Coverage for python / lsst / meas / extensions / scarlet / io / utils.py: 18%
236 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 08:32 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 08:32 +0000
1# This file is part of meas_extensions_scarlet.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from collections.abc import Mapping
25from io import BytesIO
26import logging
27import json
28from typing import Any, BinaryIO, cast
29import zipfile
31import numpy as np
32from pydantic_core import from_json
34import lsst.scarlet.lite as scl
35from lsst.afw.detection import Footprint as afwFootprint
36from lsst.afw.detection import HeavyFootprintF
37from lsst.afw.geom import Span, SpanSet
38from lsst.afw.image import Exposure, MaskedImage, MaskX, MultibandExposure
39from lsst.afw.table import SourceCatalog
40import lsst.utils as lsst_utils
41from lsst.daf.butler import StorageClassDelegate
42from lsst.daf.butler import FormatterV2
43from lsst.geom import Point2I, Extent2I, Box2I
44from lsst.pipe.base import NoWorkFound
45from lsst.resources import ResourceHandleProtocol
46from lsst.scarlet.lite import (
47 Box,
48 FactorizedComponent,
49 FixedParameter,
50)
52from ..metrics import setDeblenderMetrics
53from .. import utils
54from ..footprint import scarletModelToHeavy
55from .model_data import LsstScarletModelData
# Module-level logger for this file.
logger = logging.getLogger(__name__)

__all__ = [
    "monochromaticDataToScarlet",
    "updateCatalogFootprints",
    "buildMonochromaticObservation",
    "calculateFootprintCoverage",
    "updateBlendRecords",
    "ScarletModelFormatter",
    "ScarletModelDelegate",
    "loadBlend",
]

# The name of the band in a monochromatic blend.
# This is used as a placeholder since the band is not used in the
# monochromatic model.
monochromaticBand = "dummy"
monochromaticBands = (monochromaticBand,)
def monochromaticDataToScarlet(
    blendData: scl.io.ScarletBlendData,
    bandIndex: int,
    observation: scl.Observation,
):
    """Convert the storage data model into a scarlet lite blend

    Builds single-band (monochromatic) scarlet sources from the persisted
    component data so they can later be turned into HeavyFootprints.

    Parameters
    ----------
    blendData:
        Persistable data for the entire blend.
    bandIndex:
        Index of model to extract.
    observation:
        Observation of region inside the bounding box.

    Returns
    -------
    blend : `scarlet.lite.LiteBlend`
        A scarlet blend model extracted from persisted data.
    """
    sources = []
    # Use a dummy band, since we are only extracting a monochromatic model
    # that will be turned into a HeavyFootprint.
    bands = monochromaticBands
    for sourceId, sourceData in blendData.sources.items():
        components: list[scl.Component] = []
        # There is no need to distinguish factorized components from regular
        # components, since there is only one band being used.
        for componentData in sourceData.components:
            if componentData.component_type == "component":
                # Non-factorized component: the per-band model is stored
                # directly, so just slice out the requested band.
                bbox = Box(componentData.shape, origin=componentData.origin)
                model = scl.Image(
                    componentData.model[bandIndex][None, :, :], yx0=bbox.origin, bands=bands
                )
                component = scl.ComponentCube(
                    model=model,
                    # NOTE(review): peak is reversed here ([::-1]) but not in
                    # the factorized branch below — presumably the stored
                    # orderings differ (x, y vs y, x); confirm against the
                    # serialization code.
                    peak=tuple(componentData.peak[::-1]),
                )
                components.append(component)
            else:
                bbox = Box(componentData.shape, origin=componentData.origin)
                # Add dummy values for properties only needed for
                # model fitting.
                spectrum = FixedParameter(componentData.spectrum)
                totalBands = len(spectrum.x)
                morph = FixedParameter(componentData.morph)
                factorized = FactorizedComponent(
                    bands=("dummy",) * totalBands,
                    spectrum=spectrum,
                    morph=morph,
                    peak=tuple(int(np.round(p)) for p in componentData.peak),  # type: ignore
                    bbox=bbox,
                    # bg_rms is only used during fitting, so NaN is safe here.
                    bg_rms=np.full((totalBands,), np.nan),
                )
                # Realize the factorized model and keep only the requested
                # band as a single-band cube.
                model = factorized.get_model().data[bandIndex][None, :, :]
                model = scl.Image(model, yx0=bbox.origin, bands=bands)
                component = scl.component.CubeComponent(
                    model=model,
                    peak=factorized.peak,
                )
                components.append(component)

        source = scl.Source(components=components, metadata=sourceData.metadata)
        sources.append(source)

    bbox = scl.Box(blendData.shape, origin=blendData.origin)
    # Restrict the full observation to the blend's bounding box.
    blend = scl.Blend(sources=sources, observation=observation[:, bbox])
    return blend
