# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from functools import partial
import logging
import numpy as np
import scarlet
from scarlet.psf import ImagePSF, GaussianPSF
from scarlet import Blend, Frame, Observation
from scarlet.renderer import ConvolutionRenderer
from scarlet.detect import get_detect_wavelets
from scarlet.initialization import init_all_sources
from scarlet import lite

import lsst.pex.config as pexConfig
from lsst.pex.exceptions import InvalidParameterError
import lsst.pipe.base as pipeBase
from lsst.geom import Point2I, Box2I, Point2D
import lsst.afw.geom.ellipses as afwEll
import lsst.afw.image as afwImage
import lsst.afw.detection as afwDet
import lsst.afw.table as afwTable
from lsst.utils.logging import PeriodicLogger
from lsst.utils.timer import timeMethod

from .source import bboxToScarletBox, modelToHeavy, liteModelToHeavy

# Scarlet and proxmin have a different definition of log levels than the stack,
# so even "warnings" occur far more often than we would like.
# So for now we only display scarlet and proxmin errors, as all other
# scarlet outputs would be considered "TRACE" by our standards.
scarletLogger = logging.getLogger("scarlet")
scarletLogger.setLevel(logging.ERROR)
proxminLogger = logging.getLogger("proxmin")
proxminLogger.setLevel(logging.ERROR)

__all__ = ["deblend", "ScarletDeblendConfig", "ScarletDeblendTask"]

logger = logging.getLogger(__name__)


class IncompleteDataError(Exception):
    """The PSF could not be computed due to incomplete data
    """
    pass


class ScarletGradientError(Exception):
    """An error occurred during optimization

    This error occurs when the optimizer encounters
    a NaN value while calculating the gradient.
    """
    def __init__(self, iterations, sources):
        self.iterations = iterations
        self.sources = sources
        msg = ("ScarletGradientError in iteration {0}. "
               "NaN values introduced in sources {1}")
        self.message = msg.format(iterations, sources)

    def __str__(self):
        return self.message


def _checkBlendConvergence(blend, f_rel):
    """Check whether or not a blend has converged
    """
    deltaLoss = np.abs(blend.loss[-2] - blend.loss[-1])
    convergence = f_rel * np.abs(blend.loss[-1])
    return deltaLoss < convergence
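
# Numerical illustration of the criterion above (the loss values are made up):
# with f_rel = 1e-2 and blend.loss = [..., 100.5, 100.0], the change in the
# loss is 0.5 and the convergence threshold is 1e-2 * 100.0 = 1.0,
# so 0.5 < 1.0 and the blend counts as converged.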


def _getPsfFwhm(psf):
    """Calculate the FWHM of the `psf`
    """
    # 2.35 approximates the Gaussian sigma-to-FWHM factor 2*sqrt(2*ln(2)) ~ 2.355
    return psf.computeShape().getDeterminantRadius() * 2.35


def _computePsfImage(self, position=None):
    """Get a multiband PSF image

    The PSF Kernel Image is computed for each band
    and combined into a (filter, y, x) array.
    The result is not cached, so if the same PSF is expected
    to be used multiple times it is a good idea to store the
    result in another variable.

    Note: this is a temporary fix during the deblender sprint.
    In the future this function will replace the current method
    in `afw.MultibandExposure.computePsfImage` (DM-19789).

    Parameters
    ----------
    position : `Point2D` or `tuple`
        Coordinates to evaluate the PSF. If `position` is `None`
        then `Psf.getAveragePosition()` is used.

    Returns
    -------
    psfImage : array
        The multiband PSF image.
    """
    psfs = []
    # Make the coordinates into a Point2D (if necessary)
    if not isinstance(position, Point2D) and position is not None:
        position = Point2D(position[0], position[1])

    for bidx, single in enumerate(self.singles):
        try:
            if position is None:
                psf = single.getPsf().computeImage()
                psfs.append(psf)
            else:
                psf = single.getPsf().computeKernelImage(position)
                psfs.append(psf)
        except InvalidParameterError:
            # This band failed to compute the PSF due to incomplete data
            # at that location. This is unlikely to be a problem for Rubin,
            # however the edges of some HSC COSMOS fields contain incomplete
            # data in some bands, so we track this error to distinguish it
            # from unknown errors.
            msg = "Failed to compute PSF at {} in band {}"
            raise IncompleteDataError(msg.format(position, self.filters[bidx]))

    left = np.min([psf.getBBox().getMinX() for psf in psfs])
    bottom = np.min([psf.getBBox().getMinY() for psf in psfs])
    right = np.max([psf.getBBox().getMaxX() for psf in psfs])
    top = np.max([psf.getBBox().getMaxY() for psf in psfs])
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))
    psfs = [afwImage.utils.projectImage(psf, bbox) for psf in psfs]
    psfImage = afwImage.MultibandImage.fromImages(self.filters, psfs)
    return psfImage
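
# A calling sketch for `_computePsfImage`. It is written like a method, so a
# multiband exposure is passed in place of `self`, exactly as `deblend` and
# `deblend_lite` do below; the position used here is hypothetical:
#
#     psfImage = _computePsfImage(mExposure, Point2D(1024.0, 768.0))
#     psfArray = psfImage.array  # shape (nBands, height, width)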


def getFootprintMask(footprint, mExposure):
    """Mask pixels outside the footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        The multiband exposure containing the image,
        mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        The footprint of the parent to deblend

    Returns
    -------
    footprintMask : array
        Boolean array with pixels outside the footprint set to `True`.
    """
    bbox = footprint.getBBox()
    fpMask = afwImage.Mask(bbox)
    footprint.spans.setMask(fpMask, 1)
    fpMask = ~fpMask.getArray().astype(bool)
    return fpMask
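
# Usage sketch: the returned boolean array is `True` outside the footprint,
# so it can zero the weights of all pixels the parent does not cover,
# exactly as `deblend` and `deblend_lite` do below:
#
#     mask = getFootprintMask(footprint, mExposure)
#     weights *= ~mask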


def isPseudoSource(source, pseudoColumns):
    """Check if a source is a pseudo source.

    This is mostly for skipping sky objects,
    but any other column can also be added to disable
    deblending on a parent or individual source when
    set to `True`.

    Parameters
    ----------
    source : `lsst.afw.table.source.source.SourceRecord`
        The source to check for the pseudo bit.
    pseudoColumns : `list` of `str`
        A list of columns to check for pseudo sources.

    Returns
    -------
    isPseudo : `bool`
        True if any of the `pseudoColumns` is set on `source`.
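
    Examples
    --------
    Only item access is required of ``source``, so a plain `dict` can stand
    in for a `SourceRecord` in this minimal sketch:

    >>> isPseudoSource({"merge_peak_sky": True}, ["merge_peak_sky", "sky_source"])
    True
    >>> isPseudoSource({"merge_peak_sky": False}, ["merge_peak_sky", "sky_source"])
    False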
187 """
188 isPseudo = False
189 for col in pseudoColumns:
190 try:
191 isPseudo |= source[col]
192 except KeyError:
193 pass
194 return isPseudo


def deblend(mExposure, footprint, config):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        The multiband exposure containing the image,
        mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        Configuration of the deblending task

    Returns
    -------
    blend : `scarlet.Blend`
        The scarlet blend class that contains all of the information
        about the parameters and results from scarlet
    skipped : `list` of `int`
        The indices of any children that failed to initialize
        and were skipped.
    spectrumInit : `bool`
        Whether or not all of the sources were initialized by jointly
        fitting their SEDs. This provides a better initialization
        but can create memory issues when a blend is too large or
        contains too many sources.
    """
    # Extract coordinates from each MultiColorPeak
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        weights = np.ones_like(images)
    badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
    mask = mExposure.mask[:, bbox].array & badPixels
    weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    psfs = ImagePSF(psfs)
    model_psf = GaussianPSF(sigma=(config.modelPsfSigma,)*len(mExposure.filters))

    frame = Frame(images.shape, psf=model_psf, channels=mExposure.filters)
    observation = Observation(images, psf=psfs, weights=weights, channels=mExposure.filters)
    if config.convolutionType == "fft":
        observation.match(frame)
    elif config.convolutionType == "real":
        renderer = ConvolutionRenderer(observation, frame, convolution_type="real")
        observation.match(frame, renderer=renderer)
    else:
        raise ValueError("Unrecognized convolution type {}".format(config.convolutionType))

    assert config.sourceModel in ["single", "double", "compact", "point", "fit"]

    # Set the appropriate number of components
    if config.sourceModel == "single":
        maxComponents = 1
    elif config.sourceModel == "double":
        maxComponents = 2
    elif config.sourceModel == "compact":
        maxComponents = 0
    elif config.sourceModel == "point":
        raise NotImplementedError("Point source photometry is currently not implemented")
    elif config.sourceModel == "fit":
        # It is likely in the future that there will be some heuristic
        # used to determine what type of model to use for each source,
        # but that has not yet been implemented (see DM-22551)
        raise NotImplementedError("sourceModel 'fit' has not been implemented yet")

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Choose whether or not to use the improved spectral initialization
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0:
            spectrumInit = True
        else:
            spectrumInit = len(centers) * bbox.getArea() < config.maxSpectrumCutoff
    else:
        spectrumInit = False

    # Only deblend sources that can be initialized
    sources, skipped = init_all_sources(
        frame=frame,
        centers=centers,
        observations=observation,
        thresh=config.morphThresh,
        max_components=maxComponents,
        min_snr=config.minSNR,
        shifting=False,
        fallback=config.fallback,
        silent=config.catchFailures,
        set_spectra=spectrumInit,
    )

    # Attach the peak to all of the initialized sources
    srcIndex = 0
    for k, center in enumerate(centers):
        if k not in skipped:
            # This is just to make sure that there isn't a coding bug
            assert np.all(sources[srcIndex].center == center)
            # Store the record for the peak with the appropriate source
            sources[srcIndex].detectedPeak = footprint.peaks[k]
            srcIndex += 1

    # Create the blend and attempt to optimize it
    blend = Blend(sources, observation)
    try:
        blend.fit(max_iter=config.maxIter, e_rel=config.relativeError)
    except ArithmeticError:
        # This occurs when a gradient update produces a NaN value
        # This is usually due to a source initialized with a
        # negative SED or no flux, often because the peak
        # is a noise fluctuation in one band and not a real source.
        iterations = len(blend.loss)
        failedSources = []
        for k, src in enumerate(sources):
            if np.any(~np.isfinite(src.get_model())):
                failedSources.append(k)
        raise ScarletGradientError(iterations, failedSources)

    return blend, skipped, spectrumInit
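
# Usage sketch for a single parent footprint (`mExposure` is a
# `MultibandExposure` and `footprint` a parent `Footprint`; both are
# assumed to exist):
#
#     config = ScarletDeblendConfig()
#     blend, skipped, spectrumInit = deblend(mExposure, footprint, config)
#     print(f"fit {len(blend.sources)} sources in {len(blend.loss)} iterations")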


def deblend_lite(mExposure, footprint, config, wavelets=None):
    """Deblend a parent footprint

    Parameters
    ----------
    mExposure : `lsst.afw.image.MultibandExposure`
        The multiband exposure containing the image,
        mask, and variance data
    footprint : `lsst.afw.detection.Footprint`
        The footprint of the parent to deblend
    config : `ScarletDeblendConfig`
        Configuration of the deblending task
    wavelets : `numpy.ndarray`, optional
        Pre-computed wavelet coefficients for the full image,
        used to initialize morphologies when
        ``config.morphImage == "wavelet"``.

    Returns
    -------
    blend : `scarlet.lite.LiteBlend`
        The scarlet blend class that contains all of the information
        about the parameters and results from scarlet
    skipped : `list` of `scarlet.lite.LiteSource`
        The sources that failed to initialize and were skipped.
    spectrumInit : `bool`
        Whether or not the sources were initialized by jointly
        fitting their spectra.
    """
    # Extract coordinates from each MultiColorPeak
    bbox = footprint.getBBox()

    # Create the data array from the masked images
    images = mExposure.image[:, bbox].array
    variance = mExposure.variance[:, bbox].array

    # Use the inverse variance as the weights
    if config.useWeights:
        weights = 1/mExposure.variance[:, bbox].array
    else:
        weights = np.ones_like(images)
    badPixels = mExposure.mask.getPlaneBitMask(config.badMask)
    mask = mExposure.mask[:, bbox].array & badPixels
    weights[mask > 0] = 0

    # Mask out the pixels outside the footprint
    mask = getFootprintMask(footprint, mExposure)
    weights *= ~mask

    psfs = _computePsfImage(mExposure, footprint.getCentroid()).array.astype(np.float32)
    modelPsf = lite.integrated_circular_gaussian(sigma=config.modelPsfSigma)

    observation = lite.LiteObservation(
        images=images,
        variance=variance,
        weights=weights,
        psfs=psfs,
        model_psf=modelPsf[None, :, :],
        convolution_mode=config.convolutionType,
    )

    # Convert the centers to pixel coordinates
    xmin = bbox.getMinX()
    ymin = bbox.getMinY()
    centers = [
        np.array([peak.getIy() - ymin, peak.getIx() - xmin], dtype=int)
        for peak in footprint.peaks
        if not isPseudoSource(peak, config.pseudoColumns)
    ]

    # Initialize the sources
    if config.morphImage == "chi2":
        sources = lite.init_all_sources_main(
            observation,
            centers,
            min_snr=config.minSNR,
            thresh=config.morphThresh,
        )
    elif config.morphImage == "wavelet":
        _bbox = bboxToScarletBox(len(mExposure.filters), bbox, bbox.getMin())
        _wavelets = wavelets[(slice(None), *_bbox[1:].slices)]
        sources = lite.init_all_sources_wavelets(
            observation,
            centers,
            use_psf=False,
            wavelets=_wavelets,
            min_snr=config.minSNR,
        )
    else:
        raise ValueError("morphImage must be either 'chi2' or 'wavelet'.")

    # Set the optimizer
    if config.optimizer == "adaprox":
        parameterization = partial(
            lite.init_adaprox_component,
            bg_thresh=config.backgroundThresh,
            max_prox_iter=config.maxProxIter,
        )
    elif config.optimizer == "fista":
        parameterization = partial(
            lite.init_fista_component,
            bg_thresh=config.backgroundThresh,
        )
    else:
        raise ValueError("Unrecognized optimizer. Must be either 'adaprox' or 'fista'.")
    sources = lite.parameterize_sources(sources, observation, parameterization)

    # Attach the peak to all of the initialized sources
    for k, center in enumerate(centers):
        # This is just to make sure that there isn't a coding bug
        if len(sources[k].components) > 0 and np.any(sources[k].center != center):
            raise ValueError(f"Misaligned center, expected {center} but got {sources[k].center}")
        # Store the record for the peak with the appropriate source
        sources[k].detectedPeak = footprint.peaks[k]

    blend = lite.LiteBlend(sources, observation)

    # Initialize each source with its best fit spectrum
    # This significantly cuts down on the number of iterations
    # that the optimizer needs and usually results in a better
    # fit, but using least squares on a very large blend causes memory issues.
    # This is typically the most expensive operation in deblending, memory-wise.
    spectrumInit = False
    if config.setSpectra:
        if config.maxSpectrumCutoff <= 0 or len(centers) * bbox.getArea() < config.maxSpectrumCutoff:
            spectrumInit = True
            blend.fit_spectra()

    # Set the sources that could not be initialized and were skipped
    skipped = [src for src in sources if src.is_null]

    blend.fit(max_iter=config.maxIter, e_rel=config.relativeError, min_iter=config.minIter)

    return blend, skipped, spectrumInit
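
# The equivalent sketch for the "lite" version; the wavelet coefficients are
# only needed when `morphImage="wavelet"` and are computed for the full image,
# exactly as `ScarletDeblendTask.deblend` does below:
#
#     wavelets = get_detect_wavelets(mExposure.image.array,
#                                    mExposure.variance.array,
#                                    scales=config.waveletScales)
#     blend, skipped, spectrumInit = deblend_lite(mExposure, footprint,
#                                                 config, wavelets)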


class ScarletDeblendConfig(pexConfig.Config):
    """ScarletDeblendConfig

    Configuration for the multiband deblender.
    The parameters are organized by the parameter types, which are

    - Stopping Criteria: Used to determine if the fit has converged
    - Position Fitting Criteria: Used to fit the positions of the peaks
    - Constraints: Used to apply constraints to the peaks and their components
    - Other: Parameters that don't fit into the above categories
    """
    # Stopping Criteria
    minIter = pexConfig.Field(dtype=int, default=1,
                              doc="Minimum number of iterations before the optimizer is allowed to stop.")
    maxIter = pexConfig.Field(dtype=int, default=300,
                              doc="Maximum number of iterations to deblend a single parent")
    relativeError = pexConfig.Field(dtype=float, default=1e-2,
                                    doc=("Change in the loss function between iterations to exit fitter. "
                                         "Typically this is `1e-2` if measurements will be made on the "
                                         "flux re-distributed models and `1e-4` when making measurements "
                                         "on the models themselves."))

    # Constraints
    morphThresh = pexConfig.Field(dtype=float, default=1,
                                  doc="Fraction of background RMS a pixel must have "
                                      "to be included in the initial morphology")
    # Lite Parameters
    # All of these parameters (except version) are only valid if version='lite'
    version = pexConfig.ChoiceField(
        dtype=str,
        default="lite",
        allowed={
            "scarlet": "main scarlet version (likely to be deprecated soon)",
            "lite": "Optimized version of scarlet for survey data from a single instrument",
        },
        doc="The version of scarlet to use.",
    )
    optimizer = pexConfig.ChoiceField(
        dtype=str,
        default="adaprox",
        allowed={
            "adaprox": "Proximal ADAM optimization",
            "fista": "Accelerated proximal gradient method",
        },
        doc="The optimizer to use for fitting parameters. Only used when version='lite'.",
    )
    morphImage = pexConfig.ChoiceField(
        dtype=str,
        default="chi2",
        allowed={
            "chi2": "Initialize sources on a chi^2 image made from all available bands",
            "wavelet": "Initialize sources using a wavelet decomposition of the chi^2 image",
        },
        doc="The type of image to use for initializing the morphology. "
            "Must be either 'chi2' or 'wavelet'."
    )
    backgroundThresh = pexConfig.Field(
        dtype=float,
        default=0.25,
        doc="Fraction of background to use for a sparsity threshold. "
            "This prevents sources from growing unrealistically outside "
            "the parent footprint while still modeling flux correctly "
            "for bright sources."
    )
    maxProxIter = pexConfig.Field(
        dtype=int,
        default=1,
        doc="Maximum number of proximal operator iterations inside of each "
            "iteration of the optimizer. "
            "This config field is only used if version='lite' and optimizer='adaprox'."
    )
    waveletScales = pexConfig.Field(
        dtype=int,
        default=5,
        doc="Number of wavelet scales to use for wavelet initialization. "
            "This field is only used when `version`='lite' and `morphImage`='wavelet'."
    )

    # Other scarlet parameters
    useWeights = pexConfig.Field(
        dtype=bool, default=True,
        doc=("Whether or not to use inverse variance weighting. "
             "If `useWeights` is `False` then flat weights are used"))
    modelPsfSize = pexConfig.Field(
        dtype=int, default=11,
        doc="Model PSF side length in pixels")
    modelPsfSigma = pexConfig.Field(
        dtype=float, default=0.8,
        doc="Define sigma for the model frame PSF")
    minSNR = pexConfig.Field(
        dtype=float, default=50,
        doc="Minimum signal-to-noise ratio to accept the source. "
            "Sources with lower flux will be initialized with the PSF but updated "
            "like an ordinary ExtendedSource (known in scarlet as a `CompactSource`).")
    saveTemplates = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to save the SEDs and templates")
    processSingles = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to process isolated sources in the deblender")
    convolutionType = pexConfig.Field(
        dtype=str, default="fft",
        doc="Type of convolution to render the model to the observations.\n"
            "- 'fft': perform convolutions in Fourier space\n"
            "- 'real': perform convolutions in real space.")
    sourceModel = pexConfig.Field(
        dtype=str, default="double",
        doc=("How to determine which model to use for sources, from\n"
             "- 'single': use a single component for all sources\n"
             "- 'double': use a bulge disk model for all sources\n"
             "- 'compact': use a single component model, initialized with a point source morphology, "
             "  for all sources\n"
             "- 'point': use a point-source model for all sources\n"
             "- 'fit': use a PSF fitting model to determine the number of components "
             "  (not yet implemented)"),
        deprecated="This field will be deprecated when the default for `version` is changed to `lite`.",
    )
    setSpectra = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to solve for the best-fit spectra during initialization. "
            "This makes initialization slightly longer, as it requires a convolution "
            "to set the optimal spectra, but results in a much better initial log-likelihood "
            "and reduced total runtime, with convergence in fewer iterations. "
            "This option is only used when `len(peaks) * area < maxSpectrumCutoff`.")

    # Mask-plane restrictions
    badMask = pexConfig.ListField(
        dtype=str, default=["BAD", "CR", "NO_DATA", "SAT", "SUSPECT", "EDGE"],
        doc="Mask planes whose pixels are given zero weight during deblending")
    statsMask = pexConfig.ListField(dtype=str, default=["SAT", "INTRP", "NO_DATA"],
                                    doc="Mask planes to ignore when performing statistics")
    maskLimits = pexConfig.DictField(
        keytype=str,
        itemtype=float,
        default={},
        doc=("Mask planes with the corresponding limit on the fraction of masked pixels. "
             "Sources violating this limit will not be deblended."),
    )

    # Size restrictions
    maxNumberOfPeaks = pexConfig.Field(
        dtype=int, default=200,
        doc=("Only deblend the brightest maxNumberOfPeaks peaks in the parent"
             " (<= 0: unlimited)"))
    maxFootprintArea = pexConfig.Field(
        dtype=int, default=100_000,
        doc=("Maximum area for footprints before they are ignored as large; "
             "non-positive means no threshold applied"))
    maxAreaTimesPeaks = pexConfig.Field(
        dtype=int, default=10_000_000,
        doc=("Maximum rectangular footprint area * nPeaks in the footprint. "
             "This was introduced in DM-33690 to prevent fields that are crowded or have a "
             "LSB galaxy that causes memory intensive initialization in scarlet from dominating "
             "the overall runtime and/or causing the task to run out of memory. "
             "(<= 0: unlimited)")
    )
    maxFootprintSize = pexConfig.Field(
        dtype=int, default=0,
        doc=("Maximum linear dimension for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    minFootprintAxisRatio = pexConfig.Field(
        dtype=float, default=0.0,
        doc=("Minimum axis ratio for footprints before they are ignored "
             "as large; non-positive means no threshold applied"))
    maxSpectrumCutoff = pexConfig.Field(
        dtype=int, default=1_000_000,
        doc=("Maximum number of pixels * number of sources in a blend. "
             "This is different than `maxFootprintArea` because this isn't "
             "the footprint area but the area of the bounding box that "
             "contains the footprint, and is also multiplied by the number of "
             "sources in the footprint. This prevents large skinny blends with "
             "a high density of sources from running out of memory. "
             "If `maxSpectrumCutoff == -1` then there is no cutoff.")
    )

    # Failure modes
    fallback = pexConfig.Field(
        dtype=bool, default=True,
        doc="Whether or not to fallback to a smaller number of components if a source does not initialize"
    )
    notDeblendedMask = pexConfig.Field(
        dtype=str, default="NOT_DEBLENDED", optional=True,
        doc="Mask name for footprints not deblended, or None")
    catchFailures = pexConfig.Field(
        dtype=bool, default=True,
        doc=("If True, catch exceptions thrown by the deblender, log them, "
             "and set a flag on the parent, instead of letting them propagate up"))

    # Other options
    columnInheritance = pexConfig.DictField(
        keytype=str, itemtype=str, default={
            "deblend_nChild": "deblend_parentNChild",
            "deblend_nPeaks": "deblend_parentNPeaks",
            "deblend_spectrumInitFlag": "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag": "deblend_blendConvergenceFailedFlag",
        },
        doc="Columns to pass from the parent to the child. "
            "The key is the name of the column for the parent record, "
            "the value is the name of the column to use for the child."
    )
    pseudoColumns = pexConfig.ListField(
        dtype=str, default=['merge_peak_sky', 'sky_source'],
        doc="Names of flags which should never be deblended."
    )

    # Logging option(s)
    loggingInterval = pexConfig.Field(
        dtype=int, default=600,
        doc="Interval (in seconds) to log messages (at VERBOSE level) while deblending sources.",
        deprecated="This field is no longer used and will be removed in v25.",
    )
    # Testing options
    # Some obs packages and ci packages run the full pipeline on a small
    # subset of data to test that the pipeline is functioning properly.
    # This is not meant as scientific validation, so it can be useful
    # to only run on a small subset of the data that is large enough to
    # test the desired pipeline features but not so long that the deblender
    # is the tall pole in terms of execution times.
    useCiLimits = pexConfig.Field(
        dtype=bool, default=False,
        doc="Limit the number of sources deblended for CI to prevent long build times")
    ciDeblendChildRange = pexConfig.ListField(
        dtype=int, default=[5, 10],
        doc="Only deblend parent Footprints with a number of peaks in the (inclusive) range indicated. "
            "If `useCiLimits==False` then this parameter is ignored.")
    ciNumParentsToDeblend = pexConfig.Field(
        dtype=int, default=10,
        doc="Only use the first `ciNumParentsToDeblend` parent footprints with a total peak count "
            "within `ciDeblendChildRange`. "
            "If `useCiLimits==False` then this parameter is ignored.")


class ScarletDeblendTask(pipeBase.Task):
    """ScarletDeblendTask

    Split blended sources into individual sources.

    This task has no return value; it only modifies the SourceCatalog in-place.
    """
    ConfigClass = ScarletDeblendConfig
    _DefaultName = "scarletDeblend"

    def __init__(self, schema, peakSchema=None, **kwargs):
        """Create the task, adding necessary fields to the given schema.

        Parameters
        ----------
        schema : `lsst.afw.table.schema.schema.Schema`
            Schema object for measurement fields; will be modified in-place.
        peakSchema : `lsst.afw.table.schema.schema.Schema`
            Schema of Footprint Peaks that will be passed to the deblender.
            Any fields beyond the PeakTable minimal schema will be transferred
            to the main source Schema. If None, no fields will be transferred
            from the Peaks.
        **kwargs
            Passed to Task.__init__.
        """
        pipeBase.Task.__init__(self, **kwargs)

        peakMinimalSchema = afwDet.PeakTable.makeMinimalSchema()
        if peakSchema is None:
            # In this case, the peakSchemaMapper will transfer nothing, but
            # we'll still have one to simplify downstream code
            self.peakSchemaMapper = afwTable.SchemaMapper(peakMinimalSchema, schema)
        else:
            self.peakSchemaMapper = afwTable.SchemaMapper(peakSchema, schema)
            for item in peakSchema:
                if item.key not in peakMinimalSchema:
                    self.peakSchemaMapper.addMapping(item.key, item.field)
                    # Because SchemaMapper makes a copy of the output schema
                    # you give its ctor, it isn't updating this Schema in
                    # place. That's probably a design flaw, but in the
                    # meantime, we'll keep that schema in sync with the
                    # peakSchemaMapper.getOutputSchema() manually, by adding
                    # the same fields to both.
                    schema.addField(item.field)
            assert schema == self.peakSchemaMapper.getOutputSchema(), "Logic bug mapping schemas"
        self._addSchemaKeys(schema)
        self.schema = schema
        self.toCopyFromParent = [item.key for item in self.schema
                                 if item.field.getName().startswith("merge_footprint")]

    def _addSchemaKeys(self, schema):
        """Add deblender specific keys to the schema
        """
        self.runtimeKey = schema.addField('deblend_runtime', type=np.float32, doc='runtime in ms')

        self.iterKey = schema.addField('deblend_iterations', type=np.int32, doc='iterations to converge')

        self.nChildKey = schema.addField('deblend_nChild', type=np.int32,
                                         doc='Number of children this object has (defaults to 0)')
        self.psfKey = schema.addField('deblend_deblendedAsPsf', type='Flag',
                                      doc='Deblender thought this source looked like a PSF')
        self.tooManyPeaksKey = schema.addField('deblend_tooManyPeaks', type='Flag',
                                               doc='Source had too many peaks; '
                                                   'only the brightest were included')
        self.tooBigKey = schema.addField('deblend_parentTooBig', type='Flag',
                                         doc='Parent footprint covered too many pixels')
        self.maskedKey = schema.addField('deblend_masked', type='Flag',
                                         doc='Parent footprint was predominantly masked')
        self.sedNotConvergedKey = schema.addField('deblend_sedConvergenceFailed', type='Flag',
                                                  doc='scarlet sed optimization did not converge before '
                                                      'config.maxIter')
        self.morphNotConvergedKey = schema.addField('deblend_morphConvergenceFailed', type='Flag',
                                                    doc='scarlet morph optimization did not converge '
                                                        'before config.maxIter')
        self.blendConvergenceFailedFlagKey = schema.addField('deblend_blendConvergenceFailedFlag',
                                                             type='Flag',
                                                             doc='at least one source in the blend '
                                                                 'failed to converge')
        self.edgePixelsKey = schema.addField('deblend_edgePixels', type='Flag',
                                             doc='Source had flux on the edge of the parent footprint')
        self.deblendFailedKey = schema.addField('deblend_failed', type='Flag',
                                                doc="Deblending failed on source")
        self.deblendErrorKey = schema.addField('deblend_error', type="String", size=25,
                                               doc='Name of error if the blend failed')
        self.deblendSkippedKey = schema.addField('deblend_skipped', type='Flag',
                                                 doc="Deblender skipped this source")
        self.peakCenter = afwTable.Point2IKey.addFields(schema, name="deblend_peak_center",
                                                        doc="Center used to apply constraints in scarlet",
                                                        unit="pixel")
        self.peakIdKey = schema.addField("deblend_peakId", type=np.int32,
                                         doc="ID of the peak in the parent footprint. "
                                             "This is not unique, but the combination of 'parent' "
                                             "and 'peakId' should be for all child sources. "
                                             "Top level blends with no parents have 'peakId=0'")
        self.modelCenterFlux = schema.addField('deblend_peak_instFlux', type=float, units='count',
                                               doc="The instFlux at the peak position of the "
                                                   "deblended model")
        self.modelTypeKey = schema.addField("deblend_modelType", type="String", size=25,
                                            doc="The type of model used, for example "
                                                "MultiExtendedSource, SingleExtendedSource, PointSource")
        self.nPeaksKey = schema.addField("deblend_nPeaks", type=np.int32,
                                         doc="Number of initial peaks in the blend. "
                                             "This includes peaks that may have been culled "
                                             "during deblending or failed to deblend")
        self.parentNPeaksKey = schema.addField("deblend_parentNPeaks", type=np.int32,
                                               doc="deblend_nPeaks from this record's parent.")
        self.parentNChildKey = schema.addField("deblend_parentNChild", type=np.int32,
                                               doc="deblend_nChild from this record's parent.")
        self.scarletFluxKey = schema.addField("deblend_scarletFlux", type=np.float32,
                                              doc="Flux measurement from scarlet")
        self.scarletLogLKey = schema.addField("deblend_logL", type=np.float32,
                                              doc="Final logL, used to identify regressions in scarlet.")
        self.scarletSpectrumInitKey = schema.addField("deblend_spectrumInitFlag", type='Flag',
                                                      doc="True when scarlet initializes sources "
                                                          "in the blend with a more accurate spectrum. "
                                                          "The algorithm uses a lot of memory, "
                                                          "so large dense blends will use "
                                                          "a less accurate initialization.")

        # self.log.trace('Added keys to schema: %s', ", ".join(str(x) for x in
        #     (self.nChildKey, self.tooManyPeaksKey, self.tooBigKey))
        # )

    @timeMethod
    def run(self, mExposure, mergedSources):
        """Get the psf from each exposure and then run deblend().

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        mergedSources : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.

        Returns
        -------
        templateCatalogs : `dict`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that are the templates
            created by the multiband deblender.
        weightedCatalogs : `dict`
            Catalogs with heavy footprints created by re-distributing the
            image flux according to the scarlet models
            (only populated when ``version='lite'``).
        """
        return self.deblend(mExposure, mergedSources)
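
    # A typical calling sketch (schema construction elided; `schema`,
    # `mExposure`, and `mergedSources` are assumed to exist):
    #
    #     task = ScarletDeblendTask(schema=schema)
    #     templateCatalogs, weightedCatalogs = task.run(mExposure, mergedSources)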

    @timeMethod
    def deblend(self, mExposure, catalog):
        """Deblend a data cube of multiband images

        Parameters
        ----------
        mExposure : `MultibandExposure`
            The exposures should be co-added images of the same
            shape and region of the sky.
        catalog : `SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend. The new deblended sources are
            appended to this catalog in place.

        Returns
        -------
        catalogs : `dict` or `None`
            Keys are the names of the filters and the values are
            `lsst.afw.table.source.source.SourceCatalog`'s.
            These are catalogs with heavy footprints that are the templates
            created by the multiband deblender.
        weightedCatalogs : `dict`
            Catalogs with heavy footprints created by re-distributing the
            image flux according to the scarlet models
            (only populated when ``version='lite'``).
        """
        import time

        # Cull footprints if required by ci
        if self.config.useCiLimits:
            self.log.info("Using CI catalog limits, the original number of sources to deblend was %d.",
                          len(catalog))
            # Select parents with a number of children in the range
            # config.ciDeblendChildRange
            minChildren, maxChildren = self.config.ciDeblendChildRange
            nPeaks = np.array([len(src.getFootprint().peaks) for src in catalog])
            childrenInRange = np.where((nPeaks >= minChildren) & (nPeaks <= maxChildren))[0]
            if len(childrenInRange) < self.config.ciNumParentsToDeblend:
                raise ValueError("Fewer than ciNumParentsToDeblend children were contained in the range "
                                 "indicated by ciDeblendChildRange. Adjust this range to include more "
                                 "parents.")
            # Keep all of the isolated parents and the first
            # `ciNumParentsToDeblend` children
            parents = nPeaks == 1
            children = np.zeros((len(catalog),), dtype=bool)
            children[childrenInRange[:self.config.ciNumParentsToDeblend]] = True
            catalog = catalog[parents | children]
            # We need to update the IdFactory, otherwise the source ids
            # will not be sequential
            idFactory = catalog.getIdFactory()
            maxId = np.max(catalog["id"])
            idFactory.notify(maxId)

        filters = mExposure.filters
        self.log.info("Deblending %d sources in %d exposure bands", len(catalog), len(mExposure))
        periodicLog = PeriodicLogger(self.log)

        # Create a set of wavelet coefficients if using wavelet initialization
        if self.config.version == "lite" and self.config.morphImage == "wavelet":
            images = mExposure.image.array
            variance = mExposure.variance.array
            wavelets = get_detect_wavelets(images, variance, scales=self.config.waveletScales)
        else:
            wavelets = None

        # Add the NOT_DEBLENDED mask to the mask plane in each band
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                mask.addMaskPlane(self.config.notDeblendedMask)

        nParents = len(catalog)
        nDeblendedParents = 0
        skippedParents = []
        multibandColumns = {
            "heavies": [],
            "fluxes": [],
            "centerFluxes": [],
        }
        weightedColumns = {
            "heavies": [],
            "fluxes": [],
            "centerFluxes": [],
        }
        for parentIndex in range(nParents):
            parent = catalog[parentIndex]
            foot = parent.getFootprint()
            bbox = foot.getBBox()
            peaks = foot.getPeaks()

            # Since we use the first peak for the parent object, we should
            # propagate its flags to the parent source.
            parent.assign(peaks[0], self.peakSchemaMapper)

            # Skip isolated sources unless processSingles is turned on.
            # Note: this does not flag isolated sources as skipped or
            # set the NOT_DEBLENDED mask in the exposure,
            # since these aren't really skipped blends.
            # We also skip pseudo sources, like sky objects, which
            # are intended to be skipped.
            if ((len(peaks) < 2 and not self.config.processSingles)
                    or isPseudoSource(parent, self.config.pseudoColumns)):
                self._updateParentRecord(
                    parent=parent,
                    nPeaks=len(peaks),
                    nChild=0,
                    runtime=np.nan,
                    iterations=0,
                    logL=np.nan,
                    spectrumInit=False,
                    converged=False,
                )
                continue

            # Block of conditions for skipping a parent with multiple children
            skipKey = None
            if self._isLargeFootprint(foot):
                # The footprint is above the maximum footprint size limit
                skipKey = self.tooBigKey
                skipMessage = f"Parent {parent.getId()}: skipping large footprint"
            elif self._isMasked(foot, mExposure):
                # The footprint exceeds the maximum number of masked pixels
                skipKey = self.maskedKey
                skipMessage = f"Parent {parent.getId()}: skipping masked footprint"
            elif self.config.maxNumberOfPeaks > 0 and len(peaks) > self.config.maxNumberOfPeaks:
                # Unlike meas_deblender, in scarlet we skip the entire blend
                # if the number of peaks exceeds max peaks, since neglecting
                # to model any peaks often results in catastrophic failure
                # of scarlet to generate models for the brighter sources.
                skipKey = self.tooManyPeaksKey
                skipMessage = f"Parent {parent.getId()}: Too many peaks, skipping blend"
            if skipKey is not None:
                self._skipParent(
                    parent=parent,
                    skipKey=skipKey,
                    logMessage=skipMessage,
                )
                skippedParents.append(parentIndex)
                continue

            nDeblendedParents += 1
            self.log.trace("Parent %d: deblending %d peaks", parent.getId(), len(peaks))
            # Run the deblender
            blendError = None
            try:
                t0 = time.monotonic()
                # Build the parameter lists with the same ordering
                if self.config.version == "scarlet":
                    blend, skipped, spectrumInit = deblend(mExposure, foot, self.config)
                elif self.config.version == "lite":
                    blend, skipped, spectrumInit = deblend_lite(mExposure, foot, self.config, wavelets)
                tf = time.monotonic()
                runtime = (tf-t0)*1000
                converged = _checkBlendConvergence(blend, self.config.relativeError)

                scarletSources = [src for src in blend.sources]
                nChild = len(scarletSources)
            # Catch all errors and filter out the ones that we know about
            except Exception as e:
                blendError = type(e).__name__
                if isinstance(e, ScarletGradientError):
                    parent.set(self.iterKey, e.iterations)
                elif not isinstance(e, IncompleteDataError):
                    blendError = "UnknownError"
                    if self.config.catchFailures:
                        # Make it easy to find UnknownErrors in the log file
                        self.log.warning("UnknownError")
                        import traceback
                        traceback.print_exc()
                    else:
                        raise

                self._skipParent(
                    parent=parent,
                    skipKey=self.deblendFailedKey,
                    logMessage=f"Unable to deblend source {parent.getId()}: {blendError}",
                )
                parent.set(self.deblendErrorKey, blendError)
                skippedParents.append(parentIndex)
                continue

            # Update the parent record with the deblending results
            if self.config.version == "scarlet":
                logL = -blend.loss[-1] + blend.observations[0].log_norm
            elif self.config.version == "lite":
                logL = blend.loss[-1]
            self._updateParentRecord(
                parent=parent,
                nPeaks=len(peaks),
                nChild=nChild,
                runtime=runtime,
                iterations=len(blend.loss),
                logL=logL,
                spectrumInit=spectrumInit,
                converged=converged,
            )

            # Add each deblended source to the catalog
            for k, scarletSource in enumerate(scarletSources):
                # Skip any sources with no flux or that scarlet skipped because
                # it could not initialize
                if k in skipped or (self.config.version == "lite" and scarletSource.is_null):
                    # No need to propagate anything
                    continue
                parent.set(self.deblendSkippedKey, False)
                if self.config.version == "lite":
                    mHeavy = liteModelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
                    weightedHeavy = liteModelToHeavy(
                        scarletSource, mExposure, blend, xy0=bbox.getMin(), useFlux=True)
                    weightedColumns["heavies"].append(weightedHeavy)
                    flux = scarletSource.get_model(use_flux=True).sum(axis=(1, 2))
                    weightedColumns["fluxes"].append({
                        filters[fidx]: _flux
                        for fidx, _flux in enumerate(flux)
                    })
                    centerFlux = self._getCenterFlux(weightedHeavy, scarletSource, xy0=bbox.getMin())
                    weightedColumns["centerFluxes"].append(centerFlux)
                else:
                    mHeavy = modelToHeavy(scarletSource, mExposure, blend, xy0=bbox.getMin())
                # Fill the model columns for both versions so that the
                # per-band catalogs built below stay aligned with the
                # new source records.
                multibandColumns["heavies"].append(mHeavy)
                flux = scarlet.measure.flux(scarletSource)
                multibandColumns["fluxes"].append({
                    filters[fidx]: _flux
                    for fidx, _flux in enumerate(flux)
                })
                centerFlux = self._getCenterFlux(mHeavy, scarletSource, xy0=bbox.getMin())
                multibandColumns["centerFluxes"].append(centerFlux)

                # Add all fields except the HeavyFootprint to the
                # source record
                self._addChild(
                    parent=parent,
                    mHeavy=mHeavy,
                    catalog=catalog,
                    scarletSource=scarletSource,
                )

            # Log a message if it has been a while since the last log.
            periodicLog.log("Deblended %d parent sources out of %d", parentIndex + 1, nParents)

        # Clear the cached values in scarlet to clear out memory
        scarlet.cache.Cache._cache = {}

        # Make sure that the number of new sources matches the number of
        # entries in each of the band dependent columns.
        # This should never trigger and is just a sanity check.
        nChildren = len(catalog) - nParents
        if np.any([len(meas) != nChildren for meas in multibandColumns.values()]):
            msg = f"Added {len(catalog)-nParents} new sources, but have "
            msg += ", ".join([
                f"{len(value)} {key}"
                for key, value in multibandColumns.items()
            ])
            raise RuntimeError(msg)
        # Make a copy of the catalog in each band and update the footprints
        catalogs = {}
        for f in filters:
            _catalog = afwTable.SourceCatalog(catalog.table.clone())
            _catalog.extend(catalog, deep=True)

            # Update the footprints and columns that are different
            # for each filter
            for sourceIndex, source in enumerate(_catalog[nParents:]):
                source.setFootprint(multibandColumns["heavies"][sourceIndex][f])
                source.set(self.scarletFluxKey, multibandColumns["fluxes"][sourceIndex][f])
                source.set(self.modelCenterFlux, multibandColumns["centerFluxes"][sourceIndex][f])
            catalogs[f] = _catalog

        weightedCatalogs = {}
        if self.config.version == "lite":
            # Also create a catalog by reweighting the flux
            for f in filters:
                _catalog = afwTable.SourceCatalog(catalog.table.clone())
                _catalog.extend(catalog, deep=True)

                # Update the footprints and columns that are different
                # for each filter
                for sourceIndex, source in enumerate(_catalog[nParents:]):
                    source.setFootprint(weightedColumns["heavies"][sourceIndex][f])
                    source.set(self.scarletFluxKey, weightedColumns["fluxes"][sourceIndex][f])
                    source.set(self.modelCenterFlux, weightedColumns["centerFluxes"][sourceIndex][f])
                weightedCatalogs[f] = _catalog

        # Update the mExposure mask with the footprint of skipped parents
        if self.config.notDeblendedMask:
            for mask in mExposure.mask:
                for parentIndex in skippedParents:
                    fp = catalog[parentIndex].getFootprint()
                    fp.spans.setMask(mask, mask.getPlaneBitMask(self.config.notDeblendedMask))

        self.log.info("Deblender results: of %d parent sources, %d were deblended, "
                      "creating %d children, for a total of %d sources",
                      nParents, nDeblendedParents, nChildren, len(catalog))
        return catalogs, weightedCatalogs

    def _isLargeFootprint(self, footprint):
        """Returns whether a Footprint is large

        'Large' is defined by thresholds on the area, size and axis ratio,
        and total area of the bounding box multiplied by
        the number of peaks.
        These may be disabled independently by configuring them to be
        non-positive.
        """
        if self.config.maxFootprintArea > 0 and footprint.getArea() > self.config.maxFootprintArea:
            return True
        if self.config.maxFootprintSize > 0:
            bbox = footprint.getBBox()
            if max(bbox.getWidth(), bbox.getHeight()) > self.config.maxFootprintSize:
                return True
        if self.config.minFootprintAxisRatio > 0:
            axes = afwEll.Axes(footprint.getShape())
            if axes.getB() < self.config.minFootprintAxisRatio*axes.getA():
                return True
        if self.config.maxAreaTimesPeaks > 0:
            if footprint.getBBox().getArea() * len(footprint.peaks) > self.config.maxAreaTimesPeaks:
                return True
        return False
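
    # For example, with the config defaults above a footprint is "large" when
    # its area exceeds maxFootprintArea = 100_000 pixels, or when
    # bbox area * nPeaks exceeds maxAreaTimesPeaks = 10_000_000; the linear
    # size and axis-ratio cuts are disabled by default (0 and 0.0).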

    def _isMasked(self, footprint, mExposure):
        """Returns whether the footprint violates the mask limits"""
        bbox = footprint.getBBox()
        mask = np.bitwise_or.reduce(mExposure.mask[:, bbox].array, axis=0)
        size = float(footprint.getArea())
        for maskName, limit in self.config.maskLimits.items():
            maskVal = mExposure.mask.getPlaneBitMask(maskName)
            _mask = afwImage.MaskX(mask & maskVal, xy0=bbox.getMin())
            unmaskedSpan = footprint.spans.intersectNot(_mask)  # spanset of unmasked pixels
            if (size - unmaskedSpan.getArea())/size > limit:
                return True
        return False
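
    # For example, configuring maskLimits = {"NO_DATA": 0.25} makes this
    # return True (and thus skip the parent) whenever more than 25% of the
    # footprint pixels have the NO_DATA bit set.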

    def _skipParent(self, parent, skipKey, logMessage):
        """Update a parent record that is not being deblended.

        This is a fairly trivial function but is implemented to ensure
        that a skipped parent updates the appropriate columns
        consistently, and always has a flag to mark the reason that
        it is being skipped.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to flag as skipped.
        skipKey : `lsst.afw.table.Key`
            The schema key of the flag that marks the reason for skipping.
        logMessage : `str`
            The message to display in a log.trace when a source
            is skipped.
        """
        if logMessage is not None:
            self.log.trace(logMessage)
        self._updateParentRecord(
            parent=parent,
            nPeaks=len(parent.getFootprint().peaks),
            nChild=0,
            runtime=np.nan,
            iterations=0,
            logL=np.nan,
            spectrumInit=False,
            converged=False,
        )

        # Mark the source as skipped by the deblender and
        # flag the reason why.
        parent.set(self.deblendSkippedKey, True)
        parent.set(skipKey, True)

    def _updateParentRecord(self, parent, nPeaks, nChild,
                            runtime, iterations, logL, spectrumInit, converged):
        """Update a parent record in all of the single band catalogs.

        Ensure that all locations that update a parent record,
        whether it is skipped or updated after deblending,
        update all of the appropriate columns.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent record to update.
        nPeaks : `int`
            Number of peaks in the parent footprint.
        nChild : `int`
            Number of children deblended from the parent.
            This may differ from `nPeaks` if some of the peaks
            were culled and have no deblended model.
        runtime : `float`
            Total runtime for deblending.
        iterations : `int`
            Total number of iterations in scarlet before convergence.
        logL : `float`
            Final log likelihood of the blend.
        spectrumInit : `bool`
            True when scarlet used `set_spectra` to initialize all
            sources with better initial intensities.
        converged : `bool`
            True when the optimizer reached convergence before
            reaching the maximum number of iterations.
        """
        parent.set(self.nPeaksKey, nPeaks)
        parent.set(self.nChildKey, nChild)
        parent.set(self.runtimeKey, runtime)
        parent.set(self.iterKey, iterations)
        parent.set(self.scarletLogLKey, logL)
        parent.set(self.scarletSpectrumInitKey, spectrumInit)
        # The flag records a convergence *failure*, so invert `converged`
        parent.set(self.blendConvergenceFailedFlagKey, not converged)

    def _addChild(self, parent, mHeavy, catalog, scarletSource):
        """Add a child to a catalog.

        This creates a new child in the source catalog,
        assigning it a parent id, and adding all columns
        that are independent across all filter bands.

        Parameters
        ----------
        parent : `lsst.afw.table.source.source.SourceRecord`
            The parent of the new child record.
        mHeavy : `lsst.afw.detection.MultibandFootprint`
            The multi-band footprint containing the model and
            peak catalog for the new child record.
        catalog : `lsst.afw.table.source.source.SourceCatalog`
            The merged `SourceCatalog` that contains parent footprints
            to (potentially) deblend.
        scarletSource : `scarlet.Component`
            The scarlet model for the new source record.
        """
        src = catalog.addNew()
        for key in self.toCopyFromParent:
            src.set(key, parent.get(key))
        # The peak catalog is the same for all bands,
        # so we just use the first peak catalog
        peaks = mHeavy[mHeavy.filters[0]].peaks
        src.assign(peaks[0], self.peakSchemaMapper)
        src.setParent(parent.getId())
        # Currently all children only have a single peak,
        # but it's possible in the future that there will be hierarchical
        # deblending, so we use the footprint to set the number of peaks
        # for each child.
        src.set(self.nPeaksKey, len(peaks))
        # Set the psf key based on whether or not the source was
        # deblended using the PointSource model.
        # This key is not that useful anymore since we now keep track of
        # `modelType`, but we continue to propagate it in case code downstream
        # is expecting it.
        src.set(self.psfKey, scarletSource.__class__.__name__ == "PointSource")
        src.set(self.modelTypeKey, scarletSource.__class__.__name__)
        # We set the runtime to zero so that summing up the
        # runtime column will give the total time spent
        # running the deblender for the catalog.
        src.set(self.runtimeKey, 0)

        # Set the position of the peak from the parent footprint
        # This will make it easier to match the same source across
        # deblenders and across observations, where the peak
        # position is unlikely to change unless enough time passes
        # for a source to move on the sky.
        peak = scarletSource.detectedPeak
        src.set(self.peakCenter, Point2I(peak["i_x"], peak["i_y"]))
        src.set(self.peakIdKey, peak["id"])

        # Propagate columns from the parent to the child
        for parentColumn, childColumn in self.config.columnInheritance.items():
            src.set(childColumn, parent.get(parentColumn))

    def _getCenterFlux(self, mHeavy, scarletSource, xy0):
        """Get the flux at the center of a HeavyFootprint

        Parameters
        ----------
        mHeavy : `lsst.afw.detection.MultibandFootprint`
            The multi-band footprint containing the model for the source.
        scarletSource : `scarlet.Component`
            The scarlet model for the heavy footprint
        xy0 : `lsst.geom.Point2I`
            The lower-left corner of the bounding box, used to convert
            the scarlet center into image coordinates.

        Returns
        -------
        centerFlux : multiband pixel or `dict`
            The flux at the center of the model in each band,
            or `numpy.nan` per band when the center cannot be determined.
        """
        # Store the flux at the center of the model and the total
        # scarlet flux measurement.
        mImage = mHeavy.getImage(fill=0.0).image

        # Set the flux at the center of the model (for SNR)
        try:
            cy, cx = scarletSource.center
            cy += xy0.y
            cx += xy0.x
            # afw images are indexed [x, y], hence (cx, cy)
            return mImage[:, cx, cy]
        except AttributeError:
            msg = "Did not recognize coordinates for source type of `{0}`, "
            msg += "could not write coordinates or center flux. "
            msg += "Add `{0}` to meas_extensions_scarlet to properly persist this information."
            logger.warning(msg.format(type(scarletSource)))
            return {f: np.nan for f in mImage.filters}