Coverage for python/lsst/meas/extensions/scarlet/io.py: 9%
151 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-26 11:02 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-26 11:02 +0000
1# This file is part of meas_extensions_scarlet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
23import logging
25import numpy as np
27from lsst.afw.table import SourceCatalog
28from lsst.afw.image import MaskedImage, Exposure
29from lsst.afw.detection import Footprint as afwFootprint, HeavyFootprintF
30from lsst.afw.geom import SpanSet, Span
31from lsst.geom import Box2I, Extent2I, Point2I
32import lsst.scarlet.lite as scl
33from lsst.scarlet.lite import Blend, Source, Box, Component, FixedParameter, FactorizedComponent, Image
35from .metrics import setDeblenderMetrics
36from .utils import scarletModelToHeavy
# Logger for this module, named after the module itself so its output can be
# configured through the standard ``logging`` hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
def monochromaticDataToScarlet(
    blendData: scl.io.ScarletBlendData,
    bandIndex: int,
    observation: scl.Observation,
):
    """Build a single-band scarlet lite blend from persisted blend data.

    Every persisted component, whether a regular component or a factorized
    one, is turned into a ``ComponentCube`` that holds only the model for
    ``bandIndex``, labeled with a placeholder band name.

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    bandIndex:
        Index of the band whose model is extracted.
    observation:
        Observation of the region inside the blend bounding box.

    Returns
    -------
    blend : `lsst.scarlet.lite.Blend`
        A blend whose sources carry the ``record_id`` and ``peak_id`` of the
        persisted sources, built from the extracted single-band models.
    """
    # Placeholder band label: only one band is ever extracted here, and the
    # result is turned into a monochromatic HeavyFootprint.
    placeholderBands = ("dummy", )
    sources = []

    for recordId, persisted in blendData.sources.items():
        cubes: list[Component] = []

        # Regular components store a model per band, so the requested band
        # is sliced straight out of the stored cube.
        for data in persisted.components:
            box = Box(data.shape, origin=data.origin)
            bandModel = data.model[bandIndex][None, :, :]
            cubes.append(
                scl.io.ComponentCube(
                    bands=placeholderBands,
                    model=scl.io.Image(bandModel, yx0=box.origin, bands=placeholderBands),
                    # NOTE(review): this peak is reversed while the factorized
                    # peak below is not -- confirm the stored coordinate order.
                    peak=tuple(data.peak[::-1]),
                    bbox=box,
                )
            )

        # Factorized components must first be rebuilt so that their model can
        # be evaluated; only then is the requested band extracted.
        for data in persisted.factorized_components:
            box = Box(data.shape, origin=data.origin)
            spectrum = FixedParameter(data.spectrum)
            nBands = len(spectrum.x)
            # Band labels and background RMS are only needed for fitting, so
            # they get placeholder values here.
            factorized = FactorizedComponent(
                bands=("dummy", ) * nBands,
                spectrum=spectrum,
                morph=FixedParameter(data.morph),
                peak=tuple(int(np.round(p)) for p in data.peak),  # type: ignore
                bbox=box,
                bg_rms=np.full((nBands,), np.nan),
            )
            bandModel = factorized.get_model().data[bandIndex][None, :, :]
            cubes.append(
                scl.io.ComponentCube(
                    bands=placeholderBands,
                    model=scl.io.Image(bandModel, yx0=box.origin, bands=placeholderBands),
                    peak=factorized.peak,
                    bbox=factorized.bbox,
                )
            )

        source = Source(components=cubes)
        source.record_id = recordId
        source.peak_id = persisted.peak_id
        sources.append(source)

    return Blend(sources=sources, observation=observation)
