Coverage for python/lsst/meas/extensions/scarlet/io.py: 9%
140 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 13:52 +0000
1# This file is part of meas_extensions_scarlet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
23import logging
25import numpy as np
27from lsst.afw.table import SourceCatalog
28from lsst.afw.image import MaskedImage, Exposure
29from lsst.afw.detection import Footprint as afwFootprint
30from lsst.geom import Box2I, Extent2I, Point2I
31import lsst.scarlet.lite as scl
32from lsst.scarlet.lite import Blend, Source, Box, Component, FixedParameter, FactorizedComponent, Image
34from .metrics import setDeblenderMetrics
35from .utils import scarletModelToHeavy
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def monochromaticDataToScarlet(
    blendData: scl.io.ScarletBlendData,
    bandIndex: int,
    observation: scl.Observation,
):
    """Convert the storage data model into a single-band scarlet lite blend.

    Every persisted component, regular or factorized, is collapsed into a
    `scl.io.ComponentCube` that holds only the model for ``bandIndex``,
    labelled with a single placeholder band name.

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    bandIndex:
        Index of the band whose model is extracted.
    observation:
        Observation of the region inside the blend bounding box.

    Returns
    -------
    blend : `lsst.scarlet.lite.Blend`
        A scarlet blend model extracted from persisted data, with one
        `Source` per entry in ``blendData.sources``.
    """
    # A placeholder band name: only a monochromatic model is extracted,
    # and it will be turned into a HeavyFootprint.
    dummyBands = ("dummy", )

    def _cubeFromComponentData(componentData) -> Component:
        # Regular components store a full model cube; slice out the requested
        # band while keeping a leading band axis of length one.
        box = Box(componentData.shape, origin=componentData.origin)
        singleBand = scl.io.Image(
            componentData.model[bandIndex][None, :, :],
            yx0=box.origin,
            bands=dummyBands,
        )
        return scl.io.ComponentCube(
            bands=dummyBands,
            model=singleBand,
            peak=tuple(componentData.peak[::-1]),
            bbox=box,
        )

    def _cubeFromFactorizedData(factorizedData) -> Component:
        box = Box(factorizedData.shape, origin=factorizedData.origin)
        # Only the model is evaluated, so the parameters are fixed and the
        # properties needed only for fitting get placeholder values.
        spectrum = FixedParameter(factorizedData.spectrum)
        nBands = len(spectrum.x)
        factorized = FactorizedComponent(
            bands=("dummy", ) * nBands,
            spectrum=spectrum,
            morph=FixedParameter(factorizedData.morph),
            peak=tuple(int(np.round(p)) for p in factorizedData.peak),  # type: ignore
            bbox=box,
            bg_rms=np.full((nBands,), np.nan),
        )
        singleBand = scl.io.Image(
            factorized.get_model().data[bandIndex][None, :, :],
            yx0=box.origin,
            bands=dummyBands,
        )
        return scl.io.ComponentCube(
            bands=dummyBands,
            model=singleBand,
            peak=factorized.peak,
            bbox=factorized.bbox,
        )

    sources = []
    for sourceId, sourceData in blendData.sources.items():
        # With a single band there is no need to distinguish factorized
        # components from regular ones: both become component cubes.
        components: list[Component] = [
            _cubeFromComponentData(componentData)
            for componentData in sourceData.components
        ]
        components.extend(
            _cubeFromFactorizedData(factorizedData)
            for factorizedData in sourceData.factorized_components
        )

        source = Source(components=components)
        source.record_id = sourceId
        source.peak_id = sourceData.peak_id
        sources.append(source)

    return Blend(sources=sources, observation=observation)
