# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import logging
import numpy as np
import scarlet
from scarlet.psf import ImagePSF, GaussianPSF
from scarlet import Blend, Frame, Observation
from scarlet.renderer import ConvolutionRenderer
from scarlet.initialization import init_all_sources

import lsst.log
import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.geom import Point2I, Box2I, Point2D
import lsst.afw.geom.ellipses as afwEll
import lsst.afw.image.utils
import lsst.afw.image as afwImage
import lsst.afw.detection as afwDet
import lsst.afw.table as afwTable
from lsst.utils.timer import timeMethod

from .source import modelToHeavy

# Scarlet and proxmin have a different definition of log levels than the stack,
# so even "warnings" occur far more often than we would like.
# So for now we only display scarlet and proxmin errors, as all other
# scarlet outputs would be considered "TRACE" by our standards.
scarletLogger = logging.getLogger("scarlet")
scarletLogger.setLevel(logging.ERROR)
proxminLogger = logging.getLogger("proxmin")
proxminLogger.setLevel(logging.ERROR)

__all__ = ["deblend", "ScarletDeblendConfig", "ScarletDeblendTask"]

logger = lsst.log.Log.getLogger("meas.deblender.deblend")


class IncompleteDataError(Exception):
    """The PSF could not be computed due to incomplete data
    """
    pass


class ScarletGradientError(Exception):
    """An error occurred during optimization

    This error occurs when the optimizer encounters
    a NaN value while calculating the gradient.
    """
    def __init__(self, iterations, sources):
        self.iterations = iterations
        self.sources = sources
        msg = ("ScarletGradientError in iteration {0}. "
               "NaN values introduced in sources {1}")
        self.message = msg.format(iterations, sources)

    def __str__(self):
        return self.message


def _checkBlendConvergence(blend, f_rel):
    """Check whether or not a blend has converged
    """
    deltaLoss = np.abs(blend.loss[-2] - blend.loss[-1])
    convergence = f_rel * np.abs(blend.loss[-1])
    return deltaLoss < convergence
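

# A minimal numerical sketch of the convergence test above (the values are
# made up): if blend.loss ends in [-1000.2, -1000.1] and f_rel = 1e-4, the
# change in the loss (0.1) is compared against 1e-4 * 1000.1 ~= 0.10001, so
# this blend is just barely considered converged.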


def _getPsfFwhm(psf):
    """Calculate the FWHM of the `psf`
    """
    # 2.35 approximates the Gaussian sigma-to-FWHM conversion factor,
    # 2*sqrt(2*ln(2)) ~= 2.3548
    return psf.computeShape().getDeterminantRadius() * 2.35


def _computePsfImage(self, position=None):
    """Get a multiband PSF image

    The PSF Kernel Image is computed for each band
    and combined into a (filter, y, x) array.
    The result is not cached, so if the same PSF is expected
    to be used multiple times it is a good idea to store the
    result in another variable.

    Note: this is a temporary fix during the deblender sprint.
    In the future this function will replace the current method
    in `afw.MultibandExposure.computePsfImage` (DM-19789).

    Parameters
    ----------
    position : `Point2D` or `tuple`
        Coordinates to evaluate the PSF. If `position` is `None`
        then `Psf.getAveragePosition()` is used.

    Returns
    -------
    psfImage : `lsst.afw.image.MultibandImage`
        The multiband PSF image.
    """
    psfs = []
    # Make the coordinates into a Point2D (if necessary)
    if not isinstance(position, Point2D) and position is not None:
        position = Point2D(position[0], position[1])

    for bidx, single in enumerate(self.singles):
        try:
            if position is None:
                psf = single.getPsf().computeImage()
                psfs.append(psf)
            else:
                psf = single.getPsf().computeKernelImage(position)
                psfs.append(psf)
        except InvalidParameterError:
            # This band failed to compute the PSF due to incomplete data
            # at that location. This is unlikely to be a problem for Rubin,
            # however the edges of some HSC COSMOS fields contain incomplete
            # data in some bands, so we track this error to distinguish it
            # from unknown errors.
            msg = "Failed to compute PSF at {} in band {}"
            raise IncompleteDataError(msg.format(position, self.filters[bidx]))

    left = np.min([psf.getBBox().getMinX() for psf in psfs])
    bottom = np.min([psf.getBBox().getMinY() for psf in psfs])
    right = np.max([psf.getBBox().getMaxX() for psf in psfs])
    top = np.max([psf.getBBox().getMaxY() for psf in psfs])
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
    psfs = [afwImage.utils.projectImage(psf, bbox) for psf in psfs]
    psfImage = afwImage.MultibandImage.fromImages(self.filters, psfs)
    return psfImage
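

# Example usage of `_computePsfImage` (a sketch; `mExposure` is assumed to be
# an `lsst.afw.image.MultibandExposure` loaded elsewhere, and the position is
# illustrative). Although written like a method, this helper takes the
# exposure as its first argument:
#
#     psfImage = _computePsfImage(mExposure, position=Point2D(1500.0, 2200.0))
#     psfCube = psfImage.array  # (band, y, x) array, as scarlet expects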


def getFootprintMask(footprint, mExposure):
    """Mask pixels outside the footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        The multiband exposure containing the image,
        mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        The footprint of the parent to deblend

    Returns
    -------
    footprintMask : array
        Boolean array with pixels outside the footprint set to `True`.
    """
    bbox = footprint.getBBox()
    fpMask = afwImage.Mask(bbox)
    footprint.spans.setMask(fpMask, 1)
    fpMask = ~fpMask.getArray().astype(bool)
    return fpMask
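

# A short sketch of the mask polarity, which is easy to get backwards:
# `getFootprintMask` returns `True` *outside* the footprint, so the weights
# are zeroed there by multiplying with the negation, exactly as `deblend`
# does below:
#
#     mask = getFootprintMask(footprint, mExposure)
#     weights *= ~mask  # keep weights inside the footprint only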


def isPseudoSource(source, pseudoColumns):
    """Check if a source is a pseudo source.

    This is mostly for skipping sky objects,
    but any other column can also be added to disable
    deblending on a parent or individual source when
    set to `True`.

    Parameters
    ----------
    source : `lsst.afw.table.source.source.SourceRecord`
        The source to check for the pseudo bit.
    pseudoColumns : `list` of `str`
        A list of columns to check for pseudo sources.

    Returns
    -------
    isPseudo : `bool`
        `True` if any of the `pseudoColumns` is set on the source.
    """
    isPseudo = False
    for col in pseudoColumns:
        try:
            isPseudo |= source[col]
        except KeyError:
            pass
    return isPseudo
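

# Example (a sketch): with the default `pseudoColumns`, a record flagged as a
# sky object is never deblended:
#
#     if isPseudoSource(record, ["merge_peak_sky", "sky_source"]):
#         continue  # skip sky objects and other pseudo sources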


def deblend(mExposure, footprint, config):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        The multiband exposure containing the image,
        mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        Configuration of the deblending task

    Returns
    -------
    blend : `scarlet.Blend`
        The fitted blend containing the deblended sources.
    skipped : `list` of `int`
        Indices of peaks that could not be initialized.
    spectrumInit : `bool`
        Whether or not the improved spectrum initialization was used.
    """
    # Extract coordinates from each MultiColorPeak
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        weights = np.ones_like(images)
    badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
    mask = mExposure.mask[:, bbox].array & badPixels
    weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    psfs = ImagePSF(psfs)
    model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

    frame = Frame(images.shape, psf=model_psf, channels=mExposure.filters)
    observation = Observation(images, psf=psfs, weights=weights, channels=mExposure.filters)
    if config.convolutionType == "fft":
        observation.match(frame)
    elif config.convolutionType == "real":
        renderer = ConvolutionRenderer(observation, frame, convolution_type="real")
        observation.match(frame, renderer=renderer)
    else:
        raise ValueError("Unrecognized convolution type {}".format(config.convolutionType))

    # "point" is included so that the NotImplementedError below is reachable
    assert config.sourceModel in ["single", "double", "compact", "point", "fit"]

    # Set the appropriate number of components
    if config.sourceModel == "single":
        maxComponents = 1
    elif config.sourceModel == "double":
        maxComponents = 2
    elif config.sourceModel == "compact":
        maxComponents = 0
    elif config.sourceModel == "point":
        raise NotImplementedError("Point source photometry is currently not implemented")
    elif config.sourceModel == "fit":
        # It is likely in the future that there will be some heuristic
        # used to determine what type of model to use for each source,
        # but that has not yet been implemented (see DM-22551)
        raise NotImplementedError("sourceModel 'fit' has not been implemented yet")

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Choose whether or not to use the improved spectral initialization
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0:
            spectrumInit = True
        else:
            spectrumInit = len(centers) * bbox.getArea() < config.maxSpectrumCutoff
    else:
        spectrumInit = False

    # Only deblend sources that can be initialized
    sources, skipped = init_all_sources(
        frame=frame,
        centers=centers,
        observations=observation,
        thresh=config.morphThresh,
        max_components=maxComponents,
        min_snr=config.minSNR,
        shifting=False,
        fallback=config.fallback,
        silent=config.catchFailures,
        set_spectra=spectrumInit,
    )

    # Attach the peak to all of the initialized sources
    srcIndex = 0
    for k, center in enumerate(centers):
        if k not in skipped:
            # This is just to make sure that there isn't a coding bug
            assert np.all(sources[srcIndex].center == center)
            # Store the record for the peak with the appropriate source
            sources[srcIndex].detectedPeak = footprint.peaks[k]
            srcIndex += 1

    # Create the blend and attempt to optimize it
    blend = Blend(sources, observation)
    try:
        blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
    except ArithmeticError:
        # This occurs when a gradient update produces a NaN value
        # This is usually due to a source initialized with a
        # negative SED or no flux, often because the peak
        # is a noise fluctuation in one band and not a real source.
        iterations = len(blend.loss)
        failedSources = []
        for k, src in enumerate(sources):
            if np.any(~np.isfinite(src.get_model())):
                failedSources.append(k)
        raise ScarletGradientError(iterations, failedSources)

    return blend, skipped, spectrumInit
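

# Example usage of `deblend` (a sketch; `mExposure` and `footprint` are
# assumed to come from a coadd and its merged detection catalog):
#
#     config = ScarletDeblendConfig()
#     try:
#         blend, skipped, spectrumInit = deblend(mExposure, footprint, config)
#     except ScarletGradientError as e:
#         print(f"NaN gradient after {e.iterations} iterations: {e.sources}")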


class ScarletDeblendConfig(pexConfig.Config):
    """Configuration for the multiband deblender.

    The parameters are organized by the parameter types, which are

    - Stopping Criteria: Used to determine if the fit has converged
    - Position Fitting Criteria: Used to fit the positions of the peaks
    - Constraints: Used to apply constraints to the peaks and their components
    - Other: Parameters that don't fit into the above categories
    """
    # Stopping Criteria
    maxIter = pexConfig.Field(dtype=int, default=300,
                              doc=("Maximum number of iterations to deblend a single parent"))
    relativeError = pexConfig.Field(dtype=float, default=1e-4,
                                    doc=("Change in the loss function between "
                                         "iterations to exit fitter"))

    # Constraints
    morphThresh = pexConfig.Field(dtype=float, default=1,
                                  doc="Fraction of background RMS a pixel must have "
                                      "to be included in the initial morphology")
    # Other scarlet parameters
    useWeights = pexConfig.Field(
        dtype=bool, default=True,
        doc=("Whether or not to use inverse variance weighting. "
             "If `useWeights` is `False` then flat weights are used"))
    modelPsfSize = pexConfig.Field(
        dtype=int, default=11,
        doc="Model PSF side length in pixels")
    modelPsfSigma = pexConfig.Field(
        dtype=float, default=0.8,
        doc="Define sigma for the model frame PSF")
    minSNR = pexConfig.Field(
        dtype=float, default=50,
        doc="Minimum signal to noise to accept the source. "
            "Sources with lower flux will be initialized with the PSF but updated "
            "like an ordinary ExtendedSource (known in scarlet as a `CompactSource`).")
    saveTemplates = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to save the SEDs and templates")
    processSingles = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to process isolated sources in the deblender")
    convolutionType = pexConfig.Field(
        dtype=str, default="fft",
        doc="Type of convolution to render the model to the observations.\n"
            "- 'fft': perform convolutions in Fourier space\n"
            "- 'real': perform convolutions in real space.")
    sourceModel = pexConfig.Field(
        dtype=str, default="double",
        doc=("How to determine which model to use for sources, from\n"
             "- 'single': use a single component for all sources\n"
             "- 'double': use a bulge disk model for all sources\n"
             "- 'compact': use a single component model, initialized with a point source morphology, "
             "for all sources\n"
             "- 'point': use a point-source model for all sources\n"
             "- 'fit': use a PSF fitting model to determine the number of components (not yet implemented)")
    )
    setSpectra = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to solve for the best-fit spectra during initialization. "
            "This makes initialization slightly longer, as it requires a convolution "
            "to set the optimal spectra, but results in a much better initial log-likelihood "
            "and reduced total runtime, with convergence in fewer iterations. "
            "This option is only used when peaks*area < `maxSpectrumCutoff`.")

    # Mask-plane restrictions
    badMask = pexConfig.ListField(
        dtype=str, default=["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"],
        doc="Mask planes that mark bad pixels; "
            "these pixels are given zero weight in the fit")
    statsMask = pexConfig.ListField(dtype=str, default=["SAT", "INTRP", "NO_DATA"],
                                    doc="Mask planes to ignore when performing statistics")
    maskLimits = pexConfig.DictField(
        keytype=str,
        itemtype=float,
        default={},
        doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
             "Sources violating this limit will not be deblended."),
    )

    # Size restrictions
    maxNumberOfPeaks = pexConfig.Field(
        dtype=int, default=0,
        doc=("Only deblend the brightest maxNumberOfPeaks peaks in the parent"
             " (<= 0: unlimited)"))
    maxFootprintArea = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum area for footprints before they are ignored as large; "
             "non-positive means no threshold applied"))
    maxFootprintSize = pexConfig.Field(
        dtype=int, default=0,
        doc=("Maximum linear dimension for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    minFootprintAxisRatio = pexConfig.Field(
        dtype=float, default=0.0,
        doc=("Minimum axis ratio for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    maxSpectrumCutoff = pexConfig.Field(
        dtype=int, default=1000000,
        doc=("Maximum number of pixels * number of sources in a blend. "
             "This is different than `maxFootprintArea` because this isn't "
             "the footprint area but the area of the bounding box that "
             "contains the footprint, and is also multiplied by the number of "
             "sources in the footprint. This prevents large skinny blends with "
             "a high density of sources from running out of memory. "
             "If `maxSpectrumCutoff == -1` then there is no cutoff.")
    )

    # Failure modes
    fallback = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to fallback to a smaller number of components if a source does not initialize"
    )
    notDeblendedMask = pexConfig.Field(
        dtype=str, default="NOT_DEBLENDED", optional=True,
        doc="Mask name for footprints not deblended, or None")
    catchFailures = pexConfig.Field(
        dtype=bool, default=True,
        doc=("If True, catch exceptions thrown by the deblender, log them, "
             "and set a flag on the parent, instead of letting them propagate up"))

    # Other options
    columnInheritance = pexConfig.DictField(
        keytype=str, itemtype=str, default={
            "deblend_nChild": "deblend_parentNChild",
            "deblend_nPeaks": "deblend_parentNPeaks",
            "deblend_spectrumInitFlag": "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag": "deblend_blendConvergenceFailedFlag",
        },
        doc="Columns to pass from the parent to the child. "
            "The key is the name of the column for the parent record, "
            "the value is the name of the column to use for the child."
    )
    pseudoColumns = pexConfig.ListField(
        dtype=str, default=['merge_peak_sky', 'sky_source'],
        doc="Names of flags which should never be deblended."
    )

    # Logging option(s)
    loggingInterval = pexConfig.Field(
        dtype=int, default=600,
        doc="Interval (in seconds) to log messages (at VERBOSE level) while deblending sources."
    )
    # Testing options
    # Some obs packages and ci packages run the full pipeline on a small
    # subset of data to test that the pipeline is functioning properly.
    # This is not meant as scientific validation, so it can be useful
    # to only run on a small subset of the data that is large enough to
    # test the desired pipeline features but not so long that the deblender
    # is the tall pole in terms of execution times.
    useCiLimits = pexConfig.Field(
        dtype=bool, default=False,
        doc="Limit the number of sources deblended for CI to prevent long build times")
    ciDeblendChildRange = pexConfig.ListField(
        dtype=int, default=[5, 10],
        doc="Only deblend parent Footprints with a number of peaks in the (inclusive) range indicated. "
            "If `useCiLimits==False` then this parameter is ignored.")
    ciNumParentsToDeblend = pexConfig.Field(
        dtype=int, default=10,
        doc="Only use the first `ciNumParentsToDeblend` parent footprints with a total peak count "
            "within `ciDeblendChildRange`. "
            "If `useCiLimits==False` then this parameter is ignored.")


class ScarletDeblendTask(pipeBase.Task):
    """ScarletDeblendTask

    Split blended sources into individual sources.

    This task has no return value; it only modifies the SourceCatalog in-place.
    """
    ConfigClass = ScarletDeblendConfig
    _DefaultName = "scarletDeblend"

    def __init__(self, schema, peakSchema=None, **kwargs):
        """Create the task, adding necessary fields to the given schema.

        Parameters
        ----------
        schema : `lsst.afw.table.schema.schema.Schema`
            Schema object for measurement fields; will be modified in-place.
        peakSchema : `lsst.afw.table.schema.schema.Schema`
            Schema of Footprint Peaks that will be passed to the deblender.
            Any fields beyond the PeakTable minimal schema will be transferred
            to the main source Schema. If None, no fields will be transferred
            from the Peaks.
        **kwargs
            Passed to Task.__init__.
        """
        pipeBase.Task.__init__(self, **kwargs)

        peakMinimalSchema = afwDet.PeakTable.makeMinimalSchema()
        if peakSchema is None:
            # In this case, the peakSchemaMapper will transfer nothing,
            # but we'll still have one to simplify downstream code
            self.peakSchemaMapper = afwTable.SchemaMapper(peakMinimalSchema, schema)
        else:
            self.peakSchemaMapper = afwTable.SchemaMapper(peakSchema, schema)
            for item in peakSchema:
                if item.key not in peakMinimalSchema:
                    self.peakSchemaMapper.addMapping(item.key, item.field)
                    # Because SchemaMapper makes a copy of the output schema
                    # you give its ctor, it isn't updating this Schema in
                    # place. That's probably a design flaw, but in the
                    # meantime, we'll keep that schema in sync with the
                    # peakSchemaMapper.getOutputSchema() manually, by adding
                    # the same fields to both.
                    schema.addField(item.field)
            assert schema == self.peakSchemaMapper.getOutputSchema(), "Logic bug mapping schemas"
        self._addSchemaKeys(schema)
        self.schema = schema
        self.toCopyFromParent = [item.key for item in self.schema
                                 if item.field.getName().startswith("merge_footprint")]

    def _addSchemaKeys(self, schema):
        """Add deblender specific keys to the schema
        """
        self.runtimeKey = schema.addField('deblend_runtime', type=np.float32, doc='runtime in ms')

        self.iterKey = schema.addField('deblend_iterations', type=np.int32, doc='iterations to converge')

        self.nChildKey = schema.addField('deblend_nChild', type=np.int32,
                                         doc='Number of children this object has (defaults to 0)')
        self.psfKey = schema.addField('deblend_deblendedAsPsf', type='Flag',
                                      doc='Deblender thought this source looked like a PSF')
        self.tooManyPeaksKey = schema.addField('deblend_tooManyPeaks', type='Flag',
                                               doc='Source had too many peaks; '
                                                   'only the brightest were included')
        self.tooBigKey = schema.addField('deblend_parentTooBig', type='Flag',
                                         doc='Parent footprint covered too many pixels')
        self.maskedKey = schema.addField('deblend_masked', type='Flag',
                                         doc='Parent footprint was predominantly masked')
        self.sedNotConvergedKey = schema.addField('deblend_sedConvergenceFailed', type='Flag',
                                                  doc='scarlet sed optimization did not converge before '
                                                      'config.maxIter')
        self.morphNotConvergedKey = schema.addField('deblend_morphConvergenceFailed', type='Flag',
                                                    doc='scarlet morph optimization did not converge before '
                                                        'config.maxIter')
        self.blendConvergenceFailedFlagKey = schema.addField('deblend_blendConvergenceFailedFlag',
                                                             type='Flag',
                                                             doc='at least one source in the blend '
                                                                 'failed to converge')
        self.edgePixelsKey = schema.addField('deblend_edgePixels', type='Flag',
                                             doc='Source had flux on the edge of the parent footprint')
        self.deblendFailedKey = schema.addField('deblend_failed', type='Flag',
                                                doc="Deblending failed on source")
        self.deblendErrorKey = schema.addField('deblend_error', type="String", size=25,
                                               doc='Name of error if the blend failed')
        self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
                                                 doc="Deblender skipped this source")
        self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
                                                        doc="Center used to apply constraints in scarlet",
                                                        unit="pixel")
        self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
                                         doc="ID of the peak in the parent footprint. "
                                             "This is not unique, but the combination of 'parent' "
                                             "and 'peakId' should be for all child sources. "
                                             "Top level blends with no parents have 'peakId=0'")
        self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
                                               doc="The instFlux at the peak position of the deblended model")
        self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=25,
                                            doc="The type of model used, for example "
                                                "MultiExtendedSource, SingleExtendedSource, PointSource")
        self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
                                         doc="Number of initial peaks in the blend. "
                                             "This includes peaks that may have been culled "
                                             "during deblending or failed to deblend")
        self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
                                               doc="deblend_nPeaks from this record's parent.")
        self.parentNChildKey = schema.addField("deblend_parentNChild", type=np.int32,
                                               doc="deblend_nChild from this record's parent.")
        self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
                                              doc="Flux measurement from scarlet")
        self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
                                              doc="Final logL, used to identify regressions in scarlet.")
        self.scarletSpectrumInitKey = schema.addField("deblend_spectrumInitFlag", type='Flag',
                                                      doc="True when scarlet initializes sources "
                                                          "in the blend with a more accurate spectrum. "
                                                          "The algorithm uses a lot of memory, "
                                                          "so large dense blends will use "
                                                          "a less accurate initialization.")

        # self.log.trace('Added keys to schema: %s', ", ".join(str(x) for x in
        #                (self.nChildKey, self.tooManyPeaksKey, self.tooBigKey))
        #                )
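
    # Example (a sketch): downstream code can read the flags added above
    # directly from the output catalog by column name, e.g.
    #
    #     skipped = catalog["deblend_skipped"]
    #     tooBig = catalog["deblend_parentTooBig"]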

    @timeMethod
    def run(self, mExposure, mergedSources):
        """Get the psf from each exposure and then run deblend().

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        mergedSources : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.

        Returns
        -------
        templateCatalogs : `dict`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            model templates created by the multiband deblender.
        """
        return self.deblend(mExposure, mergedSources)

    @timeMethod
    def deblend(self, mExposure, catalog):
        """Deblend a data cube of multiband images

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        catalog : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend. The new deblended sources are
            appended to this catalog in place.

        Returns
        -------
        catalogs : `dict` or `None`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that contain the
            model templates created by the multiband deblender.
        """

        import time

        # Cull footprints if required by ci
        if self.config.useCiLimits:
            self.log.info("Using CI catalog limits, the original number of sources to deblend was %d.",
                          len(catalog))
            # Select parents with a number of children in the range
            # config.ciDeblendChildRange
            minChildren, maxChildren = self.config.ciDeblendChildRange
            nPeaks = np.array([len(src.getFootprint().peaks) for src in catalog])
            childrenInRange = np.where((nPeaks >= minChildren) & (nPeaks <= maxChildren))[0]
            if len(childrenInRange) < self.config.ciNumParentsToDeblend:
                raise ValueError("Fewer than ciNumParentsToDeblend children were contained in the range "
                                 "indicated by ciDeblendChildRange. Adjust this range to include more "
                                 "parents.")
            # Keep all of the isolated parents and the first
            # `ciNumParentsToDeblend` children
            parents = nPeaks == 1
            children = np.zeros((len(catalog),), dtype=bool)
            children[childrenInRange[:self.config.ciNumParentsToDeblend]] = True
            catalog = catalog[parents | children]
            # We need to update the IdFactory, otherwise the source ids
            # will not be sequential
            idFactory = catalog.getIdFactory()
            maxId = np.max(catalog["id"])
            idFactory.notify(maxId)

        filters = mExposure.filters
        self.log.info("Deblending %d sources in %d exposure bands", len(catalog), len(mExposure))
        nextLogTime = time.time() + self.config.loggingInterval

        # Add the NOT_DEBLENDED mask to the mask plane in each band
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                mask.addMaskPlane(self.config.notDeblendedMask)

        nParents = len(catalog)
        nDeblendedParents = 0
        skippedParents = []
        multibandColumns = {
            "heavies": [],
            "fluxes": [],
            "centerFluxes": [],
        }
        for parentIndex in range(nParents):
            parent = catalog[parentIndex]
            foot = parent.getFootprint()
            bbox = foot.getBBox()
            peaks = foot.getPeaks()

            # Since we use the first peak for the parent object, we should
            # propagate its flags to the parent source.
            parent.assign(peaks[0], self.peakSchemaMapper)

            # Skip isolated sources unless processSingles is turned on.
            # Note: this does not flag isolated sources as skipped or
            # set the NOT_DEBLENDED mask in the exposure,
            # since these aren't really skipped blends.
            # We also skip pseudo sources, like sky objects, which
            # are intended to be skipped.
            if ((len(peaks) < 2 and not self.config.processSingles)
                    or isPseudoSource(parent, self.config.pseudoColumns)):
                self._updateParentRecord(
                    parent=parent,
                    nPeaks=len(peaks),
                    nChild=0,
                    runtime=np.nan,
                    iterations=0,
                    logL=np.nan,
                    spectrumInit=False,
                    converged=False,
                )
                continue

            # Block of conditions for skipping a parent with multiple children
            skipKey = None
            if self._isLargeFootprint(foot):
                # The footprint is above the maximum footprint size limit
                skipKey = self.tooBigKey
                skipMessage = f"Parent {parent.getId()}: skipping large footprint"
            elif self._isMasked(foot, mExposure):
                # The footprint exceeds the maximum number of masked pixels
                skipKey = self.maskedKey
                skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
            elif self.config.maxNumberOfPeaks > 0 and len(peaks) > self.config.maxNumberOfPeaks:
                # Unlike meas_deblender, in scarlet we skip the entire blend
                # if the number of peaks exceeds max peaks, since neglecting
                # to model any peaks often results in catastrophic failure
                # of scarlet to generate models for the brighter sources.
                skipKey = self.tooManyPeaksKey
                skipMessage = f"Parent {parent.getId()}: Too many peaks, skipping blend"
            if skipKey is not None:
                self._skipParent(
                    parent=parent,
                    skipKey=skipKey,
                    logMessage=skipMessage,
                )
                skippedParents.append(parentIndex)
                continue

            nDeblendedParents += 1
            self.log.trace("Parent %d: deblending %d peaks", parent.getId(), len(peaks))
            # Run the deblender
            blendError = None
            try:
                t0 = time.time()
                # Build the parameter lists with the same ordering
                blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
                tf = time.time()
                runtime = (tf-t0)*1000
                converged = _checkBlendConvergence(blend, self.config.relativeError)
                scarletSources = [src for src in blend.sources]
                nChild = len(scarletSources)
                # Re-insert place holders for skipped sources
                # to propagate them in the catalog so
                # that the peaks stay consistent
                for k in skipped:
                    scarletSources.insert(k, None)
            # Catch all errors and filter out the ones that we know about
            except Exception as e:
                blendError = type(e).__name__
                if isinstance(e, ScarletGradientError):
                    parent.set(self.iterKey, e.iterations)
                elif not isinstance(e, IncompleteDataError):
                    blendError = "UnknownError"
                    if self.config.catchFailures:
                        # Make it easy to find UnknownErrors in the log file
                        self.log.warning("UnknownError")
                        import traceback
                        traceback.print_exc()
                    else:
                        raise

                self._skipParent(
                    parent=parent,
                    skipKey=self.deblendFailedKey,
                    logMessage=f"Unable to deblend source {parent.getId()}: {blendError}",
                )
                parent.set(self.deblendErrorKey, blendError)
                skippedParents.append(parentIndex)
                continue

            # Update the parent record with the deblending results
            logL = blend.loss[-1]-blend.observations[0].log_norm
            self._updateParentRecord(
                parent=parent,
                nPeaks=len(peaks),
                nChild=nChild,
                runtime=runtime,
                iterations=len(blend.loss),
                logL=logL,
                spectrumInit=spectrumInit,
                converged=converged,
            )

            # Add each deblended source to the catalog
            for k, scarletSource in enumerate(scarletSources):
                # Skip any sources with no flux or that scarlet skipped because
                # it could not initialize
                if k in skipped:
                    # No need to propagate anything
                    continue
                parent.set(self.deblendSkippedKey, False)
                mHeavy = modelToHeavy(scarletSource, filters, xy0=bbox.getMin(),
                                      observation=blend.observations[0])
                multibandColumns["heavies"].append(mHeavy)
                flux = scarlet.measure.flux(scarletSource)
                multibandColumns["fluxes"].append({
                    filters[fidx]: _flux
                    for fidx, _flux in enumerate(flux)
                })
                centerFlux = self._getCenterFlux(mHeavy, scarletSource, xy0=bbox.getMin())
                multibandColumns["centerFluxes"].append(centerFlux)

                # Add all fields except the HeavyFootprint to the
                # source record
                self._addChild(
                    parent=parent,
                    mHeavy=mHeavy,
                    catalog=catalog,
                    scarletSource=scarletSource,
                )
            # Log a message if it has been a while since the last log.
            if (currentTime := time.time()) > nextLogTime:
                nextLogTime = currentTime + self.config.loggingInterval
                self.log.verbose("Deblended %d parent sources out of %d", parentIndex + 1, nParents)

        # Make sure that the number of new sources matches the number of
        # entries in each of the band dependent columns.
        # This should never trigger and is just a sanity check.
        nChildren = len(catalog) - nParents
        if np.any([len(meas) != nChildren for meas in multibandColumns.values()]):
            msg = f"Added {len(catalog)-nParents} new sources, but have "
            msg += ", ".join([
                f"{len(value)} {key}"
                for key, value in multibandColumns.items()
            ])
            raise RuntimeError(msg)
        # Make a copy of the catalog in each band and update the footprints
        catalogs = {}
        for f in filters:
            _catalog = afwTable.SourceCatalog(catalog.table.clone())
            _catalog.extend(catalog, deep=True)
            # Update the footprints and columns that are different
            # for each filter
            for sourceIndex, source in enumerate(_catalog[nParents:]):
                source.setFootprint(multibandColumns["heavies"][sourceIndex][f])
                source.set(self.scarletFluxKey, multibandColumns["fluxes"][sourceIndex][f])
                source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
            catalogs[f] = _catalog

        # Update the mExposure mask with the footprint of skipped parents
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                for parentIndex in skippedParents:
                    # Use the merged catalog here; the per-band copies above
                    # share the same parent footprints.
                    fp = catalog[parentIndex].getFootprint()
                    fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

        self.log.info("Deblender results: of %d parent sources, %d were deblended, "
                      "creating %d children, for a total of %d sources",
                      nParents, nDeblendedParents, nChildren, len(catalog))
        return catalogs
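
    # Example (a sketch, with illustrative names): running the task and
    # consuming the per-band catalogs it returns:
    #
    #     deblendTask = ScarletDeblendTask(schema=schema)
    #     templateCatalogs = deblendTask.run(mExposure, mergedSources)
    #     gCatalog = templateCatalogs["g"]  # heavy footprints for band "g"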

    def _isLargeFootprint(self, footprint):
        """Returns whether a Footprint is large

        'Large' is defined by thresholds on the area, size and axis ratio.
        These may be disabled independently by configuring them to be
        non-positive.

        This is principally intended to get rid of satellite streaks, which the
        deblender or other downstream processing can have trouble dealing with
        (e.g., multiple large HeavyFootprints can chew up memory).
        """
        if self.config.maxFootprintArea > 0 and footprint.getArea() > self.config.maxFootprintArea:
            return True
        if self.config.maxFootprintSize > 0:
            bbox = footprint.getBBox()
            if max(bbox.getWidth(), bbox.getHeight()) > self.config.maxFootprintSize:
                return True
        if self.config.minFootprintAxisRatio > 0:
            axes = afwEll.Axes(footprint.getShape())
            if axes.getB() < self.config.minFootprintAxisRatio*axes.getA():
                return True
        return False
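
    # Example (a sketch): for a satellite streak with semi-axes a=500 and b=2
    # pixels, b/a = 0.004, so any `minFootprintAxisRatio` above 0.004 causes
    # the axis-ratio test above to reject the footprint.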

    def _isMasked(self, footprint, mExposure):
        """Returns whether the footprint violates the mask limits"""
        bbox = footprint.getBBox()
        mask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
        size = float(footprint.getArea())
        for maskName, limit in self.config.maskLimits.items():
            maskVal = mExposure.mask.getPlaneBitMask(maskName)
            _mask = afwImage.MaskX(mask & maskVal, xy0=bbox.getMin())
            unmaskedSpan = footprint.spans.intersectNot(_mask)  # spanset of unmasked pixels
            if (size - unmaskedSpan.getArea())/size > limit:
                return True
        return False
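
    # Example (a sketch): with `maskLimits = {"SAT": 0.5}`, a 1000 pixel
    # footprint containing 600 saturated pixels has a masked fraction of 0.6,
    # which exceeds the 0.5 limit, so `_isMasked` returns `True` and the
    # parent is skipped.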

    def _skipParent(self, parent, skipKey, logMessage):
        """Update a parent record that is not being deblended.

        This is a fairly trivial function but is implemented to ensure
        that a skipped parent updates the appropriate columns
        consistently, and always has a flag to mark the reason that
        it is being skipped.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to flag as skipped.
        skipKey : `lsst.afw.table.Key`
            The key of the flag that marks the reason for skipping.
        logMessage : `str`
            The message to display in a log.trace when a source
            is skipped.
        """
        if logMessage is not None:
            self.log.trace(logMessage)
        self._updateParentRecord(
            parent=parent,
            nPeaks=len(parent.getFootprint().peaks),
            nChild=0,
            runtime=np.nan,
            iterations=0,
            logL=np.nan,
            spectrumInit=False,
            converged=False,
        )

        # Mark the source as skipped by the deblender and
        # flag the reason why.
        parent.set(self.deblendSkippedKey, True)
        parent.set(skipKey, True)

    def _updateParentRecord(self, parent, nPeaks, nChild,
                            runtime, iterations, logL, spectrumInit, converged):
        """Update a parent record in all of the single band catalogs.

        Ensure that all locations that update a parent record,
        whether it is skipped or updated after deblending,
        update all of the appropriate columns.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to update.
        nPeaks : `int`
            Number of peaks in the parent footprint.
        nChild : `int`
            Number of children deblended from the parent.
            This may differ from `nPeaks` if some of the peaks
            were culled and have no deblended model.
        runtime : `float`
            Total runtime for deblending.
        iterations : `int`
            Total number of iterations in scarlet before convergence.
        logL : `float`
            Final log likelihood of the blend.
        spectrumInit : `bool`
            True when scarlet used `set_spectra` to initialize all
            sources with better initial intensities.
        converged : `bool`
            True when the optimizer reached convergence before
            reaching the maximum number of iterations.
        """
        parent.set(self.nPeaksKey, nPeaks)
        parent.set(self.nChildKey, nChild)
        parent.set(self.runtimeKey, runtime)
        parent.set(self.iterKey, iterations)
        parent.set(self.scarletLogLKey, logL)
        parent.set(self.scarletSpectrumInitKey, spectrumInit)
        # The flag records a convergence *failure*, so invert `converged`
        parent.set(self.blendConvergenceFailedFlagKey, not converged)

    def _addChild(self, parent, mHeavy, catalog, scarletSource):
        """Add a child to a catalog.

        This creates a new child in the source catalog,
        assigning it a parent id, and adding all columns
        that are independent across all filter bands.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent of the new child record.
        mHeavy : `lsst.afw.detection.MultibandFootprint`
            The multi-band footprint containing the model and
            peak catalog for the new child record.
        catalog : `lsst.afw.table.source.source.SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.
        scarletSource : `scarlet.Component`
            The scarlet model for the new source record.
        """
        src = catalog.addNew()
        for key in self.toCopyFromParent:
            src.set(key, parent.get(key))
        # The peak catalog is the same for all bands,
        # so we just use the first peak catalog
        peaks = mHeavy[mHeavy.filters[0]].peaks
        src.assign(peaks[0], self.peakSchemaMapper)
        src.setParent(parent.getId())
        # Currently all children only have a single peak,
        # but it's possible in the future that there will be hierarchical
        # deblending, so we use the footprint to set the number of peaks
        # for each child.
        src.set(self.nPeaksKey, len(peaks))
        # Set the psf key based on whether or not the source was
        # deblended using the PointSource model.
        # This key is not that useful anymore since we now keep track of
        # `modelType`, but we continue to propagate it in case code downstream
        # is expecting it.
        src.set(self.psfKey, scarletSource.__class__.__name__ == "PointSource")
        src.set(self.modelTypeKey, scarletSource.__class__.__name__)
        # We set the runtime to zero so that summing up the
        # runtime column will give the total time spent
        # running the deblender for the catalog.
        src.set(self.runtimeKey, 0)

        # Set the position of the peak from the parent footprint
        # This will make it easier to match the same source across
        # deblenders and across observations, where the peak
        # position is unlikely to change unless enough time passes
        # for a source to move on the sky.
        peak = scarletSource.detectedPeak
        src.set(self.peakCenter, Point2I(peak["i_x"], peak["i_y"]))
        src.set(self.peakIdKey, peak["id"])

        # Propagate columns from the parent to the child
        for parentColumn, childColumn in self.config.columnInheritance.items():
            src.set(childColumn, parent.get(parentColumn))

    def _getCenterFlux(self, mHeavy, scarletSource, xy0):
        """Get the flux at the center of a HeavyFootprint

        Parameters
        ----------
        mHeavy : `lsst.afw.detection.MultibandFootprint`
            The multi-band footprint containing the model for the source.
        scarletSource : `scarlet.Component`
            The scarlet model for the heavy footprint
        xy0 : `Point2I`
            The minimum corner of the parent bounding box,
            used to shift the model center into image coordinates.
        """
        # Store the flux at the center of the model and the total
        # scarlet flux measurement.
        mImage = mHeavy.getImage(fill=0.0).image

        # Set the flux at the center of the model (for SNR)
        try:
            cy, cx = scarletSource.center
            cy += xy0.y
            cx += xy0.x
            return mImage[:, cx, cy]
        except AttributeError:
            msg = "Did not recognize coordinates for source type of `{0}`, "
            msg += "could not write coordinates or center flux. "
            msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
            logger.warning(msg.format(type(scarletSource)))
            return {f: np.nan for f in mImage.filters}