Coverage for python/lsst/scarlet/lite/models/fit_psf.py: 31%
70 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-07 11:26 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Names exported by `from ... import *`
__all__ = [
    "FittedPsfObservation",
    "FittedPsfBlend",
]
24from typing import Callable, cast
26import numpy as np
28from ..bbox import Box
29from ..blend import Blend
30from ..fft import Fourier, get_fft_shape
31from ..image import Image
32from ..observation import Observation
33from ..parameters import parameter
class FittedPsfObservation(Observation):
    """An observation that fits the PSF used to convolve the model.

    The discrete Fourier transform of the difference kernel is stored as a
    fittable parameter (``_fitted_kernel``), and FFT convolutions performed
    by this observation use that fitted kernel.
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
    ):
        """Initialize a `FittedPsfObservation`

        See `Observation` for a description of the parameters.
        """
        super().__init__(
            images,
            variance,
            weights,
            psfs,
            model_psf,
            noise_rms,
            bbox,
            bands,
            padding,
            convolution_mode,
        )

        # The kernel is convolved/fit over the last two array axes
        self.axes = (-2, -1)

        # Padded FFT shape, computed from the first slice of the image data
        # and the first PSF (presumably the first band — confirm)
        self.fft_shape = get_fft_shape(self.images.data[0], psfs[0], padding, self.axes)

        # Make the DFT of the psf a fittable parameter.
        # NOTE(review): `cast` performs no runtime check; this assumes the base
        # class set `self.diff_kernel` to a `Fourier` (it may not be one if,
        # e.g., no `model_psf` was given — confirm).
        self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).fft(self.fft_shape, self.axes))

    def grad_fit_kernel(
        self, input_grad: np.ndarray, kernel: np.ndarray, model_fft: np.ndarray
    ) -> np.ndarray:
        """Gradient of the loss with respect to the fitted kernel.

        Parameters
        ----------
        input_grad:
            The upstream gradient in real space.
        kernel:
            The current kernel value. It is not used here; presumably it is
            part of the gradient-function signature expected by the
            parameter (confirm).
        model_fft:
            The DFT of the spatially flipped model, as computed in `update`.

        Returns
        -------
        result:
            The product of the DFT of ``input_grad`` and ``model_fft``.
        """
        # Transform the upstream gradient into k-space
        grad_fft = Fourier(input_grad)
        _grad_fft = grad_fft.fft(self.fft_shape, self.axes)
        return _grad_fft * model_fft

    def prox_kernel(self, kernel: np.ndarray) -> np.ndarray:
        """Proximal operator for the kernel (currently the identity)."""
        # No prox for now
        return kernel

    @property
    def fitted_kernel(self) -> np.ndarray:
        """The current value of the fitted kernel in k-space."""
        return self._fitted_kernel.x

    @property
    def cached_kernel(self) -> np.ndarray:
        """The complex conjugate of `fitted_kernel`.

        This is the kernel used for backward (gradient) convolutions.
        """
        return self.fitted_kernel.real - self.fitted_kernel.imag * 1j

    def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The image to convolve
        mode:
            The convolution mode to use.
            When this is "fft" or `None` the image is convolved in k-space
            with the fitted kernel. Note that `None` does *not* fall back to
            the ``convolution_mode`` given at init.
            Any other value is delegated to `Observation.convolve`, which
            presumably uses the original difference kernel rather than the
            fitted one.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`, using the conjugate kernel `cached_kernel`)
            or a pure convolution with the PSF (using `fitted_kernel`).

        Returns
        -------
        result:
            The convolved image, with the same bands and ``yx0`` as ``image``.
        """
        if mode is not None and mode != "fft":
            # Delegate before building the fitted kernel, which this path
            # never uses (avoids an unnecessary conjugate copy when grad=True)
            return super().convolve(image, mode, grad)

        kernel = self.cached_kernel if grad else self.fitted_kernel

        fft_image = Fourier(image.data)
        fft = fft_image.fft(self.fft_shape, self.axes)

        result = Fourier.from_fft(fft * kernel, self.fft_shape, image.shape, self.axes)
        return Image(result.image, bands=image.bands, yx0=image.yx0)

    def update(self, it: int, input_grad: np.ndarray, model: Image) -> None:
        """Take one optimization step for the fitted kernel.

        Parameters
        ----------
        it:
            The current iteration number.
        input_grad:
            The upstream gradient in real space.
        model:
            The model image used to compute the kernel gradient. Its data
            must be 3D, since the last two axes are flipped.
        """
        # Reverse both spatial axes before transforming. Presumably this turns
        # the k-space product in `grad_fit_kernel` into a correlation with the
        # upstream gradient — confirm.
        _model = Fourier(model.data[:, ::-1, ::-1])
        model_fft = _model.fft(self.fft_shape, self.axes)
        self._fitted_kernel.update(it, input_grad, model_fft)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert this observation's parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It takes a single argument, which here
            is this observation.
        """
        # Let the parameterization convert this observation's parameters
        # (presumably `_fitted_kernel`) in place
        parameterization(self)
        # Attach the kernel's gradient and proximal operator to the parameter
        self._fitted_kernel.grad = self.grad_fit_kernel
        self._fitted_kernel.prox = self.prox_kernel
class FittedPsfBlend(Blend):
    """A blend that attempts to fit the PSF along with the source models.

    The blend's ``observation`` is expected to be a `FittedPsfObservation`,
    whose fitted kernel is updated once per iteration in `fit`.
    """

    def _grad_log_likelihood(self) -> Image:
        """Gradient of the likelihood wrt the convolved model

        As a side effect, the log-likelihood of the current convolved model
        is appended to ``self.loss``. The result is *not* convolved:
        `fit` applies the backward (gradient) convolution itself, because
        the unconvolved-space gradient is also needed for the PSF update.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the model d(logL)/d(model)
        return self.observation.weights * (model - self.observation.images)

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 1,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Parameters
        ----------
        max_iter: int
            The maximum number of iterations
        e_rel: float
            The relative error to use for determining convergence.
        min_iter: int
            The minimum number of iterations.
        resize: int | None
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` then
            no resizing is ever attempted.

        Returns
        -------
        it: int
            The iteration counter after fitting (also stored in ``self.it``).
        loss: float
            The last entry of ``self.loss``.
        """
        observation = cast(FittedPsfObservation, self.observation)
        it = self.it
        while it < max_iter:
            # Gradient wrt the convolved model (used for the PSF update)
            grad_log_likelihood = self._grad_log_likelihood()
            # Gradient wrt the unconvolved model (used for the components)
            _grad_log_likelihood = observation.convolve(grad_log_likelihood, grad=True)
            # Check if resizing needs to be performed in this iteration.
            # This must use the local counter `it`: `self.it` is only
            # updated after the loop finishes.
            do_resize = resize is not None and it > 0 and it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(it, _grad_log_likelihood[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)

            # Update the PSF.
            # NOTE(review): `grad_log_likelihood` was computed before this
            # iteration's component updates, while `self.get_model()` is
            # called after them; confirm this pairing is intended.
            observation.update(it, grad_log_likelihood.data, self.get_model())

            # Stopping criteria (at least two losses are needed to compare)
            it += 1
            if (
                it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        self.it = it
        return it, self.loss[-1]