# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import logging
import numpy as np
import scarlet
from scarlet.psf import ImagePSF, GaussianPSF
from scarlet import Blend, Frame, Observation
from scarlet.renderer import ConvolutionRenderer
from scarlet.initialization import init_all_sources

import lsst.log
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.geom import Point2I, Box2I, Point2D
import lsst.afw.geom.ellipses as afwEll
import lsst.afw.image.utils
import lsst.afw.image as afwImage
import lsst.afw.detection as afwDet
import lsst.afw.table as afwTable

from .source import modelToHeavy

# Scarlet and proxmin define log levels differently than the stack,
# so even "warnings" occur far more often than we would like.
# For now we only display scarlet and proxmin errors, as all other
# scarlet output would be considered "TRACE" by our standards.
scarletLogger = logging.getLogger("scarlet")
scarletLogger.setLevel(logging.ERROR)
proxminLogger = logging.getLogger("proxmin")
proxminLogger.setLevel(logging.ERROR)

__all__ = ["deblend", "ScarletDeblendConfig", "ScarletDeblendTask"]

logger = lsst.log.Log.getLogger("meas.deblender.deblend")


class IncompleteDataError(Exception):
    """The PSF could not be computed due to incomplete data
    """
    pass


class ScarletGradientError(Exception):
    """An error occurred during optimization

    This error occurs when the optimizer encounters
    a NaN value while calculating the gradient.
    """
    def __init__(self, iterations, sources):
        self.iterations = iterations
        self.sources = sources
        msg = ("ScarletGradientError in iteration {0}. "
               "NaN values introduced in sources {1}")
        self.message = msg.format(iterations, sources)

    def __str__(self):
        return self.message


def _checkBlendConvergence(blend, f_rel):
    """Check whether or not a blend has converged
    """
    deltaLoss = np.abs(blend.loss[-2] - blend.loss[-1])
    convergence = f_rel * np.abs(blend.loss[-1])
    return deltaLoss < convergence
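
# Illustrative numbers for the test above: with blend.loss ending in
# [..., 100.0, 99.995] and f_rel = 1e-4, deltaLoss = 0.005 while the
# threshold is 1e-4 * 99.995 ~= 0.01, so the blend counts as converged.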


def _getPsfFwhm(psf):
    """Calculate the FWHM of the `psf`
    """
    return psf.computeShape().getDeterminantRadius() * 2.35


def _computePsfImage(self, position=None):
    """Get a multiband PSF image

    The PSF kernel image is computed for each band and combined
    into a (filter, y, x) array that is returned to the caller.
    The result is not cached, so if the same PSF is expected
    to be used multiple times it is a good idea to store the
    result in another variable.

    Note: this is a temporary fix during the deblender sprint.
    In the future this function will replace the current method
    in `afw.MultibandExposure.computePsfImage` (DM-19789).

    Parameters
    ----------
    position : `Point2D` or `tuple`
        Coordinates to evaluate the PSF. If `position` is `None`
        then `Psf.getAveragePosition()` is used.

    Returns
    -------
    psfImage : array
        The multiband PSF image.
    """
    psfs = []
    # Make the coordinates into a Point2D (if necessary)
    if not isinstance(position, Point2D) and position is not None:
        position = Point2D(position[0], position[1])

    for bidx, single in enumerate(self.singles):
        try:
            if position is None:
                psf = single.getPsf().computeImage()
                psfs.append(psf)
            else:
                psf = single.getPsf().computeKernelImage(position)
                psfs.append(psf)
        except InvalidParameterError:
            # This band failed to compute the PSF due to incomplete data
            # at that location. This is unlikely to be a problem for Rubin,
            # however the edges of some HSC COSMOS fields contain incomplete
            # data in some bands, so we track this error to distinguish it
            # from unknown errors.
            msg = "Failed to compute PSF at {} in band {}"
            raise IncompleteDataError(msg.format(position, self.filters[bidx]))

    left = np.min([psf.getBBox().getMinX() for psf in psfs])
    bottom = np.min([psf.getBBox().getMinY() for psf in psfs])
    right = np.max([psf.getBBox().getMaxX() for psf in psfs])
    top = np.max([psf.getBBox().getMaxY() for psf in psfs])
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
    psfs = [afwImage.utils.projectImage(psf, bbox) for psf in psfs]
    psfImage = afwImage.MultibandImage.fromImages(self.filters, psfs)
    return psfImage
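
# A minimal usage sketch with illustrative names: the function is written
# to be bound as a method, so a `MultibandExposure` is passed in place of
# `self` (as `deblend` below does):
#
#     psfCube = _computePsfImage(mExposure, position=Point2D(1000, 1200))
#     psfArray = psfCube.array.astype(np.float32)  # (filter, y, x) cube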


def getFootprintMask(footprint, mExposure):
    """Mask pixels outside the footprint

    Parameters
    ----------
    mExposure : `lsst.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data
    footprint : `lsst.detection.Footprint`
        - The footprint of the parent to deblend

    Returns
    -------
    footprintMask : array
        Boolean array with pixels outside the footprint set to `True`.
    """
    bbox = footprint.getBBox()
    fpMask = afwImage.Mask(bbox)
    footprint.spans.setMask(fpMask, 1)
    fpMask = ~fpMask.getArray().astype(bool)
    return fpMask
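
# For example, a footprint whose spans cover only part of its bounding box
# yields an array that is True outside the spans, so `weights *= ~mask`
# in `deblend` zeroes the weight of every pixel off the parent footprint.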


def isPseudoSource(source, pseudoColumns):
    """Check if a source is a pseudo source.

    This is mostly for skipping sky objects,
    but any other column can also be added to disable
    deblending on a parent or individual source when
    set to `True`.

    Parameters
    ----------
    source : `lsst.afw.table.source.source.SourceRecord`
        The source to check for the pseudo bit.
    pseudoColumns : `list` of `str`
        A list of columns to check for pseudo sources.
    """
    isPseudo = False
    for col in pseudoColumns:
        try:
            isPseudo |= source[col]
        except KeyError:
            pass
    return isPseudo
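
# For example, with pseudoColumns=["merge_peak_sky"], a record with
# merge_peak_sky=True is treated as a pseudo source and is not deblended;
# records that lack the column entirely are passed over by the KeyError
# handler above.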


def deblend(mExposure, footprint, config):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data
    footprint : `lsst.detection.Footprint`
        - The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        - Configuration of the deblending task

    Returns
    -------
    blend : `scarlet.Blend`
        The fitted blend model.
    skipped : `list` of `int`
        Indices of peaks that could not be initialized.
    spectrumInit : `bool`
        Whether or not the improved spectral initialization was used.
    """
    # Extract coordinates from each MultiColorPeak
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        weights = np.ones_like(images)
    badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
    mask = mExposure.mask[:, bbox].array & badPixels
    weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    psfs = ImagePSF(psfs)
    model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

    frame = Frame(images.shape, psf=model_psf, channels=mExposure.filters)
    observation = Observation(images, psf=psfs, weights=weights, channels=mExposure.filters)
    if config.convolutionType == "fft":
        observation.match(frame)
    elif config.convolutionType == "real":
        renderer = ConvolutionRenderer(observation, frame, convolution_type="real")
        observation.match(frame, renderer=renderer)
    else:
        raise ValueError("Unrecognized convolution type {}".format(config.convolutionType))

    # "point" must be included here, otherwise the branch below that
    # reports it as unimplemented could never be reached
    assert config.sourceModel in ["single", "double", "compact", "point", "fit"]

    # Set the appropriate number of components
    if config.sourceModel == "single":
        maxComponents = 1
    elif config.sourceModel == "double":
        maxComponents = 2
    elif config.sourceModel == "compact":
        maxComponents = 0
    elif config.sourceModel == "point":
        raise NotImplementedError("Point source photometry is currently not implemented")
    elif config.sourceModel == "fit":
        # It is likely in the future that there will be some heuristic
        # used to determine what type of model to use for each source,
        # but that has not yet been implemented (see DM-22551)
        raise NotImplementedError("sourceModel 'fit' has not been implemented yet")

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Choose whether or not to use the improved spectral initialization
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0:
            spectrumInit = True
        else:
            spectrumInit = len(centers) * bbox.getArea() < config.maxSpectrumCutoff
    else:
        spectrumInit = False

    # Only deblend sources that can be initialized
    sources, skipped = init_all_sources(
        frame=frame,
        centers=centers,
        observations=observation,
        thresh=config.morphThresh,
        max_components=maxComponents,
        min_snr=config.minSNR,
        shifting=False,
        fallback=config.fallback,
        silent=config.catchFailures,
        set_spectra=spectrumInit,
    )

    # Attach the peak to all of the initialized sources
    srcIndex = 0
    for k, center in enumerate(centers):
        if k not in skipped:
            # This is just to make sure that there isn't a coding bug
            assert np.all(sources[srcIndex].center == center)
            # Store the record for the peak with the appropriate source
            sources[srcIndex].detectedPeak = footprint.peaks[k]
            srcIndex += 1

    # Create the blend and attempt to optimize it
    blend = Blend(sources, observation)
    try:
        blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
    except ArithmeticError:
        # This occurs when a gradient update produces a NaN value.
        # This is usually due to a source initialized with a
        # negative SED or no flux, often because the peak
        # is a noise fluctuation in one band and not a real source.
        iterations = len(blend.loss)
        failedSources = []
        for k, src in enumerate(sources):
            if np.any(~np.isfinite(src.get_model())):
                failedSources.append(k)
        raise ScarletGradientError(iterations, failedSources)

    return blend, skipped, spectrumInit
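
# A minimal usage sketch; `mExposure` and `footprint` are placeholders
# assumed to already exist, not part of this module:
#
#     config = ScarletDeblendConfig()
#     blend, skipped, spectrumInit = deblend(mExposure, footprint, config)
#     models = [src.get_model() for src in blend.sources]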


class ScarletDeblendConfig(pexConfig.Config):
    """MultibandDeblendConfig

    Configuration for the multiband deblender.
    The parameters are organized by the parameter types, which are
    - Stopping Criteria: Used to determine if the fit has converged
    - Position Fitting Criteria: Used to fit the positions of the peaks
    - Constraints: Used to apply constraints to the peaks and their components
    - Other: Parameters that don't fit into the above categories
    """
    # Stopping Criteria
    maxIter = pexConfig.Field(dtype=int, default=300,
                              doc=("Maximum number of iterations to deblend a single parent"))
    relativeError = pexConfig.Field(dtype=float, default=1e-4,
                                    doc=("Change in the loss function between "
                                         "iterations to exit fitter"))

    # Constraints
    morphThresh = pexConfig.Field(dtype=float, default=1,
                                  doc="Fraction of background RMS a pixel must have "
                                      "to be included in the initial morphology")
    # Other scarlet parameters
    useWeights = pexConfig.Field(
        dtype=bool, default=True,
        doc=("Whether or not to use inverse variance weighting. "
             "If `useWeights` is `False` then flat weights are used"))
    modelPsfSize = pexConfig.Field(
        dtype=int, default=11,
        doc="Model PSF side length in pixels")
    modelPsfSigma = pexConfig.Field(
        dtype=float, default=0.8,
        doc="Define sigma for the model frame PSF")
    minSNR = pexConfig.Field(
        dtype=float, default=50,
        doc="Minimum signal to noise to accept the source. "
            "Sources with lower flux will be initialized with the PSF but updated "
            "like an ordinary ExtendedSource (known in scarlet as a `CompactSource`).")
    saveTemplates = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to save the SEDs and templates")
    processSingles = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to process isolated sources in the deblender")
    convolutionType = pexConfig.Field(
        dtype=str, default="fft",
        doc="Type of convolution to render the model to the observations.\n"
            "- 'fft': perform convolutions in Fourier space\n"
            "- 'real': perform convolutions in real space.")
    sourceModel = pexConfig.Field(
        dtype=str, default="double",
        doc=("How to determine which model to use for sources, from\n"
             "- 'single': use a single component for all sources\n"
             "- 'double': use a bulge-disk model for all sources\n"
             "- 'compact': use a single component model, initialized with a point source morphology, "
             "for all sources\n"
             "- 'point': use a point-source model for all sources\n"
             "- 'fit': use a PSF fitting model to determine the number of components (not yet implemented)")
    )
    setSpectra = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to solve for the best-fit spectra during initialization. "
            "This makes initialization slightly longer, as it requires a convolution "
            "to set the optimal spectra, but results in a much better initial log-likelihood "
            "and reduced total runtime, with convergence in fewer iterations. "
            "The improved initialization is only used when "
            "`peaks * area < maxSpectrumCutoff`.")

    # Mask-plane restrictions
    badMask = pexConfig.ListField(
        dtype=str, default=["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"],
        doc="Mask planes whose pixels are given zero weight during deblending")
    statsMask = pexConfig.ListField(dtype=str, default=["SAT", "INTRP", "NO_DATA"],
                                    doc="Mask planes to ignore when performing statistics")
    maskLimits = pexConfig.DictField(
        keytype=str,
        itemtype=float,
        default={},
        doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
             "Sources violating this limit will not be deblended."),
    )

    # Size restrictions
    maxNumberOfPeaks = pexConfig.Field(
        dtype=int, default=0,
        doc=("Only deblend the brightest maxNumberOfPeaks peaks in the parent"
             " (<= 0: unlimited)"))
    maxFootprintArea = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum area for footprints before they are ignored as large; "
             "non-positive means no threshold applied"))
    maxFootprintSize = pexConfig.Field(
        dtype=int, default=0,
        doc=("Maximum linear dimension for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    minFootprintAxisRatio = pexConfig.Field(
        dtype=float, default=0.0,
        doc=("Minimum axis ratio for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    maxSpectrumCutoff = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum number of pixels * number of sources in a blend. "
             "This is different than `maxFootprintArea` because this isn't "
             "the footprint area but the area of the bounding box that "
             "contains the footprint, and is also multiplied by the number of "
             "sources in the footprint. This prevents large skinny blends with "
             "a high density of sources from running out of memory. "
             "If `maxSpectrumCutoff == -1` then there is no cutoff.")
    )

    # Failure modes
    fallback = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to fall back to a smaller number of components if a source does not initialize"
    )
    notDeblendedMask = pexConfig.Field(
        dtype=str, default="NOT_DEBLENDED", optional=True,
        doc="Mask name for footprints not deblended, or None")
    catchFailures = pexConfig.Field(
        dtype=bool, default=True,
        doc=("If True, catch exceptions thrown by the deblender, log them, "
             "and set a flag on the parent, instead of letting them propagate up"))

    # Other options
    columnInheritance = pexConfig.DictField(
        keytype=str, itemtype=str, default={
            "deblend_nChild": "deblend_parentNChild",
            "deblend_nPeaks": "deblend_parentNPeaks",
            "deblend_spectrumInitFlag": "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag": "deblend_blendConvergenceFailedFlag",
        },
        doc="Columns to pass from the parent to the child. "
            "The key is the name of the column for the parent record, "
            "the value is the name of the column to use for the child."
    )
    pseudoColumns = pexConfig.ListField(
        dtype=str, default=['merge_peak_sky', 'sky_source'],
        doc="Names of flags which should never be deblended."
    )

    # Testing options
    # Some obs packages and ci packages run the full pipeline on a small
    # subset of data to test that the pipeline is functioning properly.
    # This is not meant as scientific validation, so it can be useful
    # to only run on a small subset of the data that is large enough to
    # test the desired pipeline features but not so long that the deblender
    # is the tall pole in terms of execution times.
    useCiLimits = pexConfig.Field(
        dtype=bool, default=False,
        doc="Limit the number of sources deblended for CI to prevent long build times")
    ciDeblendChildRange = pexConfig.ListField(
        dtype=int, default=[5, 10],
        doc="Only deblend parent Footprints with a number of peaks in the (inclusive) range indicated. "
            "If `useCiLimits==False` then this parameter is ignored.")
    ciNumParentsToDeblend = pexConfig.Field(
        dtype=int, default=10,
        doc="Only use the first `ciNumParentsToDeblend` parent footprints with a total peak count "
            "within `ciDeblendChildRange`. "
            "If `useCiLimits==False` then this parameter is ignored.")


class ScarletDeblendTask(pipeBase.Task):
    """ScarletDeblendTask

    Split blended sources into individual sources.

    This task modifies the input SourceCatalog in place and returns
    per-band catalogs of the deblended model templates.
    """
    ConfigClass = ScarletDeblendConfig
    _DefaultName = "scarletDeblend"

    def __init__(self, schema, peakSchema=None, **kwargs):
        """Create the task, adding necessary fields to the given schema.

        Parameters
        ----------
        schema : `lsst.afw.table.schema.schema.Schema`
            Schema object for measurement fields; will be modified in-place.
        peakSchema : `lsst.afw.table.schema.schema.Schema`
            Schema of Footprint Peaks that will be passed to the deblender.
            Any fields beyond the PeakTable minimal schema will be transferred
            to the main source Schema. If None, no fields will be transferred
            from the Peaks.
        **kwargs
            Passed to Task.__init__.
        """
        pipeBase.Task.__init__(self, **kwargs)

        peakMinimalSchema = afwDet.PeakTable.makeMinimalSchema()
        if peakSchema is None:
            # In this case, the peakSchemaMapper will transfer nothing,
            # but we'll still have one to simplify downstream code
            self.peakSchemaMapper = afwTable.SchemaMapper(peakMinimalSchema, schema)
        else:
            self.peakSchemaMapper = afwTable.SchemaMapper(peakSchema, schema)
            for item in peakSchema:
                if item.key not in peakMinimalSchema:
                    self.peakSchemaMapper.addMapping(item.key, item.field)
                    # Because SchemaMapper makes a copy of the output schema
                    # you give its ctor, it isn't updating this Schema in
                    # place. That's probably a design flaw, but in the
                    # meantime, we'll keep that schema in sync with the
                    # peakSchemaMapper.getOutputSchema() manually, by adding
                    # the same fields to both.
                    schema.addField(item.field)
            assert schema == self.peakSchemaMapper.getOutputSchema(), "Logic bug mapping schemas"
        self._addSchemaKeys(schema)
        self.schema = schema
        self.toCopyFromParent = [item.key for item in self.schema
                                 if item.field.getName().startswith("merge_footprint")]

    def _addSchemaKeys(self, schema):
        """Add deblender specific keys to the schema
        """
        self.runtimeKey = schema.addField('deblend_runtime', type=np.float32, doc='runtime in ms')

        self.iterKey = schema.addField('deblend_iterations', type=np.int32, doc='iterations to converge')

        self.nChildKey = schema.addField('deblend_nChild', type=np.int32,
                                         doc='Number of children this object has (defaults to 0)')
        self.psfKey = schema.addField('deblend_deblendedAsPsf', type='Flag',
                                      doc='Deblender thought this source looked like a PSF')
        self.tooManyPeaksKey = schema.addField('deblend_tooManyPeaks', type='Flag',
                                               doc='Source had too many peaks; '
                                                   'only the brightest were included')
        self.tooBigKey = schema.addField('deblend_parentTooBig', type='Flag',
                                         doc='Parent footprint covered too many pixels')
        self.maskedKey = schema.addField('deblend_masked', type='Flag',
                                         doc='Parent footprint was predominantly masked')
        self.sedNotConvergedKey = schema.addField('deblend_sedConvergenceFailed', type='Flag',
                                                  doc='scarlet sed optimization did not converge before '
                                                      'config.maxIter')
        self.morphNotConvergedKey = schema.addField('deblend_morphConvergenceFailed', type='Flag',
                                                    doc='scarlet morph optimization did not converge before '
                                                        'config.maxIter')
        self.blendConvergenceFailedFlagKey = schema.addField('deblend_blendConvergenceFailedFlag',
                                                             type='Flag',
                                                             doc='at least one source in the blend '
                                                                 'failed to converge')
        self.edgePixelsKey = schema.addField('deblend_edgePixels', type='Flag',
                                             doc='Source had flux on the edge of the parent footprint')
        self.deblendFailedKey = schema.addField('deblend_failed', type='Flag',
                                                doc="Deblending failed on source")
        self.deblendErrorKey = schema.addField('deblend_error', type="String", size=25,
                                               doc='Name of error if the blend failed')
        self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
                                                 doc="Deblender skipped this source")
        self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
                                                        doc="Center used to apply constraints in scarlet",
                                                        unit="pixel")
        self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
                                         doc="ID of the peak in the parent footprint. "
                                             "This is not unique, but the combination of 'parent' "
                                             "and 'peakId' should be unique for all child sources. "
                                             "Top level blends with no parents have 'peakId=0'")
        self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
                                               doc="The instFlux at the peak position of the "
                                                   "deblended model")
        self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=25,
                                            doc="The type of model used, for example "
                                                "MultiExtendedSource, SingleExtendedSource, PointSource")
        self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
                                         doc="Number of initial peaks in the blend. "
                                             "This includes peaks that may have been culled "
                                             "during deblending or failed to deblend")
        self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
                                               doc="deblend_nPeaks from this record's parent.")
        self.parentNChildKey = schema.addField("deblend_parentNChild", type=np.int32,
                                               doc="deblend_nChild from this record's parent.")
        self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
                                              doc="Flux measurement from scarlet")
        self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
                                              doc="Final logL, used to identify regressions in scarlet.")
        self.scarletSpectrumInitKey = schema.addField("deblend_spectrumInitFlag", type='Flag',
                                                      doc="True when scarlet initializes sources "
                                                          "in the blend with a more accurate spectrum. "
                                                          "The algorithm uses a lot of memory, "
                                                          "so large dense blends will use "
                                                          "a less accurate initialization.")

        # self.log.trace('Added keys to schema: %s', ", ".join(str(x) for x in
        #                (self.nChildKey, self.tooManyPeaksKey, self.tooBigKey))
        #                )

    @pipeBase.timeMethod
    def run(self, mExposure, mergedSources):
        """Get the psf from each exposure and then run deblend().

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        mergedSources : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.

        Returns
        -------
        templateCatalogs : dict
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            multiband model templates for each deblended source.
        """
        return self.deblend(mExposure, mergedSources)
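
    # A minimal usage sketch; `schema`, `mExposure`, and `mergedSources`
    # are placeholders assumed to already exist, and "g" is an illustrative
    # filter name:
    #
    #     deblendTask = ScarletDeblendTask(schema=schema)
    #     templateCatalogs = deblendTask.run(mExposure, mergedSources)
    #     gTemplates = templateCatalogs["g"]   # per-band template catalog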

    @pipeBase.timeMethod
    def deblend(self, mExposure, catalog):
        """Deblend a data cube of multiband images

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        catalog : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend. The new deblended sources are
            appended to this catalog in place.

        Returns
        -------
        catalogs : `dict` or `None`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            multiband model templates for each deblended source.
        """
        import time

        # Cull footprints if required by CI
        if self.config.useCiLimits:
            self.log.info(f"Using CI catalog limits, "
                          f"the original number of sources to deblend was {len(catalog)}.")
            # Select parents with a number of children in the range
            # config.ciDeblendChildRange
            minChildren, maxChildren = self.config.ciDeblendChildRange
            nPeaks = np.array([len(src.getFootprint().peaks) for src in catalog])
            childrenInRange = np.where((nPeaks >= minChildren) & (nPeaks <= maxChildren))[0]
            if len(childrenInRange) < self.config.ciNumParentsToDeblend:
                raise ValueError("Fewer than ciNumParentsToDeblend children were contained in the range "
                                 "indicated by ciDeblendChildRange. Adjust this range to include more "
                                 "parents.")
            # Keep all of the isolated parents and the first
            # `ciNumParentsToDeblend` children
            parents = nPeaks == 1
            children = np.zeros((len(catalog),), dtype=bool)
            children[childrenInRange[:self.config.ciNumParentsToDeblend]] = True
            catalog = catalog[parents | children]
            # We need to update the IdFactory, otherwise the source ids
            # will not be sequential
            idFactory = catalog.getIdFactory()
            maxId = np.max(catalog["id"])
            idFactory.notify(maxId)

        filters = mExposure.filters
        self.log.info(f"Deblending {len(catalog)} sources in {len(mExposure)} exposure bands")

        # Add the NOT_DEBLENDED mask to the mask plane in each band
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                mask.addMaskPlane(self.config.notDeblendedMask)

        nParents = len(catalog)
        nDeblendedParents = 0
        skippedParents = []
        multibandColumns = {
            "heavies": [],
            "fluxes": [],
            "centerFluxes": [],
        }
        for parentIndex in range(nParents):
            parent = catalog[parentIndex]
            foot = parent.getFootprint()
            bbox = foot.getBBox()
            peaks = foot.getPeaks()

            # Since we use the first peak for the parent object, we should
            # propagate its flags to the parent source.
            parent.assign(peaks[0], self.peakSchemaMapper)

            # Skip isolated sources unless processSingles is turned on.
            # Note: this does not flag isolated sources as skipped or
            # set the NOT_DEBLENDED mask in the exposure,
            # since these aren't really skipped blends.
            # We also skip pseudo sources, like sky objects, which
            # are intended to be skipped.
            if ((len(peaks) < 2 and not self.config.processSingles)
                    or isPseudoSource(parent, self.config.pseudoColumns)):
                self._updateParentRecord(
                    parent=parent,
                    nPeaks=len(peaks),
                    nChild=0,
                    runtime=np.nan,
                    iterations=0,
                    logL=np.nan,
                    spectrumInit=False,
                    converged=False,
                )
                continue

            # Block of conditions for skipping a parent with multiple children
            skipKey = None
            if self._isLargeFootprint(foot):
                # The footprint is above the maximum footprint size limit
                skipKey = self.tooBigKey
                skipMessage = f"Parent {parent.getId()}: skipping large footprint"
            elif self._isMasked(foot, mExposure):
                # The footprint exceeds the maximum number of masked pixels
                skipKey = self.maskedKey
                skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
            elif self.config.maxNumberOfPeaks > 0 and len(peaks) > self.config.maxNumberOfPeaks:
                # Unlike meas_deblender, in scarlet we skip the entire blend
                # if the number of peaks exceeds max peaks, since neglecting
                # to model any peaks often results in catastrophic failure
                # of scarlet to generate models for the brighter sources.
                skipKey = self.tooManyPeaksKey
                skipMessage = f"Parent {parent.getId()}: Too many peaks, skipping blend"
            if skipKey is not None:
                self._skipParent(
                    parent=parent,
                    skipKey=skipKey,
                    logMessage=skipMessage,
                )
                skippedParents.append(parentIndex)
                continue

            nDeblendedParents += 1
            self.log.trace(f"Parent {parent.getId()}: deblending {len(peaks)} peaks")
            # Run the deblender
            blendError = None
            try:
                t0 = time.time()
                # Build the parameter lists with the same ordering
                blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
                tf = time.time()
                runtime = (tf-t0)*1000
                converged = _checkBlendConvergence(blend, self.config.relativeError)
                scarletSources = [src for src in blend.sources]
                nChild = len(scarletSources)
                # Re-insert placeholders for skipped sources
                # to propagate them in the catalog so
                # that the peaks stay consistent
                for k in skipped:
                    scarletSources.insert(k, None)
            # Catch all errors and filter out the ones that we know about
            except Exception as e:
                blendError = type(e).__name__
                if isinstance(e, ScarletGradientError):
                    parent.set(self.iterKey, e.iterations)
                elif not isinstance(e, IncompleteDataError):
                    blendError = "UnknownError"
                    if self.config.catchFailures:
                        # Make it easy to find UnknownErrors in the log file
                        self.log.warn("UnknownError")
                        import traceback
                        traceback.print_exc()
                    else:
                        raise

                self._skipParent(
                    parent=parent,
                    skipKey=self.deblendFailedKey,
                    logMessage=f"Unable to deblend source {parent.getId()}: {blendError}",
                )
                parent.set(self.deblendErrorKey, blendError)
                skippedParents.append(parentIndex)
                continue

            # Update the parent record with the deblending results
            logL = blend.loss[-1]-blend.observations[0].log_norm
            self._updateParentRecord(
                parent=parent,
                nPeaks=len(peaks),
                nChild=nChild,
                runtime=runtime,
                iterations=len(blend.loss),
                logL=logL,
                spectrumInit=spectrumInit,
                converged=converged,
            )

            # Add each deblended source to the catalog
            for k, scarletSource in enumerate(scarletSources):
                # Skip any sources with no flux or that scarlet skipped because
                # it could not initialize
                if k in skipped:
                    # No need to propagate anything
                    continue
                parent.set(self.deblendSkippedKey, False)
                mHeavy = modelToHeavy(scarletSource, filters, xy0=bbox.getMin(),
                                      observation=blend.observations[0])
                multibandColumns["heavies"].append(mHeavy)
                flux = scarlet.measure.flux(scarletSource)
                multibandColumns["fluxes"].append({
                    filters[fidx]: _flux
                    for fidx, _flux in enumerate(flux)
                })
                centerFlux = self._getCenterFlux(mHeavy, scarletSource, xy0=bbox.getMin())
                multibandColumns["centerFluxes"].append(centerFlux)

                # Add all fields except the HeavyFootprint to the
                # source record
                self._addChild(
                    parent=parent,
                    mHeavy=mHeavy,
                    catalog=catalog,
                    scarletSource=scarletSource,
                )

        # Make sure that the number of new sources matches the number of
        # entries in each of the band dependent columns.
        # This should never trigger and is just a sanity check.
        nChildren = len(catalog) - nParents
        if np.any([len(meas) != nChildren for meas in multibandColumns.values()]):
            msg = f"Added {len(catalog)-nParents} new sources, but have "
            msg += ", ".join([
                f"{len(value)} {key}"
                for key, value in multibandColumns.items()
            ])
            raise RuntimeError(msg)
        # Make a copy of the catalog in each band and update the footprints
        catalogs = {}
        for f in filters:
            _catalog = afwTable.SourceCatalog(catalog.table.clone())
            _catalog.extend(catalog, deep=True)
            # Update the footprints and columns that are different
            # for each filter
            for sourceIndex, source in enumerate(_catalog[nParents:]):
                source.setFootprint(multibandColumns["heavies"][sourceIndex][f])
                source.set(self.scarletFluxKey, multibandColumns["fluxes"][sourceIndex][f])
                source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
            catalogs[f] = _catalog

        # Update the mExposure mask with the footprint of skipped parents
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                for parentIndex in skippedParents:
                    fp = catalog[parentIndex].getFootprint()
                    fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

        self.log.info(f"Deblender results: of {nParents} parent sources, {nDeblendedParents} "
                      f"were deblended, creating {nChildren} children, "
                      f"for a total of {len(catalog)} sources")
        return catalogs

    def _isLargeFootprint(self, footprint):
        """Returns whether a Footprint is large

        'Large' is defined by thresholds on the area, size and axis ratio.
        These may be disabled independently by configuring them to be
        non-positive.

        This is principally intended to get rid of satellite streaks, which
        the deblender or other downstream processing can have trouble dealing
        with (e.g., multiple large HeavyFootprints can chew up memory).
        """
        if self.config.maxFootprintArea > 0 and footprint.getArea() > self.config.maxFootprintArea:
            return True
        if self.config.maxFootprintSize > 0:
            bbox = footprint.getBBox()
            if max(bbox.getWidth(), bbox.getHeight()) > self.config.maxFootprintSize:
                return True
        if self.config.minFootprintAxisRatio > 0:
            axes = afwEll.Axes(footprint.getShape())
            if axes.getB() < self.config.minFootprintAxisRatio*axes.getA():
                return True
        return False

    def _isMasked(self, footprint, mExposure):
        """Returns whether the footprint violates the mask limits"""
        bbox = footprint.getBBox()
        mask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
        size = float(footprint.getArea())
        for maskName, limit in self.config.maskLimits.items():
            maskVal = mExposure.mask.getPlaneBitMask(maskName)
            _mask = afwImage.MaskX(mask & maskVal, xy0=bbox.getMin())
            unmaskedSpan = footprint.spans.intersectNot(_mask)  # spanset of unmasked pixels
            if (size - unmaskedSpan.getArea())/size > limit:
                return True
        return False
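
        # Illustrative example: with maskLimits={"NO_DATA": 0.25}, a
        # footprint of 1000 pixels in which 300 carry the NO_DATA bit has a
        # masked fraction of 0.3 > 0.25, so the parent is skipped.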

    def _skipParent(self, parent, skipKey, logMessage):
        """Update a parent record that is not being deblended.

        This is a fairly trivial function but is implemented to ensure
        that a skipped parent updates the appropriate columns
        consistently, and always has a flag to mark the reason that
        it is being skipped.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to flag as skipped.
        skipKey : `bool`
            The name of the flag to mark the reason for skipping.
        logMessage : `str`
            The message to display in a log.trace when a source
            is skipped.
        """
        if logMessage is not None:
            self.log.trace(logMessage)
        self._updateParentRecord(
            parent=parent,
            nPeaks=len(parent.getFootprint().peaks),
            nChild=0,
            runtime=np.nan,
            iterations=0,
            logL=np.nan,
            spectrumInit=False,
            converged=False,
        )

        # Mark the source as skipped by the deblender and
        # flag the reason why.
        parent.set(self.deblendSkippedKey, True)
        parent.set(skipKey, True)

    def _updateParentRecord(self, parent, nPeaks, nChild,
                            runtime, iterations, logL, spectrumInit, converged):
        """Update a parent record in all of the single band catalogs.

        Ensure that all locations that update a parent record,
        whether it is skipped or updated after deblending,
        update all of the appropriate columns.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to update.
        nPeaks : `int`
            Number of peaks in the parent footprint.
        nChild : `int`
            Number of children deblended from the parent.
            This may differ from `nPeaks` if some of the peaks
            were culled and have no deblended model.
        runtime : `float`
            Total runtime for deblending.
        iterations : `int`
            Total number of iterations in scarlet before convergence.
        logL : `float`
            Final log likelihood of the blend.
        spectrumInit : `bool`
            True when scarlet used `set_spectra` to initialize all
            sources with better initial intensities.
        converged : `bool`
            True when the optimizer reached convergence before
            reaching the maximum number of iterations.
        """
        parent.set(self.nPeaksKey, nPeaks)
        parent.set(self.nChildKey, nChild)
        parent.set(self.runtimeKey, runtime)
        parent.set(self.iterKey, iterations)
        parent.set(self.scarletLogLKey, logL)
        parent.set(self.scarletSpectrumInitKey, spectrumInit)
        # `converged` is True on convergence, so the "convergence failed"
        # flag must record its negation
        parent.set(self.blendConvergenceFailedFlagKey, not converged)

    def _addChild(self, parent, mHeavy, catalog, scarletSource):
        """Add a child to a catalog.

        This creates a new child in the source catalog,
        assigning it a parent id, and adding all columns
        that are independent across all filter bands.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent of the new child record.
        mHeavy : `lsst.detection.MultibandFootprint`
            The multi-band footprint containing the model and
            peak catalog for the new child record.
        catalog : `lsst.afw.table.source.source.SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.
        scarletSource : `scarlet.Component`
            The scarlet model for the new source record.
        """
        src = catalog.addNew()
        for key in self.toCopyFromParent:
            src.set(key, parent.get(key))
        # The peak catalog is the same for all bands,
        # so we just use the first peak catalog
        peaks = mHeavy[mHeavy.filters[0]].peaks
        src.assign(peaks[0], self.peakSchemaMapper)
        src.setParent(parent.getId())
        # Currently all children only have a single peak,
        # but it's possible in the future that there will be hierarchical
        # deblending, so we use the footprint to set the number of peaks
        # for each child.
        src.set(self.nPeaksKey, len(peaks))
        # Set the psf key based on whether or not the source was
        # deblended using the PointSource model.
        # This key is not that useful anymore since we now keep track of
        # `modelType`, but we continue to propagate it in case code downstream
        # is expecting it.
        src.set(self.psfKey, scarletSource.__class__.__name__ == "PointSource")
        src.set(self.modelTypeKey, scarletSource.__class__.__name__)
        # We set the runtime to zero so that summing up the
        # runtime column will give the total time spent
        # running the deblender for the catalog.
        src.set(self.runtimeKey, 0)

        # Set the position of the peak from the parent footprint.
        # This will make it easier to match the same source across
        # deblenders and across observations, where the peak
        # position is unlikely to change unless enough time passes
        # for a source to move on the sky.
        peak = scarletSource.detectedPeak
        src.set(self.peakCenter, Point2I(peak["i_x"], peak["i_y"]))
        src.set(self.peakIdKey, peak["id"])

        # Propagate columns from the parent to the child
        for parentColumn, childColumn in self.config.columnInheritance.items():
            src.set(childColumn, parent.get(parentColumn))

    def _getCenterFlux(self, mHeavy, scarletSource, xy0):
        """Get the flux at the center of a HeavyFootprint

        Parameters
        ----------
        mHeavy : `lsst.detection.MultibandFootprint`
            The multi-band footprint containing the model for the source.
        scarletSource : `scarlet.Component`
            The scarlet model for the heavy footprint
        xy0 : `lsst.geom.Point2I`
            The origin of the parent bounding box, used to shift the
            model center into image coordinates.
        """
        # Store the flux at the center of the model and the total
        # scarlet flux measurement.
        mImage = mHeavy.getImage(fill=0.0).image

        # Set the flux at the center of the model (for SNR)
        try:
            cy, cx = scarletSource.center
            cy += xy0.y
            cx += xy0.x
            return mImage[:, cx, cy]
        except AttributeError:
            msg = "Did not recognize coordinates for source type of `{0}`, "
            msg += "could not write coordinates or center flux. "
            msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
            logger.warning(msg.format(type(scarletSource)))
            return {f: np.nan for f in mImage.filters}