def updateCatalogFootprints(
    modelData: scl.io.ScarletModelData,
    catalog: SourceCatalog,
    band: str,
    imageForRedistribution: MaskedImage | Exposure | None = None,
    removeScarletData: bool = True,
    updateFluxColumns: bool = True
):
    """Use the scarlet models to set HeavyFootprints for modeled sources

    Parameters
    ----------
    modelData:
        The scarlet model data. ``modelData.blends`` is looked up by parent
        record ID; parents without an entry are skipped.
    catalog:
        The catalog missing heavy footprints for deblended sources.
    band:
        The name of the band that the catalog data describes.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    removeScarletData:
        Whether or not to remove `ScarletBlendData` for each blend
        in order to save memory. This applies to every processed blend,
        including blends that have no model in ``band``.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    # Iterate over the blends, since flux re-distribution must be done on
    # all of the children with the same parent
    parents = catalog[catalog["parent"] == 0]

    for parentRecord in parents:
        parentId = parentRecord.getId()

        try:
            blendModel = modelData.blends[parentId]
        except KeyError:
            # The parent was skipped in the deblender, so there are
            # no models for its sources.
            continue

        parentFootprint = parentRecord.getFootprint()
        if updateFluxColumns and imageForRedistribution is not None:
            # Update the data coverage (1 - # of NO_DATA pixels/# of pixels)
            parentRecord["deblend_dataCoverage"] = calculateFootprintCoverage(
                parentFootprint,
                imageForRedistribution.mask
            )

        if band in blendModel.bands:
            updateBlendRecords(
                blendData=blendModel,
                catalog=catalog,
                modelPsf=modelData.psf,
                imageForRedistribution=imageForRedistribution,
                bandIndex=blendModel.bands.index(band),
                parentFootprint=parentFootprint,
                updateFluxColumns=updateFluxColumns,
            )
        else:
            _setPeakOnlyFootprints(
                blendModel=blendModel,
                catalog=catalog,
                peaks=parentFootprint.peaks,
                updateFluxColumns=updateFluxColumns,
            )

        # Save memory by removing the data for the blend. This is done for
        # every processed blend; previously, blends without a model in
        # ``band`` kept their data because of an early ``continue``.
        if removeScarletData:
            del modelData.blends[parentId]


def _setPeakOnlyFootprints(blendModel, catalog, peaks, updateFluxColumns):
    """Give each source of a blend that has no model in the current band an
    otherwise empty Footprint holding only the source's peak.

    ``peaks`` is the peak catalog of the parent footprint; each source's peak
    is found in it by matching ``peaks["id"]`` to the source's ``peak_id``.
    """
    for sourceId, sourceData in blendModel.sources.items():
        sourceRecord = catalog.find(sourceId)
        footprint = afwFootprint()
        peakIdx = np.where(peaks["id"] == sourceData.peak_id)[0][0]
        peak = peaks[peakIdx]
        footprint.addPeak(peak.getIx(), peak.getIy(), peak.getPeakValue())
        sourceRecord.setFootprint(footprint)
        if updateFluxColumns:
            # No model in this band, so there is no data coverage.
            sourceRecord["deblend_dataCoverage"] = 0
def calculateFootprintCoverage(footprint, maskImage):
    """Calculate the fraction of pixels in a Footprint that contain data.

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint to check for missing data.
    maskImage : `lsst.afw.image.MaskX`
        The mask image with the ``NO_DATA`` bit set.

    Returns
    -------
    coverage : `float`
        ``1 - nNoData / area``, where ``nNoData`` is the number of pixels in
        ``footprint`` whose ``NO_DATA`` bit is set. In other words, this is
        the fraction of footprint pixels that do *not* have ``NO_DATA`` set.
        A footprint with an empty bounding box has a coverage of 0.
    """
    # Bit value of the "NO_DATA" plane in the mask.
    noDataBit = 2**maskImage.getMaskPlaneDict()["NO_DATA"]

    bbox = footprint.getBBox()
    if bbox.area == 0:
        # The source has no footprint, so it has no coverage
        return 0.0

    # Footprint membership of each pixel in ``bbox``.
    spans = footprint.spans.asArray()
    hasNoData = (maskImage[bbox].array & noDataBit) != 0
    # Only NO_DATA pixels that are actually inside the footprint count.
    nNoData = np.count_nonzero(np.logical_and(hasNoData, spans))
    return 1 - nNoData / footprint.getArea()
def updateBlendRecords(
    blendData: scl.io.ScarletBlendData,
    catalog: SourceCatalog,
    modelPsf: np.ndarray,
    imageForRedistribution: MaskedImage | Exposure | None,
    bandIndex: int,
    parentFootprint: afwFootprint,
    updateFluxColumns: bool,
):
    """Create footprints and update band-dependent columns in the catalog

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    catalog:
        The catalog that is being updated.
    modelPsf:
        The 2D model of the PSF.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    bandIndex:
        The index of the band to extract from ``blendData``.
    parentFootprint:
        The footprint of the parent, used for masking out the model
        when re-distributing flux.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    useFlux = imageForRedistribution is not None
    observation = _makeMonochromaticObservation(
        blendData=blendData,
        modelPsf=modelPsf,
        imageForRedistribution=imageForRedistribution,
        bandIndex=bandIndex,
        parentFootprint=parentFootprint,
    )

    blend = monochromaticDataToScarlet(
        blendData=blendData,
        bandIndex=bandIndex,
        observation=observation,
    )

    if useFlux:
        # Re-distribute the flux in the images
        blend.conserve_flux()

    # Set the metrics for the blend.
    # TODO: remove this once DM-34558 runs all deblender metrics
    # in a separate task.
    if updateFluxColumns:
        setDeblenderMetrics(blend)

    # Update the HeavyFootprints for deblended sources
    # and update the band-dependent catalog columns.
    for source in blend.sources:
        sourceRecord = catalog.find(source.record_id)
        parent = catalog.find(sourceRecord["parent"])
        peaks = parent.getFootprint().peaks
        peakIdx = np.where(peaks["id"] == source.peak_id)[0][0]
        source.detectedPeak = peaks[peakIdx]
        # Set the Footprint
        heavy = scarletModelToHeavy(
            source=source,
            blend=blend,
            useFlux=useFlux,
        )

        if not updateFluxColumns:
            sourceRecord.setFootprint(heavy)
            continue

        if heavy.getArea() == 0:
            # The source has no flux after being weighted with the PSF
            # in this particular band (it might have flux in others).
            sourceRecord.set("deblend_zeroFlux", True)
            # Create a Footprint with a single pixel, set to zero,
            # to avoid breakage in measurement algorithms.
            center = Point2I(heavy.peaks[0]["i_x"], heavy.peaks[0]["i_y"])
            spanList = [Span(center.y, center.x, center.x)]
            footprint = afwFootprint(SpanSet(spanList))
            footprint.setPeakCatalog(heavy.peaks)
            heavy = HeavyFootprintF(footprint)
            heavy.getImageArray()[0] = 0.0
        else:
            sourceRecord.set("deblend_zeroFlux", False)
        sourceRecord.setFootprint(heavy)

        if useFlux:
            # Set the fraction of pixels with valid data.
            coverage = calculateFootprintCoverage(heavy, imageForRedistribution.mask)
            sourceRecord.set("deblend_dataCoverage", coverage)

        # Set the flux of the scarlet model
        # TODO: this field should probably be deprecated,
        # since DM-33710 gives users access to the scarlet models.
        model = source.get_model().data[0]
        sourceRecord.set("deblend_scarletFlux", np.sum(model))

        # Set the flux at the center of the model
        peak = heavy.peaks[0]
        img = heavy.extractImage(fill=0.0)
        try:
            sourceRecord.set("deblend_peak_instFlux", img[Point2I(peak["i_x"], peak["i_y"])])
        except Exception:
            # Best effort: log with traceback and fall back to NaN.
            logger.warning(
                "Source %s at %s,%s could not set the peak flux with error:",
                sourceRecord.getId(),
                peak["i_x"],
                peak["i_y"],
                exc_info=True,
            )
            sourceRecord.set("deblend_peak_instFlux", np.nan)

        # Set the metrics columns.
        # TODO: remove this once DM-34558 runs all deblender metrics
        # in a separate task.
        sourceRecord.set("deblend_maxOverlap", source.metrics.maxOverlap[0])
        sourceRecord.set("deblend_fluxOverlap", source.metrics.fluxOverlap[0])
        sourceRecord.set("deblend_fluxOverlapFraction", source.metrics.fluxOverlapFraction[0])
        sourceRecord.set("deblend_blendedness", source.metrics.blendedness[0])


