Coverage for python/lsst/meas/extensions/scarlet/scarletDeblendTask.py: 16%
525 statements
« prev ^ index » next coverage.py v6.4, created at 2022-05-26 10:56 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-05-26 10:56 +0000
1# This file is part of meas_extensions_scarlet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from dataclasses import dataclass
23from functools import partial
24import logging
25import numpy as np
26import scarlet
27from scarlet.psf import ImagePSF, GaussianPSF
28from scarlet import Blend, Frame, Observation
29from scarlet.renderer import ConvolutionRenderer
30from scarlet.detect import get_detect_wavelets
31from scarlet.initialization import init_all_sources
32from scarlet import lite
34import lsst.pex.config as pexConfig
35from lsst.pex.exceptions import InvalidParameterError
36import lsst.pipe.base as pipeBase
37from lsst.geom import Point2I, Box2I, Point2D
38import lsst.afw.geom.ellipses as afwEll
39import lsst.afw.image as afwImage
40import lsst.afw.detection as afwDet
41import lsst.afw.table as afwTable
42from lsst.utils.logging import PeriodicLogger
43from lsst.utils.timer import timeMethod
45from .source import bboxToScarletBox, modelToHeavy, liteModelToHeavy
# scarlet and proxmin use log levels that do not line up with the stack's
# conventions: even their "warnings" occur far more often than we would like,
# and everything below an error would be considered "TRACE" output by our
# standards. So for now only errors from those two packages are displayed.
scarletLogger = logging.getLogger("scarlet")
proxminLogger = logging.getLogger("proxmin")
scarletLogger.setLevel(logging.ERROR)
proxminLogger.setLevel(logging.ERROR)
56__all__ = ["deblend", "ScarletDeblendConfig", "ScarletDeblendTask"]
58logger = logging.getLogger(__name__)
class IncompleteDataError(Exception):
    """Raised when the PSF cannot be computed because the data is incomplete.
    """
class ScarletGradientError(Exception):
    """An error occurred during optimization

    This error occurs when the optimizer encounters
    a NaN value while calculating the gradient.

    Parameters
    ----------
    iterations : `int`
        The number of iterations that ran before the failure.
    sources : `list` of `int`
        Indices of the sources whose models contain non-finite values.
    """
    def __init__(self, iterations, sources):
        # NOTE: the base class ``__init__`` is intentionally not called with
        # the formatted message, so that ``self.args`` remains
        # ``(iterations, sources)`` and the exception can be re-created
        # from its arguments.
        self.iterations = iterations
        self.sources = sources
        self.message = (
            f"ScarletGradientError in iteration {iterations}. "
            f"NaN values introduced in sources {sources}"
        )

    def __str__(self):
        return self.message
84def _checkBlendConvergence(blend, f_rel):
85 """Check whether or not a blend has converged
86 """
87 deltaLoss = np.abs(blend.loss[-2] - blend.loss[-1])
88 convergence = f_rel * np.abs(blend.loss[-1])
89 return deltaLoss < convergence
def _computePsfImage(self, position):
    """Get a multiband PSF image

    The PSF kernel image is computed for each band at ``position``.
    The per-band images are then projected onto the smallest bounding box
    that contains all of them and combined into a single multiband image.

    The result is neither cached nor stored on ``self``, so if the same
    PSF is expected to be used multiple times it is a good idea to store
    the result in another variable.

    Note: this is a temporary fix during the deblender sprint.
    In the future this function will replace the current method
    in `afw.MultibandExposure.computePsfImage` (DM-19789).

    Parameters
    ----------
    self : multiband exposure
        Object providing ``singles`` (the per-band exposures, each with a
        ``getPsf`` method) and ``filters`` (the band names).
    position : `Point2D` or `tuple`
        Coordinates to evaluate the PSF.

    Returns
    -------
    psfImage : `lsst.afw.image.MultibandImage`
        The multiband PSF image.

    Raises
    ------
    IncompleteDataError
        If the PSF cannot be computed at ``position`` in one of the bands.
    """
    # Make the coordinates into a Point2D (if necessary)
    if not isinstance(position, Point2D):
        position = Point2D(position[0], position[1])

    psfs = []
    for bidx, single in enumerate(self.singles):
        try:
            psf = single.getPsf().computeKernelImage(position)
        except InvalidParameterError as err:
            # This band failed to compute the PSF due to incomplete data
            # at that location. This is unlikely to be a problem for Rubin,
            # however the edges of some HSC COSMOS fields contain incomplete
            # data in some bands, so we track this error to distinguish it
            # from unknown errors.
            msg = "Failed to compute PSF at {} in band {}"
            raise IncompleteDataError(msg.format(position, self.filters[bidx])) from err
        psfs.append(psf)

    # Find the smallest box that contains the PSF image of every band
    psfBoxes = [psf.getBBox() for psf in psfs]
    left = np.min([box.getMinX() for box in psfBoxes])
    bottom = np.min([box.getMinY() for box in psfBoxes])
    right = np.max([box.getMaxX() for box in psfBoxes])
    top = np.max([box.getMaxY() for box in psfBoxes])
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
    # Project every band onto the common box so they can be stacked
    psfs = [afwImage.utils.projectImage(psf, bbox) for psf in psfs]
    psfImage = afwImage.MultibandImage.fromImages(self.filters, psfs)
    return psfImage
def getFootprintMask(footprint, mExposure):
    """Build a mask that flags the pixels outside of a footprint

    Parameters
    ----------
    footprint : `lsst.detection.Footprint`
        - The footprint of the parent to deblend
    mExposure : `lsst.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data.
          It is not used by this function.

    Returns
    -------
    footprintMask : array
        Boolean array, covering the bounding box of the footprint,
        with pixels not in the footprint set to one.
    """
    insideMask = afwImage.Mask(footprint.getBBox())
    # Flag every pixel that belongs to the footprint ...
    footprint.spans.setMask(insideMask, 1)
    isInside = insideMask.getArray().astype(bool)
    # ... and invert the selection to flag the pixels outside of it
    return ~isInside
def isPseudoSource(source, pseudoColumns):
    """Check if a source is a pseudo source.

    This is mostly for skipping sky objects,
    but any other column can also be added to disable
    deblending on a parent or individual source when
    set to `True`.

    Parameters
    ----------
    source : `lsst.afw.table.source.source.SourceRecord`
        The source to check for the pseudo bit.
    pseudoColumns : `list` of `str`
        A list of columns to check for pseudo sources.
        Columns that raise a `KeyError` on lookup are ignored.

    Returns
    -------
    isPseudo : `bool`
        The values of all of the available columns combined with ``|``,
        or `False` if none of the columns is available.
    """
    flagged = False
    for column in pseudoColumns:
        try:
            value = source[column]
        except KeyError:
            # The record does not have this column, so it cannot be set
            continue
        flagged |= value
    return flagged
