Coverage for python/lsst/scarlet/lite/io.py: 47%
146 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:54 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:54 +0000
1from __future__ import annotations
3import json
4import logging
5from dataclasses import dataclass
6from typing import Any, Callable
8import numpy as np
9from numpy.typing import DTypeLike
11from .bbox import Box
12from .blend import Blend
13from .component import Component, FactorizedComponent
14from .image import Image
15from .observation import Observation
16from .parameters import FixedParameter
17from .source import Source
19__all__ = [
20 "ScarletComponentData",
21 "ScarletFactorizedComponentData",
22 "ScarletSourceData",
23 "ScarletBlendData",
24 "ScarletModelData",
25 "ComponentCube",
26]
28logger = logging.getLogger(__name__)
31@dataclass(kw_only=True)
32class ScarletComponentData:
33 """Data for a component expressed as a 3D data cube
35 This is used for scarlet component models that are not factorized,
36 storing their entire model as a 3D data cube (bands, y, x).
38 Attributes
39 ----------
40 origin:
41 The lower bound of the components bounding box.
42 peak:
43 The peak of the component.
44 model:
45 The model for the component.
46 """
48 origin: tuple[int, int]
49 peak: tuple[float, float]
50 model: np.ndarray
52 @property
53 def shape(self):
54 return self.model.shape[-2:]
56 def as_dict(self) -> dict:
57 """Return the object encoded into a dict for JSON serialization
59 Returns
60 -------
61 result:
62 The object encoded as a JSON compatible dict
63 """
64 return {
65 "origin": self.origin,
66 "shape": self.model.shape,
67 "peak": self.peak,
68 "model": tuple(self.model.flatten().astype(float)),
69 }
71 @classmethod
72 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletComponentData:
73 """Reconstruct `ScarletComponentData` from JSON compatible dict
75 Parameters
76 ----------
77 data:
78 Dictionary representation of the object
79 dtype:
80 Datatype of the resulting model.
82 Returns
83 -------
84 result:
85 The reconstructed object
86 """
87 shape = tuple(data["shape"])
89 return cls(
90 origin=tuple(data["origin"]), # type: ignore
91 peak=data["peak"],
92 model=np.array(data["model"]).reshape(shape).astype(dtype),
93 )
96@dataclass(kw_only=True)
97class ScarletFactorizedComponentData:
98 """Data for a factorized component
100 Attributes
101 ----------
102 origin:
103 The lower bound of the component's bounding box.
104 peak:
105 The ``(y, x)`` peak of the component.
106 spectrum:
107 The SED of the component.
108 morph:
109 The 2D morphology of the component.
110 """
112 origin: tuple[int, int]
113 peak: tuple[float, float]
114 spectrum: np.ndarray
115 morph: np.ndarray
117 @property
118 def shape(self):
119 return self.morph.shape
121 def as_dict(self) -> dict:
122 """Return the object encoded into a dict for JSON serialization
124 Returns
125 -------
126 result:
127 The object encoded as a JSON compatible dict
128 """
129 return {
130 "origin": tuple(int(o) for o in self.origin),
131 "shape": tuple(int(s) for s in self.morph.shape),
132 "peak": tuple(int(p) for p in self.peak),
133 "spectrum": tuple(self.spectrum.astype(float)),
134 "morph": tuple(self.morph.flatten().astype(float)),
135 }
137 @classmethod
138 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletFactorizedComponentData:
139 """Reconstruct `ScarletFactorizedComponentData` from JSON compatible
140 dict.
142 Parameters
143 ----------
144 data:
145 Dictionary representation of the object
146 dtype:
147 Datatype of the resulting model.
149 Returns
150 -------
151 result:
152 The reconstructed object
153 """
154 shape = tuple(data["shape"])
156 return cls(
157 origin=tuple(data["origin"]), # type: ignore
158 peak=data["peak"],
159 spectrum=np.array(data["spectrum"]).astype(dtype),
160 morph=np.array(data["morph"]).reshape(shape).astype(dtype),
161 )
164@dataclass(kw_only=True)
165class ScarletSourceData:
166 """Data for a scarlet source
168 Attributes
169 ----------
170 components:
171 The components contained in the source that are not factorized.
172 factorized_components:
173 The components contained in the source that are factorized.
174 peak_id:
175 The peak ID of the source in it's parent's footprint peak catalog.
176 """
178 components: list[ScarletComponentData]
179 factorized_components: list[ScarletFactorizedComponentData]
180 peak_id: int
182 def as_dict(self) -> dict:
183 """Return the object encoded into a dict for JSON serialization
185 Returns
186 -------
187 result:
188 The object encoded as a JSON compatible dict
189 """
190 result = {
191 "components": [component.as_dict() for component in self.components],
192 "factorized": [component.as_dict() for component in self.factorized_components],
193 "peak_id": self.peak_id,
194 }
195 return result
197 @classmethod
198 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletSourceData:
199 """Reconstruct `ScarletSourceData` from JSON compatible
200 dict.
202 Parameters
203 ----------
204 data:
205 Dictionary representation of the object
206 dtype:
207 Datatype of the resulting model.
209 Returns
210 -------
211 result:
212 The reconstructed object
213 """
214 components = []
215 for component in data["components"]:
216 component = ScarletComponentData.from_dict(component, dtype=dtype)
217 components.append(component)
219 factorized = []
220 for component in data["factorized"]:
221 component = ScarletFactorizedComponentData.from_dict(component, dtype=dtype)
222 factorized.append(component)
224 return cls(components=components, factorized_components=factorized, peak_id=int(data["peak_id"]))
227@dataclass(kw_only=True)
228class ScarletBlendData:
229 """Data for an entire blend.
231 Attributes
232 ----------
233 origin:
234 The lower bound of the blend's bounding box.
235 shape:
236 The shape of the blend's bounding box.
237 sources:
238 Data for the sources contained in the blend,
239 indexed by the source id.
240 psf_center:
241 The location used for the center of the PSF for
242 the blend.
243 psf:
244 The PSF of the observation.
245 bands : `list` of `str`
246 The names of the bands.
247 The order of the bands must be the same as the order of
248 the multiband model arrays, and SEDs.
249 """
251 origin: tuple[int, int]
252 shape: tuple[int, int]
253 sources: dict[int, ScarletSourceData]
254 psf_center: tuple[float, float]
255 psf: np.ndarray
256 bands: tuple[str]
258 def as_dict(self) -> dict:
259 """Return the object encoded into a dict for JSON serialization
261 Returns
262 -------
263 result:
264 The object encoded as a JSON compatible dict
265 """
266 result = {
267 "origin": self.origin,
268 "shape": self.shape,
269 "psf_center": self.psf_center,
270 "psf_shape": self.psf.shape,
271 "psf": tuple(self.psf.flatten().astype(float)),
272 "sources": {bid: source.as_dict() for bid, source in self.sources.items()},
273 "bands": self.bands,
274 }
275 return result
277 @classmethod
278 def from_dict(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletBlendData:
279 """Reconstruct `ScarletBlendData` from JSON compatible
280 dict.
282 Parameters
283 ----------
284 data:
285 Dictionary representation of the object
286 dtype:
287 Datatype of the resulting model.
289 Returns
290 -------
291 result:
292 The reconstructed object
293 """
294 psf_shape = data["psf_shape"]
295 return cls(
296 origin=tuple(data["origin"]), # type: ignore
297 shape=tuple(data["shape"]), # type: ignore
298 psf_center=tuple(data["psf_center"]), # type: ignore
299 psf=np.array(data["psf"]).reshape(psf_shape).astype(dtype),
300 sources={
301 int(bid): ScarletSourceData.from_dict(source, dtype=dtype)
302 for bid, source in data["sources"].items()
303 },
304 bands=tuple(data["bands"]), # type: ignore
305 )
307 def minimal_data_to_blend(self, model_psf: np.ndarray, dtype: DTypeLike) -> Blend:
308 """Convert the storage data model into a scarlet lite blend
310 Parameters
311 ----------
312 model_psf:
313 PSF in model space (usually a nyquist sampled circular Gaussian).
314 dtype:
315 The data type of the model that is generated.
317 Returns
318 -------
319 blend:
320 A scarlet blend model extracted from persisted data.
321 """
322 model_box = Box(self.shape, origin=(0, 0))
323 observation = Observation.empty(
324 bands=self.bands,
325 psfs=self.psf,
326 model_psf=model_psf,
327 bbox=model_box,
328 dtype=dtype,
329 )
330 return self.to_blend(observation)
332 def to_blend(self, observation: Observation) -> Blend:
333 """Convert the storage data model into a scarlet lite blend
335 Parameters
336 ----------
337 observation:
338 The observation that contains the blend.
339 If `observation` is ``None`` then an `Observation` containing
340 no image data is initialized.
342 Returns
343 -------
344 blend:
345 A scarlet blend model extracted from persisted data.
346 """
347 sources = []
348 for source_id, source_data in self.sources.items():
349 components: list[Component] = []
350 for component_data in source_data.components:
351 bbox = Box(component_data.shape, origin=component_data.origin)
352 model = component_data.model
353 if component_data.peak is None:
354 peak = None
355 else:
356 peak = (int(np.round(component_data.peak[0])), int(np.round(component_data.peak[0])))
357 component = ComponentCube(
358 bands=observation.bands,
359 bbox=bbox,
360 model=Image(model, yx0=bbox.origin, bands=observation.bands), # type: ignore
361 peak=peak,
362 )
363 components.append(component)
364 for factorized_data in source_data.factorized_components:
365 bbox = Box(factorized_data.shape, origin=factorized_data.origin)
366 # Add dummy values for properties only needed for
367 # model fitting.
368 spectrum = FixedParameter(factorized_data.spectrum)
369 morph = FixedParameter(factorized_data.morph)
370 # Note: since we aren't fitting a model, we don't need to
371 # set the RMS of the background.
372 # We set it to NaN just to be safe.
373 factorized = FactorizedComponent(
374 bands=observation.bands,
375 spectrum=spectrum,
376 morph=morph,
377 peak=tuple(int(np.round(p)) for p in factorized_data.peak), # type: ignore
378 bbox=bbox,
379 bg_rms=np.full((len(observation.bands),), np.nan),
380 )
381 components.append(factorized)
383 source = Source(components=components)
384 # Store identifiers for the source
385 source.record_id = source_id # type: ignore
386 source.peak_id = source_data.peak_id # type: ignore
387 sources.append(source)
389 return Blend(sources=sources, observation=observation)
391 @staticmethod
392 def from_blend(blend: Blend, psf_center: tuple[int, int]) -> ScarletBlendData:
393 """Convert a scarlet lite blend into a persistable data object
395 Parameters
396 ----------
397 blend:
398 The blend that is being persisted.
399 psf_center:
400 The center of the PSF.
402 Returns
403 -------
404 blend_data:
405 The data model for a single blend.
406 """
407 sources = {}
408 for source in blend.sources:
409 components = []
410 factorized = []
411 for component in source.components:
412 if type(component) is FactorizedComponent:
413 factorized_data = ScarletFactorizedComponentData(
414 origin=component.bbox.origin, # type: ignore
415 peak=component.peak, # type: ignore
416 spectrum=component.spectrum,
417 morph=component.morph,
418 )
419 factorized.append(factorized_data)
420 else:
421 component_data = ScarletComponentData(
422 origin=component.bbox.origin, # type: ignore
423 peak=component.peak, # type: ignore
424 model=component.get_model().data,
425 )
426 components.append(component_data)
427 source_data = ScarletSourceData(
428 components=components,
429 factorized_components=factorized,
430 peak_id=source.peak_id, # type: ignore
431 )
432 sources[source.record_id] = source_data # type: ignore
434 blend_data = ScarletBlendData(
435 origin=blend.bbox.origin, # type: ignore
436 shape=blend.bbox.shape, # type: ignore
437 sources=sources,
438 psf_center=psf_center,
439 psf=blend.observation.psfs,
440 bands=blend.observation.bands, # type: ignore
441 )
443 return blend_data
class ScarletModelData:
    """A container that propagates scarlet models for an entire catalog."""

    def __init__(self, psf: np.ndarray, blends: dict[int, ScarletBlendData] | None = None):
        """Initialize an instance

        Parameters
        ----------
        psf:
            The 2D array of the PSF in scarlet model space.
            This is typically a narrow Gaussian integrated over the
            pixels in the exposure.
        blends:
            Map from parent IDs in the source catalog
            to scarlet model data for each parent ID (blend).
            If ``None``, an empty map is created.
        """
        self.psf = psf
        # Create a fresh dict per instance rather than sharing a default.
        self.blends = {} if blends is None else blends

    def json(self) -> str:
        """Serialize the data model to a JSON formatted string

        Returns
        -------
        result : `str`
            The result of the object converted into a JSON format
        """
        result = {
            "psfShape": self.psf.shape,
            "psf": list(self.psf.flatten().astype(float)),
            "blends": {bid: blend.as_dict() for bid, blend in self.blends.items()},
        }
        return json.dumps(result)

    @classmethod
    def parse_obj(cls, data: dict, dtype: DTypeLike = np.float32) -> ScarletModelData:
        """Construct a ScarletModelData from python decoded JSON object.

        Parameters
        ----------
        data:
            The result of json.load(s) on a JSON persisted ScarletModelData
        dtype:
            Datatype of the model PSF and of the blend models that are
            reconstructed.

        Returns
        -------
        result:
            The `ScarletModelData` that was loaded from the input object
        """
        model_psf = np.array(data["psf"]).reshape(data["psfShape"]).astype(dtype)
        return cls(
            psf=model_psf,
            blends={
                # JSON object keys are always strings; restore the int ids.
                int(bid): ScarletBlendData.from_dict(blend, dtype=dtype)
                for bid, blend in data["blends"].items()
            },
        )
class ComponentCube(Component):
    """Dummy component for a component cube.

    This subclasses `Component` and wraps a precomputed model so that
    the model can be returned through the regular component interface.

    If scarlet lite ever implements a component as a data cube,
    this class can be removed.
    """

    def __init__(
        self,
        bands: tuple[Any, ...],
        bbox: Box,
        model: Image,
        peak: tuple[int, int] | None,
    ):
        """Initialization

        Parameters
        ----------
        bands:
            The bands of the component, passed to the `Component`
            initializer.
        bbox:
            The bounding box of the component.
        model:
            The 3D (bands, y, x) model of the component.
        peak:
            The `(y, x)` peak of the component, or ``None``.
        """
        super().__init__(bands, bbox)
        self._model = model
        self.peak = peak

    def get_model(self) -> Image:
        """Return the stored model of the component

        Returns
        -------
        model:
            The model as a 3D `(band, y, x)` array.
        """
        return self._model

    def resize(self, model_box: Box) -> bool:
        """Report that no resize took place.

        The stored model is fixed, so this always returns ``False``.
        """
        return False

    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Implementation of unused abstract method"""

    def parameterize(self, parameterization: Callable) -> None:
        """Implementation of unused abstract method"""