Coverage for python / lsst / meas / extensions / scarlet / utils.py: 10%
134 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:00 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:00 +0000
1import lsst.geom as geom
2import lsst.scarlet.lite as scl
3import numpy as np
4from scipy.signal import convolve
5from lsst.afw.detection import InvalidPsfError, Footprint as afwFootprint
6from lsst.afw.image import (
7 IncompleteDataError,
8 MultibandExposure,
9 MultibandImage,
10 Exposure,
11)
12from lsst.afw.image.utils import projectImage
13from lsst.afw.table import SourceCatalog
14from lsst.geom import Box2I, Point2D, Point2I
15from lsst.pipe.base import NoWorkFound
# Mask plane names whose pixels are given zero weight by `buildObservation`
# when the caller does not pass its own `badPixelMasks`.
defaultBadPixelMasks = ["BAD", "NO_DATA", "SAT", "SUSPECT", "EDGE"]
def scarletBoxToBBox(box: scl.Box, xy0: geom.Point2I = geom.Point2I()) -> geom.Box2I:
    """Build a `Box2I` from a scarlet_lite `Box`.

    Parameters
    ----------
    box:
        The scarlet bounding box to convert.
    xy0:
        Offset added to the origin of the scarlet box.
        Scarlet sources usually have their `(0,0)` origin at the lower
        left corner of the blend, while the blend itself generally sits
        at some offset inside the `Exposure`.

    Returns
    -------
    bbox:
        The converted bounding box.
    """
    # The trailing two scarlet dimensions are read as (y, x),
    # while the afw geometry constructors below take (x, y).
    originY, originX = box.origin[-2], box.origin[-1]
    height, width = box.shape[-2], box.shape[-1]
    minimum = geom.Point2I(originX + xy0.x, originY + xy0.y)
    dimensions = geom.Extent2I(width, height)
    return geom.Box2I(minimum, dimensions)
def bboxToScarletBox(bbox: geom.Box2I, xy0: geom.Point2I = geom.Point2I()) -> scl.Box:
    """Build a scarlet_lite `Box` from a `Box2I`.

    Parameters
    ----------
    bbox:
        The `Box2I` to convert into a scarlet `Box`.
    xy0:
        Offset subtracted from the minimum corner of `bbox`.
        For blends this is typically the minimum pixel location of the
        blend, with `bbox` being the box of a source inside that blend.

    Returns
    -------
    box:
        A scarlet `Box`, with shape and origin given in `(y, x)` order,
        which is convenient for slicing image data as a numpy array.
    """
    originY = bbox.getMinY() - xy0.y
    originX = bbox.getMinX() - xy0.x
    shape = (bbox.getHeight(), bbox.getWidth())
    return scl.Box(shape, (originY, originX))
67def multiband_convolve(images: np.ndarray, psfs: np.ndarray) -> np.ndarray:
68 """Convolve a multi-band image with the PSF in each band.
70 `images` and `psfs` should have dimensions `(bands, height, width)`.
72 Parameters
73 ----------
74 images :
75 The multi-band images to convolve.
76 psfs :
77 The PSF for each band.
79 Returns
80 -------
81 result :
82 The convolved images.
83 """
84 result = np.zeros(images.shape, dtype=images.dtype)
85 for bidx, (image, psf) in enumerate(zip(images, psfs, strict=True)):
86 result[bidx] = convolve(image, psf, mode="same")
87 return result
def computePsfKernelImage(mExposure, psfCenter, catalog=None):
    """Evaluate the PSF kernel image of a multiband exposure.

    When the PSF cannot be computed in every band, the exposure is
    trimmed to the bands for which a PSF image is available.

    Parameters
    ----------
    mExposure : `MultibandExposure`
        The multi-band exposure whose PSF is evaluated.
    psfCenter : `tuple` or `Point2I` or `Point2D`
        The location `(x, y)` used as the center of the PSF.
    catalog :
        Not used by this function.

    Returns
    -------
    psfModels : `np.ndarray`
        The multiband PSF image, or `None` if no partial PSF was
        available after a failure.
    mExposure : `MultibandExposure`
        The exposure, updated to only use bands that
        successfully generated a PSF image, or `None` if no partial PSF
        was available after a failure.
    """
    center = psfCenter if isinstance(psfCenter, geom.Point2D) else geom.Point2D(*psfCenter)

    try:
        psfModels = mExposure.computePsfKernelImage(center)
    except IncompleteDataError as error:
        # The exception carries whatever PSF images could be computed.
        psfModels = error.partialPsf
        if psfModels is None:
            return None, None
        # Keep only the bands that produced a PSF image.
        validBands = psfModels.bands
        mExposure = mExposure[validBands,]
        if len(validBands) == 1:
            # Selecting a single band yields a single band exposure
            # rather than a MultibandExposure, so wrap it back up.
            mExposure = MultibandExposure.fromExposures(validBands, [mExposure])
    return psfModels.array, mExposure