def updateCatalogFootprints(
    modelData: LsstScarletModelData,
    catalog: SourceCatalog,
    band: str,
    imageForRedistribution: MaskedImage | Exposure | None = None,
    removeScarletData: bool = True,
    updateFluxColumns: bool = True,
) -> None:
    """Use the scarlet models to set HeavyFootprints for modeled sources

    Parameters
    ----------
    modelData :
        Persistable data for the entire catalog.
    catalog:
        The catalog missing heavy footprints for deblended sources.
    band:
        The name of the band that the catalog data describes.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    removeScarletData:
        Whether or not to remove `ScarletBlendData` for each blend
        in order to save memory.
    updateFluxColumns:
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.

    Raises
    ------
    NoWorkFound
        If the model data contains no blends and no isolated sources,
        or if `band` is not one of the bands in the persisted models.
    ValueError
        If the model data does not contain metadata.
    """
    if len(modelData.blends) == 0:
        if len(modelData.isolated) == 0:
            # Bug fix: this used to be `return NoWorkFound(...)`, which
            # handed the exception instance back to the caller instead of
            # raising it (consistent with the `raise NoWorkFound` below).
            raise NoWorkFound("Scarlet model data is empty")
        # All of the sources must have been isolated so there is nothing
        # to do in this function. This is rare but it does occasionally
        # happen in fields that only have u-band images.
        return
    if modelData.metadata is None:
        raise ValueError("Scarlet model data does not contain metadata")
    bands = modelData.metadata["bands"]
    try:
        bandIndex = bands.index(band)
    except ValueError:
        raise NoWorkFound(f"Band '{band}' not found in scarlet model data")
    # All of the blends should have the same PSF,
    # so we extract it from the metadata.
    modelPsf = modelData.metadata["model_psf"]
    observedPsf = modelData.metadata["psf"][bandIndex][None, :, :]

    # Flux re-distribution may mix depth=1 blends, so we iterate over the
    # completely flux separated parents to ensure that the full models
    # are used for each source.
    # Iterate over a snapshot of the items so blends can be removed from
    # the mapping while looping.
    blend_items = list(modelData.blends.items())
    for parentId, blendData in blend_items:
        spans = blendData.metadata["spans"]
        bbox = scl.Box(spans.shape, blendData.metadata["origin"])

        observation = buildMonochromaticObservation(
            modelPsf=modelPsf,
            observedPsf=observedPsf,
            scarletBox=bbox,
            footprint=spans,
            imageForRedistribution=imageForRedistribution,
        )

        updateBlendRecords(
            blendData=blendData,
            bandIndex=bandIndex,
            catalog=catalog,
            observation=observation,
            updateFluxColumns=updateFluxColumns,
            imageForRedistribution=imageForRedistribution,
        )

        if removeScarletData:
            # Free the persisted blend data as soon as it has been used.
            modelData.blends.pop(parentId, None)