def updateCatalogFootprints(
    modelData: scl.io.ScarletModelData,
    catalog: SourceCatalog,
    band: str,
    imageForRedistribution: MaskedImage | Exposure | None = None,
    removeScarletData: bool = True,
    updateFluxColumns: bool = True
):
    """Use the scarlet models to set HeavyFootprints for modeled sources

    Parameters
    ----------
    modelData:
        The scarlet model data. ``modelData.blends`` is looked up by parent
        record ID, and ``modelData.psf`` is used as the model PSF.
    catalog:
        The catalog missing heavy footprints for deblended sources.
    band:
        The name of the band that the catalog data describes.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    removeScarletData:
        Whether or not to remove `ScarletBlendData` for each blend
        in order to save memory. When `True`, every blend visited here is
        deleted from ``modelData.blends``, including blends that have no
        model in ``band``.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    # Iterate over the blends, since flux re-distribution must be done on
    # all of the children with the same parent
    parents = catalog[catalog["parent"] == 0]

    for parentRecord in parents:
        parentId = parentRecord.getId()

        try:
            blendModel = modelData.blends[parentId]
        except KeyError:
            # The parent was skipped in the deblender, so there are
            # no models for its sources.
            continue

        parentFootprint = parentRecord.getFootprint()
        if updateFluxColumns and imageForRedistribution is not None:
            # Update the data coverage (1 - # of NO_DATA pixels/# of pixels)
            parentRecord["deblend_dataCoverage"] = calculateFootprintCoverage(
                parentFootprint,
                imageForRedistribution.mask
            )

        if band in blendModel.bands:
            updateBlendRecords(
                blendData=blendModel,
                catalog=catalog,
                modelPsf=modelData.psf,
                imageForRedistribution=imageForRedistribution,
                bandIndex=blendModel.bands.index(band),
                parentFootprint=parentFootprint,
                updateFluxColumns=updateFluxColumns,
            )
        else:
            # The blend has no model in this band: give each child a
            # footprint containing only its peak.
            _setPeakOnlyFootprints(
                blendModel, catalog, parentFootprint.peaks, updateFluxColumns
            )

        # Save memory by removing the data for the blend. Blends without a
        # model in ``band`` are removed too; previously they were skipped by
        # an early ``continue`` and kept in memory despite the flag.
        if removeScarletData:
            del modelData.blends[parentId]


def _setPeakOnlyFootprints(blendModel, catalog, peaks, updateFluxColumns):
    """Set a footprint holding only the detected peak for each blend source.

    Parameters
    ----------
    blendModel:
        Persisted blend data whose ``sources`` map record IDs to source data
        carrying a ``peak_id``.
    catalog:
        The catalog that is being updated.
    peaks:
        The peak catalog of the parent footprint.
    updateFluxColumns:
        Whether to set ``deblend_dataCoverage`` to zero for each source.
    """
    for sourceId, sourceData in blendModel.sources.items():
        sourceRecord = catalog.find(sourceId)
        footprint = afwFootprint()
        peakIdx = np.where(peaks["id"] == sourceData.peak_id)[0][0]
        peak = peaks[peakIdx]
        footprint.addPeak(peak.getIx(), peak.getIy(), peak.getPeakValue())
        sourceRecord.setFootprint(footprint)
        if updateFluxColumns:
            # The source has no modeled pixels in this band.
            sourceRecord["deblend_dataCoverage"] = 0
def calculateFootprintCoverage(footprint, maskImage):
    """Calculate the fraction of pixels in a Footprint that contain data.

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint to check for missing data.
    maskImage : `lsst.afw.image.MaskX`
        The mask image with the ``NO_DATA`` bit set.

    Returns
    -------
    coverage : `float`
        One minus the fraction of pixels in ``footprint`` where the
        ``NO_DATA`` bit is set: ``1.0`` when every footprint pixel has data,
        ``0.0`` when none do. A footprint with an empty bounding box has
        a coverage of ``0.0``.
    """
    # Bit value of the "NO_DATA" mask plane.
    noDataBit = 2**maskImage.getMaskPlaneDict()["NO_DATA"]

    bbox = footprint.getBBox()
    if bbox.area == 0:
        # The source has no footprint, so it has no coverage
        return 0.0

    # Pixel mask of the footprint within ``bbox``.
    spans = footprint.spans.asArray()
    noDataPixels = (maskImage[bbox].array & noDataBit) != 0
    # Only NO_DATA pixels that are part of the footprint count against it.
    nNoData = np.count_nonzero(np.logical_and(noDataPixels, spans))
    return 1.0 - nNoData / footprint.getArea()
def updateBlendRecords(
    blendData: scl.io.ScarletBlendData,
    catalog: SourceCatalog,
    modelPsf: np.ndarray,
    imageForRedistribution: MaskedImage | Exposure | None,
    bandIndex: int,
    parentFootprint: afwFootprint,
    updateFluxColumns: bool,
):
    """Create footprints and update band-dependent columns in the catalog

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    catalog:
        The catalog that is being updated.
    modelPsf:
        The 2D model of the PSF.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    bandIndex:
        Index of the band in ``blendData`` to extract.
    parentFootprint:
        The footprint of the parent. Its span mask is used as the
        observation weights when re-distributing flux.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    """
    useFlux = imageForRedistribution is not None
    bands = ("dummy",)
    # Only use the PSF for the current image
    psfs = np.array([blendData.psf[bandIndex]])

    if useFlux:
        # ``blendData`` stores origin and shape in (y, x) order, while
        # lsst.geom expects (x, y).
        xy0 = Point2I(*blendData.origin[::-1])
        extent = Extent2I(*blendData.shape[::-1])
        # Extract the blend region once and reuse it for both planes.
        subImage = imageForRedistribution[Box2I(xy0, extent)]

        images = Image(
            subImage.image.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )
        variance = Image(
            subImage.variance.array[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )
        weights = Image(
            parentFootprint.spans.asArray()[None, :, :],
            yx0=blendData.origin,
            bands=bands,
        )

        observation = scl.io.Observation(
            images=images,
            variance=variance,
            weights=weights,
            psfs=psfs,
            model_psf=modelPsf[None, :, :],
        )
    else:
        observation = scl.io.Observation.empty(
            bands=bands,
            psfs=psfs,
            model_psf=modelPsf[None, :, :],
            bbox=Box(blendData.shape, blendData.origin),
            dtype=np.float32,
        )

    blend = monochromaticDataToScarlet(
        blendData=blendData,
        bandIndex=bandIndex,
        observation=observation,
    )

    if useFlux:
        # Re-distribute the flux in the images
        blend.conserve_flux()

    # Set the metrics for the blend.
    # TODO: remove this once DM-34558 runs all deblender metrics
    # in a separate task.
    if updateFluxColumns:
        setDeblenderMetrics(blend)

    # Update the HeavyFootprints for deblended sources
    # and update the band-dependent catalog columns.
    for source in blend.sources:
        sourceRecord = catalog.find(source.record_id)
        parent = catalog.find(sourceRecord["parent"])
        peaks = parent.getFootprint().peaks
        peakIdx = np.where(peaks["id"] == source.peak_id)[0][0]
        source.detectedPeak = peaks[peakIdx]
        # Set the Footprint
        heavy = scarletModelToHeavy(
            source=source,
            blend=blend,
            useFlux=useFlux,
        )
        sourceRecord.setFootprint(heavy)

        if not updateFluxColumns:
            continue

        if useFlux:
            # Set the fraction of pixels with valid data.
            coverage = calculateFootprintCoverage(heavy, imageForRedistribution.mask)
            sourceRecord.set("deblend_dataCoverage", coverage)

        # Set the flux of the scarlet model
        # TODO: this field should probably be deprecated,
        # since DM-33710 gives users access to the scarlet models.
        model = source.get_model().data[0]
        sourceRecord.set("deblend_scarletFlux", np.sum(model))

        # Set the flux at the center of the model
        peak = heavy.peaks[0]
        img = heavy.extractImage(fill=0.0)
        try:
            peakFlux = img[Point2I(peak["i_x"], peak["i_y"])]
        except Exception:
            # Best effort: log the failure with its traceback and record NaN
            # rather than aborting the whole blend.
            logger.warning(
                "Source %s at %s,%s could not set the peak flux with error:",
                sourceRecord.getId(),
                peak["i_x"],
                peak["i_y"],
                exc_info=True,
            )
            peakFlux = np.nan
        sourceRecord.set("deblend_peak_instFlux", peakFlux)

        # Set the metrics columns.
        # TODO: remove this once DM-34558 runs all deblender metrics
        # in a separate task.
        sourceRecord.set("deblend_maxOverlap", source.metrics.maxOverlap[0])
        sourceRecord.set("deblend_fluxOverlap", source.metrics.fluxOverlap[0])
        sourceRecord.set("deblend_fluxOverlapFraction", source.metrics.fluxOverlapFraction[0])
        sourceRecord.set("deblend_blendedness", source.metrics.blendedness[0])
def oldScarletToData(blend: Blend, psfCenter: tuple[int, int], xy0: Point2I):
    """Convert an old-style scarlet.lite blend into a persistable data object.

    Note: This converts a blend from the old version of scarlet.lite,
    which is deprecated, to the persistable data format used in the
    new scarlet lite package.
    This is kept to compare the two scarlet versions,
    and can be removed once the new lsst.scarlet.lite package is
    used in production.

    Parameters
    ----------
    blend:
        The blend that is being persisted. Despite the `Blend` annotation,
        its components are checked against the deprecated
        ``scarlet.lite.LiteFactorizedComponent`` class.
    psfCenter:
        The center of the PSF.
    xy0:
        The lower coordinate of the entire blend.

    Returns
    -------
    blendData : `lsst.scarlet.lite.io.ScarletBlendData`
        The data model for a single blend.
    """
    from scarlet import lite

    # The persisted data model uses (y, x) ordering.
    offset = (xy0.y, xy0.x)

    sources = {}
    for source in blend.sources:
        regularData = []
        factorizedData = []
        for component in source.components:
            # Component boxes carry a leading band axis, so skip index 0.
            origin = tuple(component.bbox.origin[axis + 1] + offset[axis] for axis in range(2))
            peak = tuple(component.center[axis] + offset[axis] for axis in range(2))

            if not isinstance(component, lite.LiteFactorizedComponent):
                regularData.append(
                    scl.io.ScarletComponentData(
                        origin=origin,
                        peak=peak,
                        model=component.get_model(),
                    )
                )
                continue

            factorizedData.append(
                scl.io.ScarletFactorizedComponentData(
                    origin=origin,
                    peak=peak,
                    spectrum=component.sed,
                    morph=component.morph,
                )
            )

        sources[source.record_id] = scl.io.ScarletSourceData(
            components=regularData,
            factorized_components=factorizedData,
            peak_id=source.peak_id,
        )

    return scl.io.ScarletBlendData(
        origin=offset,
        shape=blend.observation.bbox.shape[-2:],
        sources=sources,
        psf_center=psfCenter,
        psf=blend.observation.psfs,
        bands=blend.observation.bands,
    )