# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import logging
import numpy as np
import scarlet
from scarlet.psf import ImagePSF, GaussianPSF
from scarlet import Blend, Frame, Observation
from scarlet.renderer import ConvolutionRenderer
from scarlet.initialization import init_all_sources

import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.geom import Point2I, Box2I, Point2D
import lsst.afw.geom.ellipses as afwEll
import lsst.afw.image as afwImage
import lsst.afw.detection as afwDet
import lsst.afw.table as afwTable
from lsst.utils.timer import timeMethod

from .source import modelToHeavy

# Scarlet and proxmin have a different definition of log levels than the stack,
# so even "warnings" occur far more often than we would like.
# For now we only display scarlet and proxmin errors; all other
# scarlet output would be considered "TRACE" by our standards.
scarletLogger = logging.getLogger("scarlet")
scarletLogger.setLevel(logging.ERROR)
proxminLogger = logging.getLogger("proxmin")
proxminLogger.setLevel(logging.ERROR)

__all__ = ["deblend", "ScarletDeblendConfig", "ScarletDeblendTask"]

logger = logging.getLogger(__name__)


class IncompleteDataError(Exception):
    """The PSF could not be computed due to incomplete data
    """
    pass


class ScarletGradientError(Exception):
    """An error occurred during optimization

    This error occurs when the optimizer encounters
    a NaN value while calculating the gradient.
    """
    def __init__(self, iterations, sources):
        self.iterations = iterations
        self.sources = sources
        msg = ("ScarletGradientError in iteration {0}. "
               "NaN values introduced in sources {1}")
        self.message = msg.format(iterations, sources)

    def __str__(self):
        return self.message


def _checkBlendConvergence(blend, f_rel):
    """Check whether or not a blend has converged
    """
    deltaLoss = np.abs(blend.loss[-2] - blend.loss[-1])
    convergence = f_rel * np.abs(blend.loss[-1])
    return deltaLoss < convergence
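

# A worked example of the convergence test above (illustrative numbers, not
# from a real blend): with ``blend.loss == [..., -1000.00, -1000.05]`` and
# ``f_rel == 1e-4``, the change in loss is 0.05 while the threshold is
# 1e-4 * 1000.05 ~ 0.1, so the blend is considered converged.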


def _getPsfFwhm(psf):
    """Calculate the FWHM of the `psf`
    """
    # 2.35 approximates 2*sqrt(2*ln(2)) ~ 2.355, the conversion from a
    # Gaussian sigma (the determinant radius) to a FWHM.
    return psf.computeShape().getDeterminantRadius() * 2.35


def _computePsfImage(self, position=None):
    """Get a multiband PSF image

    The PSF Kernel Image is computed for each band
    and combined into a (filter, y, x) array.
    The result is not cached, so if the same PSF is expected
    to be used multiple times it is a good idea to store the
    result in another variable.

    Note: this is a temporary fix during the deblender sprint.
    In the future this function will replace the current method
    in `afw.MultibandExposure.computePsfImage` (DM-19789).

    Parameters
    ----------
    position : `Point2D` or `tuple`
        Coordinates to evaluate the PSF. If `position` is `None`
        then `Psf.getAveragePosition()` is used.

    Returns
    -------
    psfImage : `lsst.afw.image.MultibandImage`
        The multiband PSF image.
    """
    psfs = []
    # Make the coordinates into a Point2D (if necessary)
    if not isinstance(position, Point2D) and position is not None:
        position = Point2D(position[0], position[1])

    for bidx, single in enumerate(self.singles):
        try:
            if position is None:
                psf = single.getPsf().computeImage()
                psfs.append(psf)
            else:
                psf = single.getPsf().computeKernelImage(position)
                psfs.append(psf)
        except InvalidParameterError:
            # This band failed to compute the PSF due to incomplete data
            # at that location. This is unlikely to be a problem for Rubin,
            # however the edges of some HSC COSMOS fields contain incomplete
            # data in some bands, so we track this error to distinguish it
            # from unknown errors.
            msg = "Failed to compute PSF at {} in band {}"
            raise IncompleteDataError(msg.format(position, self.filters[bidx]))

    left = np.min([psf.getBBox().getMinX() for psf in psfs])
    bottom = np.min([psf.getBBox().getMinY() for psf in psfs])
    right = np.max([psf.getBBox().getMaxX() for psf in psfs])
    top = np.max([psf.getBBox().getMaxY() for psf in psfs])
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
    psfs = [afwImage.utils.projectImage(psf, bbox) for psf in psfs]
    psfImage = afwImage.MultibandImage.fromImages(self.filters, psfs)
    return psfImage
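

# A minimal usage sketch (hypothetical objects): because this helper takes
# the exposure as its first argument, it is currently called like a free
# function, exactly as ``deblend`` does below:
#
#   >>> psfImage = _computePsfImage(mExposure, footprint.getCentroid())
#   >>> psfImage.array.shape  # (nBands, height, width)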


def getFootprintMask(footprint, mExposure):
    """Mask pixels outside the footprint

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint of the parent to deblend
    mExposure : `lsst.afw.image.MultibandExposure`
        The multiband exposure containing the image,
        mask, and variance data

    Returns
    -------
    footprintMask : array
        Boolean array with pixels outside the footprint set to `True`.
    """
    bbox = footprint.getBBox()
    fpMask = afwImage.Mask(bbox)
    footprint.spans.setMask(fpMask, 1)
    fpMask = ~fpMask.getArray().astype(bool)
    return fpMask
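

# Example of how the mask is meant to be combined with the weights (this
# mirrors the step inside ``deblend`` below; ``weights`` is assumed to be a
# (bands, y, x) float array covering the footprint bounding box):
#
#   >>> mask = getFootprintMask(footprint, mExposure)
#   >>> weights *= ~mask  # zero the weight of pixels outside the footprint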


def isPseudoSource(source, pseudoColumns):
    """Check if a source is a pseudo source.

    This is mostly for skipping sky objects,
    but any other column can also be added to disable
    deblending on a parent or individual source when
    set to `True`.

    Parameters
    ----------
    source : `lsst.afw.table.source.source.SourceRecord`
        The source to check for the pseudo bit.
    pseudoColumns : `list` of `str`
        A list of columns to check for pseudo sources.

    Returns
    -------
    isPseudo : `bool`
        `True` if any of the `pseudoColumns` is set on the source.
    """
    isPseudo = False
    for col in pseudoColumns:
        try:
            isPseudo |= source[col]
        except KeyError:
            pass
    return isPseudo
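

# For example, with the default configuration a sky object is treated as a
# pseudo source because its ``merge_peak_sky`` flag is set (columns missing
# from the record are silently ignored):
#
#   >>> isPseudoSource(source, ["merge_peak_sky", "sky_source"])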


def deblend(mExposure, footprint, config):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        The multiband exposure containing the image,
        mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        Configuration of the deblending task

    Returns
    -------
    blend : `scarlet.Blend`
        The scarlet blend class that contains all of the information
        about the parameters and results from scarlet
    skipped : `list` of `int`
        The indices of any children that failed to initialize
        and were skipped.
    spectrumInit : `bool`
        Whether or not all of the sources were initialized by jointly
        fitting their SEDs. This provides a better initialization,
        but can create memory issues when a blend is too large or
        contains too many sources.
    """
    # Extract coordinates from each MultiColorPeak
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        weights = np.ones_like(images)
    badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
    mask = mExposure.mask[:, bbox].array & badPixels
    weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    psfs = ImagePSF(psfs)
    model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

    frame = Frame(images.shape, psf=model_psf, channels=mExposure.filters)
    observation = Observation(images, psf=psfs, weights=weights, channels=mExposure.filters)
    if config.convolutionType == "fft":
        observation.match(frame)
    elif config.convolutionType == "real":
        renderer = ConvolutionRenderer(observation, frame, convolution_type="real")
        observation.match(frame, renderer=renderer)
    else:
        raise ValueError("Unrecognized convolution type {}".format(config.convolutionType))

    # "point" and "fit" models are recognized but not yet implemented, so
    # they fall through to the NotImplementedError branches below.
    assert config.sourceModel in ["single", "double", "compact", "point", "fit"]

    # Set the appropriate number of components
    if config.sourceModel == "single":
        maxComponents = 1
    elif config.sourceModel == "double":
        maxComponents = 2
    elif config.sourceModel == "compact":
        maxComponents = 0
    elif config.sourceModel == "point":
        raise NotImplementedError("Point source photometry is currently not implemented")
    elif config.sourceModel == "fit":
        # It is likely in the future that there will be some heuristic
        # used to determine what type of model to use for each source,
        # but that has not yet been implemented (see DM-22551)
        raise NotImplementedError("sourceModel 'fit' has not been implemented yet")

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Choose whether or not to use the improved spectral initialization
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0:
            spectrumInit = True
        else:
            spectrumInit = len(centers) * bbox.getArea() < config.maxSpectrumCutoff
    else:
        spectrumInit = False

    # Only deblend sources that can be initialized
    sources, skipped = init_all_sources(
        frame=frame,
        centers=centers,
        observations=observation,
        thresh=config.morphThresh,
        max_components=maxComponents,
        min_snr=config.minSNR,
        shifting=False,
        fallback=config.fallback,
        silent=config.catchFailures,
        set_spectra=spectrumInit,
    )

    # Attach the peak to all of the initialized sources
    srcIndex = 0
    for k, center in enumerate(centers):
        if k not in skipped:
            # This is just to make sure that there isn't a coding bug
            assert np.all(sources[srcIndex].center == center)
            # Store the record for the peak with the appropriate source
            sources[srcIndex].detectedPeak = footprint.peaks[k]
            srcIndex += 1

    # Create the blend and attempt to optimize it
    blend = Blend(sources, observation)
    try:
        blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
    except ArithmeticError:
        # This occurs when a gradient update produces a NaN value.
        # This is usually due to a source initialized with a
        # negative SED or no flux, often because the peak
        # is a noise fluctuation in one band and not a real source.
        iterations = len(blend.loss)
        failedSources = []
        for k, src in enumerate(sources):
            if np.any(~np.isfinite(src.get_model())):
                failedSources.append(k)
        raise ScarletGradientError(iterations, failedSources)

    return blend, skipped, spectrumInit
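

# A minimal end-to-end sketch of calling ``deblend`` directly (hypothetical
# ``mExposure`` and ``footprint`` from upstream detection and merging):
#
#   >>> config = ScarletDeblendConfig()
#   >>> try:
#   ...     blend, skipped, spectrumInit = deblend(mExposure, footprint, config)
#   ... except ScarletGradientError as e:
#   ...     print(f"NaN gradient after {e.iterations} iterations: {e.sources}")
#   ... except IncompleteDataError:
#   ...     pass  # the PSF could not be evaluated in every band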


class ScarletDeblendConfig(pexConfig.Config):
    """ScarletDeblendConfig

    Configuration for the multiband deblender.
    The parameters are organized by the parameter types, which are

    - Stopping Criteria: Used to determine if the fit has converged
    - Position Fitting Criteria: Used to fit the positions of the peaks
    - Constraints: Used to apply constraints to the peaks and their components
    - Other: Parameters that don't fit into the above categories
    """
    # Stopping Criteria
    maxIter = pexConfig.Field(dtype=int, default=300,
                              doc=("Maximum number of iterations to deblend a single parent"))
    relativeError = pexConfig.Field(dtype=float, default=1e-4,
                                    doc=("Change in the loss function between "
                                         "iterations required to exit the fitter"))

    # Constraints
    morphThresh = pexConfig.Field(dtype=float, default=1,
                                  doc="Fraction of background RMS a pixel must have "
                                      "to be included in the initial morphology")
    # Other scarlet parameters
    useWeights = pexConfig.Field(
        dtype=bool, default=True,
        doc=("Whether or not to use inverse variance weighting. "
             "If `useWeights` is `False` then flat weights are used"))
    modelPsfSize = pexConfig.Field(
        dtype=int, default=11,
        doc="Model PSF side length in pixels")
    modelPsfSigma = pexConfig.Field(
        dtype=float, default=0.8,
        doc="Define sigma for the model frame PSF")
    minSNR = pexConfig.Field(
        dtype=float, default=50,
        doc="Minimum signal to noise to accept the source. "
            "Sources with lower flux will be initialized with the PSF but updated "
            "like an ordinary ExtendedSource (known in scarlet as a `CompactSource`).")
    saveTemplates = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to save the SEDs and templates")
    processSingles = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to process isolated sources in the deblender")
    convolutionType = pexConfig.Field(
        dtype=str, default="fft",
        doc="Type of convolution to render the model to the observations.\n"
            "- 'fft': perform convolutions in Fourier space\n"
            "- 'real': perform convolutions in real space.")
    sourceModel = pexConfig.Field(
        dtype=str, default="double",
        doc=("How to determine which model to use for sources, from\n"
             "- 'single': use a single component for all sources\n"
             "- 'double': use a bulge disk model for all sources\n"
             "- 'compact': use a single component model, initialized with a point source morphology, "
             "for all sources\n"
             "- 'point': use a point-source model for all sources\n"
             "- 'fit': use a PSF fitting model to determine the number of components (not yet implemented)")
    )
    setSpectra = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to solve for the best-fit spectra during initialization. "
            "This makes initialization slightly longer, as it requires a convolution "
            "to set the optimal spectra, but results in a much better initial log-likelihood "
            "and reduced total runtime, with convergence in fewer iterations. "
            "The improved initialization is only used when "
            "`peaks * area < maxSpectrumCutoff`.")

    # Mask-plane restrictions
    badMask = pexConfig.ListField(
        dtype=str, default=["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"],
        doc="Mask planes that mark pixels to exclude from the fit "
            "by setting their weights to zero")
    statsMask = pexConfig.ListField(dtype=str, default=["SAT", "INTRP", "NO_DATA"],
                                    doc="Mask planes to ignore when performing statistics")
    maskLimits = pexConfig.DictField(
        keytype=str,
        itemtype=float,
        default={},
        doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
             "Sources violating this limit will not be deblended."),
    )

    # Size restrictions
    maxNumberOfPeaks = pexConfig.Field(
        dtype=int, default=0,
        doc=("Only deblend the brightest maxNumberOfPeaks peaks in the parent"
             " (<= 0: unlimited)"))
    maxFootprintArea = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum area for footprints before they are ignored as large; "
             "non-positive means no threshold applied"))
    maxFootprintSize = pexConfig.Field(
        dtype=int, default=0,
        doc=("Maximum linear dimension for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    minFootprintAxisRatio = pexConfig.Field(
        dtype=float, default=0.0,
        doc=("Minimum axis ratio for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    maxSpectrumCutoff = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum number of pixels * number of sources in a blend. "
             "This is different than `maxFootprintArea` because this isn't "
             "the footprint area but the area of the bounding box that "
             "contains the footprint, and is also multiplied by the number of "
             "sources in the footprint. This prevents large skinny blends with "
             "a high density of sources from running out of memory. "
             "If `maxSpectrumCutoff == -1` then there is no cutoff.")
    )

    # Failure modes
    fallback = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to fall back to a smaller number of components if a source does not initialize"
    )
    notDeblendedMask = pexConfig.Field(
        dtype=str, default="NOT_DEBLENDED", optional=True,
        doc="Mask name for footprints not deblended, or None")
    catchFailures = pexConfig.Field(
        dtype=bool, default=True,
        doc=("If True, catch exceptions thrown by the deblender, log them, "
             "and set a flag on the parent, instead of letting them propagate up"))

    # Other options
    columnInheritance = pexConfig.DictField(
        keytype=str, itemtype=str, default={
            "deblend_nChild": "deblend_parentNChild",
            "deblend_nPeaks": "deblend_parentNPeaks",
            "deblend_spectrumInitFlag": "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag": "deblend_blendConvergenceFailedFlag",
        },
        doc="Columns to pass from the parent to the child. "
            "The key is the name of the column for the parent record, "
            "the value is the name of the column to use for the child."
    )
    pseudoColumns = pexConfig.ListField(
        dtype=str, default=['merge_peak_sky', 'sky_source'],
        doc="Names of flag columns that mark sources which should never be deblended."
    )

    # Logging option(s)
    loggingInterval = pexConfig.Field(
        dtype=int, default=600,
        doc="Interval (in seconds) to log messages (at VERBOSE level) while deblending sources."
    )
    # Testing options
    # Some obs packages and ci packages run the full pipeline on a small
    # subset of data to test that the pipeline is functioning properly.
    # This is not meant as scientific validation, so it can be useful
    # to only run on a small subset of the data that is large enough to
    # test the desired pipeline features but not so large that the deblender
    # is the tall pole in terms of execution times.
    useCiLimits = pexConfig.Field(
        dtype=bool, default=False,
        doc="Limit the number of sources deblended for CI to prevent long build times")
    ciDeblendChildRange = pexConfig.ListField(
        dtype=int, default=[5, 10],
        doc="Only deblend parent Footprints with a number of peaks in the (inclusive) range indicated. "
            "If `useCiLimits==False` then this parameter is ignored.")
    ciNumParentsToDeblend = pexConfig.Field(
        dtype=int, default=10,
        doc="Only use the first `ciNumParentsToDeblend` parent footprints with a total peak count "
            "within `ciDeblendChildRange`. "
            "If `useCiLimits==False` then this parameter is ignored.")


class ScarletDeblendTask(pipeBase.Task):
    """ScarletDeblendTask

    Split blended sources into individual sources.

    This task has no return value; it only modifies the SourceCatalog in-place.
    """
    ConfigClass = ScarletDeblendConfig
    _DefaultName = "scarletDeblend"

    def __init__(self, schema, peakSchema=None, **kwargs):
        """Create the task, adding necessary fields to the given schema.

        Parameters
        ----------
        schema : `lsst.afw.table.schema.schema.Schema`
            Schema object for measurement fields; will be modified in-place.
        peakSchema : `lsst.afw.table.schema.schema.Schema`
            Schema of Footprint Peaks that will be passed to the deblender.
            Any fields beyond the PeakTable minimal schema will be transferred
            to the main source Schema. If None, no fields will be transferred
            from the Peaks.
        **kwargs
            Passed to Task.__init__.
        """
        pipeBase.Task.__init__(self, **kwargs)

        peakMinimalSchema = afwDet.PeakTable.makeMinimalSchema()
        if peakSchema is None:
            # In this case, the peakSchemaMapper will transfer nothing, but
            # we'll still have one to simplify downstream code
            self.peakSchemaMapper = afwTable.SchemaMapper(peakMinimalSchema, schema)
        else:
            self.peakSchemaMapper = afwTable.SchemaMapper(peakSchema, schema)
            for item in peakSchema:
                if item.key not in peakMinimalSchema:
                    self.peakSchemaMapper.addMapping(item.key, item.field)
                    # Because SchemaMapper makes a copy of the output schema
                    # you give its ctor, it isn't updating this Schema in
                    # place. That's probably a design flaw, but in the
                    # meantime, we'll keep that schema in sync with the
                    # peakSchemaMapper.getOutputSchema() manually, by adding
                    # the same fields to both.
                    schema.addField(item.field)
            assert schema == self.peakSchemaMapper.getOutputSchema(), "Logic bug mapping schemas"
        self._addSchemaKeys(schema)
        self.schema = schema
        self.toCopyFromParent = [item.key for item in self.schema
                                 if item.field.getName().startswith("merge_footprint")]

    def _addSchemaKeys(self, schema):
        """Add deblender specific keys to the schema
        """
        self.runtimeKey = schema.addField('deblend_runtime', type=np.float32, doc='runtime in ms')

        self.iterKey = schema.addField('deblend_iterations', type=np.int32, doc='iterations to converge')

        self.nChildKey = schema.addField('deblend_nChild', type=np.int32,
                                         doc='Number of children this object has (defaults to 0)')
        self.psfKey = schema.addField('deblend_deblendedAsPsf', type='Flag',
                                      doc='Deblender thought this source looked like a PSF')
        self.tooManyPeaksKey = schema.addField('deblend_tooManyPeaks', type='Flag',
                                               doc='Source had too many peaks; '
                                                   'only the brightest were included')
        self.tooBigKey = schema.addField('deblend_parentTooBig', type='Flag',
                                         doc='Parent footprint covered too many pixels')
        self.maskedKey = schema.addField('deblend_masked', type='Flag',
                                         doc='Parent footprint was predominantly masked')
        self.sedNotConvergedKey = schema.addField('deblend_sedConvergenceFailed', type='Flag',
                                                  doc='scarlet sed optimization did not converge before '
                                                      'config.maxIter')
        self.morphNotConvergedKey = schema.addField('deblend_morphConvergenceFailed', type='Flag',
                                                    doc='scarlet morph optimization did not converge before '
                                                        'config.maxIter')
        self.blendConvergenceFailedFlagKey = schema.addField('deblend_blendConvergenceFailedFlag',
                                                             type='Flag',
                                                             doc='at least one source in the blend '
                                                                 'failed to converge')
        self.edgePixelsKey = schema.addField('deblend_edgePixels', type='Flag',
                                             doc='Source had flux on the edge of the parent footprint')
        self.deblendFailedKey = schema.addField('deblend_failed', type='Flag',
                                                doc="Deblending failed on source")
        self.deblendErrorKey = schema.addField('deblend_error', type="String", size=25,
                                               doc='Name of error if the blend failed')
        self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
                                                 doc="Deblender skipped this source")
        self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
                                                        doc="Center used to apply constraints in scarlet",
                                                        unit="pixel")
        self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
                                         doc="ID of the peak in the parent footprint. "
                                             "This is not unique, but the combination of 'parent' "
                                             "and 'peakId' should be for all child sources. "
                                             "Top level blends with no parents have 'peakId=0'")
        self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
                                               doc="The instFlux at the peak position of the "
                                                   "deblended model")
        self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=25,
                                            doc="The type of model used, for example "
                                                "MultiExtendedSource, SingleExtendedSource, PointSource")
        self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
                                         doc="Number of initial peaks in the blend. "
                                             "This includes peaks that may have been culled "
                                             "during deblending or failed to deblend")
        self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
                                               doc="deblend_nPeaks from this record's parent.")
        self.parentNChildKey = schema.addField("deblend_parentNChild", type=np.int32,
                                               doc="deblend_nChild from this record's parent.")
        self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
                                              doc="Flux measurement from scarlet")
        self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
                                              doc="Final logL, used to identify regressions in scarlet.")
        self.scarletSpectrumInitKey = schema.addField("deblend_spectrumInitFlag", type='Flag',
                                                      doc="True when scarlet initializes sources "
                                                          "in the blend with a more accurate spectrum. "
                                                          "The algorithm uses a lot of memory, "
                                                          "so large dense blends will use "
                                                          "a less accurate initialization.")

        # self.log.trace('Added keys to schema: %s', ", ".join(str(x) for x in
        #                (self.nChildKey, self.tooManyPeaksKey, self.tooBigKey)))

    @timeMethod
    def run(self, mExposure, mergedSources):
        """Get the psf from each exposure and then run deblend().

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        mergedSources : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.

        Returns
        -------
        templateCatalogs : `dict`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            deblended templates in each band.
        """
        return self.deblend(mExposure, mergedSources)
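
    # A minimal usage sketch for the task (hypothetical inputs; in production
    # this is normally invoked by a pipeline driver):
    #
    #   >>> schema = afwTable.SourceTable.makeMinimalSchema()
    #   >>> task = ScarletDeblendTask(schema=schema)
    #   >>> templateCatalogs = task.run(mExposure, mergedSources)
    #   >>> gCatalog = templateCatalogs["g"]  # per-band catalog with models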

    @timeMethod
    def deblend(self, mExposure, catalog):
        """Deblend a data cube of multiband images

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        catalog : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend. The new deblended sources are
            appended to this catalog in place.

        Returns
        -------
        catalogs : `dict` or `None`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            deblended templates in each band.
        """
        import time

        # Cull footprints if required by ci
        if self.config.useCiLimits:
            self.log.info("Using CI catalog limits, the original number of sources to deblend was %d.",
                          len(catalog))
            # Select parents with a number of children in the range
            # config.ciDeblendChildRange
            minChildren, maxChildren = self.config.ciDeblendChildRange
            nPeaks = np.array([len(src.getFootprint().peaks) for src in catalog])
            childrenInRange = np.where((nPeaks >= minChildren) & (nPeaks <= maxChildren))[0]
            if len(childrenInRange) < self.config.ciNumParentsToDeblend:
                raise ValueError("Fewer than ciNumParentsToDeblend parents were found with a peak count "
                                 "in the range indicated by ciDeblendChildRange. Adjust this range to "
                                 "include more parents.")
            # Keep all of the isolated parents and the first
            # `ciNumParentsToDeblend` children
            parents = nPeaks == 1
            children = np.zeros((len(catalog),), dtype=bool)
            children[childrenInRange[:self.config.ciNumParentsToDeblend]] = True
            catalog = catalog[parents | children]
            # We need to update the IdFactory, otherwise the source ids
            # will not be sequential
            idFactory = catalog.getIdFactory()
            maxId = np.max(catalog["id"])
            idFactory.notify(maxId)

        filters = mExposure.filters
        self.log.info("Deblending %d sources in %d exposure bands", len(catalog), len(mExposure))
        nextLogTime = time.time() + self.config.loggingInterval

        # Add the NOT_DEBLENDED mask to the mask plane in each band
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                mask.addMaskPlane(self.config.notDeblendedMask)

        nParents = len(catalog)
        nDeblendedParents = 0
        skippedParents = []
        multibandColumns = {
            "heavies": [],
            "fluxes": [],
            "centerFluxes": [],
        }
        for parentIndex in range(nParents):
            parent = catalog[parentIndex]
            foot = parent.getFootprint()
            bbox = foot.getBBox()
            peaks = foot.getPeaks()

            # Since we use the first peak for the parent object, we should
            # propagate its flags to the parent source.
            parent.assign(peaks[0], self.peakSchemaMapper)

            # Skip isolated sources unless processSingles is turned on.
            # Note: this does not flag isolated sources as skipped or
            # set the NOT_DEBLENDED mask in the exposure,
            # since these aren't really skipped blends.
            # We also skip pseudo sources, like sky objects, which
            # are intended to be skipped.
            if ((len(peaks) < 2 and not self.config.processSingles)
                    or isPseudoSource(parent, self.config.pseudoColumns)):
                self._updateParentRecord(
                    parent=parent,
                    nPeaks=len(peaks),
                    nChild=0,
                    runtime=np.nan,
                    iterations=0,
                    logL=np.nan,
                    spectrumInit=False,
                    converged=False,
                )
                continue

            # Block of conditions for skipping a parent with multiple children
            skipKey = None
            if self._isLargeFootprint(foot):
                # The footprint is above the maximum footprint size limit
                skipKey = self.tooBigKey
                skipMessage = f"Parent {parent.getId()}: skipping large footprint"
            elif self._isMasked(foot, mExposure):
                # The footprint exceeds the maximum number of masked pixels
                skipKey = self.maskedKey
                skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
            elif self.config.maxNumberOfPeaks > 0 and len(peaks) > self.config.maxNumberOfPeaks:
                # Unlike meas_deblender, in scarlet we skip the entire blend
                # if the number of peaks exceeds max peaks, since neglecting
                # to model any peaks often results in catastrophic failure
                # of scarlet to generate models for the brighter sources.
                skipKey = self.tooManyPeaksKey
                skipMessage = f"Parent {parent.getId()}: Too many peaks, skipping blend"
            if skipKey is not None:
                self._skipParent(
                    parent=parent,
                    skipKey=skipKey,
                    logMessage=skipMessage,
                )
                skippedParents.append(parentIndex)
                continue

            nDeblendedParents += 1
            self.log.trace("Parent %d: deblending %d peaks", parent.getId(), len(peaks))
            # Run the deblender
            blendError = None
            try:
                t0 = time.time()
                # Build the parameter lists with the same ordering
                blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
                tf = time.time()
                runtime = (tf - t0)*1000
                converged = _checkBlendConvergence(blend, self.config.relativeError)
                scarletSources = [src for src in blend.sources]
                nChild = len(scarletSources)
                # Re-insert place holders for skipped sources
                # to propagate them in the catalog so
                # that the peaks stay consistent
                for k in skipped:
                    scarletSources.insert(k, None)
            # Catch all errors and filter out the ones that we know about
            except Exception as e:
                blendError = type(e).__name__
                if isinstance(e, ScarletGradientError):
                    parent.set(self.iterKey, e.iterations)
                elif not isinstance(e, IncompleteDataError):
                    blendError = "UnknownError"
                    if self.config.catchFailures:
                        # Make it easy to find UnknownErrors in the log file
                        self.log.warn("UnknownError")
                        import traceback
                        traceback.print_exc()
                    else:
                        raise

                self._skipParent(
                    parent=parent,
                    skipKey=self.deblendFailedKey,
                    logMessage=f"Unable to deblend source {parent.getId()}: {blendError}",
                )
                parent.set(self.deblendErrorKey, blendError)
                skippedParents.append(parentIndex)
                continue

            # Update the parent record with the deblending results
            logL = blend.loss[-1] - blend.observations[0].log_norm
            self._updateParentRecord(
                parent=parent,
                nPeaks=len(peaks),
                nChild=nChild,
                runtime=runtime,
                iterations=len(blend.loss),
                logL=logL,
                spectrumInit=spectrumInit,
                converged=converged,
            )

            # Add each deblended source to the catalog
            for k, scarletSource in enumerate(scarletSources):
                # Skip any sources with no flux or that scarlet skipped because
                # it could not initialize
                if k in skipped:
                    # No need to propagate anything
                    continue
                parent.set(self.deblendSkippedKey, False)
                mHeavy = modelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
                multibandColumns["heavies"].append(mHeavy)
                flux = scarlet.measure.flux(scarletSource)
                multibandColumns["fluxes"].append({
                    filters[fidx]: _flux
                    for fidx, _flux in enumerate(flux)
                })
                centerFlux = self._getCenterFlux(mHeavy, scarletSource, xy0=bbox.getMin())
                multibandColumns["centerFluxes"].append(centerFlux)

                # Add all fields except the HeavyFootprint to the
                # source record
                self._addChild(
                    parent=parent,
                    mHeavy=mHeavy,
                    catalog=catalog,
                    scarletSource=scarletSource,
                )

            # Log a message if it has been a while since the last log.
            if (currentTime := time.time()) > nextLogTime:
                nextLogTime = currentTime + self.config.loggingInterval
                self.log.verbose("Deblended %d parent sources out of %d", parentIndex + 1, nParents)

        # Clear the cached values in scarlet to clear out memory
        scarlet.cache.Cache._cache = {}

        # Make sure that the number of new sources matches the number of
        # entries in each of the band dependent columns.
        # This should never trigger and is just a sanity check.
        nChildren = len(catalog) - nParents
        if np.any([len(meas) != nChildren for meas in multibandColumns.values()]):
            msg = f"Added {len(catalog)-nParents} new sources, but have "
            msg += ", ".join([
                f"{len(value)} {key}"
                for key, value in multibandColumns.items()
            ])
            raise RuntimeError(msg)
        # Make a copy of the catalog in each band and update the footprints
        catalogs = {}
        for f in filters:
            _catalog = afwTable.SourceCatalog(catalog.table.clone())
            _catalog.extend(catalog, deep=True)
            # Update the footprints and columns that are different
            # for each filter
            for sourceIndex, source in enumerate(_catalog[nParents:]):
                source.setFootprint(multibandColumns["heavies"][sourceIndex][f])
                source.set(self.scarletFluxKey, multibandColumns["fluxes"][sourceIndex][f])
                source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
            catalogs[f] = _catalog

        # Update the mExposure mask with the footprint of skipped parents
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                for parentIndex in skippedParents:
                    fp = _catalog[parentIndex].getFootprint()
                    fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

        self.log.info("Deblender results: of %d parent sources, %d were deblended, "
                      "creating %d children, for a total of %d sources",
                      nParents, nDeblendedParents, nChildren, len(catalog))
        return catalogs

    def _isLargeFootprint(self, footprint):
        """Returns whether a Footprint is large

        'Large' is defined by thresholds on the area, size and axis ratio.
        These may be disabled independently by configuring them to be
        non-positive.

        This is principally intended to get rid of satellite streaks, which the
        deblender or other downstream processing can have trouble dealing with
        (e.g., multiple large HeavyFootprints can chew up memory).
        """
        if self.config.maxFootprintArea > 0 and footprint.getArea() > self.config.maxFootprintArea:
            return True
        if self.config.maxFootprintSize > 0:
            bbox = footprint.getBBox()
            if max(bbox.getWidth(), bbox.getHeight()) > self.config.maxFootprintSize:
                return True
        if self.config.minFootprintAxisRatio > 0:
            axes = afwEll.Axes(footprint.getShape())
            if axes.getB() < self.config.minFootprintAxisRatio*axes.getA():
                return True
        return False

    def _isMasked(self, footprint, mExposure):
        """Returns whether the footprint violates the mask limits"""
        bbox = footprint.getBBox()
        mask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
        size = float(footprint.getArea())
        for maskName, limit in self.config.maskLimits.items():
            maskVal = mExposure.mask.getPlaneBitMask(maskName)
            _mask = afwImage.MaskX(mask & maskVal, xy0=bbox.getMin())
            # spanset of unmasked pixels
            unmaskedSpan = footprint.spans.intersectNot(_mask)
            if (size - unmaskedSpan.getArea())/size > limit:
                return True
        return False
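
    # For example (hypothetical config): with ``maskLimits = {"NO_DATA": 0.25}``
    # a parent is treated as masked, and therefore skipped, as soon as more
    # than 25% of its footprint pixels carry the NO_DATA bit.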

    def _skipParent(self, parent, skipKey, logMessage):
        """Update a parent record that is not being deblended.

        This is a fairly trivial function but is implemented to ensure
        that a skipped parent updates the appropriate columns
        consistently, and always has a flag to mark the reason that
        it is being skipped.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to flag as skipped.
        skipKey : `lsst.afw.table.Key`
            The schema key (Flag) that marks the reason for skipping.
        logMessage : `str`
            The message to display in a log.trace when a source
            is skipped.
        """
        if logMessage is not None:
            self.log.trace(logMessage)
        self._updateParentRecord(
            parent=parent,
            nPeaks=len(parent.getFootprint().peaks),
            nChild=0,
            runtime=np.nan,
            iterations=0,
            logL=np.nan,
            spectrumInit=False,
            converged=False,
        )

        # Mark the source as skipped by the deblender and
        # flag the reason why.
        parent.set(self.deblendSkippedKey, True)
        parent.set(skipKey, True)

    def _updateParentRecord(self, parent, nPeaks, nChild,
                            runtime, iterations, logL, spectrumInit, converged):
        """Update a parent record in all of the single band catalogs.

        Ensure that all locations that update a parent record,
        whether it is skipped or updated after deblending,
        update all of the appropriate columns.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to update.
        nPeaks : `int`
            Number of peaks in the parent footprint.
        nChild : `int`
            Number of children deblended from the parent.
            This may differ from `nPeaks` if some of the peaks
            were culled and have no deblended model.
        runtime : `float`
            Total runtime for deblending.
        iterations : `int`
            Total number of iterations in scarlet before convergence.
        logL : `float`
            Final log likelihood of the blend.
        spectrumInit : `bool`
            True when scarlet used `set_spectra` to initialize all
            sources with better initial intensities.
        converged : `bool`
            True when the optimizer reached convergence before
            reaching the maximum number of iterations.
        """
        parent.set(self.nPeaksKey, nPeaks)
        parent.set(self.nChildKey, nChild)
        parent.set(self.runtimeKey, runtime)
        parent.set(self.iterKey, iterations)
        parent.set(self.scarletLogLKey, logL)
        parent.set(self.scarletSpectrumInitKey, spectrumInit)
        # The flag records a convergence *failure*, so invert ``converged``
        parent.set(self.blendConvergenceFailedFlagKey, not converged)

    def _addChild(self, parent, mHeavy, catalog, scarletSource):
        """Add a child to a catalog.

        This creates a new child in the source catalog,
        assigning it a parent id, and adding all columns
        that are independent across all filter bands.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent of the new child record.
        mHeavy : `lsst.detection.MultibandFootprint`
            The multi-band footprint containing the model and
            peak catalog for the new child record.
        catalog : `lsst.afw.table.source.source.SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.
        scarletSource : `scarlet.Component`
            The scarlet model for the new source record.
        """
        src = catalog.addNew()
        for key in self.toCopyFromParent:
            src.set(key, parent.get(key))
        # The peak catalog is the same for all bands,
        # so we just use the first peak catalog
        peaks = mHeavy[mHeavy.filters[0]].peaks
        src.assign(peaks[0], self.peakSchemaMapper)
        src.setParent(parent.getId())
        # Currently all children only have a single peak,
        # but it's possible in the future that there will be hierarchical
        # deblending, so we use the footprint to set the number of peaks
        # for each child.
        src.set(self.nPeaksKey, len(peaks))
        # Set the psf key based on whether or not the source was
        # deblended using the PointSource model.
        # This key is not that useful anymore since we now keep track of
        # `modelType`, but we continue to propagate it in case code downstream
        # is expecting it.
        src.set(self.psfKey, scarletSource.__class__.__name__ == "PointSource")
        src.set(self.modelTypeKey, scarletSource.__class__.__name__)
        # We set the runtime to zero so that summing up the
        # runtime column will give the total time spent
        # running the deblender for the catalog.
        src.set(self.runtimeKey, 0)

        # Set the position of the peak from the parent footprint.
        # This will make it easier to match the same source across
        # deblenders and across observations, where the peak
        # position is unlikely to change unless enough time passes
        # for a source to move on the sky.
        peak = scarletSource.detectedPeak
        src.set(self.peakCenter, Point2I(peak["i_x"], peak["i_y"]))
        src.set(self.peakIdKey, peak["id"])

        # Propagate columns from the parent to the child
        for parentColumn, childColumn in self.config.columnInheritance.items():
            src.set(childColumn, parent.get(parentColumn))

    def _getCenterFlux(self, mHeavy, scarletSource, xy0):
        """Get the flux at the center of a HeavyFootprint

        Parameters
        ----------
        mHeavy : `lsst.detection.MultibandFootprint`
            The multi-band footprint containing the model for the source.
        scarletSource : `scarlet.Component`
            The scarlet model for the heavy footprint
        xy0 : `Point2I`
            The origin of the parent bounding box, used to shift the
            source center into the image coordinate system.

        Returns
        -------
        centerFlux
            The flux at the center of the model in each band,
            or `numpy.nan` per filter when the center is unavailable.
        """
        # Store the flux at the center of the model and the total
        # scarlet flux measurement.
        mImage = mHeavy.getImage(fill=0.0).image

        # Set the flux at the center of the model (for SNR)
        try:
            cy, cx = scarletSource.center
            cy += xy0.y
            cx += xy0.x
            return mImage[:, cx, cy]
        except AttributeError:
            msg = "Did not recognize coordinates for source type of `{0}`, "
            msg += "could not write coordinates or center flux. "
            msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
            logger.warning(msg.format(type(scarletSource)))
            return {f: np.nan for f in mImage.filters}