def deblend(mExposure, footprint, config):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data
    footprint : `lsst.detection.Footprint`
        - The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        - Configuration of the deblending task

    Returns
    -------
    blend : `scarlet.Blend`
        The scarlet blend class that contains all of the information
        about the parameters and results from scarlet
    skipped : `list` of `int`
        The indices of any children that failed to initialize
        and were skipped.
    spectrumInit : `bool`
        Whether or not all of the sources were initialized by jointly
        fitting their SED's. This provides a better initialization
        but created memory issues when a blend is too large or
        contains too many sources.

    Raises
    ------
    ValueError
        If ``config.convolutionType`` or ``config.sourceModel``
        is not recognized.
    NotImplementedError
        If ``config.sourceModel`` is "point" or "fit".
    ScarletGradientError
        If the optimizer raises an `ArithmeticError` while fitting.
    """
    # All of the pixel data is extracted in the bounding box of the parent
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        # NOTE(review): the bad pixel mask is only applied to the flat
        # weights; the inverse variance weights are presumably expected to
        # already down-weight bad pixels -- confirm.
        weights = np.ones_like(images)
        badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
        mask = mExposure.mask[:, bbox].array & badPixels
        weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    psfs = ImagePSF(psfs)
    model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

    frame = Frame(images.shape, psf=model_psf, channels=mExposure.filters)
    observation = Observation(images, psf=psfs, weights=weights, channels=mExposure.filters)
    if config.convolutionType == "fft":
        observation.match(frame)
    elif config.convolutionType == "real":
        renderer = ConvolutionRenderer(observation, frame, convolution_type="real")
        observation.match(frame, renderer=renderer)
    else:
        raise ValueError("Unrecognized convolution type {}".format(config.convolutionType))

    # Set the appropriate number of components.
    # The model is validated explicitly (instead of with an ``assert``)
    # so that the check is not stripped when running with ``-O`` and
    # so that every branch below is reachable.
    if config.sourceModel == "single":
        maxComponents = 1
    elif config.sourceModel == "double":
        maxComponents = 2
    elif config.sourceModel == "compact":
        maxComponents = 0
    elif config.sourceModel == "point":
        raise NotImplementedError("Point source photometry is currently not implemented")
    elif config.sourceModel == "fit":
        # It is likely in the future that there will be some heuristic
        # used to determine what type of model to use for each source,
        # but that has not yet been implemented (see DM-22551)
        raise NotImplementedError("sourceModel 'fit' has not been implemented yet")
    else:
        raise ValueError("Unrecognized sourceModel {}".format(config.sourceModel))

    # Select the peaks to deblend once, so that the centers and the peak
    # records attached to the sources below always share the same index,
    # even when pseudo sources have been removed.
    peaks = [
        peak for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in peaks
    ]

    # Choose whether or not to use the improved spectral initialization
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0:
            spectrumInit = True
        else:
            spectrumInit = len(centers) * bbox.getArea() < config.maxSpectrumCutoff
    else:
        spectrumInit = False

    # Only deblend sources that can be initialized
    sources, skipped = init_all_sources(
        frame=frame,
        centers=centers,
        observations=observation,
        thresh=config.morphThresh,
        max_components=maxComponents,
        min_snr=config.minSNR,
        shifting=False,
        fallback=config.fallback,
        silent=config.catchFailures,
        set_spectra=spectrumInit,
    )

    # Attach the peak to all of the initialized sources.
    # ``sources`` only contains the sources that were not skipped,
    # so it is indexed separately from ``centers``.
    srcIndex = 0
    for k, center in enumerate(centers):
        if k not in skipped:
            # This is just to make sure that there isn't a coding bug
            assert np.all(sources[srcIndex].center == center)
            # Store the record for the peak with the appropriate source
            sources[srcIndex].detectedPeak = peaks[k]
            srcIndex += 1

    # Create the blend and attempt to optimize it
    blend = Blend(sources, observation)
    try:
        blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
    except ArithmeticError as err:
        # This occurs when a gradient update produces a NaN value
        # This is usually due to a source initialized with a
        # negative SED or no flux, often because the peak
        # is a noise fluctuation in one band and not a real source.
        iterations = len(blend.loss)
        failedSources = [
            k for k, src in enumerate(sources)
            if np.any(~np.isfinite(src.get_model()))
        ]
        raise ScarletGradientError(iterations, failedSources) from err

    return blend, skipped, spectrumInit
def deblend_lite(mExposure, footprint, config, wavelets=None):
    """Deblend a parent footprint using the lite version of scarlet

    Parameters
    ----------
    mExposure : `lsst.image.MultibandExposure`
        - The multiband exposure containing the image,
          mask, and variance data
    footprint : `lsst.detection.Footprint`
        - The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        - Configuration of the deblending task
    wavelets : array, optional
        - The wavelet coefficients used to initialize the sources.
          Only used, and required, when ``config.morphImage`` is "wavelet".
          NOTE(review): presumably the coefficients of the full exposure,
          since they are sliced with the footprint bounding box -- confirm.

    Returns
    -------
    blend : `scarlet.lite.LiteBlend`
        The blend containing the fit sources and the observation.
    skipped : `list`
        The sources that could not be initialized (``is_null`` is set).
    spectrumInit : `bool`
        Whether or not the spectra of the sources were fit
        before optimizing the blend.

    Raises
    ------
    ValueError
        If ``config.morphImage`` or ``config.optimizer`` is not recognized,
        if ``wavelets`` is missing for wavelet initialization, or if an
        initialized source is not centered on its peak.
    """
    # All of the pixel data is extracted in the bounding box of the parent
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array
    variance = mExposure.variance[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/variance
    else:
        # NOTE(review): the bad pixel mask is only applied to the flat
        # weights; the inverse variance weights are presumably expected to
        # already down-weight bad pixels -- confirm.
        weights = np.ones_like(images)
        badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
        mask = mExposure.mask[:, bbox].array & badPixels
        weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    modelPsf = lite.integrated_circular_gaussian(sigma=config.modelPsfSigma)

    observation = lite.LiteObservation(
        images=images,
        variance=variance,
        weights=weights,
        psfs=psfs,
        model_psf=modelPsf[None, :, :],
        convolution_mode=config.convolutionType,
    )

    # Select the peaks to deblend once, so that the centers and the peak
    # records attached to the sources below always share the same index,
    # even when pseudo sources have been removed.
    peaks = [
        peak for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in peaks
    ]

    # Initialize the sources
    if config.morphImage == "chi2":
        sources = lite.init_all_sources_main(
            observation,
            centers,
            min_snr=config.minSNR,
            thresh=config.morphThresh,
        )
    elif config.morphImage == "wavelet":
        if wavelets is None:
            raise ValueError("`wavelets` must be provided when morphImage is 'wavelet'.")
        _bbox = bboxToScarletBox(len(mExposure.filters), bbox, bbox.getMin())
        _wavelets = wavelets[(slice(None), *_bbox[1:].slices)]
        sources = lite.init_all_sources_wavelets(
            observation,
            centers,
            use_psf=False,
            wavelets=_wavelets,
            min_snr=config.minSNR,
        )
    else:
        raise ValueError("morphImage must be either 'chi2' or 'wavelet'.")

    # Set the optimizer
    if config.optimizer == "adaprox":
        parameterization = partial(
            lite.init_adaprox_component,
            bg_thresh=config.backgroundThresh,
            max_prox_iter=config.maxProxIter,
        )
    elif config.optimizer == "fista":
        parameterization = partial(
            lite.init_fista_component,
            bg_thresh=config.backgroundThresh,
        )
    else:
        raise ValueError("Unrecognized optimizer. Must be either 'adaprox' or 'fista'.")
    sources = lite.parameterize_sources(sources, observation, parameterization)

    # Attach the peak to all of the initialized sources
    for k, center in enumerate(centers):
        # This is just to make sure that there isn't a coding bug
        if len(sources[k].components) > 0 and np.any(sources[k].center != center):
            raise ValueError(f"Misaligned center, expected {center} but got {sources[k].center}")
        # Store the record for the peak with the appropriate source
        sources[k].detectedPeak = peaks[k]

    blend = lite.LiteBlend(sources, observation)

    # Initialize each source with its best fit spectrum
    # This significantly cuts down on the number of iterations
    # that the optimizer needs and usually results in a better
    # fit, but using least squares on a very large blend causes memory issues.
    # This is typically the most expensive operation in deblending, memorywise.
    spectrumInit = False
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0 or len(centers) * bbox.getArea() < config.maxSpectrumCutoff:
            spectrumInit = True
            blend.fit_spectra()

    # Set the sources that could not be initialized and were skipped
    skipped = [src for src in sources if src.is_null]

    blend.fit(max_iter=config.maxIter, e_rel=config.relativeError, min_iter=config.minIter)

    return blend, skipped, spectrumInit
@dataclass
class DeblenderMetrics:
    """Metrics and measurements made on single sources.

    Store deblender metrics to be added as attributes to a scarlet source
    before it is converted into a `SourceRecord`.
    When DM-34414 is finished this class will be eliminated and the metrics
    will be added to the schema using a pipeline task that calculates them
    from the stored deconvolved models.

    All of the parameters are one dimensional numpy arrays,
    with an element for each band in the observed images.

    `maxOverlap` is useful as a metric for determining how blended a source
    is because if it only overlaps with other sources at or below
    the noise level, it is likely to be a mostly isolated source
    in the deconvolved model frame.

    `fluxOverlapFraction` is potentially more useful than the canonical
    "blendedness" (or purity) metric because it accounts for potential
    biases created during deblending by not weighting the overlapping
    flux with the flux of this source's model.

    Attributes
    ----------
    maxOverlap : `numpy.ndarray`
        The maximum overlap that the source has with its neighbors in
        a single pixel.
    fluxOverlap : `numpy.ndarray`
        The total flux from neighbors overlapping with the current source.
    fluxOverlapFraction : `numpy.ndarray`
        The fraction of `flux from neighbors/source flux` for a
        given source within the source's footprint.
    blendedness : `numpy.ndarray`
        The metric for determining how blended a source is using the
        Bosch et al. 2018 metric for "blendedness." Note that some
        surveys use the term "purity," which is `1-blendedness`.
    """
    # The fields are annotated with the array type (`numpy.ndarray`),
    # not with the `numpy.array` factory function.
    maxOverlap: np.ndarray
    fluxOverlap: np.ndarray
    fluxOverlapFraction: np.ndarray
    blendedness: np.ndarray
def setDeblenderMetrics(blend):
    """Set metrics that can be used to evaluate the deblender accuracy

    The `DeblenderMetrics` are calculated for every source in the blend
    and stored, in place, as the ``metrics`` attribute of that source.

    Parameters
    ----------
    blend : `scarlet.lite.Blend`
        The blend containing the sources to measure.
    """
    # The model of the full scene is the reference for every source
    sceneModel = blend.get_model()
    # Sum over the two trailing axes to get a single value per band
    perBand = (1, 2)
    for source in blend.sources:
        # Extract the source model in the full bounding box
        sourceModel = source.get_model(bbox=blend.bbox)
        # The footprint is the 2D array of pixels
        # that are positive in at least one band
        inFootprint = np.any(sourceModel > 0, axis=0)
        # Flux of all of the other sources inside the footprint.
        # See `DeblenderMetrics` for a description of each metric.
        neighborFlux = (sceneModel - sourceModel) * inFootprint[None, :, :]
        maxOverlap = np.max(neighborFlux, axis=perBand)
        fluxOverlap = np.sum(neighborFlux, axis=perBand)
        sourceFlux = np.sum(sourceModel, axis=perBand)
        # Bands without any flux in the source keep a fraction of zero
        hasFlux = sourceFlux > 0
        fluxOverlapFraction = np.zeros((len(sourceModel), ), dtype=float)
        fluxOverlapFraction[hasFlux] = fluxOverlap[hasFlux]/sourceFlux[hasFlux]
        selfProduct = np.sum(sourceModel*sourceModel, axis=perBand)
        sceneProduct = np.sum(sceneModel*sourceModel, axis=perBand)
        blendedness = 1 - selfProduct/sceneProduct
        source.metrics = DeblenderMetrics(
            maxOverlap=maxOverlap,
            fluxOverlap=fluxOverlap,
            fluxOverlapFraction=fluxOverlapFraction,
            blendedness=blendedness,
        )
class ScarletDeblendConfig(pexConfig.Config):
    """ScarletDeblendConfig

    Configuration for the multiband deblender.
    The parameters are organized by the parameter types, which are
    - Stopping Criteria: Used to determine if the fit has converged
    - Constraints: Used to apply constraints to the peaks and their components
    - Lite Parameters: Parameters that select and tune the lite version
    - Other: Parameters that don't fit into the above categories
    """
    # Stopping Criteria
    minIter = pexConfig.Field(dtype=int, default=1,
                              doc="Minimum number of iterations before the optimizer is allowed to stop.")
    maxIter = pexConfig.Field(dtype=int, default=300,
                              doc=("Maximum number of iterations to deblend a single parent"))
    relativeError = pexConfig.Field(dtype=float, default=1e-2,
                                    doc=("Change in the loss function between iterations to exit fitter. "
                                         "Typically this is `1e-2` if measurements will be made on the "
                                         "flux re-distributed models and `1e-4` when making measurements "
                                         "on the models themselves."))

    # Constraints
    morphThresh = pexConfig.Field(dtype=float, default=1,
                                  doc="Fraction of background RMS a pixel must have "
                                      "to be included in the initial morphology")
    # Lite Parameters
    # All of these parameters (except version) are only valid if version='lite'
    version = pexConfig.ChoiceField(
        dtype=str,
        default="lite",
        allowed={
            "scarlet": "main scarlet version (likely to be deprecated soon)",
            "lite": "Optimized version of scarlet for survey data from a single instrument",
        },
        doc="The version of scarlet to use.",
    )
    optimizer = pexConfig.ChoiceField(
        dtype=str,
        default="adaprox",
        allowed={
            "adaprox": "Proximal ADAM optimization",
            "fista": "Accelerated proximal gradient method",
        },
        doc="The optimizer to use for fitting parameters and is only used when version='lite'",
    )
    morphImage = pexConfig.ChoiceField(
        dtype=str,
        default="chi2",
        allowed={
            "chi2": "Initialize sources on a chi^2 image made from all available bands",
            "wavelet": "Initialize sources using a wavelet decomposition of the chi^2 image",
        },
        doc="The type of image to use for initializing the morphology. "
            "Must be either 'chi2' or 'wavelet'."
    )
    backgroundThresh = pexConfig.Field(
        dtype=float,
        default=0.25,
        doc="Fraction of background to use for a sparsity threshold. "
            "This prevents sources from growing unrealistically outside "
            "the parent footprint while still modeling flux correctly "
            "for bright sources."
    )
    maxProxIter = pexConfig.Field(
        dtype=int,
        default=1,
        doc="Maximum number of proximal operator iterations inside of each "
            "iteration of the optimizer. "
            "This config field is only used if version='lite' and optimizer='adaprox'."
    )
    waveletScales = pexConfig.Field(
        dtype=int,
        default=5,
        doc="Number of wavelet scales to use for wavelet initialization. "
            "This field is only used when `version`='lite' and `morphImage`='wavelet'."
    )

    # Other scarlet parameters
    useWeights = pexConfig.Field(
        dtype=bool, default=True,
        doc=("Whether or not to use inverse variance weighting. "
             "If `useWeights` is `False` then flat weights are used"))
    modelPsfSize = pexConfig.Field(
        dtype=int, default=11,
        doc="Model PSF side length in pixels")
    modelPsfSigma = pexConfig.Field(
        dtype=float, default=0.8,
        doc="Define sigma for the model frame PSF")
    minSNR = pexConfig.Field(
        dtype=float, default=50,
        doc="Minimum Signal to noise to accept the source. "
            "Sources with lower flux will be initialized with the PSF but updated "
            "like an ordinary ExtendedSource (known in scarlet as a `CompactSource`).")
    saveTemplates = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to save the SEDs and templates")
    processSingles = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to process isolated sources in the deblender")
    convolutionType = pexConfig.Field(
        dtype=str, default="fft",
        doc="Type of convolution to render the model to the observations.\n"
            "- 'fft': perform convolutions in Fourier space\n"
            "- 'real': perform convolutions in real space.")
    sourceModel = pexConfig.Field(
        dtype=str, default="double",
        doc=("How to determine which model to use for sources, from\n"
             "- 'single': use a single component for all sources\n"
             "- 'double': use a bulge disk model for all sources\n"
             "- 'compact': use a single component model, initialized with a point source morphology, "
             " for all sources\n"
             "- 'point': use a point-source model for all sources\n"
             "- 'fit': use a PSF fitting model to determine the number of components (not yet implemented)"),
        deprecated="This field will be deprecated when the default for `version` is changed to `lite`.",
    )
    setSpectra = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to solve for the best-fit spectra during initialization. "
            "This makes initialization slightly longer, as it requires a convolution "
            "to set the optimal spectra, but results in a much better initial log-likelihood "
            "and reduced total runtime, with convergence in fewer iterations. "
            "Only blends with peaks*area < `maxSpectrumCutoff` will use the improved initialization.")

    # Mask-plane restrictions
    badMask = pexConfig.ListField(
        dtype=str, default=["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"],
        doc="Names of the mask planes that identify bad pixels")
    statsMask = pexConfig.ListField(dtype=str, default=["SAT", "INTRP", "NO_DATA"],
                                    doc="Mask planes to ignore when performing statistics")
    maskLimits = pexConfig.DictField(
        keytype=str,
        itemtype=float,
        default={},
        doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
             "Sources violating this limit will not be deblended. "
             "If the fraction is `0` then the limit is a single pixel."),
    )

    # Size restrictions
    maxNumberOfPeaks = pexConfig.Field(
        dtype=int, default=200,
        doc=("Only deblend the brightest maxNumberOfPeaks peaks in the parent"
             " (<= 0: unlimited)"))
    maxFootprintArea = pexConfig.Field(
        dtype=int, default=100_000,
        doc=("Maximum area for footprints before they are ignored as large; "
             "non-positive means no threshold applied"))
    maxAreaTimesPeaks = pexConfig.Field(
        dtype=int, default=10_000_000,
        doc=("Maximum rectangular footprint area * nPeaks in the footprint. "
             "This was introduced in DM-33690 to prevent fields that are crowded or have a "
             "LSB galaxy that causes memory intensive initialization in scarlet from dominating "
             "the overall runtime and/or causing the task to run out of memory. "
             "(<= 0: unlimited)")
    )
    maxFootprintSize = pexConfig.Field(
        dtype=int, default=0,
        doc=("Maximum linear dimension for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    minFootprintAxisRatio = pexConfig.Field(
        dtype=float, default=0.0,
        doc=("Minimum axis ratio for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    maxSpectrumCutoff = pexConfig.Field(
        dtype=int, default=1_000_000,
        doc=("Maximum number of pixels * number of sources in a blend. "
             "This is different than `maxFootprintArea` because this isn't "
             "the footprint area but the area of the bounding box that "
             "contains the footprint, and is also multiplied by the number of "
             "sources in the footprint. This prevents large skinny blends with "
             "a high density of sources from running out of memory. "
             "If `maxSpectrumCutoff <= 0` then there is no cutoff.")
    )
    # Failure modes
    fallback = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to fallback to a smaller number of components if a source does not initialize"
    )
    notDeblendedMask = pexConfig.Field(
        dtype=str, default="NOT_DEBLENDED", optional=True,
        doc="Mask name for footprints not deblended, or None")
    catchFailures = pexConfig.Field(
        dtype=bool, default=True,
        doc=("If True, catch exceptions thrown by the deblender, log them, "
             "and set a flag on the parent, instead of letting them propagate up"))

    # Other options
    columnInheritance = pexConfig.DictField(
        keytype=str, itemtype=str, default={
            "deblend_nChild": "deblend_parentNChild",
            "deblend_nPeaks": "deblend_parentNPeaks",
            "deblend_spectrumInitFlag": "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag": "deblend_blendConvergenceFailedFlag",
        },
        doc="Columns to pass from the parent to the child. "
            "The key is the name of the column for the parent record, "
            "the value is the name of the column to use for the child."
    )
    pseudoColumns = pexConfig.ListField(
        dtype=str, default=['merge_peak_sky', 'sky_source'],
        doc="Names of flags which should never be deblended."
    )

    # Logging option(s)
    loggingInterval = pexConfig.Field(
        dtype=int, default=600,
        doc="Interval (in seconds) to log messages (at VERBOSE level) while deblending sources.",
        deprecated="This field is no longer used and will be removed in v25.",
    )
    # Testing options
    # Some obs packages and ci packages run the full pipeline on a small
    # subset of data to test that the pipeline is functioning properly.
    # This is not meant as scientific validation, so it can be useful
    # to only run on a small subset of the data that is large enough to
    # test the desired pipeline features but not so long that the deblender
    # is the tall pole in terms of execution times.
    useCiLimits = pexConfig.Field(
        dtype=bool, default=False,
        doc="Limit the number of sources deblended for CI to prevent long build times")
    ciDeblendChildRange = pexConfig.ListField(
        dtype=int, default=[5, 10],
        doc="Only deblend parent Footprints with a number of peaks in the (inclusive) range indicated. "
            "If `useCiLimits==False` then this parameter is ignored.")
    ciNumParentsToDeblend = pexConfig.Field(
        dtype=int, default=10,
        doc="Only use the first `ciNumParentsToDeblend` parent footprints with a total peak count "
            "within `ciDeblendChildRange`. "
            "If `useCiLimits==False` then this parameter is ignored.")
751class ScarletDeblendTask(pipeBase.Task):
752 """ScarletDeblendTask
754 Split blended sources into individual sources.
756 This task has no return value; it only modifies the SourceCatalog in-place.
757 """
758 ConfigClass = ScarletDeblendConfig
759 _DefaultName = "scarletDeblend"
761 def __init__(self, schema, peakSchema=None, **kwargs):
762 """Create the task, adding necessary fields to the given schema.
764 Parameters
765 ----------
766 schema : `lsst.afw.table.schema.schema.Schema`
767 Schema object for measurement fields; will be modified in-place.
768 peakSchema : `lsst.afw.table.schema.schema.Schema`
769 Schema of Footprint Peaks that will be passed to the deblender.
770 Any fields beyond the PeakTable minimal schema will be transferred
771 to the main source Schema. If None, no fields will be transferred
772 from the Peaks.
773 filters : list of str
774 Names of the filters used for the eposures. This is needed to store
775 the SED as a field
776 **kwargs
777 Passed to Task.__init__.
778 """
779 pipeBase.Task.__init__(self, **kwargs)
781 peakMinimalSchema = afwDet.PeakTable.makeMinimalSchema()
782 if peakSchema is None:
783 # In this case, the peakSchemaMapper will transfer nothing, but
784 # we'll still have one
785 # to simplify downstream code
786 self.peakSchemaMapper = afwTable.SchemaMapper(peakMinimalSchema, schema)
787 else:
788 self.peakSchemaMapper = afwTable.SchemaMapper(peakSchema, schema)
789 for item in peakSchema:
790 if item.key not in peakMinimalSchema:
791 self.peakSchemaMapper.addMapping(item.key, item.field)
792 # Because SchemaMapper makes a copy of the output schema
793 # you give its ctor, it isn't updating this Schema in
794 # place. That's probably a design flaw, but in the
795 # meantime, we'll keep that schema in sync with the
796 # peakSchemaMapper.getOutputSchema() manually, by adding
797 # the same fields to both.
798 schema.addField(item.field)
799 assert schema == self.peakSchemaMapper.getOutputSchema(), "Logic bug mapping schemas"
800 self._addSchemaKeys(schema)
801 self.schema = schema
802 self.toCopyFromParent = [item.key for item in self.schema
803 if item.field.getName().startswith("merge_footprint")]
def _addSchemaKeys(self, schema):
    """Add deblender specific keys to the schema.

    Each key returned by the schema is stored as an attribute of ``self``
    so the other methods of the task can fill the matching columns.

    Parameters
    ----------
    schema : `lsst.afw.table.Schema`
        Schema that receives the new fields; it is modified in place.
    """
    # Parent (blend) fields
    self.runtimeKey = schema.addField('deblend_runtime', type=np.float32, doc='runtime in ms')
    self.iterKey = schema.addField('deblend_iterations', type=np.int32, doc='iterations to converge')
    self.nChildKey = schema.addField('deblend_nChild', type=np.int32,
                                     doc='Number of children this object has (defaults to 0)')
    self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
                                     doc="Number of initial peaks in the blend. "
                                         "This includes peaks that may have been culled "
                                         "during deblending or failed to deblend")
    # Skipped flags
    self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
                                             doc="Deblender skipped this source")
    self.isolatedParentKey = schema.addField('deblend_isolatedParent', type='Flag',
                                             doc='The source has only a single peak '
                                                 'and was not deblended')
    self.pseudoKey = schema.addField('deblend_isPseudo', type='Flag',
                                     doc='The source is identified as a "pseudo" source and '
                                         'was not deblended')
    self.tooManyPeaksKey = schema.addField('deblend_tooManyPeaks', type='Flag',
                                           doc='Source had too many peaks; '
                                               'only the brightest were included')
    self.tooBigKey = schema.addField('deblend_parentTooBig', type='Flag',
                                     doc='Parent footprint covered too many pixels')
    self.maskedKey = schema.addField('deblend_masked', type='Flag',
                                     doc='Parent footprint had too many masked pixels')
    # Convergence flags
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment except the last must end with a space.
    self.sedNotConvergedKey = schema.addField('deblend_sedConvergenceFailed', type='Flag',
                                              doc='scarlet sed optimization did not converge before '
                                                  'config.maxIter')
    self.morphNotConvergedKey = schema.addField('deblend_morphConvergenceFailed', type='Flag',
                                                doc='scarlet morph optimization did not converge before '
                                                    'config.maxIter')
    self.blendConvergenceFailedFlagKey = schema.addField('deblend_blendConvergenceFailedFlag',
                                                         type='Flag',
                                                         doc='at least one source in the blend '
                                                             'failed to converge')
    # Error flags
    self.deblendFailedKey = schema.addField('deblend_failed', type='Flag',
                                            doc="Deblending failed on source")
    self.deblendErrorKey = schema.addField('deblend_error', type="String", size=25,
                                           doc='Name of error if the blend failed')
    # Deblended source fields
    self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
                                                    doc="Center used to apply constraints in scarlet",
                                                    unit="pixel")
    self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
                                     doc="ID of the peak in the parent footprint. "
                                         "This is not unique, but the combination of 'parent' "
                                         "and 'peakId' should be for all child sources. "
                                         "Top level blends with no parents have 'peakId=0'")
    self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
                                           doc="The instFlux at the peak position of deblended mode")
    self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=25,
                                        doc="The type of model used, for example "
                                            "MultiExtendedSource, SingleExtendedSource, PointSource")
    self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
                                           doc="deblend_nPeaks from this records parent.")
    self.parentNChildKey = schema.addField("deblend_parentNChild", type=np.int32,
                                           doc="deblend_nChild from this records parent.")
    self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
                                          doc="Flux measurement from scarlet")
    self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
                                          doc="Final logL, used to identify regressions in scarlet.")
    self.edgePixelsKey = schema.addField('deblend_edgePixels', type='Flag',
                                         doc='Source had flux on the edge of the parent footprint')
    self.scarletSpectrumInitKey = schema.addField("deblend_spectrumInitFlag", type='Flag',
                                                  doc="True when scarlet initializes sources "
                                                      "in the blend with a more accurate spectrum. "
                                                      "The algorithm uses a lot of memory, "
                                                      "so large dense blends will use "
                                                      "a less accurate initialization.")
    self.nComponentsKey = schema.addField("deblend_nComponents", type=np.int32,
                                          doc="Number of components in a ScarletLiteSource. "
                                              "If `config.version != 'lite'` then "
                                              "this column is set to zero.")
    self.psfKey = schema.addField('deblend_deblendedAsPsf', type='Flag',
                                  doc='Deblender thought this source looked like a PSF')
    # Blendedness/classification metrics
    self.maxOverlapKey = schema.addField("deblend_maxOverlap", type=np.float32,
                                         doc="Maximum overlap with all of the other neighbors flux "
                                             "combined. "
                                             "This is useful as a metric for determining how blended a "
                                             "source is because if it only overlaps with other sources "
                                             "at or below the noise level, it is likely to be a mostly "
                                             "isolated source in the deconvolved model frame.")
    self.fluxOverlapKey = schema.addField("deblend_fluxOverlap", type=np.float32,
                                          doc="This is the total flux from neighboring objects that "
                                              "overlaps with this source.")
    self.fluxOverlapFractionKey = schema.addField("deblend_fluxOverlapFraction", type=np.float32,
                                                  doc="This is the fraction of "
                                                      "`flux from neighbors/source flux` "
                                                      "for a given source within the source's "
                                                      "footprint.")
    self.blendednessKey = schema.addField("deblend_blendedness", type=np.float32,
                                          doc="The Bosch et al. 2018 metric for 'blendedness.' ")
@timeMethod
def run(self, mExposure, mergedSources):
    """Deblend the merged catalog; thin wrapper around `deblend`.

    Parameters
    ----------
    mExposure : `MultibandExposure`
        The exposures should be co-added images of the same
        shape and region of the sky.
    mergedSources : `SourceCatalog`
        The merged `SourceCatalog` that contains parent footprints
        to (potentially) deblend.

    Returns
    -------
    result
        The value returned by `deblend` for the same arguments,
        passed through unchanged.
    """
    result = self.deblend(mExposure, mergedSources)
    return result
@timeMethod
def deblend(self, mExposure, catalog):
    """Deblend a data cube of multiband images.

    Every parent record in ``catalog`` is either skipped (and flagged
    with the reason) or deblended, in which case one child record per
    successfully modelled source is appended to the catalog.

    Parameters
    ----------
    mExposure : `MultibandExposure`
        The exposures should be co-added images of the same
        shape and region of the sky.
    catalog : `SourceCatalog`
        The merged `SourceCatalog` that contains parent footprints
        to (potentially) deblend. The new deblended sources are
        appended to this catalog in place (when ``config.useCiLimits``
        is set the catalog is first replaced by a subset of itself).

    Returns
    -------
    catalogs : `dict`
        Keys are the names of the filters and the values are
        `lsst.afw.table.source.source.SourceCatalog`'s.
        These are per-band copies of the catalog in which the child
        records carry the heavy footprint, flux and center flux of
        the model in that band.
    weightedCatalogs : `dict`
        Same layout as ``catalogs`` but built from the flux
        re-weighted models. Empty unless ``config.version == "lite"``.

    Raises
    ------
    ValueError
        Raised when ``config.useCiLimits`` is set and fewer than
        ``config.ciNumParentsToDeblend`` parents fall inside
        ``config.ciDeblendChildRange``.
    RuntimeError
        Raised if the number of per-band entries does not match the
        number of children that were added to the catalog.
    """
    import time
    import traceback

    # Cull footprints if required by ci
    if self.config.useCiLimits:
        self.log.info("Using CI catalog limits, the original number of sources to deblend was %d.",
                      len(catalog))
        # Select parents with a number of children in the range
        # config.ciDeblendChildRange
        minChildren, maxChildren = self.config.ciDeblendChildRange
        nPeaks = np.array([len(src.getFootprint().peaks) for src in catalog])
        childrenInRange = np.where((nPeaks >= minChildren) & (nPeaks <= maxChildren))[0]
        if len(childrenInRange) < self.config.ciNumParentsToDeblend:
            raise ValueError("Fewer than ciNumParentsToDeblend children were contained in the range "
                             "indicated by ciDeblendChildRange. Adjust this range to include more "
                             "parents.")
        # Keep all of the isolated parents and the first
        # `ciNumParentsToDeblend` children
        parents = nPeaks == 1
        children = np.zeros((len(catalog),), dtype=bool)
        children[childrenInRange[:self.config.ciNumParentsToDeblend]] = True
        catalog = catalog[parents | children]
        # We need to update the IdFactory, otherwise the source ids
        # will not be sequential
        idFactory = catalog.getIdFactory()
        maxId = np.max(catalog["id"])
        idFactory.notify(maxId)

    filters = mExposure.filters
    self.log.info("Deblending %d sources in %d exposure bands", len(catalog), len(mExposure))
    periodicLog = PeriodicLogger(self.log)

    # Create a set of wavelet coefficients if using wavelet initialization
    if self.config.version == "lite" and self.config.morphImage == "wavelet":
        images = mExposure.image.array
        variance = mExposure.variance.array
        wavelets = get_detect_wavelets(images, variance, scales=self.config.waveletScales)
    else:
        wavelets = None

    # Add the NOT_DEBLENDED mask to the mask plane in each band
    if self.config.notDeblendedMask:
        for mask in mExposure.mask:
            mask.addMaskPlane(self.config.notDeblendedMask)

    nParents = len(catalog)
    nDeblendedParents = 0
    skippedParents = []
    multibandColumns = {
        "heavies": [],
        "fluxes": [],
        "centerFluxes": [],
        "metrics": [],
    }
    weightedColumns = {
        "heavies": [],
        "fluxes": [],
        "centerFluxes": [],
    }
    # Iterate by index over the original parents only: children are
    # appended to `catalog` while this loop is running.
    for parentIndex in range(nParents):
        parent = catalog[parentIndex]
        foot = parent.getFootprint()
        bbox = foot.getBBox()
        peaks = foot.getPeaks()

        # Since we use the first peak for the parent object, we should
        # propagate its flags to the parent source.
        parent.assign(peaks[0], self.peakSchemaMapper)

        # Block of conditions for skipping a parent with multiple children
        if (skipArgs := self._checkSkipped(parent, mExposure)) is not None:
            self._skipParent(parent, *skipArgs)
            skippedParents.append(parentIndex)
            continue

        nDeblendedParents += 1
        self.log.trace("Parent %d: deblending %d peaks", parent.getId(), len(peaks))
        # Run the deblender
        blendError = None
        try:
            t0 = time.monotonic()
            # Build the parameter lists with the same ordering
            if self.config.version == "scarlet":
                blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
            elif self.config.version == "lite":
                blend, skipped, spectrumInit = deblend_lite(mExposure, foot, self.config, wavelets)
            tf = time.monotonic()
            runtime = (tf-t0)*1000
            converged = _checkBlendConvergence(blend, self.config.relativeError)
            # Store the number of components in the blend
            if self.config.version == "lite":
                nComponents = len(blend.components)
            else:
                nComponents = 0
            nChild = len(blend.sources)
        # Catch all errors and filter out the ones that we know about
        except Exception as e:
            blendError = type(e).__name__
            failedIterations = None
            if isinstance(e, ScarletGradientError):
                failedIterations = e.iterations
            elif not isinstance(e, IncompleteDataError):
                blendError = "UnknownError"
                if self.config.catchFailures:
                    # Make it easy to find UnknownErrors in the log file
                    self.log.warning("UnknownError")
                    traceback.print_exc()
                else:
                    raise

            self._skipParent(
                parent=parent,
                skipKey=self.deblendFailedKey,
                logMessage=f"Unable to deblend source {parent.getId()}: {blendError}",
            )
            if failedIterations is not None:
                # `_skipParent` resets the iteration column to zero, so
                # the iteration count of the failed fit has to be
                # written after it, not before.
                parent.set(self.iterKey, failedIterations)
            parent.set(self.deblendErrorKey, blendError)
            skippedParents.append(parentIndex)
            continue

        # Update the parent record with the deblending results
        if self.config.version == "scarlet":
            logL = -blend.loss[-1] + blend.observations[0].log_norm
        elif self.config.version == "lite":
            logL = blend.loss[-1]
            setDeblenderMetrics(blend)
        self._updateParentRecord(
            parent=parent,
            nPeaks=len(peaks),
            nChild=nChild,
            nComponents=nComponents,
            runtime=runtime,
            iterations=len(blend.loss),
            logL=logL,
            spectrumInit=spectrumInit,
            converged=converged,
        )

        # Add each deblended source to the catalog
        for k, scarletSource in enumerate(blend.sources):
            # Skip any sources with no flux or that scarlet skipped because
            # it could not initialize
            if k in skipped or (self.config.version == "lite" and scarletSource.is_null):
                # No need to propagate anything
                continue
            parent.set(self.deblendSkippedKey, False)
            if self.config.version == "lite":
                mHeavy = liteModelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
                weightedHeavy = liteModelToHeavy(
                    scarletSource, mExposure, blend, xy0=bbox.getMin(), useFlux=True)
                weightedColumns["heavies"].append(weightedHeavy)
                flux = scarletSource.get_model(use_flux=True).sum(axis=(1, 2))
                weightedColumns["fluxes"].append({
                    filters[fidx]: _flux
                    for fidx, _flux in enumerate(flux)
                })
                centerFlux = self._getCenterFlux(weightedHeavy, scarletSource, xy0=bbox.getMin())
                weightedColumns["centerFluxes"].append(centerFlux)
                multibandColumns["metrics"].append(scarletSource.metrics)
            else:
                mHeavy = modelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
            multibandColumns["heavies"].append(mHeavy)
            flux = scarlet.measure.flux(scarletSource)
            multibandColumns["fluxes"].append({
                filters[fidx]: _flux
                for fidx, _flux in enumerate(flux)
            })
            centerFlux = self._getCenterFlux(mHeavy, scarletSource, xy0=bbox.getMin())
            multibandColumns["centerFluxes"].append(centerFlux)

            # Add all fields except the HeavyFootprint to the
            # source record
            self._addChild(
                parent=parent,
                mHeavy=mHeavy,
                catalog=catalog,
                scarletSource=scarletSource,
            )

        # Log a message if it has been a while since the last log.
        periodicLog.log("Deblended %d parent sources out of %d", parentIndex + 1, nParents)

    # Clear the cached values in scarlet to clear out memory
    scarlet.cache.Cache._cache = {}

    # Make sure that the number of new sources matches the number of
    # entries in each of the band dependent columns.
    # This should never trigger and is just a sanity check.
    # The "metrics" column is only filled for the lite models, so it is
    # only checked in that case.
    nChildren = len(catalog) - nParents
    checkedColumns = {
        key: value
        for key, value in multibandColumns.items()
        if key != "metrics" or self.config.version == "lite"
    }
    if any(len(meas) != nChildren for meas in checkedColumns.values()):
        msg = f"Added {nChildren} new sources, but have "
        msg += ", ".join([
            f"{len(value)} {key}"
            for key, value in checkedColumns.items()
        ])
        raise RuntimeError(msg)
    # Make a copy of the catalog in each band and update the footprints
    catalogs = {}
    for fidx, f in enumerate(filters):
        _catalog = afwTable.SourceCatalog(catalog.table.clone())
        _catalog.extend(catalog, deep=True)

        # Update the footprints and columns that are different
        # for each filter
        for sourceIndex, source in enumerate(_catalog[nParents:]):
            source.setFootprint(multibandColumns["heavies"][sourceIndex][f])
            source.set(self.scarletFluxKey, multibandColumns["fluxes"][sourceIndex][f])
            source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
            if self.config.version == "lite":
                self._setMetrics(source, multibandColumns["metrics"][sourceIndex], fidx)
        catalogs[f] = _catalog

    weightedCatalogs = {}
    if self.config.version == "lite":
        # Also create a catalog by reweighting the flux
        for fidx, f in enumerate(filters):
            _catalog = afwTable.SourceCatalog(catalog.table.clone())
            _catalog.extend(catalog, deep=True)

            # Update the footprints and columns that are different
            # for each filter
            for sourceIndex, source in enumerate(_catalog[nParents:]):
                source.setFootprint(weightedColumns["heavies"][sourceIndex][f])
                source.set(self.scarletFluxKey, weightedColumns["fluxes"][sourceIndex][f])
                source.set(self.modelCenterFlux, weightedColumns["centerFluxes"][sourceIndex][f])
                self._setMetrics(source, multibandColumns["metrics"][sourceIndex], fidx)
            weightedCatalogs[f] = _catalog

    # Update the mExposure mask with the footprint of skipped parents.
    # The parent footprints are read from `catalog` itself; the per-band
    # copies only replace the footprints of the children.
    if self.config.notDeblendedMask:
        for mask in mExposure.mask:
            for parentIndex in skippedParents:
                fp = catalog[parentIndex].getFootprint()
                fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

    self.log.info("Deblender results: of %d parent sources, %d were deblended, "
                  "creating %d children, for a total of %d sources",
                  nParents, nDeblendedParents, nChildren, len(catalog))
    return catalogs, weightedCatalogs
1188 def _setMetrics(self, source, metrics, filterIndex):
1189 """Set the metrics for a source in either a template or weighted
1190 catalog.
1192 Parameters
1193 ----------
1194 source : `lsst.afw.SourceRecord`
1195 The source that is being updated.
1196 metrics : `DeblenderMetric`
1197 The deblender metrics calculated in the deconvolved seeing.
1198 filterIndex : `int`
1199 The index of the filter to select a given element of the metric.
1200 """
1201 source.set(self.maxOverlapKey, metrics.maxOverlap[filterIndex])
1202 source.set(self.fluxOverlapKey, metrics.fluxOverlap[filterIndex])
1203 source.set(self.fluxOverlapFractionKey, metrics.fluxOverlapFraction[filterIndex])
1204 source.set(self.blendednessKey, metrics.blendedness[filterIndex])
1206 def _isLargeFootprint(self, footprint):
1207 """Returns whether a Footprint is large
1209 'Large' is defined by thresholds on the area, size and axis ratio,
1210 and total area of the bounding box multiplied by
1211 the number of children.
1212 These may be disabled independently by configuring them to be
1213 non-positive.
1214 """
1215 if self.config.maxFootprintArea > 0 and footprint.getArea() > self.config.maxFootprintArea:
1216 return True
1217 if self.config.maxFootprintSize > 0:
1218 bbox = footprint.getBBox()
1219 if max(bbox.getWidth(), bbox.getHeight()) > self.config.maxFootprintSize:
1220 return True
1221 if self.config.minFootprintAxisRatio > 0:
1222 axes = afwEll.Axes(footprint.getShape())
1223 if axes.getB() < self.config.minFootprintAxisRatio*axes.getA():
1224 return True
1225 if self.config.maxAreaTimesPeaks > 0:
1226 if footprint.getBBox().getArea() * len(footprint.peaks) > self.config.maxAreaTimesPeaks:
1227 return True
1228 return False
def _isMasked(self, footprint, mExposure):
    """Return whether the footprint violates the mask limits.

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint to check for masked pixels.
    mExposure : `MultibandExposure`
        The exposures whose mask planes are combined (bitwise OR
        over the first array axis) and compared with the footprint.

    Returns
    -------
    isMasked : `bool`
        `True` if, for any mask plane named in
        ``self.config.maskLimits``, the fraction of footprint pixels
        with that plane set is greater than the configured limit.
    """
    bbox = footprint.getBBox()
    combinedMask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
    footprintArea = float(footprint.getArea())

    def maskedFraction(planeName):
        # Fraction of the footprint area covered by pixels with this plane set
        planeBit = mExposure.mask.getPlaneBitMask(planeName)
        planeMask = afwImage.MaskX(combinedMask & planeBit, xy0=bbox.getMin())
        # spanset of masked pixels
        maskedSpans = footprint.spans.intersect(planeMask, planeBit)
        return (maskedSpans.getArea())/footprintArea

    return any(
        maskedFraction(planeName) > limit
        for planeName, limit in self.config.maskLimits.items()
    )
1259 def _skipParent(self, parent, skipKey, logMessage):
1260 """Update a parent record that is not being deblended.
1262 This is a fairly trivial function but is implemented to ensure
1263 that a skipped parent updates the appropriate columns
1264 consistently, and always has a flag to mark the reason that
1265 it is being skipped.
1267 Parameters
1268 ----------
1269 parent : `lsst.afw.table.source.source.SourceRecord`
1270 The parent record to flag as skipped.
1271 skipKey : `bool`
1272 The name of the flag to mark the reason for skipping.
1273 logMessage : `str`
1274 The message to display in a log.trace when a source
1275 is skipped.
1276 """
1277 if logMessage is not None:
1278 self.log.trace(logMessage)
1279 self._updateParentRecord(
1280 parent=parent,
1281 nPeaks=len(parent.getFootprint().peaks),
1282 nChild=0,
1283 nComponents=0,
1284 runtime=np.nan,
1285 iterations=0,
1286 logL=np.nan,
1287 spectrumInit=False,
1288 converged=False,
1289 )
1291 # Mark the source as skipped by the deblender and
1292 # flag the reason why.
1293 parent.set(self.deblendSkippedKey, True)
1294 parent.set(skipKey, True)
1296 def _checkSkipped(self, parent, mExposure):
1297 """Update a parent record that is not being deblended.
1299 This is a fairly trivial function but is implemented to ensure
1300 that a skipped parent updates the appropriate columns
1301 consistently, and always has a flag to mark the reason that
1302 it is being skipped.
1304 Parameters
1305 ----------
1306 parent : `lsst.afw.table.source.source.SourceRecord`
1307 The parent record to flag as skipped.
1308 mExposure : `MultibandExposure`
1309 The exposures should be co-added images of the same
1310 shape and region of the sky.
1311 Returns
1312 -------
1313 skip: `bool`
1314 `True` if the deblender will skip the parent
1315 """
1316 skipKey = None
1317 skipMessage = None
1318 footprint = parent.getFootprint()
1319 if len(footprint.peaks) < 2 and not self.config.processSingles:
1320 # Skip isolated sources unless processSingles is turned on.
1321 # Note: this does not flag isolated sources as skipped or
1322 # set the NOT_DEBLENDED mask in the exposure,
1323 # since these aren't really any skipped blends.
1324 skipKey = self.isolatedParentKey
1325 elif isPseudoSource(parent, self.config.pseudoColumns):
1326 # We also skip pseudo sources, like sky objects, which
1327 # are intended to be skipped.
1328 skipKey = self.pseudoKey
1329 if self._isLargeFootprint(footprint):
1330 # The footprint is above the maximum footprint size limit
1331 skipKey = self.tooBigKey
1332 skipMessage = f"Parent {parent.getId()}: skipping large footprint"
1333 elif self._isMasked(footprint, mExposure):
1334 # The footprint exceeds the maximum number of masked pixels
1335 skipKey = self.maskedKey
1336 skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
1337 elif self.config.maxNumberOfPeaks > 0 and len(footprint.peaks) > self.config.maxNumberOfPeaks:
1338 # Unlike meas_deblender, in scarlet we skip the entire blend
1339 # if the number of peaks exceeds max peaks, since neglecting
1340 # to model any peaks often results in catastrophic failure
1341 # of scarlet to generate models for the brighter sources.
1342 skipKey = self.tooManyPeaksKey
1343 skipMessage = f"Parent {parent.getId()}: skipping blend with too many peaks"
1344 if skipKey is not None:
1345 return (skipKey, skipMessage)
1346 return None
def setSkipFlags(self, mExposure, catalog):
    """Set the skip flags for all of the parent sources.

    This is mostly used for testing which parent sources will be deblended
    and which will be skipped based on the current configuration options.
    Skipped sources will have the appropriate flags set in place in the
    catalog.

    Parameters
    ----------
    mExposure : `MultibandExposure`
        The exposures should be co-added images of the same
        shape and region of the sky.
    catalog : `SourceCatalog`
        The merged `SourceCatalog` that contains the parent footprints
        to test. Records that would be skipped are updated in place.
    """
    for src in catalog:
        # The parentheses are required: `:=` binds more loosely than
        # `is not`, so without them `skipArgs` would be a bool instead of
        # the (skipKey, logMessage) tuple.
        if (skipArgs := self._checkSkipped(src, mExposure)) is not None:
            self._skipParent(src, *skipArgs)
1370 def _updateParentRecord(self, parent, nPeaks, nChild, nComponents,
1371 runtime, iterations, logL, spectrumInit, converged):
1372 """Update a parent record in all of the single band catalogs.
1374 Ensure that all locations that update a parent record,
1375 whether it is skipped or updated after deblending,
1376 update all of the appropriate columns.
1378 Parameters
1379 ----------
1380 parent : `lsst.afw.table.source.source.SourceRecord`
1381 The parent record to update.
1382 nPeaks : `int`
1383 Number of peaks in the parent footprint.
1384 nChild : `int`
1385 Number of children deblended from the parent.
1386 This may differ from `nPeaks` if some of the peaks
1387 were culled and have no deblended model.
1388 nComponents : `int`
1389 Total number of components in the parent.
1390 This is usually different than the number of children,
1391 since it is common for a single source to have multiple
1392 components.
1393 runtime : `float`
1394 Total runtime for deblending.
1395 iterations : `int`
1396 Total number of iterations in scarlet before convergence.
1397 logL : `float`
1398 Final log likelihood of the blend.
1399 spectrumInit : `bool`
1400 True when scarlet used `set_spectra` to initialize all
1401 sources with better initial intensities.
1402 converged : `bool`
1403 True when the optimizer reached convergence before
1404 reaching the maximum number of iterations.
1405 """
1406 parent.set(self.nPeaksKey, nPeaks)
1407 parent.set(self.nChildKey, nChild)
1408 parent.set(self.nComponentsKey, nComponents)
1409 parent.set(self.runtimeKey, runtime)
1410 parent.set(self.iterKey, iterations)
1411 parent.set(self.scarletLogLKey, logL)
1412 parent.set(self.scarletSpectrumInitKey, spectrumInit)
1413 parent.set(self.blendConvergenceFailedFlagKey, converged)
def _addChild(self, parent, mHeavy, catalog, scarletSource):
    """Append a new child record to the catalog.

    The record is created in ``catalog``, linked to its parent, and
    filled with every column that does not depend on the band.

    Parameters
    ----------
    parent : `lsst.afw.table.source.source.SourceRecord`
        The parent of the new child record.
    mHeavy : `lsst.detection.MultibandFootprint`
        The multi-band footprint containing the model and
        peak catalog for the new child record.
    catalog : `lsst.afw.table.source.source.SourceCatalog`
        The catalog that receives the new record.
    scarletSource : `scarlet.Component`
        The scarlet model for the new source record.
    """
    child = catalog.addNew()
    for inheritedKey in self.toCopyFromParent:
        child.set(inheritedKey, parent.get(inheritedKey))

    # The peak catalog is the same for all bands,
    # so we just use the first peak catalog
    peaks = mHeavy[mHeavy.filters[0]].peaks
    child.assign(peaks[0], self.peakSchemaMapper)
    child.setParent(parent.getId())
    child.set(self.nPeaksKey, len(peaks))

    # The PSF flag duplicates information that is also available from the
    # model type column; it is still written in case code downstream
    # expects it.
    modelType = scarletSource.__class__.__name__
    child.set(self.psfKey, modelType == "PointSource")
    child.set(self.modelTypeKey, modelType)

    # Children get a runtime of zero so that the sum of the runtime
    # column is the total time spent deblending the catalog.
    child.set(self.runtimeKey, 0)

    # Store the position and id of the peak from the parent footprint.
    # This makes it easier to match the same source across deblenders
    # and across observations.
    firstPeak = peaks[0]
    child.set(self.peakCenter, Point2I(firstPeak["i_x"], firstPeak["i_y"]))
    child.set(self.peakIdKey, firstPeak["id"])

    # Store the number of components for the source
    child.set(self.nComponentsKey, len(scarletSource.components))

    # Propagate columns from the parent to the child
    for parentColumn, childColumn in self.config.columnInheritance.items():
        child.set(childColumn, parent.get(parentColumn))
1472 def _getCenterFlux(self, mHeavy, scarletSource, xy0):
1473 """Get the flux at the center of a HeavyFootprint
1475 Parameters
1476 ----------
1477 mHeavy : `lsst.detection.MultibandFootprint`
1478 The multi-band footprint containing the model for the source.
1479 scarletSource : `scarlet.Component`
1480 The scarlet model for the heavy footprint
1481 """
1482 # Store the flux at the center of the model and the total
1483 # scarlet flux measurement.
1484 mImage = mHeavy.getImage(fill=0.0).image
1486 # Set the flux at the center of the model (for SNR)
1487 try:
1488 cy, cx = scarletSource.center
1489 cy += xy0.y
1490 cx += xy0.x
1491 return mImage[:, cx, cy]
1492 except AttributeError:
1493 msg = "Did not recognize coordinates for source type of `{0}`, "
1494 msg += "could not write coordinates or center flux. "
1495 msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
1496 logger.warning(msg.format(type(scarletSource)))
1497 return {f: np.nan for f in mImage.filters}