Coverage for python / lsst / scarlet / lite / component.py: 31%
206 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23from copy import deepcopy
# Public names exported by this module (used by ``from ... import *``).
__all__ = [
    "Component",
    "CubeComponent",
    "FactorizedComponent",
    "default_fista_parameterization",
    "default_adaprox_parameterization",
]
33from abc import ABC, abstractmethod
34from functools import partial
35from typing import TYPE_CHECKING, Any, Callable, cast
37import numpy as np
39from .bbox import Box
40from .image import Image
41from .operators import Monotonicity, prox_uncentered_symmetry
42from .parameters import AdaproxParameter, FistaParameter, Parameter, parameter, relative_step
43from .utils import convert_indices
45if TYPE_CHECKING:
46 from .io import ScarletComponentBaseData, ScarletCubeComponentData
import logging

# Module-level logger; `CubeComponent` uses it to warn when an
# unsupported operation (resize/update/parameterize) is requested.
Logger = logging.getLogger(__name__)
class Component(ABC):
    """Abstract base class for every component in scarlet lite.

    A component records which bands it models and where it sits in the
    full model image. Subclasses provide the model itself, the way it is
    optimized, and how it is persisted.

    Parameters
    ----------
    bands:
        The bands used when the component model is created.
    bbox: Box
        The bounding box for this component.
    """

    def __init__(
        self,
        bands: tuple,
        bbox: Box,
    ):
        self._bbox = bbox
        self._bands = bands

    @property
    def bbox(self) -> Box:
        """Bounding box of the component inside the full image"""
        return self._bbox

    @property
    def bands(self) -> tuple:
        """Bands covered by the component model"""
        return self._bands

    @abstractmethod
    def resize(self, model_box: Box) -> bool:
        """Check whether the component needs to be resized.

        Subclasses must override this and return `True` when the
        component needs to be resized.

        Parameters
        ----------
        model_box:
            Bounding box of the full model.
        """

    @abstractmethod
    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Update the component parameters using an input gradient.

        Parameters
        ----------
        it:
            Current iteration number of the optimizer.
        input_grad:
            Gradient of the likelihood with respect to the component model.
        """

    @abstractmethod
    def get_model(self) -> Image:
        """Build the model image of the component.

        Subclasses must implement this.

        Returns
        -------
        model: Image
            Image containing the component model.
        """

    @abstractmethod
    def parameterize(self, parameterization: Callable) -> None:
        """Turn the component parameter arrays into `Parameter` instances.

        Parameters
        ----------
        parameterization: Callable
            Function that converts the parameters of a given type into
            a `Parameter` in place. It takes a single argument: the
            `Component` or `Source` being parameterized.
        """

    @abstractmethod
    def to_data(self) -> ScarletComponentBaseData:
        """Convert the component into persistable ScarletComponentBaseData.

        Returns
        -------
        component_data: ScarletComponentBaseData
            Data object holding the component information.
        """

    @abstractmethod
    def __getitem__(self, indices: Any) -> Component:
        """Return the sub-component selected by ``indices``.

        Parameters
        ----------
        indices: Any
            Indices used to slice the component model.

        Returns
        -------
        sub_component: Component
            New component that is a sub-component of this one.

        Raises
        ------
        IndexError :
            If the index includes a ``Box`` or spatial indices.
        """

    @abstractmethod
    def __copy__(self) -> Component:
        """Return a shallow copy of this component.

        Returns
        -------
        component : Component
            New component that is a copy of this one.
        """

    @abstractmethod
    def __deepcopy__(self, memo: dict[int, Any]) -> Component:
        """Return a deep copy of this component.

        Returns
        -------
        component : Component
            New component that is a deep copy of this one.
        """

    def copy(self, deep: bool = False) -> Component:
        """Copy this component.

        Parameters
        ----------
        deep : bool, optional
            When `True` a deep copy is returned (using a fresh memo
            dictionary); otherwise a shallow copy. Default is `False`.

        Returns
        -------
        component : Component
            New component that is a copy of this one.
        """
        return self.__deepcopy__({}) if deep else self.__copy__()
class FactorizedComponent(Component):
    """A component that can be factorized into spectrum and morphology
    parameters.

    The model of the component is the outer product of ``spectrum``
    (one value per band) and the 2D ``morph`` image, placed at
    ``bbox.origin`` in the full model.

    Parameters
    ----------
    bands:
        The bands of the spectral dimension, in order.
    spectrum:
        The parameter to store and update the spectrum.
    morph:
        The parameter to store and update the morphology.
    bbox:
        The `Box` in the `model_bbox` that contains the source.
    peak:
        Location of the peak for the source, in the same (full model)
        coordinate frame as ``bbox.origin``.
    bg_rms:
        The RMS of the background used to threshold, grow,
        and shrink the component.
    bg_thresh:
        The threshold to use for the background RMS.
        If `None`, no background thresholding is applied, otherwise
        a sparsity constraint is applied to the morphology that
        requires flux in at least one band to be at least `bg_thresh`
        multiplied by `bg_rms` in that band.
    floor:
        Minimum value enforced on every spectrum element and on the
        peak pixel of the morphology.
    monotonicity:
        The monotonicity operator to use for making the source monotonic.
        If this parameter is `None`, the source will not be made monotonic.
    padding:
        The amount of padding to add to the component bounding box
        when resizing the component.
    is_symmetric:
        Whether the component is symmetric or not.
        If `True`, the morphology will be symmetrized using
        `prox_uncentered_symmetry`.
        If `False`, the morphology will not be symmetrized.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: Parameter | np.ndarray,
        morph: Parameter | np.ndarray,
        bbox: Box,
        peak: tuple[int, int] | None = None,
        bg_rms: np.ndarray | None = None,
        bg_thresh: float | None = 0.25,
        floor: float = 1e-20,
        monotonicity: Monotonicity | None = None,
        padding: int = 5,
        is_symmetric: bool = False,
    ):
        # Initialize all of the base attributes
        super().__init__(
            bands=bands,
            bbox=bbox,
        )
        # Wrap raw arrays (or pass through existing Parameters)
        self._spectrum = parameter(spectrum)
        self._morph = parameter(morph)
        self._peak = peak
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh

        self.floor = floor
        self.monotonicity = monotonicity
        self.padding = padding
        self.is_symmetric = is_symmetric

    @property
    def peak(self) -> tuple[int, int] | None:
        """The peak of the component

        Returns
        -------
        peak:
            The peak of the component in full-model coordinates,
            or `None` if no peak was given.
        """
        return self._peak

    @property
    def component_center(self) -> tuple[int, int] | None:
        """The peak of the component relative to its bounding box

        `peak` is expressed in the same frame as ``bbox.origin``; this
        property subtracts that origin so the result can be used to
        index directly into `morph`.

        Returns
        -------
        center:
            The peak position inside the component bounding box (same
            axis order as `peak`), or `None` if the component has no peak.
        """
        _center = self.peak
        if _center is None:
            return None
        center = (
            _center[0] - self.bbox.origin[-2],
            _center[1] - self.bbox.origin[-1],
        )
        return center

    @property
    def spectrum(self) -> np.ndarray:
        """The array of spectrum values"""
        return self._spectrum.x

    @property
    def morph(self) -> np.ndarray:
        """The array of morphology values"""
        return self._morph.x

    @property
    def shape(self) -> tuple:
        """Shape of the resulting model image (spectrum shape + morph shape)"""
        return self.spectrum.shape + self.morph.shape

    def get_model(self) -> Image:
        """Build the model from the spectrum and morphology"""
        # The spectrum and morph might be Parameters,
        # so cast them as arrays in the model.
        spectrum = self.spectrum
        morph = self.morph
        model = spectrum[:, None, None] * morph[None, :, :]
        return Image(model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    def grad_spectrum(self, input_grad: np.ndarray, spectrum: np.ndarray, morph: np.ndarray):
        """Gradient of the spectrum wrt. the component model

        Contracts the last two (spatial) axes of ``input_grad`` with
        ``morph``. ``spectrum`` is unused but kept for the gradient
        callback signature.
        """
        return np.einsum("...jk,jk", input_grad, morph)

    def grad_morph(self, input_grad: np.ndarray, morph: np.ndarray, spectrum: np.ndarray):
        """Gradient of the morph wrt. the component model

        Contracts the first (band) axis of ``input_grad`` with
        ``spectrum``. ``morph`` is unused but kept for the gradient
        callback signature.
        """
        return np.einsum("i,i...", spectrum, input_grad)

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        Values below `floor` and non-finite values are replaced by `floor`.
        The input array is modified in place and returned.
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        spectrum[~np.isfinite(spectrum)] = self.floor
        return spectrum

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        The constraints are applied in this order: monotonicity (if
        `monotonicity` is set), symmetry (if `is_symmetric`), background
        thresholding (if both `bg_thresh` and `bg_rms` are set) or
        positivity otherwise, removal of non-finite pixels, a `floor` on
        the peak pixel, and normalization so the maximum pixel is one.

        Parameters
        ----------
        morph:
            The morphology to constrain. It may be modified in place.

        Returns
        -------
        morph:
            The constrained morphology.
        """
        # Peak position in the coordinates of the current bbox. When no
        # peak is known, fall back to the center of the morphology so that
        # every operator below receives a valid position (previously the
        # monotonicity operator was handed `None` in that case).
        center = self.component_center
        if center is None:
            center = (morph.shape[0] // 2, morph.shape[1] // 2)

        # monotonicity
        if self.monotonicity is not None:
            morph = self.monotonicity(morph, center)

        # symmetry
        if self.is_symmetric:
            morph = prox_uncentered_symmetry(morph, center, fill=0.0)

        if self.bg_thresh is not None and self.bg_rms is not None:
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero every pixel whose model
            # is below threshold in all bands.
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < bg_thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        # Remove non-finite pixels *before* flooring the peak, so that a
        # NaN/inf peak is replaced by `floor` instead of being zeroed
        # afterwards.
        morph[~np.isfinite(morph)] = 0

        # prevent divergent morphology
        morph[center] = np.max([morph[center], self.floor])

        # Normalize the morphology
        max_value = np.max(morph)
        if max_value > 0:
            morph[:] = morph / max_value
        return morph

    def resize(self, model_box: Box) -> bool:
        """Resize the component bounding box to its significant pixels

        The new box contains every pixel whose model is at least
        ``bg_rms * bg_thresh`` in some band, grown by `padding` and clipped
        to ``model_box``. If no pixel is significant, a small box around
        `peak` is used instead.

        Parameters
        ----------
        model_box:
            Bounding box of the full model, used to clip the new box.

        Returns
        -------
        resized:
            `True` if the bounding box (and morphology) changed.
        """
        # No need to resize if there is no size threshold.
        # To allow box sizing but no thresholding use `bg_thresh=0`.
        if self.bg_thresh is None or self.bg_rms is None:
            return False

        model = self.spectrum[:, None, None] * self.morph[None, :, :]
        bg_thresh = self.bg_rms * self.bg_thresh
        significant = np.any(model >= bg_thresh[:, None, None], axis=0)
        if not np.any(significant):
            # There are no significant pixels,
            # so make a small box around the center
            center = self.peak
            if center is None:
                center = (0, 0)
            new_box = Box((1, 1), center).grow(self.padding) & model_box
        else:
            # `from_data` gives a box relative to the current bbox,
            # so shift it into the full model frame.
            new_box = (
                Box.from_data(significant, threshold=0).grow(self.padding) + self.bbox.origin  # type: ignore
            ) & model_box
        if new_box == self.bbox:
            return False

        old_box = self.bbox
        self._bbox = new_box
        self._morph.resize(old_box, new_box)
        return True

    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Update the spectrum and morphology parameters"""
        # Store the input spectrum so that the morphology can
        # have a consistent update
        spectrum = self.spectrum.copy()
        self._spectrum.update(it, input_grad, self.morph)
        self._morph.update(it, input_grad, spectrum)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the spectrum and morph in place
        parameterization(self)
        # Attach this component's gradient and prox operators
        # to the freshly created parameters.
        self._spectrum.grad = self.grad_spectrum
        self._spectrum.prox = self.prox_spectrum
        self._morph.grad = self.grad_morph
        self._morph.prox = self.prox_morph

    def to_data(self) -> ScarletComponentBaseData:
        """Convert the component to persistable ScarletComponentBaseData

        Returns
        -------
        component_data: ScarletComponentBaseData
            The data object containing the component information
        """
        from .io import ScarletFactorizedComponentData

        return ScarletFactorizedComponentData(
            origin=self.bbox.origin,  # type: ignore
            peak=self.peak,  # type: ignore
            spectrum=self.spectrum,
            morph=self.morph,
        )

    def __str__(self):
        result = (
            f"FactorizedComponent<\n bands={self.bands},\n center={self.peak},\n "
            f"spectrum={self.spectrum},\n morph_shape={self.morph.shape}\n>"
        )
        return result

    def __repr__(self):
        return self.__str__()

    def __getitem__(self, indices: Any) -> FactorizedComponent:
        """Get a sub-component corresponding to the given indices.

        Only the spectrum is sliced; the morphology array is shared
        with this component.

        Parameters
        ----------
        indices: Any
            The indices to use to slice the component model.

        Returns
        -------
        component: FactorizedComponent
            A new component that is a sub-component of this one.

        Raises
        ------
        IndexError :
            If the index includes a ``Box`` or spatial indices.
        """
        # Convert the band indices into numerical indices
        band_indices = convert_indices(self.bands, indices)
        if isinstance(band_indices, slice):
            bands = self.bands[band_indices]
        else:
            bands = tuple(self.bands[i] for i in band_indices)

        # Slice the spectrum
        spectrum = self._spectrum.x[band_indices,]

        return FactorizedComponent(
            bands=bands,
            spectrum=spectrum,
            morph=self.morph,
            bbox=self.bbox,
            peak=self.peak,
            bg_rms=self.bg_rms,
            bg_thresh=self.bg_thresh,
            floor=self.floor,
            monotonicity=self.monotonicity,
            padding=self.padding,
            is_symmetric=self.is_symmetric,
        )

    def __deepcopy__(self, memo: dict[int, Any]) -> FactorizedComponent:
        """Create a deep copy of this component.

        The copy is built from the current spectrum and morphology arrays,
        so it holds freshly wrapped parameters rather than copies of the
        original `Parameter` objects.

        Parameters
        ----------
        memo: dict[int, Any]
            The memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        component : FactorizedComponent
            A new component that is a deep copy of this one.
        """
        # Check if already copied
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST so that any reference
        # cycle back to this component resolves to the placeholder.
        component = FactorizedComponent.__new__(FactorizedComponent)
        memo[id(self)] = component

        # Now safely initialize the placeholder with deepcopied arguments
        component.__init__(  # type: ignore[misc]
            bands=deepcopy(self.bands, memo),
            spectrum=deepcopy(self.spectrum, memo),
            morph=deepcopy(self.morph, memo),
            bbox=deepcopy(self.bbox, memo),
            peak=deepcopy(self.peak, memo),
            bg_rms=deepcopy(self.bg_rms, memo),
            bg_thresh=self.bg_thresh,
            floor=self.floor,
            monotonicity=deepcopy(self.monotonicity, memo),
            padding=self.padding,
            is_symmetric=self.is_symmetric,
        )
        return component

    def __copy__(self) -> FactorizedComponent:
        """Create a copy of this component.

        Returns
        -------
        component : FactorizedComponent
            A new component that is a shallow copy of this one
            (it shares the spectrum and morphology arrays).
        """
        return FactorizedComponent(
            bands=self.bands,
            spectrum=self.spectrum,
            morph=self.morph,
            bbox=self.bbox,
            peak=self.peak,
            bg_rms=self.bg_rms,
            bg_thresh=self.bg_thresh,
            floor=self.floor,
            monotonicity=self.monotonicity,
            padding=self.padding,
            is_symmetric=self.is_symmetric,
        )
class CubeComponent(Component):
    """Dummy component for a component cube.

    This is duck-typed to a `lsst.scarlet.lite.Component` in order to
    generate a model from the component but it is currently not functional
    in that it cannot be optimized, only persisted and loaded.

    If scarlet lite ever implements a component as a data cube,
    this class can be removed.
    """

    def __init__(self, model: Image, peak: tuple[int, int]):
        """Initialization

        Parameters
        ----------
        model :
            The 3D (bands, y, x) model of the component. The bands and
            bounding box of the component are taken from ``model.bands``
            and ``model.bbox``.
        peak :
            The `(y, x)` peak of the component.
        """
        super().__init__(model.bands, model.bbox)
        self._model = model
        self.peak = peak

    def get_model(self) -> Image:
        """Generate the model for the source

        Returns
        -------
        model :
            The stored model image (not a copy), a 3D `(band, y, x)` image.
        """
        return self._model

    def resize(self, model_box: Box) -> bool:
        """Resizing is not supported: log a warning and return `False`"""
        Logger.warning("CubeComponent does not support resizing")
        return False

    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Updates are not supported: log a warning and do nothing"""
        Logger.warning("CubeComponent does not support updates")

    def parameterize(self, parameterization: Callable) -> None:
        """Parameterization is not supported: log a warning and do nothing"""
        Logger.warning("CubeComponent does not support parameterization")

    def to_data(self) -> ScarletCubeComponentData:
        """Convert the component to persistable ScarletCubeComponentData

        Returns
        -------
        component_data: ScarletCubeComponentData
            The data object containing the component information
        """
        from .io import ScarletCubeComponentData

        return ScarletCubeComponentData(
            origin=self.bbox.origin,  # type: ignore
            peak=self.peak,  # type: ignore
            model=self.get_model().data,
        )

    def __getitem__(self, indices: Any) -> CubeComponent:
        """Get a sub-component corresponding to the given indices.

        Parameters
        ----------
        indices :
            The band indices to select.

        Returns
        -------
        sub_component :
            A new component that is a sub-component of this one.
        """
        band_indices = convert_indices(self.bands, indices)
        if isinstance(band_indices, slice):
            bands = self.bands[band_indices]
        else:
            bands = tuple(self.bands[i] for i in band_indices)

        # Slice only the band axis; the spatial extent is unchanged.
        data = self.get_model()._data[band_indices,]
        model = Image(data=data, bands=bands, yx0=cast(tuple[int, int], self.bbox.origin))
        return CubeComponent(model=model, peak=self.peak)

    def __copy__(self) -> CubeComponent:
        """Create a copy of this component.

        Returns
        -------
        component : CubeComponent
            A new component that is a shallow copy of this one
            (it shares the same model image).
        """
        return CubeComponent(model=self._model, peak=self.peak)

    def __deepcopy__(self, memo: dict[int, Any]) -> CubeComponent:
        """Create a deep copy of this component.

        Parameters
        ----------
        memo: dict[int, Any]
            The memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        component : CubeComponent
            A new component built from ``self._model.copy()``.
        """
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST
        component = CubeComponent.__new__(CubeComponent)
        memo[id(self)] = component

        # Now safely initialize the placeholder with a copied model;
        # `peak` is a tuple and is shared as-is.
        component.__init__(  # type: ignore[misc]
            model=self._model.copy(),
            peak=self.peak,
        )
        return component
def default_fista_parameterization(component: Component):
    """Set up a factorized component for FISTA PGM optimization.

    The spectrum and the morphology are each wrapped in a
    `FistaParameter` with a step size of 0.5, replacing the existing
    parameters in place.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")
    component._spectrum = FistaParameter(component.spectrum, step=0.5)
    component._morph = FistaParameter(component.morph, step=0.5)
def default_adaprox_parameterization(component: Component, noise_rms: float | None = None):
    """Set up a factorized component for Proximal ADAM optimization.

    The spectrum gets a relative step (factor ``1e-2``) bounded below by
    ``noise_rms``. The morphology gets a fixed step of ``1e-2``.

    Parameters
    ----------
    component:
        The component to parameterize in place.
    noise_rms:
        Minimum spectrum step; ``1e-16`` is used when `None`.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")

    minimum_step = 1e-16 if noise_rms is None else noise_rms
    spectrum_step = partial(relative_step, factor=1e-2, minimum=minimum_step)
    component._spectrum = AdaproxParameter(component.spectrum, step=spectrum_step)
    component._morph = AdaproxParameter(component.morph, step=1e-2)