def _makeMonochromaticObservation(
    blendData: scl.io.ScarletBlendData,
    modelPsf: np.ndarray,
    imageForRedistribution: MaskedImage | Exposure | None,
    bandIndex: int,
    parentFootprint: afwFootprint,
):
    """Build a single-band scarlet lite observation of the blend region.

    If ``imageForRedistribution`` is `None`, an empty observation over the
    blend box is returned. Otherwise, the image and variance inside the blend
    bounding box are used, with the parent footprint's spans as weights.
    """
    bands = ("dummy",)
    # Only use the PSF for the current image
    psfs = np.array([blendData.psf[bandIndex]])
    modelPsfCube = modelPsf[None, :, :]

    if imageForRedistribution is None:
        return scl.io.Observation.empty(
            bands=bands,
            psfs=psfs,
            model_psf=modelPsfCube,
            bbox=Box(blendData.shape, blendData.origin),
            dtype=np.float32,
        )

    # Extract the image array to re-distribute its flux. ``blendData`` stores
    # (y, x) coordinates while afw geometry takes (x, y).
    xy0 = Point2I(*blendData.origin[::-1])
    extent = Extent2I(*blendData.shape[::-1])
    subImage = imageForRedistribution[Box2I(xy0, extent)]

    images = Image(
        subImage.image.array[None, :, :],
        yx0=blendData.origin,
        bands=bands,
    )
    variance = Image(
        subImage.variance.array[None, :, :],
        yx0=blendData.origin,
        bands=bands,
    )
    weights = Image(
        parentFootprint.spans.asArray()[None, :, :],
        yx0=blendData.origin,
        bands=bands,
    )

    return scl.io.Observation(
        images=images,
        variance=variance,
        weights=weights,
        psfs=psfs,
        model_psf=modelPsfCube,
    )