def buildMonochromaticObservation(
    modelPsf: np.ndarray,
    observedPsf: np.ndarray,
    scarletBox: Box,
    footprint: np.ndarray | None,
    imageForRedistribution: MaskedImage | Exposure | None = None,
) -> scl.Observation:
    """Create a single-band observation for the entire image

    Parameters
    ----------
    modelPsf :
        The 2D model of the PSF.
    observedPsf :
        The observed PSF model for the catalog.
    scarletBox :
        The bounding box for the scarlet observation.
    footprint :
        The footprint of the source, used for masking out the model.
    imageForRedistribution:
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.

    Returns
    -------
    observation : `scarlet.lite.Observation`
        The observation for the entire image
    """
    bbox = utils.scarletBoxToBBox(scarletBox)

    if imageForRedistribution is None:
        # No image to redistribute flux from: build an empty observation
        # that only carries the PSF information.
        return scl.Observation.empty(
            bands=monochromaticBands,
            psfs=observedPsf,
            model_psf=modelPsf[None, :, :],
            bbox=scarletBox,
            dtype=modelPsf.dtype,
        )

    cutout = imageForRedistribution[bbox]
    imageArray = cutout.image.array

    # Weight every pixel equally, then zero out pixels that fall outside
    # the footprint (when one was supplied).
    weights = np.ones(imageArray.shape, dtype=imageArray.dtype)
    if footprint is not None:
        weights = weights * footprint

    return scl.Observation(
        images=imageArray[None, :, :],
        variance=cutout.variance.array[None, :, :],
        weights=weights[None, :, :],
        psfs=observedPsf,
        model_psf=modelPsf[None, :, :],
        convolution_mode="real",
        bands=monochromaticBands,
        bbox=scarletBox,
    )
def calculateFootprintCoverage(footprint: afwFootprint, maskImage: MaskX) -> np.floating:
    """Calculate the fraction of pixels in a Footprint that have valid data

    Parameters
    ----------
    footprint : `lsst.afw.detection.Footprint`
        The footprint to check for missing data.
    maskImage : `lsst.afw.image.MaskX`
        The mask image with the ``NO_DATA`` bit set.

    Returns
    -------
    coverage : `float`
        The fraction of pixels in `footprint` where the ``NO_DATA`` bit is
        *not* set, i.e. ``1 - nNoData / footprintArea``.
        Zero when the footprint is empty.
        (Doc fix: the previous docstring claimed the NO_DATA fraction,
        which is the complement of what the code computes and of what the
        ``deblend_dataCoverage`` column stores.)
    """
    # Store the value of "NO_DATA" from the mask plane.
    noDataInt = 2 ** maskImage.getMaskPlaneDict()["NO_DATA"]

    # Calculate the coverage in the footprint
    bbox = footprint.getBBox()
    if bbox.area == 0:
        # The source has no footprint, so it has no coverage.
        # Return a numpy float to match the annotated return type.
        return np.float64(0)
    spans = footprint.spans.asArray()
    totalArea = footprint.getArea()
    # Pixels inside the footprint whose NO_DATA bit is set.
    mask = maskImage[bbox].array & noDataInt
    noData = (mask * spans) > 0
    coverage = 1 - np.sum(noData) / totalArea
    return coverage
