Coverage for python / lsst / scarlet / lite / models / fit_psf.py: 28%
71 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:28 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:28 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["FittedPsfObservation", "FittedPsfBlend"]
24from typing import Callable, cast
26import numpy as np
28from ..bbox import Box
29from ..blend import Blend
30from ..fft import Fourier, centered
31from ..fft import convolve as fft_convolve
32from ..image import Image
33from ..observation import Observation
34from ..parameters import parameter
class FittedPsfObservation(Observation):
    """An observation that fits the PSF used to convolve the model.

    The image of the difference kernel built by `Observation` is wrapped
    in a parameter so that it can be updated during fitting.
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
        shape: tuple[int, int] | None = None,
    ):
        """Initialize a `FittedPsfObservation`

        See `Observation` for a description of the parameters that are
        forwarded unchanged to the parent initializer.

        Parameters
        ----------
        shape:
            Currently unused: the value is accepted but never read.
        """
        super().__init__(
            images,
            variance,
            weights,
            psfs,
            model_psf,
            noise_rms,
            bbox,
            bands,
            padding,
            convolution_mode,
        )

        self.axes = (-2, -1)

        # Make the image of the difference kernel a fittable parameter
        self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).image)

    def grad_fit_kernel(self, input_grad: np.ndarray, psf: np.ndarray, model: np.ndarray) -> np.ndarray:
        """Gradient of the loss wrt the PSF

        This is just the cross correlation of the input gradient
        with the model.

        Parameters
        ----------
        input_grad:
            The gradient of the loss wrt the model
        psf:
            The PSF of the model. Only its shape is used, to size
            the result.
        model:
            The deconvolved model.

        Returns
        -------
        grad:
            The cross correlation, passed through `centered` with
            the shape of `psf`.
        """
        # Flipping the two trailing axes of the gradient turns the
        # convolution below into a cross correlation.
        flipped_grad = input_grad[:, ::-1, ::-1]
        grad = cast(
            np.ndarray,
            fft_convolve(
                Fourier(model),
                Fourier(flipped_grad),
                axes=(1, 2),
                return_fourier=False,
            ),
        )

        return centered(grad, psf.shape)

    def prox_kernel(self, kernel: np.ndarray) -> np.ndarray:
        """Proximal operator for the kernel.

        No constraint is applied for now, so the kernel is returned
        unchanged.
        """
        return kernel

    @property
    def fitted_kernel(self) -> np.ndarray:
        """The current value of the fitted kernel parameter."""
        return self._fitted_kernel.x

    @property
    def cached_kernel(self) -> np.ndarray:
        """The fitted kernel flipped along its two trailing axes.

        This is the kernel used for the backward (gradient) convolution.
        """
        return self.fitted_kernel[:, ::-1, ::-1]

    def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The image to convolve
        mode:
            The convolution mode to use.
            With "fft" or `None` the image is convolved with the fitted
            kernel using FFTs. Any other value is delegated to
            `Observation.convolve`, which does not go through the
            fitted-kernel branch below.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`) or a pure convolution with the PSF.

        Returns
        -------
        result:
            The convolved image, with the same bands and origin
            as `image`.
        """
        if mode not in (None, "fft"):
            return super().convolve(image, mode, grad)

        # The backward pass uses the flipped kernel
        kernel = self.cached_kernel if grad else self.fitted_kernel

        result = fft_convolve(
            Fourier(image.data),
            Fourier(kernel),
            axes=(1, 2),
            return_fourier=False,
        )
        return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0)

    def update(self, it: int, input_grad: np.ndarray, model: np.ndarray):
        """Update the PSF given the gradient of the loss

        Parameters
        ----------
        it: int
            The current iteration
        input_grad: np.ndarray
            The gradient of the loss wrt the model
        model: np.ndarray
            The deconvolved model.
        """
        self._fitted_kernel.update(it, input_grad, model)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the fitted kernel in place
        parameterization(self)
        # Attach the gradient and proximal operators to the (possibly
        # replaced) kernel parameter
        self._fitted_kernel.grad = self.grad_fit_kernel
        self._fitted_kernel.prox = self.prox_kernel
class FittedPsfBlend(Blend):
    """A blend that attempts to fit the PSF along with the source models."""

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Evaluate the loss and the weighted residual of the model.

        The log likelihood of the convolved model is appended to
        `self.loss` as a side effect.

        Returns
        -------
        residual:
            `weights * (model - images)` for the convolved model.
            It is not convolved with the PSF here; `fit` does that
            before handing it to the components.
        model:
            The data array of the convolved model.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        residual = self.observation.weights * (model - self.observation.images)

        return residual, model.data

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 1,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Parameters
        ----------
        max_iter: int
            The maximum number of iterations
        e_rel: float
            The relative error to use for determining convergence.
        min_iter: int
            The minimum number of iterations.
        resize: int
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` then
            no resizing is ever attempted.

        Returns
        -------
        it:
            The iteration counter after fitting (also stored in
            `self.it`).
        loss:
            The most recent entry of `self.loss`.
        """
        psf_observation = cast(FittedPsfObservation, self.observation)

        # `it` is the single source of truth for the iteration number
        # inside the loop; `self.it` is only written back at the end.
        it = self.it
        while it < max_iter:
            # Calculate the gradient wrt the un-convolved model
            grad_log_likelihood, model = self._grad_log_likelihood()
            convolved_grad = self.observation.convolve(grad_log_likelihood, grad=True)

            # Check if resizing needs to be performed in this iteration
            do_resize = resize is not None and it > 0 and it % resize == 0

            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(it, convolved_grad[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)

            # Update the PSF with the same iteration number that the
            # components were given.
            # NOTE(review): `model` here is the data of the convolved model,
            # while `FittedPsfObservation.update` documents its argument as
            # the deconvolved model -- confirm which one is intended.
            psf_observation.update(it, grad_log_likelihood.data, model)

            # Stopping criteria (needs two losses to compare)
            it += 1
            if (
                it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break

        self.it = it
        return it, self.loss[-1]

    def parameterize(self, parameterization: Callable):
        """Convert the component parameter arrays into Parameter instances

        The sources are parameterized first, followed by the
        observation so that its fitted kernel is parameterized too.

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)
        cast(FittedPsfObservation, self.observation).parameterize(parameterization)