Coverage for python / lsst / scarlet / lite / blend.py: 20%
172 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:40 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ["Blend"]
26from abc import ABC, abstractmethod
27from copy import deepcopy
28from typing import TYPE_CHECKING, Any, Callable, Self, Sequence, cast
30import numpy as np
32from .bbox import Box
33from .component import Component, FactorizedComponent
34from .image import Image
35from .observation import Observation
36from .source import Source, SourceBase
38if TYPE_CHECKING:
39 from .io import ScarletBlendData, ScarletSourceBaseData
42class BlendBase(ABC):
43 """A base class for blends that can be extended to add additional
44 functionality.
46 This class holds all of the sources and observation that are to be fit,
47 as well as performing fitting and joint initialization of the
48 spectral components (when applicable).
50 Parameters
51 ----------
52 sources:
53 The sources to fit.
54 observation:
55 The observation that contains the images,
56 PSF, etc. that are being fit.
57 metadata:
58 Additional metadata to store with the blend.
59 """
61 sources: Sequence[SourceBase]
62 observation: Observation
63 metadata: dict | None
65 @property
66 def shape(self) -> tuple[int, int, int]:
67 """Shape of the model for the entire `Blend`."""
68 return self.observation.shape
70 @property
71 def bbox(self) -> Box:
72 """The bounding box of the entire blend."""
73 return self.observation.bbox
75 @property
76 def components(self) -> list[Component]:
77 """The list of all components in the blend.
79 Since the list of sources might change,
80 this is always built on the fly.
81 """
82 return [c for src in self.sources for c in src.components]
84 @abstractmethod
85 def __getitem__(self, indices: Any) -> Self:
86 """Get a sub-blend corresponding to the given indices.
88 Parameters
89 ----------
90 indices :
91 The indices to use to slice the blend.
93 Returns
94 -------
95 sub_blend :
96 A new `BlendBase` instance containing only data from the
97 specified bands in the specified order.
99 Raises
100 ------
101 IndexError :
102 If the indices contain bands not included in the original
103 blend or any spatial indices are given.
104 """
106 @abstractmethod
107 def __copy__(self) -> Self:
108 """Create a copy of this blend.
110 Returns
111 -------
112 blend : BlendBase
113 A new blend that is a copy of this one.
114 """
116 @abstractmethod
117 def __deepcopy__(self, memo: dict[int, Any]) -> Self:
118 """Create a deep copy of this blend.
120 Parameters
121 ----------
122 memo : dict[int, Any]
123 A memoization dictionary used by `copy.deepcopy`.
125 Returns
126 -------
127 blend : BlendBase
128 A new blend that is a deep copy of this one.
129 """
131 def copy(self, deep: bool = False) -> Self:
132 """Create a copy of this blend.
134 Parameters
135 ----------
136 deep :
137 If `True`, a deep copy is made. If `False`, a shallow copy is made.
138 Default is `False`.
140 Returns
141 -------
142 blend : Self
143 A new blend that is a copy of this one.
144 """
145 if deep:
146 return self.__deepcopy__({})
147 else:
148 return self.__copy__()
150 @abstractmethod
151 def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
152 """Generate a model of the entire blend.
154 Parameters
155 ----------
156 convolve:
157 Whether to convolve the model with the observed PSF in each band.
158 use_flux:
159 Whether to use the re-distributed flux associated with the sources
160 instead of the component models.
162 Returns
163 -------
164 model:
165 The model created by combining all of the source models.
166 """
168 @abstractmethod
169 def to_data(self) -> ScarletBlendData:
170 """Convert the blend into a serializable dictionary format.
172 Returns
173 -------
174 data:
175 A dictionary containing all of the information needed to
176 reconstruct the blend.
177 """
class Blend(BlendBase):
    """A single blend.

    This class holds all of the sources and observation that are to be fit,
    as well as performing fitting and joint initialization of the
    spectral components (when applicable).

    Parameters
    ----------
    sources:
        The sources to fit.
    observation:
        The observation that contains the images,
        PSF, etc. that are being fit.
    metadata:
        Additional metadata to store with the blend.
    """

    sources: list[Source]

    def __init__(self, sources: Sequence[Source], observation: Observation, metadata: dict | None = None):
        self.sources = list(sources)
        self.observation = observation
        # Normalize empty metadata to `None` so downstream code only
        # needs to check for `None`.
        if metadata is not None and len(metadata) == 0:
            metadata = None
        self.metadata = metadata

        # Initialize the iteration count and loss function
        self.it = 0
        self.loss: list[float] = []

    def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
        """Generate a model of the entire blend.

        Parameters
        ----------
        convolve:
            Whether to convolve the model with the observed PSF in each band.
        use_flux:
            Whether to use the re-distributed flux associated with the sources
            instead of the component models.

        Returns
        -------
        model:
            The model created by combining all of the source models.

        Raises
        ------
        ValueError:
            If ``use_flux`` is `True` but some sources have no
            ``flux_weighted_image`` (i.e. flux has not been conserved yet).
        """
        model = Image(
            np.zeros(self.shape, dtype=self.observation.images.dtype),
            bands=self.observation.bands,
            yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]),
        )

        if use_flux:
            for src in self.sources:
                if src.flux_weighted_image is None:
                    raise ValueError(
                        "Some sources do not have 'flux' attribute set. Run measure.conserve_flux"
                    )
                src.flux_weighted_image.insert_into(model)
        else:
            for component in self.components:
                component.get_model().insert_into(model)
            if convolve:
                return self.observation.convolve(model)
        return model

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Gradient of the likelihood wrt the unconvolved model

        As a side effect, the current log-likelihood is appended to
        ``self.loss``.

        Returns
        -------
        result:
            The gradient of the likelihood wrt the model
        model_data:
            The convolved model data used to calculate the gradient.
            This can be useful for debugging but is not used in
            production.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        result = self.observation.weights * (model - self.observation.images)
        result = self.observation.convolve(result, grad=True)
        return result, model.data

    @property
    def log_likelihood(self) -> float:
        """The current log-likelihood

        This is calculated on the fly to ensure that it is always up to date
        with the current model parameters.
        """
        return self.observation.log_likelihood(self.get_model(convolve=True))

    def fit_spectra(self, clip: bool = False) -> Blend:
        """Fit all of the spectra given their current morphologies with a
        linear least squares algorithm.

        Parameters
        ----------
        clip:
            Whether or not to clip components that were not
            assigned any flux during the fit.

        Returns
        -------
        blend:
            The blend with updated components is returned.
        """
        from .initialization import multifit_spectra

        morphs = []
        factorized_indices = []
        model = Image.from_box(
            self.observation.bbox,
            bands=self.observation.bands,
            dtype=self.observation.dtype,
        )
        components = self.components
        for idx, component in enumerate(components):
            if hasattr(component, "morph") and hasattr(component, "spectrum"):
                component = cast(FactorizedComponent, component)
                morphs.append(component.morph)
                factorized_indices.append(idx)
            else:
                # Non-factorized components cannot be re-fit;
                # their models are accumulated so that `multifit_spectra`
                # can account for their flux.
                model.insert(component.get_model())
        model = self.observation.convolve(model, mode="real")

        # Only use the boxes of the factorized components, in the same
        # order as `morphs`. Using the boxes of *all* components here
        # would mis-align morphologies and boxes whenever a
        # non-factorized component precedes a factorized one.
        boxes = [components[idx].bbox for idx in factorized_indices]
        fit_spectra = multifit_spectra(
            self.observation,
            [Image(morph, yx0=cast(tuple[int, int], bbox.origin)) for morph, bbox in zip(morphs, boxes)],
            model,
        )
        for idx in range(len(morphs)):
            component = cast(FactorizedComponent, components[factorized_indices[idx]])
            component.spectrum[:] = fit_spectra[idx]
            component.spectrum[component.spectrum < 0] = 0

        # Run the proxes for all of the components to make sure that the
        # spectra are consistent with the constraints.
        # In practice this usually means making sure that they are
        # non-negative.
        for src in self.sources:
            for component in src.components:
                if (
                    hasattr(component, "spectrum")
                    and hasattr(component, "prox_spectrum")
                    and component.prox_spectrum is not None  # type: ignore
                ):
                    component.prox_spectrum(component.spectrum)  # type: ignore

        if clip:
            # Remove components with no positive flux
            for src in self.sources:
                _components = []
                for component in src.components:
                    component_model = component.get_model()
                    component_model.data[component_model.data < 0] = 0
                    if np.sum(component_model.data) > 0:
                        _components.append(component)
                src.components = _components

        return self

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 15,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Parameters
        ----------
        max_iter:
            The maximum number of iterations
        e_rel:
            The relative error to use for determining convergence.
        min_iter:
            The minimum number of iterations.
        resize:
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` then
            no resizing is ever attempted.

        Returns
        -------
        it:
            Number of iterations.
        loss:
            Loss for the last solution
        """
        while self.it < max_iter:
            # Calculate the gradient wrt the un-convolved model
            # (this also appends the current loss to `self.loss`)
            grad_log_likelihood = self._grad_log_likelihood()
            do_resize = resize is not None and self.it > 0 and self.it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(self.it, grad_log_likelihood[0][overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)
            # Stopping criteria: converged once the relative change in
            # the loss drops below `e_rel` (after `min_iter` iterations).
            self.it += 1
            if self.it > min_iter and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1]):
                break
        return self.it, self.loss[-1]

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)

    def conserve_flux(self, mask_footprint: bool = True, weight_image: Image | None = None) -> None:
        """Use the source models as templates to re-distribute flux
        from the data

        The source models are used as approximations to the data,
        which redistribute the flux in the data according to the
        ratio of the models for each source.
        There is no return value for this function,
        instead it adds (or modifies) a ``flux_weighted_image``
        attribute to each of the sources with the flux attributed to
        that source.

        Parameters
        ----------
        mask_footprint:
            Whether or not to apply a mask for pixels with zero weight.
        weight_image:
            The weight image to use for the redistribution.
            If `None` then the convolved blend model is used.
        """
        observation = self.observation
        # Half-widths of the PSF, used to grow each source box so that
        # flux in the PSF wings is included.
        py = observation.psfs.shape[-2] // 2
        px = observation.psfs.shape[-1] // 2

        images = observation.images.copy()
        if mask_footprint:
            images.data[observation.weights.data == 0] = 0

        if weight_image is None:
            weight_image = self.get_model()
            # Always convolve in real space to avoid FFT artifacts
            weight_image = observation.convolve(weight_image, mode="real")

        # Due to ringing in the PSF, the convolved model can have
        # negative values. We take the absolute value to avoid
        # negative fluxes in the flux weighted images.
        weight_image.data[:] = np.abs(weight_image.data)

        for src in self.sources:
            if src.is_null:
                # Null sources have no model; give them an empty image.
                src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands)  # type: ignore
                continue
            src_model = src.get_model()

            # Grow the model to include the wings of the PSF
            src_box = src.bbox.grow((py, px))
            overlap = observation.bbox & src_box
            src_model = src_model.project(bbox=overlap)
            src_model = observation.convolve(src_model, mode="real")
            src_model.data[:] = np.abs(src_model.data)
            numerator = src_model.data
            denominator = weight_image[overlap].data
            # The zero-initialization leaves ratio = 0 wherever the
            # denominator vanishes.
            cuts = denominator != 0
            ratio = np.zeros(numerator.shape, dtype=numerator.dtype)
            ratio[cuts] = numerator[cuts] / denominator[cuts]
            # sometimes numerical errors can cause a hot pixel to have a
            # slightly higher ratio than 1
            ratio[ratio > 1] = 1
            src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]

    def to_data(self) -> ScarletBlendData:
        """Convert the Blend into a persistable data object

        Returns
        -------
        blend_data :
            The data model for a single blend.
        """
        from .io import ScarletBlendData

        sources: dict[Any, ScarletSourceBaseData] = {}
        for sidx, source in enumerate(self.sources):
            # Key each source by its "id" metadata when available,
            # otherwise fall back to its position in the source list.
            metadata = source.metadata or {}
            if "id" in metadata:
                sources[metadata["id"]] = source.to_data()
            else:
                sources[sidx] = source.to_data()

        blend_data = ScarletBlendData(
            origin=self.bbox.origin,  # type: ignore
            shape=self.bbox.shape,  # type: ignore
            sources=sources,
            metadata=self.metadata,
        )

        return blend_data

    def __getitem__(self, indices: Any) -> Blend:
        """Get a sub-blend corresponding to the given indices.

        Parameters
        ----------
        indices :
            The indices to use to slice the blend.

        Returns
        -------
        blend :
            A new `Blend` instance containing only data from the
            specified bands in the specified order.

        Raises
        ------
        IndexError :
            If the indices contain bands not included in the original
            blend or a bounding box is given.
        """
        return Blend(
            sources=[src[indices] for src in self.sources],
            observation=self.observation[indices],
            metadata=self.metadata,
        )

    def __copy__(self) -> Blend:
        """Create a shallow copy of this blend.

        The sources, observation, and metadata are shared with the
        original blend.

        Returns
        -------
        blend : Blend
            A new blend that is a copy of this one.
        """
        return Blend(sources=self.sources, observation=self.observation, metadata=self.metadata)

    def __deepcopy__(self, memo: dict[int, Any]) -> Blend:
        """Create a deep copy of this blend.

        Parameters
        ----------
        memo : dict[int, Any]
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        blend : Blend
            A new blend that is a deep copy of this one.
        """
        # Check if already copied
        if id(self) in memo:
            return memo[id(self)]

        # Create placeholder and add to memo FIRST so that reference
        # cycles back to this blend resolve to the copy.
        blend = Blend.__new__(Blend)
        memo[id(self)] = blend

        # Now safely initialize the placeholder with deepcopied arguments
        blend.__init__(  # type: ignore[misc]
            sources=[deepcopy(src, memo) for src in self.sources],
            observation=deepcopy(self.observation, memo),
            metadata=deepcopy(self.metadata, memo),
        )

        return blend