def updateBlendRecords(
    blendData: scl.io.ScarletBlendData | scl.io.HierarchicalBlendData,
    bandIndex: int,
    catalog: SourceCatalog,
    observation: scl.Observation,
    updateFluxColumns: bool,
    imageForRedistribution: MaskedImage | Exposure | None = None,
):
    """Create footprints and update band-dependent columns in the catalog

    Parameters
    ----------
    blendData :
        Persistable data for a single blend or hierarchical blend.
    bandIndex :
        The number of the band to extract.
    catalog :
        The catalog that is being updated.
    observation :
        The observation of the blend.
    updateFluxColumns :
        Whether or not to update the `deblend_*` columns in the catalog.
        This should only be true when the input catalog schema already
        contains those columns.
    imageForRedistribution :
        The image that is the source for flux re-distribution.
        If `imageForRedistribution` is `None` then flux re-distribution is
        not performed.
    """
    useFlux = imageForRedistribution is not None

    # Create a blend with the parent and all of its children.
    sources = []
    if isinstance(blendData, scl.io.HierarchicalBlendData):
        for blendId in blendData.children:
            _blendData = cast(scl.io.ScarletBlendData, blendData.children[blendId])
            blend = monochromaticDataToScarlet(_blendData, bandIndex, observation)
            sources.extend(blend.sources)
    # NOTE(review): a plain (non-hierarchical) ScarletBlendData falls
    # through here with no sources and is silently skipped below —
    # confirm that only hierarchical blends are expected at this point.

    if len(sources) == 0:
        # No sources to update, so we can skip the rest of the function.
        return

    blend = scl.Blend(
        sources=sources,
        observation=observation,
    )

    if useFlux:
        # Re-distribute the observed flux among the source models.
        blend.conserve_flux()

    # Set the metrics for the blend.
    # TODO: remove this once DM-34558 runs all deblender metrics
    # in a separate task.
    if updateFluxColumns:
        setDeblenderMetrics(blend)

    # Update the HeavyFootprints for deblended sources
    # and update the band-dependent catalog columns.
    for source in blend.sources:
        sourceRecord = catalog.find(source.metadata["id"])
        # Set the Footprint
        heavy = scarletModelToHeavy(
            source=source,
            blend=blend,
            useFlux=useFlux,
        )

        if updateFluxColumns:
            if heavy.getArea() == 0:
                # The source has no flux after being weighted with the PSF
                # in this particular band (it might have flux in others).
                sourceRecord.set("deblend_zeroFlux", True)
                # Create a Footprint with a single pixel, set to zero,
                # to avoid breakage in measurement algorithms.
                center = Point2I(heavy.peaks[0]["i_x"], heavy.peaks[0]["i_y"])
                spanList = [Span(center.y, center.x, center.x)]
                footprint = afwFootprint(SpanSet(spanList))
                footprint.setPeakCatalog(heavy.peaks)
                heavy = HeavyFootprintF(footprint)
                heavy.getImageArray()[0] = 0.0
            else:
                sourceRecord.set("deblend_zeroFlux", False)
            sourceRecord.setFootprint(heavy)

            if useFlux:
                # Set the fraction of pixels with valid data.
                coverage = calculateFootprintCoverage(
                    heavy, imageForRedistribution.mask
                )
                sourceRecord.set("deblend_dataCoverage", coverage)

            # Set the flux of the scarlet model
            # TODO: this field should probably be deprecated,
            # since DM-33710 gives users access to the scarlet models.
            model = source.get_model().data[0]
            sourceRecord.set("deblend_scarletFlux", np.sum(model))

            # Set the flux at the center of the model
            peak = heavy.peaks[0]

            img = heavy.extractImage(fill=0.0)
            try:
                sourceRecord.set(
                    "deblend_peak_instFlux", img[Point2I(peak["i_x"], peak["i_y"])]
                )
            except Exception:
                # Idiom fix: use lazy %-style logging arguments instead of
                # an eagerly-evaluated f-string, and exc_info=True rather
                # than the integer 1.
                logger.warning(
                    "Source %s at %s,%s could not set the peak flux with error:",
                    sourceRecord.getId(),
                    peak["i_x"],
                    peak["i_y"],
                    exc_info=True,
                )
                sourceRecord.set("deblend_peak_instFlux", np.nan)
        else:
            sourceRecord.setFootprint(heavy)
def build_scarlet_model(zip_dict: dict[str, Any]) -> LsstScarletModelData:
    """Build a LsstScarletModelData instance from a dictionary of files.

    Parameters
    ----------
    zip_dict : dict[str, Any]
        Dictionary mapping filenames to the parsed contents of each file.

    Returns
    -------
    model :
        LsstScarletModelData instance.
    """
    metadata = zip_dict.pop('metadata', None)
    version = zip_dict.pop('version', scl.io.migration.PRE_SCHEMA)
    if metadata is None:
        # Legacy format: the PSF information was stored as separate files
        # instead of a metadata entry.
        metadata = {
            'psf': zip_dict.pop('psf'),
            'psfShape': zip_dict.pop('psf_shape'),
        }

    blends: dict[int, Any] = {}
    isolated: dict[int, Any] = {}
    # Every remaining entry is either a blend or an isolated source,
    # keyed by its integer ID.
    for name, entry in zip_dict.items():
        if "blend_type" in entry:
            blends[int(name)] = entry
        elif "source_type" in entry:
            if entry["source_type"] != "isolated":
                raise ValueError("Found unknown source type in scarlet model data isolated sources")
            isolated[int(name)] = entry
        else:
            raise ValueError(f"Found unknown file '{entry}' in scarlet model data")

    return LsstScarletModelData.parse_obj({
        'version': version,
        'isolated': isolated,
        'blends': blends,
        'metadata': metadata,
    })