def computeNearestPsf(
    calexp: Exposure,
    catalog: SourceCatalog,
    band: str | None = None,
    psfCenter: Point2D | None = None,
) -> tuple[np.ndarray, Point2I, float]:
    """Create a PSF image at the valid location nearest to a position.

    The PSF cannot always be evaluated everywhere in an image, so when
    the requested position fails, the peaks in the source catalog are
    tried in order of increasing distance from that position.

    Parameters
    ----------
    calexp :
        The exposure.
    catalog :
        The catalog.
    band :
        The band of the exposure. When given, only records whose
        ``merge_footprint_{band}`` field is set, and only peaks whose
        ``merge_peak_{band}`` field is set, are considered.
        If band is ``None`` then the full catalog is used.
    psfCenter :
        The location of the PSF image.
        If no location is provided, the center of the exposure is used.

    Returns
    -------
    psf :
        The PSF image, or `None` if no location produced a PSF.
    location :
        The location of the PSF image, or `None` if no location
        produced a PSF.
    diff :
        The distance between the requested location and the
        nearest valid location (``0`` when the requested location
        itself is valid), or `None` if no location produced a PSF.
    """
    if psfCenter is None:
        psfCenter = calexp.getBBox().getCenter()
    if not isinstance(psfCenter, geom.Point2D):
        psfCenter = geom.Point2D(*psfCenter)

    # First try the requested location itself.
    try:
        psf = calexp.getPsf().computeKernelImage(psfCenter)
    except InvalidPsfError:
        pass
    else:
        return psf, psfCenter, 0

    centerX, centerY = psfCenter

    # Only select records that have detections in this band
    sources = catalog if band is None else catalog[catalog[f'merge_footprint_{band}']]

    # Collect the position of every usable peak
    peakPositions = [
        (peak['i_x'], peak['i_y'])
        for src in sources
        for peak in src.getFootprint().peaks
        if band is None or peak[f'merge_peak_{band}']
    ]
    x = np.array([position[0] for position in peakPositions])
    y = np.array([position[1] for position in peakPositions])

    # Order the peaks by their squared distance to the requested location
    offsetX = x - centerX
    offsetY = y - centerY
    nearestFirst = np.argsort(offsetX**2 + offsetY**2)

    # Walk outwards until a peak location yields a PSF
    psf = None
    for index in nearestFirst:
        try:
            psf = calexp.getPsf().computeKernelImage(Point2D(x[index], y[index]))
        except InvalidPsfError:
            continue
        else:
            break
    if psf is None:
        return None, None, None

    # `index` still refers to the peak that succeeded.
    location = Point2I(x[index], y[index])
    distance = np.sqrt(offsetX[index]**2 + offsetY[index]**2)
    return psf, location, distance
def computeNearestPsfMultiBand(
    mExposure: MultibandExposure,
    psfCenter: tuple[int, int] | geom.Point2I | geom.Point2D,
    catalog: SourceCatalog,
) -> tuple[np.ndarray | None, MultibandExposure | None]:
    """Compute the image in each band at the location nearest to the PSF Center

    If the PSF cannot be generated in all bands then `mExposure` is updated
    to use only the bands that successfully generated a PSF image.

    Parameters
    ----------
    mExposure :
        The multi-band exposure.
    psfCenter :
        The location `(x, y)` used as the center of the PSF.
        The same requested location is used for every band.
    catalog :
        The source catalog.

    Returns
    -------
    psfModels :
        The multiband PSF image, with every band projected onto the
        union of the per-band PSF bounding boxes, or `None` if no band
        produced a PSF.
    mExposure :
        The exposure, restricted to the bands that produced a PSF,
        or `None` if no band produced a PSF.
    """
    psfs = {}
    incomplete = False
    for band in mExposure.bands:
        # Do not rebind `psfCenter` to the location returned here:
        # a failed band returns `None`, which would make the next band
        # fall back to the exposure center, and a shifted location would
        # make later bands search around the wrong point.
        psf, _location, _distance = computeNearestPsf(
            mExposure[band,],
            catalog,
            band,
            psfCenter,
        )
        if psf is None:
            incomplete = True
        else:
            psfs[band] = psf

    if not psfs:
        return None, None

    # The per-band PSF images may have different bounding boxes,
    # so project them all onto the box that contains every one of them.
    boxes = [psf.getBBox() for psf in psfs.values()]
    left = min(box.getMinX() for box in boxes)
    bottom = min(box.getMinY() for box in boxes)
    right = max(box.getMaxX() for box in boxes)
    top = max(box.getMaxY() for box in boxes)
    bbox = Box2I(Point2I(left, bottom), Point2I(right, top))

    psf_images = [projectImage(psf, bbox) for psf in psfs.values()]

    mPsf = MultibandImage.fromImages(list(psfs.keys()), psf_images)

    if incomplete:
        # Use only the bands that successfully generated a PSF image.
        bands = mPsf.bands
        mExposure = mExposure[bands,]

        if len(bands) == 1:
            # Only a single band generated a PSF, so the MultibandExposure
            # became a single band ExposureF.
            # Convert the result back into a MultibandExposure.
            mExposure = MultibandExposure.fromExposures(bands, [mExposure])

    return mPsf.array, mExposure