def oldScarletToData(blend: Blend, psfCenter: tuple[int, int], xy0: Point2I):
    """Convert a deprecated ``scarlet.lite`` blend into persistable data.

    The input comes from the old (deprecated) ``scarlet.lite`` package. The
    output uses the persistable data model of the new ``lsst.scarlet.lite``
    package. This exists only to compare the two scarlet versions, and can be
    removed once ``lsst.scarlet.lite`` is used in production.

    Parameters
    ----------
    blend:
        The blend that is being persisted.
    psfCenter:
        The center of the PSF.
    xy0:
        The lower coordinate of the entire blend.

    Returns
    -------
    blendData : `lsst.scarlet.lite.io.ScarletBlendData`
        The data model for a single blend.
    """
    from scarlet import lite

    # Blend offset in (y, x) order, added to component-local coordinates.
    yx0 = (xy0.y, xy0.x)

    persistedSources = {}
    for source in blend.sources:
        regularData = []
        factorizedData = []
        for component in source.components:
            # The component box origin includes a leading band axis, so its
            # spatial (y, x) part starts at index 1.
            boxOrigin = component.bbox.origin
            origin = (boxOrigin[1] + yx0[0], boxOrigin[2] + yx0[1])
            center = component.center
            peak = (center[0] + yx0[0], center[1] + yx0[1])

            if not isinstance(component, lite.LiteFactorizedComponent):
                regularData.append(
                    scl.io.ScarletComponentData(
                        origin=origin,
                        peak=peak,
                        model=component.get_model(),
                    )
                )
                continue

            factorizedData.append(
                scl.io.ScarletFactorizedComponentData(
                    origin=origin,
                    peak=peak,
                    spectrum=component.sed,
                    morph=component.morph,
                )
            )

        persistedSources[source.record_id] = scl.io.ScarletSourceData(
            components=regularData,
            factorized_components=factorizedData,
            peak_id=source.peak_id,
        )

    return scl.io.ScarletBlendData(
        origin=yx0,
        shape=blend.observation.bbox.shape[-2:],
        sources=persistedSources,
        psf_center=psfCenter,
        psf=blend.observation.psfs,
        bands=blend.observation.bands,
    )