Coverage for python / lsst / meas / extensions / scarlet / deconvolveExposureTask.py: 25%
120 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:34 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:34 +0000
# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
24import lsst.afw.detection as afwDet
25import lsst.afw.image as afwImage
26import lsst.afw.table as afwTable
27import lsst.pex.config as pexConfig
28import lsst.pipe.base as pipeBase
29import lsst.pipe.base.connectionTypes as cT
30import lsst.scarlet.lite as scl
31import numpy as np
33from . import utils
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Names exported by ``from ... import *``.
__all__ = ["DeconvolveExposureTask", "DeconvolveExposureConfig", "DeconvolveExposureConnections"]
def calculate_update_step(
    observation: "scl.Observation",
    min_scale: float = 0.01,
    default_scale: float = 0.1,
    signal_threshold: float = 3.0,
    reference_sparsity: float = 0.1,
) -> float:
    """Calculate the scale factor for the update step in deconvolution.

    For most images this will be 1.0 but for images with low SNR
    and/or high sparsity (for example LSST u-band images) the scale
    factor will be less than 1.0.

    The scale is ``sparsity * sqrt(snr) / reference_sparsity`` clipped to
    the range ``[min_scale, 1.0]``. Here ``sparsity`` is the fraction of
    pixels brighter than ``signal_threshold`` times the noise level, and
    ``snr`` is the median of those pixels divided by the noise level
    (1.0 when no pixel passes the threshold).

    Parameters
    ----------
    observation :
        Scarlet lite Observation. Only ``noise_rms[0]`` and
        ``images.data`` are read.
    min_scale :
        Minimum allowed scale factor.
    default_scale :
        Scale factor returned when the noise level is non-finite or
        non-positive, or when the image contains no pixels.
    signal_threshold :
        Number of noise RMS above which a pixel counts as signal.
    reference_sparsity :
        Sparsity at which an SNR of 1 gives a scale factor of 1.0.

    Returns
    -------
    scale : float
        Scale factor for the update step.
    """
    noise_level = observation.noise_rms[0]
    # Guard against non-finite or non-positive noise levels.
    if not np.isfinite(noise_level) or noise_level <= 0:
        return default_scale

    data = observation.images.data
    # An empty image would otherwise give sparsity = 0/0 = nan, and
    # min(1.0, nan) silently evaluates to 1.0.
    if data.size == 0:
        return default_scale

    # Sparsity is the fraction of pixels significantly above the noise.
    signal_mask = data > signal_threshold * noise_level
    sparsity = np.count_nonzero(signal_mask) / data.size

    if np.any(signal_mask):
        snr = np.median(data[signal_mask]) / noise_level
    else:
        snr = 1.0

    # The scale factor decreases with sparsity and increases with SNR.
    scale = min(1.0, (sparsity * np.sqrt(snr)) / reference_sparsity)
    return max(min_scale, scale)
class DeconvolveExposureConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "patch", "skymap", "band"),
    defaultTemplates={"inputCoaddName": "deep"},
):
    """Connections for DeconvolveExposureTask.

    Exactly one image source is used. When ``config.useCellCoadds`` is set,
    ``coadd_cell`` and ``background`` are used; otherwise ``coadd`` is used.
    The ``catalog`` input is only present when ``config.useFootprints`` is set.
    """

    coadd = cT.Input(
        doc="Exposure to deconvolve",
        name="{inputCoaddName}Coadd_calexp",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    coadd_cell = cT.Input(
        doc="Cell-based coadd to stitch together and deconvolve",
        name="{inputCoaddName}CoaddCell",
        storageClass="MultipleCellCoadd",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    background = cT.Input(
        doc="Background model to subtract from the cell-based coadd",
        name="{inputCoaddName}Coadd_calexp_background",
        storageClass="Background",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    catalog = cT.Input(
        doc="Merged detection catalog whose footprints constrain the deconvolved model",
        name="{inputCoaddName}Coadd_mergeDet",
        storageClass="SourceCatalog",
        dimensions=("tract", "patch", "skymap"),
    )

    deconvolved = cT.Output(
        doc="Deconvolved exposure",
        name="deconvolved_{inputCoaddName}_coadd",
        storageClass="ExposureF",
        dimensions=("tract", "patch", "band", "skymap"),
    )

    def __init__(self, *, config=None):
        if not config.useFootprints:
            # Deconvolution will not use the input catalog if
            # footprints are not used.
            del self.catalog

        # Keep only the connections for the selected coadd flavor.
        if config.useCellCoadds:
            del self.coadd
        else:
            del self.coadd_cell
            del self.background
class DeconvolveExposureConfig(
    pipeBase.PipelineTaskConfig,
    pipelineConnections=DeconvolveExposureConnections,
):
    """Configuration for DeconvolveExposureTask"""

    # Iteration limits and convergence criterion for the deconvolution loop.
    maxIter = pexConfig.Field[int](
        default=100,
        doc="Maximum number of iterations",
    )
    minIter = pexConfig.Field[int](
        default=10,
        doc="Minimum number of iterations",
    )
    eRel = pexConfig.Field[float](
        default=1e-3,
        doc="Relative error threshold",
    )

    # Post-processing applied to the model at each iteration.
    backgroundThreshold = pexConfig.Field[float](
        doc="Threshold for background subtraction. "
        "Pixels in the fit below this threshold will be set to zero",
        default=0,
    )
    useFootprints = pexConfig.Field[bool](
        doc="Use footprints to constrain the deconvolved model",
        default=True,
    )

    # Input selection.
    useCellCoadds = pexConfig.Field[bool](
        default=False,
        doc="Use cell-based coadd instead of regular coadd?",
    )
class DeconvolveExposureTask(pipeBase.PipelineTask):
    """Deconvolve an Exposure using scarlet lite."""

    ConfigClass = DeconvolveExposureConfig
    _DefaultName = "deconvolveExposure"

    def __init__(self, initInputs=None, **kwargs):
        # This task has no init-inputs; pass an empty mapping when none given.
        if initInputs is None:
            initInputs = {}
        super().__init__(initInputs=initInputs, **kwargs)

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        inputs = butlerQC.get(inputRefs)

        # Stitch together cell-based coadds (if necessary) and subtract
        # the background model; otherwise use the regular coadd directly.
        if self.config.useCellCoadds:
            band = inputRefs.coadd_cell.dataId["band"]
            cellCoadd = inputs.pop("coadd_cell")
            background = inputs.pop("background")
            coadd = cellCoadd.stitch().asExposure()
            coadd.image -= background.getImage()
        else:
            coadd = inputs.pop("coadd")
            band = inputRefs.coadd.dataId["band"]

        # The catalog connection is removed when config.useFootprints is False.
        catalog = inputs.pop("catalog", None)

        assert not inputs, "runQuantum got more inputs than expected."
        outputs = self.run(
            coadd=coadd,
            catalog=catalog,
            band=band,
        )
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        coadd: afwImage.Exposure,
        catalog: afwTable.SourceCatalog | None = None,
        band: str = 'dummy'
    ) -> pipeBase.Struct:
        """Deconvolve an Exposure.

        Parameters
        ----------
        coadd :
            Coadd image to deconvolve.
        catalog :
            Merged detection catalog.
            Its footprints are used to suppress noise in regions with no
            significant flux above the noise in the coadds.
        band :
            Band of the coadd image.
            Since this is a single band task the band isn't really necessary
            but can be useful for debugging so we keep it as a parameter.

        Returns
        -------
        result : `pipeBase.Struct`
            Struct with a single ``deconvolved`` attribute holding the
            deconvolved Exposure.

        Raises
        ------
        pipeBase.NoWorkFound
            Raised if a catalog is given and no valid PSF could be obtained.
        """
        # ``_deconvolve`` reads the bounding box to rasterize the footprints,
        # so record it before anything else.
        self.bbox = coadd.getBBox()
        observation = self._buildObservation(coadd, catalog, band)

        # Deconvolve.
        # Store the loss history for debugging purposes.
        model, self.loss = self._deconvolve(observation, catalog)

        # Store the model in an Exposure
        exposure = self._modelToExposure(model.data[0], coadd)
        return pipeBase.Struct(deconvolved=exposure)

    def _buildObservation(
        self,
        coadd: afwImage.Exposure,
        catalog: afwTable.SourceCatalog | None = None,
        band: str = 'dummy'
    ) -> scl.Observation:
        """Build a scarlet lite Observation from an Exposure.

        We don't actually use scarlet, but the optimized convolutions
        using scarlet data products are still useful.

        Parameters
        ----------
        coadd :
            Coadd image to deconvolve.
        catalog :
            Catalog of sources.
            When given, the PSF is obtained from ``utils.computeNearestPsf``
            relative to the coadd center (presumably the nearest valid
            location when the center itself is unusable). Otherwise the
            coadd PSF is evaluated at the center of the coadd bounding box.
        band :
            Band of the coadd image.

        Returns
        -------
        observation : `scl.Observation`
            Single-band observation using FFT convolution.
            Non-finite pixels and pixels with any of
            ``utils.defaultBadPixelMasks`` set are given zero weight.
            Non-finite image pixels are set to zero.

        Raises
        ------
        pipeBase.NoWorkFound
            Raised if a catalog is given but no PSF could be obtained.
        """
        bands = (band,)
        model_psf = scl.utils.integrated_circular_gaussian(sigma=0.8)

        # Give zero weight to non-finite pixels
        weights = np.ones_like(coadd.image.array)
        weights[~np.isfinite(coadd.image.array)] = 0

        image = coadd.image.array.copy()
        # Set non-finite pixels to zero
        image[~np.isfinite(image)] = 0.0
        psfCenter = coadd.getBBox().getCenter()
        if catalog is not None:
            psf, _, _ = utils.computeNearestPsf(coadd, catalog, band, psfCenter)
            if psf is None:
                # There were no valid locations from
                # which a PSF could be obtained
                raise pipeBase.NoWorkFound("No valid PSF could be obtained for deconvolution")
            psf = psf.array
        else:
            psf = coadd.getPsf().computeKernelImage(psfCenter).array

        # Give zero weight to pixels flagged with any of the bad-pixel masks.
        badPixelMasks = utils.defaultBadPixelMasks
        badPixels = coadd.mask.getPlaneBitMask(badPixelMasks)
        mask = coadd.mask.array & badPixels
        weights[mask > 0] = 0

        observation = scl.Observation(
            images=image[None],
            variance=coadd.variance.array.copy()[None],
            weights=weights[None],
            psfs=psf[None],
            model_psf=model_psf[None],
            convolution_mode="fft",
            bands=bands,
            bbox=utils.bboxToScarletBox(coadd.getBBox()),
        )
        return observation

    def _deconvolve(
        self,
        observation: scl.Observation,
        catalog: afwTable.SourceCatalog | None = None,
    ) -> tuple[scl.Image, list[float]]:
        """Deconvolve the observed image.

        Starting from the observed image, the model is iteratively updated
        by the gradient of the convolved residual, scaled by a step size
        from `calculate_update_step`.

        After each update, pixels below ``config.backgroundThreshold`` or
        non-finite are set to zero. If a catalog is given, pixels outside
        its footprints are also zeroed.

        The step is halved whenever the loss gets worse. Iteration stops
        after ``config.maxIter`` iterations, or earlier once past
        ``config.minIter`` when the relative loss change drops below
        ``config.eRel``.

        Parameters
        ----------
        observation :
            Scarlet lite Observation.
        catalog :
            Catalog of sources detected in the original coadds.
            This is used to mask the deconvolved image so that
            the deconvolved footprints detected downstream will always
            fit inside of the original footprints.
            Requires ``self.bbox`` to be set (done by `run`).

        Returns
        -------
        model : `scl.Image`
            Deconvolved model.
        loss : `list` [`float`]
            ``-0.5 * sum(residual**2)`` for each iteration, evaluated
            before that iteration's update (higher is better).
        """
        model = observation.images.copy()
        loss = []
        step = calculate_update_step(observation)
        threshold = self.config.backgroundThreshold

        footprintMask = None
        if catalog is not None:
            width, height = self.bbox.getDimensions()
            x0, y0 = self.bbox.getMin()
            footprintImage = afwDet.footprintsToNumpy(catalog, shape=(height, width), xy0=(x0, y0))
            # Normalize to a boolean mask so the multiplication below can
            # only zero pixels outside the footprints, never rescale pixels
            # inside them.
            footprintMask = np.asarray(footprintImage) != 0

        for n in range(self.config.maxIter):
            residual = observation.images - observation.convolve(model)
            loss.append(-0.5 * np.sum(residual.data**2))
            update = observation.convolve(residual, grad=True)
            update.data[:] *= step
            model += update
            # Honor the configured background threshold (default 0, i.e.
            # clip negative pixels) and drop any non-finite values.
            model.data[(model.data < threshold) | ~np.isfinite(model.data)] = 0
            if footprintMask is not None:
                # Ensure that the deconvolved model footprints fit
                # inside of the original footprints by setting regions
                # outside of the original footprints to zero.
                model.data[:] *= footprintMask

            # Check for a diverging model
            if len(loss) > 1 and loss[-1] < loss[-2]:
                step = step / 2
                self.log.warning("Loss increased at iteration %d, decreasing scale to %s", n, step)

            # Check for convergence (guard loss[-2] in case minIter < 0).
            if (
                n > self.config.minIter
                and len(loss) > 1
                and np.abs(loss[-1] - loss[-2]) < self.config.eRel * np.abs(loss[-1])
            ):
                break

        return model, loss

    def _modelToExposure(self, model: np.ndarray, coadd: afwImage.Exposure) -> afwImage.Exposure:
        """Wrap a deconvolved model array in an Exposure.

        Parameters
        ----------
        model :
            2D model array (a single band of the scarlet lite model).
            It is wrapped without copying (``deep=False``).
        coadd :
            Original coadd. It supplies the origin (xy0) and the dtype.
            Its mask, variance and exposure info are reused by reference.

        Returns
        -------
        exposure : `afwImage.Exposure`
            Exposure whose image is the model.
        """
        image = afwImage.Image(
            array=model,
            xy0=coadd.getBBox().getMin(),
            deep=False,
            dtype=coadd.image.array.dtype,
        )
        maskedImage = afwImage.MaskedImage(
            image=image,
            mask=coadd.mask,
            variance=coadd.variance,
            dtype=coadd.image.array.dtype,
        )
        exposure = afwImage.Exposure(
            maskedImage=maskedImage,
            exposureInfo=coadd.getInfo(),
            dtype=coadd.image.array.dtype,
        )
        return exposure