def read_scarlet_model(path_or_stream: str, blend_ids: list[int] | None = None) -> LsstScarletModelData:
    """Read a zip file and return a LsstScarletModelData instance.

    Parameters
    ----------
    path_or_stream : `str`
        Path to the zip file, or an open binary stream containing one.
    blend_ids : `list[int]`, optional
        List of blend IDs to extract from the zip file. If None,
        all blends in the dataset will be extracted.

    Returns
    -------
    model :
        LsstScarletModelData instance.
    """
    if blend_ids is not None:
        # Zip members are named after the (stringified) blend IDs.
        filenames = [str(f) for f in blend_ids]
    else:
        filenames = None

    with zipfile.ZipFile(path_or_stream, 'r') as zip_file:
        unzipped_files = {}
        if filenames is None:
            filenames = zip_file.namelist()
        # Attempt to read the metadata file first, if it exists.
        try:
            with zip_file.open('metadata') as f:
                metadata = from_json(f.read())
                unzipped_files['metadata'] = metadata
        except KeyError:
            # Bug fix: ``ZipFile.open`` raises ``KeyError`` for a missing
            # member, not ``ValueError`` as previously caught, so the
            # legacy fallback never triggered. The metadata file is not
            # present, so assume the model is in the legacy format.
            filenames += ['psf', 'psf_shape']
        try:
            with zip_file.open('version') as f:
                version = from_json(f.read())
                unzipped_files['version'] = version
        except KeyError:
            # The version file is not present.
            pass

        for filename in filenames:
            with zip_file.open(filename) as f:
                unzipped_files[filename] = from_json(f.read())

    return build_scarlet_model(unzipped_files)
def scarlet_model_to_zip_json(model_data: LsstScarletModelData) -> dict[str, Any]:
    """Convert a LsstScarletModelData instance to a dictionary of files.

    This is required to convert the model data into a format that
    can be inserted into a zip archive.

    Parameters
    ----------
    model_data : `lsst.scarlet.lite.io.LsstScarletModelData`
        LsstScarletModelData instance.

    Returns
    -------
    data : dict[str, Any]
        Dictionary mapping filenames to JSON-encoded file contents.
    """
    json_model = model_data.as_dict()

    data: dict[str, Any] = {}
    # Each blend and each isolated source becomes its own archive member,
    # named after its (stringified) ID.
    for blend_id, blend_data in json_model['blends'].items():
        data[str(blend_id)] = json.dumps(blend_data)
    for source_id, source_data in json_model['isolated'].items():
        data[str(source_id)] = json.dumps(source_data)

    # Support for legacy models
    if 'psf' in json_model:
        data['psf_shape'] = json.dumps(json_model['psfShape'])
        data['psf'] = json.dumps(json_model['psf'])
    else:
        data['metadata'] = json.dumps(json_model['metadata'])
        data['version'] = json.dumps(json_model['version'])
    return data
def write_scarlet_model(path_or_stream: str | BinaryIO, model_data: LsstScarletModelData):
    """Write a LsstScarletModelData instance to a zip file.

    Parameters
    ----------
    path_or_stream : `str` | `BinaryIO`
        Path of the zip file to create, or an open binary stream to write
        the archive into.
    model_data : `lsst.scarlet.lite.io.LsstScarletModelData`
        LsstScarletModelData instance to persist.
    """
    # Serialize the model into {filename: json-string} pairs and store
    # each pair as a member of the zip archive.
    with zipfile.ZipFile(path_or_stream, 'w') as archive:
        for member_name, contents in scarlet_model_to_zip_json(model_data).items():
            archive.writestr(member_name, contents)
def scarlet_model_to_lsst_scarlet_model(model_data: scl.io.ScarletModelData) -> LsstScarletModelData:
    """Convert a scarlet ModelData instance to a LsstScarletModelData instance.

    Parameters
    ----------
    model_data : `scarlet.lite.io.ScarletModelData`
        Scarlet ModelData instance.

    Returns
    -------
    result : `lsst.scarlet.lite.io.LsstScarletModelData`
        LsstScarletModelData instance carrying the same blends and no
        metadata.
    """
    # The upstream scarlet model has no LSST-specific metadata, so only
    # the blends are carried over.
    converted = LsstScarletModelData(blends=model_data.blends, metadata=None)
    return converted
