Coverage for python / lsst / scarlet / lite / models / fit_psf.py: 28%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:40 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = [
    "FittedPsfObservation",
    "FittedPsfBlend",
]

23 

24from typing import Callable, cast 

25 

26import numpy as np 

27 

28from ..bbox import Box 

29from ..blend import Blend 

30from ..fft import Fourier, centered 

31from ..fft import convolve as fft_convolve 

32from ..image import Image 

33from ..observation import Observation 

34from ..parameters import parameter 

35 

36 

class FittedPsfObservation(Observation):
    """An observation that fits the PSF used to convolve the model.

    The difference kernel built by `Observation` is wrapped in a fittable
    parameter, and FFT convolutions use the current value of that parameter
    instead of a fixed kernel.
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
        shape: tuple[int, int] | None = None,
    ):
        """Initialize a `FittedPsfObservation`

        See `Observation` for a description of the shared parameters.

        Parameters
        ----------
        shape:
            Currently unused; accepted only for backward compatibility.
        """
        super().__init__(
            images,
            variance,
            weights,
            psfs,
            model_psf,
            noise_rms,
            bbox,
            bands,
            padding,
            convolution_mode,
        )

        self.axes = (-2, -1)

        # ``shape`` is intentionally ignored: it used to be defaulted to
        # (41, 41) but was never used to size the fitted kernel.
        # NOTE(review): confirm whether it was meant to set the kernel size.

        # Make the difference kernel (the ``.image`` of the `Fourier` object
        # built by `Observation`) a fittable parameter.
        self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).image)

    def grad_fit_kernel(self, input_grad: np.ndarray, psf: np.ndarray, model: np.ndarray) -> np.ndarray:
        """Gradient of the loss wrt the PSF

        This is the cross-correlation of the input gradient with the model,
        computed as an FFT convolution of the model with the spatially
        flipped gradient and cropped to the shape of ``psf``.

        Parameters
        ----------
        input_grad:
            The gradient of the loss wrt the (convolved) model.
            Indexed as a 3D array; the last two axes are flipped.
        psf:
            The current value of the fitted kernel; only its shape is used.
        model:
            The deconvolved model.

        Returns
        -------
        grad:
            The gradient, centered and cropped to ``psf.shape``.
        """
        grad = cast(
            np.ndarray,
            fft_convolve(
                Fourier(model),
                # Flipping turns the convolution into a cross-correlation
                Fourier(input_grad[:, ::-1, ::-1]),
                axes=(1, 2),
                return_fourier=False,
            ),
        )

        return centered(grad, psf.shape)

    def prox_kernel(self, kernel: np.ndarray) -> np.ndarray:
        """Proximal operator for the fitted kernel (currently the identity)."""
        # No prox for now
        return kernel

    @property
    def fitted_kernel(self) -> np.ndarray:
        """The current value of the fitted kernel."""
        return self._fitted_kernel.x

    @property
    def cached_kernel(self) -> np.ndarray:
        """The fitted kernel flipped in its last two axes (used for gradients)."""
        return self.fitted_kernel[:, ::-1, ::-1]

    def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The image to convolve
        mode:
            The convolution mode to use. ``"fft"`` or `None` convolve with
            the fitted kernel using FFTs, regardless of the
            ``convolution_mode`` given at init. Any other value is delegated
            to `Observation.convolve`.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`) or a pure convolution with the PSF.

        Returns
        -------
        result:
            The convolved image, with the same bands and origin as ``image``.
        """
        if mode is not None and mode != "fft":
            # NOTE(review): `Observation.convolve` is defined elsewhere;
            # confirm whether non-FFT modes should honor the fitted kernel.
            return super().convolve(image, mode, grad)

        # The backward pass convolves with the flipped kernel
        kernel = self.cached_kernel if grad else self.fitted_kernel
        result = fft_convolve(
            Fourier(image.data),
            Fourier(kernel),
            axes=(1, 2),
            return_fourier=False,
        )
        return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0)

    def update(self, it: int, input_grad: np.ndarray, model: np.ndarray) -> None:
        """Update the PSF given the gradient of the loss

        Once `parameterize` has run, the parameter's gradient and prox are
        `grad_fit_kernel` and `prox_kernel`.

        Parameters
        ----------
        it: int
            The current iteration
        input_grad: np.ndarray
            The gradient of the loss wrt the model
        model: np.ndarray
            The deconvolved model.
        """
        self._fitted_kernel.update(it, input_grad, model)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the fitted kernel in place
        parameterization(self)
        # Attach the kernel-specific gradient and proximal operator
        self._fitted_kernel.grad = self.grad_fit_kernel
        self._fitted_kernel.prox = self.prox_kernel

179 

180 

class FittedPsfBlend(Blend):
    """A blend that attempts to fit the PSF along with the source models."""

    def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
        """Gradient of the loss wrt the convolved model.

        Also appends the current log-likelihood to ``self.loss``.

        Returns
        -------
        residual:
            ``weights * (model - images)``, i.e. the gradient of
            ``0.5 * sum(weights * (model - images)**2)`` with respect to the
            *convolved* model. `fit` convolves it with the flipped kernel to
            obtain the gradient wrt the unconvolved model.
        model:
            The data array of the *deconvolved* model. This is what
            `FittedPsfObservation.grad_fit_kernel` expects: the PSF gradient
            is the correlation of the residual with the unconvolved model.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        # Calculate the gradient wrt the convolved model d(logL)/d(model)
        residual = self.observation.weights * (model - self.observation.images)
        # The PSF gradient needs the unconvolved model, not the convolved one
        deconvolved = self.get_model(convolve=False)
        return residual, deconvolved.data

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 1,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters

        Parameters
        ----------
        max_iter: int
            The maximum number of iterations
        e_rel: float
            The relative error to use for determining convergence.
        min_iter: int
            The minimum number of iterations.
        resize: int
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` (or not positive)
            then no resizing is ever attempted.

        Returns
        -------
        it: int
            The iteration counter after fitting (also stored in ``self.it``).
        loss: float
            The most recent entry in ``self.loss``.
        """
        observation = cast(FittedPsfObservation, self.observation)
        it = self.it
        while it < max_iter:
            # Gradient wrt the convolved model, plus the deconvolved model
            # needed for the PSF gradient
            grad_convolved, model = self._grad_log_likelihood()
            # Back-propagate through the PSF to get the gradient wrt the
            # unconvolved model that the components are defined in
            grad_deconvolved = observation.convolve(grad_convolved, grad=True)
            # Use the live counter ``it``: ``self.it`` is only synced after
            # the loop, so reading it here would give a stale value.
            do_resize = resize is not None and resize > 0 and it > 0 and it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(it, grad_deconvolved[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)

            # Update the PSF using the same iteration counter as the components
            observation.update(it, grad_convolved.data, model)

            # Stopping criteria
            it += 1
            if (
                it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        self.it = it
        return it, self.loss[-1]

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization:
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        for source in self.sources:
            source.parameterize(parameterization)
        # The fitted PSF is a parameter too
        cast(FittedPsfObservation, self.observation).parameterize(parameterization)