def buildObservation(
    modelPsf: np.ndarray,
    psfCenter: tuple[int, int] | geom.Point2I | geom.Point2D,
    mExposure: MultibandExposure,
    badPixelMasks: list[str] | None = None,
    footprint: afwFootprint | None = None,
    useWeights: bool = True,
    convolutionType: str = "real",
    catalog: SourceCatalog | None = None,
) -> scl.Observation:
    """Generate an Observation from a set of arguments.

    Make the generation and reconstruction of a scarlet model consistent
    by building an `Observation` from a set of arguments.

    Parameters
    ----------
    modelPsf :
        The 2D model of the PSF in the partially deconvolved space.
    psfCenter :
        The location `(x, y)` used as the center of the PSF.
    mExposure :
        The multi-band exposure that the model represents.
        If a PSF cannot be generated in every band, only the bands
        with a valid PSF are used in the observation.
    badPixelMasks :
        The keys from the bit mask plane used to mask out pixels
        during the fit.
        If `badPixelMasks` is `None` then the module level
        `defaultBadPixelMasks` are used.
    footprint :
        The footprint that is being fit.
        If `footprint` is `None` then the weights are not updated to mask
        out pixels not contained in the footprint.
    useWeights :
        Whether or not fitting should use inverse variance weights to
        calculate the log-likelihood.
        If `False`, all weights start at one before masking.
    convolutionType :
        The type of convolution to use (either "real" or "fft").
        When reconstructing an image it is advised to use "real" to avoid
        polluting the footprint with artifacts from the fft.
    catalog :
        A source catalog to use for PSFs that cannot be determined at
        the center of the image.
        If `catalog` is `None` the PSF is only evaluated at `psfCenter`.

    Returns
    -------
    observation:
        The observation constructed from the input parameters.

    Raises
    ------
    NoWorkFound
        Raised if no PSF image could be generated in any band.
    """
    # Initialize the observed PSFs
    if not isinstance(psfCenter, geom.Point2D):
        psfCenter = geom.Point2D(*psfCenter)
    if catalog is None:
        psfModels, mExposure = computePsfKernelImage(mExposure, psfCenter)
    else:
        psfModels, mExposure = computeNearestPsfMultiBand(mExposure, psfCenter, catalog)

    if psfModels is None:
        raise NoWorkFound("No valid PSF could be obtained for building the observation")

    # Use the inverse variance as the weights
    if useWeights:
        # Zero or NaN variance pixels are expected and are zeroed out on
        # the next line, so do not emit floating point warnings for them.
        with np.errstate(divide="ignore", invalid="ignore"):
            weights = 1 / mExposure.variance.array
        weights[~np.isfinite(weights)] = 0
    else:
        weights = np.ones_like(mExposure.image.array)

    # Mask out bad pixels
    if badPixelMasks is None:
        badPixelMasks = defaultBadPixelMasks
    badPixels = mExposure.mask.getPlaneBitMask(badPixelMasks)
    mask = mExposure.mask.array & badPixels
    weights[mask > 0] = 0

    if footprint is not None:
        # Mask out the pixels outside the footprint
        weights *= footprint.spans.asArray()

    # Mask out non-finite pixels.
    # Work on a copy so that the exposure itself is not modified.
    image = mExposure.image.array.copy()
    nonFinite = ~np.isfinite(image)
    weights[nonFinite] = 0
    image[nonFinite] = 0

    return scl.Observation(
        images=image,
        variance=mExposure.variance.array,
        weights=weights,
        psfs=psfModels,
        # The model PSF is 2D, so add a leading axis of length one.
        model_psf=modelPsf[None, :, :],
        convolution_mode=convolutionType,
        bands=mExposure.bands,
        bbox=bboxToScarletBox(mExposure.getBBox()),
    )
def calcChi2(
    model: scl.Image,
    observation: scl.Observation,
    footprint: np.ndarray | None = None,
    doConvolve: bool = True,
) -> scl.Image:
    """Build the chi2 image of a model against an observation.

    Parameters
    ----------
    model :
        The model used to calculate the chi2.
    observation :
        The observation used to calculate the chi2.
    footprint :
        The footprint to use when calculating the chi2.
        If `footprint` is `None` then the footprint is taken to be
        the pixels where the (optionally convolved) model is greater
        than 0.
    doConvolve :
        Whether or not to convolve the model with the PSF.

    Returns
    -------
    chi2 :
        The chi2/pixel image for the model.
    """
    if doConvolve:
        model = observation.convolve(model)
    modelData = model.data
    if footprint is None:
        footprint = modelData > 0

    region = model.bbox
    bandCount = len(observation.images.bands)

    # Residual inside the footprint, over the model's bounding box only.
    residual = (observation.images[:, region].data - modelData) * footprint

    # Pixels with zero variance are left at zero in the output
    # instead of being divided by zero.
    variance = observation.variance[:, region].data
    valid = variance != 0
    chi2Data = np.zeros(residual.shape, dtype=residual.dtype)
    chi2Data[valid] = residual[valid]**2 / variance[valid] / bandCount

    return scl.Image(chi2Data, bands=model.bands, yx0=model.yx0)