class ScarletModelFormatter(FormatterV2):
    """Read and write scarlet models persisted as zip archives."""

    default_extension = ".scarlet"
    unsupported_parameters = frozenset()
    can_read_from_stream = True
    can_read_from_local_file = True

    def read_from_local_file(self, path: str, component: str | None = None, expected_size: int = -1) -> Any:
        # Override of `FormatterV2.read_from_local_file`.
        # Read every blend from the file on disk.
        return read_scarlet_model(path)

    def read_from_stream(
        self, stream: BinaryIO | ResourceHandleProtocol, component: str | None = None, expected_size: int = -1
    ) -> Any:
        # Override of `FormatterV2.read_from_stream`.
        # Only worthwhile when a subset of blends was requested; otherwise
        # signal the butler to fall back to another read path.
        parameters = self.file_descriptor.parameters
        if parameters is None or "blend_id" not in parameters:
            return NotImplemented
        blend_ids = lsst_utils.iteration.ensure_iterable(parameters["blend_id"])
        return read_scarlet_model(stream, blend_ids=blend_ids)

    def to_bytes(self, in_memory_dataset: Any) -> bytes:
        # Override of `FormatterV2.to_bytes`.
        # Serialize the model into an in-memory zip archive.
        buffer = BytesIO()
        write_scarlet_model(buffer, in_memory_dataset)
        return buffer.getvalue()
class ScarletModelDelegate(StorageClassDelegate):
    """Delegate to extract a blend from an in-memory
    LsstScarletModelData object.
    """

    def can_accept(self, inMemoryDataset: Any) -> bool:
        # Only LsstScarletModelData instances are supported.
        return isinstance(inMemoryDataset, LsstScarletModelData)

    def getComponent(self, composite: Any, componentName: str) -> Any:
        # This delegate exposes no components.
        raise AttributeError(f"Unsupported component: {componentName}")

    def handleParameters(self, inMemoryDataset: Any, parameters: Mapping[str, Any] | None = None) -> Any:
        # Bug fix: the previous implementation evaluated
        # `"blend_id" in parameters` before checking for None, raising
        # TypeError whenever no parameters were supplied.
        if parameters is None:
            return inMemoryDataset
        if "blend_id" in parameters:
            # Restrict the dataset to the requested blend IDs.
            blend_ids = lsst_utils.iteration.ensure_iterable(parameters["blend_id"])
            inMemoryDataset.blends = {
                blend_id: inMemoryDataset.blends[blend_id] for blend_id in blend_ids
            }
        else:
            raise ValueError(f"Unsupported parameters: {parameters}")
        return inMemoryDataset
def loadBlend(blendData: scl.io.ScarletBlendData, model_psf: np.ndarray, mCoadd: MultibandExposure):
    """Load a blend from the persisted data

    Parameters
    ----------
    blendData:
        The persisted scarlet BlendData to load into the blend.
    model_psf:
        The psf of the model in each band. This should be 2D, as scarlet
        lite assumes that the PSF is the same for all bands.
    mCoadd:
        The coadd image to use for the observation attached to the blend.
        This is required in order to create a difference kernel to convolve
        the model into an observed seeing.

    Returns
    -------
    blend : `scarlet.lite.Blend`
        The blend object loaded from the persisted data.
    afw_box : `lsst.geom.Box2I`
        The afw bounding box of the blend in the coadd.
    """
    # Evaluate the observed PSF at the stored blend center.
    psf, _ = utils.computePsfKernelImage(mCoadd, blendData.psf_center)
    # scarlet boxes are (y, x) ordered while afw boxes are (x, y) ordered,
    # hence the [::-1] reversals.
    scarlet_box = Box(blendData.shape, origin=blendData.origin)
    afw_box = Box2I(
        Point2I(scarlet_box.origin[::-1]),
        Extent2I(scarlet_box.shape[::-1]),
    )
    cutout = mCoadd[blendData.bands, afw_box]
    observation = scl.Observation(
        images=cutout.image.array,
        variance=cutout.variance.array,
        weights=np.ones(cutout.image.array.shape, dtype=np.float32),
        psfs=psf,
        model_psf=model_psf[None, :, :],
        convolution_mode='real',
        bands=mCoadd.bands,
        bbox=scarlet_box,
    )
    return blendData.to_blend(observation), afw_box