Coverage for python/lsst/scarlet/lite/component.py: 39%
133 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 02:46 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 02:46 -0700
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Public names exported by ``from ...component import *``.
__all__ = [
    "Component",
    "FactorizedComponent",
    "default_fista_parameterization",
    "default_adaprox_parameterization",
]
29from abc import ABC, abstractmethod
30from functools import partial
31from typing import Callable, cast
33import numpy as np
35from .bbox import Box
36from .image import Image
37from .operators import Monotonicity
38from .parameters import AdaproxParameter, FistaParameter, Parameter, parameter, relative_step
class Component(ABC):
    """Abstract base class shared by every component in scarlet lite.

    Concrete subclasses must implement `resize`, `update`, `get_model`
    and `parameterize`.

    Parameters
    ----------
    bands:
        The bands used when the component model is created.
    bbox: Box
        The bounding box for this component.
    """

    def __init__(
        self,
        bands: tuple,
        bbox: Box,
    ):
        # Both values are stored privately and exposed read-only
        # through the properties below.
        self._bbox = bbox
        self._bands = bands

    @property
    def bbox(self) -> Box:
        """Bounding box of the component within the full image"""
        return self._bbox

    @property
    def bands(self) -> tuple:
        """Bands covered by the component model"""
        return self._bands

    @abstractmethod
    def resize(self, model_box: Box) -> bool:
        """Check whether the component must be resized

        Inherited classes should override this and return `True`
        when the component needs to be resized.
        """

    @abstractmethod
    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Update the component parameters using an input gradient

        Parameters
        ----------
        it:
            Current iteration number of the optimizer.
        input_grad:
            Gradient of the likelihood with respect to the component model.
        """

    @abstractmethod
    def get_model(self) -> Image:
        """Create the model image for this component

        Inherited classes are required to implement this.

        Returns
        -------
        model: Image
            Image of the component model.
        """

    @abstractmethod
    def parameterize(self, parameterization: Callable) -> None:
        """Turn the component parameter arrays into `Parameter` instances

        Parameters
        ----------
        parameterization: Callable
            Function that converts parameters of a given type into a
            `Parameter` in place. It takes a single argument: the
            `Component` or `Source` being parameterized.
        """
class FactorizedComponent(Component):
    """A component that can be factorized into spectrum and morphology
    parameters.

    The model is the product ``spectrum[:, None, None] * morph[None, :, :]``:
    one spectrum value per band times a 2D morphology that lives inside
    ``bbox``.

    Parameters
    ----------
    bands:
        The bands of the spectral dimension, in order.
    spectrum:
        The parameter to store and update the spectrum. A plain array is
        wrapped with `parameter`.
    morph:
        The parameter to store and update the morphology. A plain array is
        wrapped with `parameter`.
    bbox:
        The `Box` in the `model_bbox` that contains the source.
    peak:
        Location of the peak for the source, in the same coordinate frame
        as ``bbox.origin`` (i.e. the full model, not the component box).
    bg_rms:
        The RMS of the background in each band, used to threshold, grow,
        and shrink the component.
    bg_thresh:
        Multiplier applied to ``bg_rms``. A morphology pixel whose model is
        below ``bg_rms * bg_thresh`` in every band is set to zero, and
        `resize` keeps only pixels at or above it in some band. If either
        ``bg_thresh`` or ``bg_rms`` is `None`, no thresholding or resizing
        is done.
    floor:
        Minimum value of the spectrum or center morphology pixel
        (depending on which is normalized).
    monotonicity:
        The monotonicity operator to use for making the source monotonic.
        If this parameter is `None`, the source will not be made monotonic.
    padding:
        Number of pixels used to grow the new bounding box in `resize`.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: Parameter | np.ndarray,
        morph: Parameter | np.ndarray,
        bbox: Box,
        peak: tuple[int, int] | None = None,
        bg_rms: np.ndarray | None = None,
        bg_thresh: float | None = 0.25,
        floor: float = 1e-20,
        monotonicity: Monotonicity | None = None,
        padding: int = 5,
    ):
        # Initialize all of the base attributes
        super().__init__(
            bands=bands,
            bbox=bbox,
        )
        self._spectrum = parameter(spectrum)
        self._morph = parameter(morph)
        self._peak = peak
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh

        self.floor = floor
        self.monotonicity = monotonicity
        self.padding = padding

    @property
    def peak(self) -> tuple[int, int] | None:
        """The peak of the component

        Returns
        -------
        peak:
            The peak of the component in full-model coordinates,
            or `None` if no peak was given.
        """
        return self._peak

    @property
    def component_center(self) -> tuple[int, int] | None:
        """The center of the component in its bounding box

        `peak` is expressed in the coordinates of the full model, whereas
        `component_center` is the same location measured from the origin
        of the component's bounding box.

        Returns
        -------
        center:
            The center of the component in its bounding box, or `None`
            if the component has no peak.
        """
        _center = self.peak
        if _center is None:
            return None
        return (
            _center[0] - self.bbox.origin[-2],
            _center[1] - self.bbox.origin[-1],
        )

    @property
    def spectrum(self) -> np.ndarray:
        """The array of spectrum values"""
        return self._spectrum.x

    @property
    def morph(self) -> np.ndarray:
        """The array of morphology values"""
        return self._morph.x

    @property
    def shape(self) -> tuple:
        """Shape of the resulting model image"""
        return self.spectrum.shape + self.morph.shape

    def get_model(self) -> Image:
        """Build the model from the spectrum and morphology

        Returns
        -------
        model: Image
            The outer product of spectrum and morphology, placed at the
            origin of the component's bounding box.
        """
        # The spectrum and morph might be Parameters,
        # so cast them as arrays in the model.
        spectrum = self.spectrum
        morph = self.morph
        model = spectrum[:, None, None] * morph[None, :, :]
        return Image(model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    def grad_spectrum(self, input_grad: np.ndarray, spectrum: np.ndarray, morph: np.ndarray):
        """Gradient of the likelihood wrt. the spectrum

        Contracts the last two axes of ``input_grad`` (the gradient wrt.
        the component model) with ``morph``. ``spectrum`` is unused; it is
        presumably present to match the argument order the `Parameter`
        calls ``grad`` with — TODO confirm against `Parameter.update`.
        """
        return np.einsum("...jk,jk", input_grad, morph)

    def grad_morph(self, input_grad: np.ndarray, morph: np.ndarray, spectrum: np.ndarray):
        """Gradient of the likelihood wrt. the morphology

        Contracts the leading (band) axis of ``input_grad`` with
        ``spectrum``. ``morph`` is unused; see `grad_spectrum`.
        """
        return np.einsum("i,i...", spectrum, input_grad)

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        Values below ``floor`` and non-finite values are replaced (in place)
        by ``floor``.
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        spectrum[~np.isfinite(spectrum)] = self.floor
        return spectrum

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        Steps, in order:

        1. Optionally apply the monotonicity operator.
        2. Apply background thresholding when ``bg_thresh`` and ``bg_rms``
           are set, otherwise clip negative pixels to zero.
        3. Replace non-finite pixels with zero.
        4. Ensure the center pixel is at least ``floor``.
        5. Normalize so that the maximum pixel is one (if it is positive).
        """
        # monotonicity
        if self.monotonicity is not None:
            morph = self.monotonicity(morph, cast(tuple[int, int], self.component_center))

        if self.bg_thresh is not None and self.bg_rms is not None:
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero every pixel whose model
            # falls below the threshold in all bands.
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < bg_thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        # Ensure that the morphology is finite. This must happen before the
        # peak is floored below; otherwise a NaN/inf peak survives the
        # floor and is then zeroed, leaving a divergent morphology.
        morph[~np.isfinite(morph)] = 0

        # prevent divergent morphology
        center = self.component_center
        if center is None:
            shape = morph.shape
            center = (shape[0] // 2, shape[1] // 2)
        # NOTE(review): if the peak lies outside the bbox, these indices can
        # be negative and silently wrap around; confirm callers keep the
        # peak inside the component box.
        morph[center] = np.max([morph[center], self.floor])

        # Normalize the morphology
        max_value = np.max(morph)
        if max_value > 0:
            morph[:] = morph / max_value
        return morph

    def resize(self, model_box: Box) -> bool:
        """Test whether or not the component needs to be resized

        If it does, the bounding box is updated and the morphology is
        resized in place.

        Parameters
        ----------
        model_box:
            The box of the full model; the new box is clipped to it.

        Returns
        -------
        resized:
            `True` if the bounding box changed, `False` otherwise.
        """
        # No need to resize if there is no size threshold.
        # To allow box sizing but no thresholding use `bg_thresh=0`.
        if self.bg_thresh is None or self.bg_rms is None:
            return False

        model = self.spectrum[:, None, None] * self.morph[None, :, :]
        bg_thresh = self.bg_rms * self.bg_thresh
        significant = np.any(model >= bg_thresh[:, None, None], axis=0)
        if not np.any(significant):
            # There are no significant pixels,
            # so make a small box around the center
            center = self.peak
            if center is None:
                # Without a peak, use the center of the current box (the
                # same choice `prox_morph` makes) instead of the absolute
                # coordinate (0, 0), which is unrelated to the component.
                origin = self.bbox.origin
                box_shape = self.bbox.shape
                center = (
                    origin[-2] + box_shape[-2] // 2,
                    origin[-1] + box_shape[-1] // 2,
                )
            new_box = Box((1, 1), center).grow(self.padding) & model_box
        else:
            new_box = (
                Box.from_data(significant, threshold=0).grow(self.padding) + self.bbox.origin
            ) & model_box
        if new_box == self.bbox:
            return False

        old_box = self.bbox
        self._bbox = new_box
        self._morph.resize(old_box, new_box)
        return True

    def update(self, it: int, input_grad: np.ndarray):
        """Update the spectrum and morphology parameters

        Parameters
        ----------
        it:
            The current iteration of the optimizer.
        input_grad:
            Gradient of the likelihood wrt the component model.
        """
        # Store the input spectrum so that the morphology update uses the
        # spectrum from before this iteration's spectrum update
        spectrum = self.spectrum.copy()
        self._spectrum.update(it, input_grad, self.morph)
        self._morph.update(it, input_grad, spectrum)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the spectrum and morph in place
        parameterization(self)
        # Attach this component's gradient and prox functions to the new
        # parameters
        self._spectrum.grad = self.grad_spectrum
        self._spectrum.prox = self.prox_spectrum
        self._morph.grad = self.grad_morph
        self._morph.prox = self.prox_morph

    def __str__(self):
        result = (
            f"FactorizedComponent<\n bands={self.bands},\n center={self.peak},\n "
            f"spectrum={self.spectrum},\n morph_shape={self.morph.shape}\n>"
        )
        return result

    def __repr__(self):
        return self.__str__()
def default_fista_parameterization(component: Component):
    """Set up a factorized component for FISTA PGM optimization

    The spectrum and morphology arrays are each wrapped, in place, in a
    `FistaParameter` with a fixed step of 0.5.

    Parameters
    ----------
    component:
        The component to parameterize.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")

    fista_step = 0.5
    component._spectrum = FistaParameter(component.spectrum, step=fista_step)
    component._morph = FistaParameter(component.morph, step=fista_step)
def default_adaprox_parameterization(component: Component, noise_rms: float | None = None):
    """Set up a factorized component for Proximal ADAM optimization

    The spectrum uses a relative step (factor ``1e-2``) bounded below by
    ``noise_rms``. The morphology uses a fixed step of ``1e-2``.

    Parameters
    ----------
    component:
        The component to parameterize in place.
    noise_rms:
        Lower bound passed as ``minimum`` to `relative_step` for the
        spectrum. Defaults to ``1e-16`` when `None`.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")

    step_minimum = 1e-16 if noise_rms is None else noise_rms
    spectrum_step = partial(relative_step, factor=1e-2, minimum=step_minimum)

    component._spectrum = AdaproxParameter(component.spectrum, step=spectrum_step)
    component._morph = AdaproxParameter(component.morph, step=1e-2)