Coverage for python/lsst/scarlet/lite/models/fit_psf.py: 31%

70 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 02:46 -0700

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["FittedPsfObservation", "FittedPsfBlend"] 

23 

24from typing import Callable, cast 

25 

26import numpy as np 

27 

28from ..bbox import Box 

29from ..blend import Blend 

30from ..fft import Fourier, get_fft_shape 

31from ..image import Image 

32from ..observation import Observation 

33from ..parameters import parameter 

34 

35 

class FittedPsfObservation(Observation):
    """An observation that fits the PSF used to convolve the model.

    The DFT of the difference kernel is stored as a fittable `parameter`
    (``self._fitted_kernel``), so it can be optimized alongside the source
    models (see `FittedPsfBlend`).
    """

    def __init__(
        self,
        images: np.ndarray | Image,
        variance: np.ndarray | Image,
        weights: np.ndarray | Image,
        psfs: np.ndarray,
        model_psf: np.ndarray | None = None,
        noise_rms: np.ndarray | None = None,
        bbox: Box | None = None,
        bands: tuple | None = None,
        padding: int = 3,
        convolution_mode: str = "fft",
    ):
        """Initialize a `FittedPsfObservation`

        See `Observation` for a description of the parameters.
        """
        super().__init__(
            images,
            variance,
            weights,
            psfs,
            model_psf,
            noise_rms,
            bbox,
            bands,
            padding,
            convolution_mode,
        )

        # The last two (spatial) axes are transformed; axis 0 presumably
        # indexes bands.
        self.axes = (-2, -1)

        self.fft_shape = get_fft_shape(self.images.data[0], psfs[0], padding, self.axes)

        # Make the DFT of the difference kernel a fittable parameter.
        # NOTE(review): `cast` performs no runtime check; this assumes
        # `Observation.__init__` set a `Fourier` `diff_kernel` (presumably
        # only when `model_psf` is given) — confirm.
        self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).fft(self.fft_shape, self.axes))

    def grad_fit_kernel(
        self, input_grad: np.ndarray, kernel: np.ndarray, model_fft: np.ndarray
    ) -> np.ndarray:
        """Gradient of the loss with respect to the fitted (Fourier) kernel.

        Parameters
        ----------
        input_grad:
            The upstream gradient in real space.
        kernel:
            The current kernel value. Unused here; presumably supplied by the
            `parameter` machinery when it invokes this gradient
            (see `parameterize`).
        model_fft:
            The DFT of the spatially flipped model, as built in `update`.

        Returns
        -------
        result:
            The DFT of ``input_grad`` multiplied by ``model_fft``.
        """
        # Transform the upstream gradient into k-space
        grad_fft = Fourier(input_grad).fft(self.fft_shape, self.axes)
        return grad_fft * model_fft

    def prox_kernel(self, kernel: np.ndarray) -> np.ndarray:
        """Proximal operator for the kernel: currently the identity."""
        # No prox for now
        return kernel

    @property
    def fitted_kernel(self) -> np.ndarray:
        """The current value of the fitted kernel in Fourier space."""
        return self._fitted_kernel.x

    @property
    def cached_kernel(self) -> np.ndarray:
        """Complex conjugate of `fitted_kernel`.

        Multiplying by the conjugate in Fourier space is what `convolve`
        uses for the backward (gradient) convolution.
        """
        return np.conj(self.fitted_kernel)

    def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> Image:
        """Convolve the model into the observed seeing in each band.

        Parameters
        ----------
        image:
            The image to convolve
        mode:
            The convolution mode to use. ``"fft"`` and `None` both convolve
            with the fitted kernel in Fourier space (`None` does this
            regardless of the ``convolution_mode`` given at init). Any other
            value is delegated to `Observation.convolve`, which presumably
            uses the original, unfitted difference kernel.
        grad:
            Whether this is a backward gradient convolution
            (`grad==True`, which uses the conjugate kernel) or a pure
            convolution with the PSF.

        Returns
        -------
        result:
            The convolved image, with the same bands and origin as ``image``.
        """
        if mode is not None and mode != "fft":
            return super().convolve(image, mode, grad)

        kernel = self.cached_kernel if grad else self.fitted_kernel

        fft = Fourier(image.data).fft(self.fft_shape, self.axes)
        result = Fourier.from_fft(fft * kernel, self.fft_shape, image.shape, self.axes)
        return Image(result.image, bands=image.bands, yx0=image.yx0)

    def update(self, it: int, input_grad: np.ndarray, model: Image):
        """Take one optimization step on the fitted kernel.

        Parameters
        ----------
        it:
            The current iteration number.
        input_grad:
            The gradient of the loss with respect to the convolved model.
        model:
            The model image (presumably unconvolved — see `FittedPsfBlend.fit`).
        """
        # Flip the model in both spatial axes so that the product in
        # `grad_fit_kernel` is a correlation with the model, i.e. the
        # gradient of a convolution with respect to its kernel.
        model_fft = Fourier(model.data[:, ::-1, ::-1]).fft(self.fft_shape, self.axes)
        self._fitted_kernel.update(it, input_grad, model_fft)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the parameter arrays into `Parameter` instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It is called with this observation as
            its only argument.
        """
        # Convert this observation's parameters in place
        parameterization(self)
        # Attach the kernel-specific gradient and proximal operator
        self._fitted_kernel.grad = self.grad_fit_kernel
        self._fitted_kernel.prox = self.prox_kernel

146 

147 

class FittedPsfBlend(Blend):
    """A blend that attempts to fit the PSF along with the source models."""

    def _grad_log_likelihood(self) -> Image:
        """Gradient of the loss with respect to the *convolved* model.

        Also appends the current log-likelihood to ``self.loss``.

        Returns
        -------
        result:
            ``weights * (model - images)``, where ``model`` is the convolved
            model. For a weighted least-squares likelihood this is the
            gradient of the negative log-likelihood (TODO confirm against
            `Observation.log_likelihood`). `fit` convolves it with the
            conjugate kernel to get the gradient with respect to the
            unconvolved model.
        """
        model = self.get_model(convolve=True)
        # Update the loss
        self.loss.append(self.observation.log_likelihood(model))
        return self.observation.weights * (model - self.observation.images)

    def fit(
        self,
        max_iter: int,
        e_rel: float = 1e-4,
        min_iter: int = 1,
        resize: int | None = 10,
    ) -> tuple[int, float]:
        """Fit all of the parameters, including the PSF kernel.

        Parameters
        ----------
        max_iter: int
            The maximum number of iterations
        e_rel: float
            The relative error to use for determining convergence.
        min_iter: int
            The minimum number of iterations.
        resize: int | None
            Number of iterations before attempting to resize the
            resizable components. If `resize` is `None` then
            no resizing is ever attempted.

        Returns
        -------
        it: int
            The iteration counter after fitting (also stored in ``self.it``).
        loss: float
            The most recent entry of ``self.loss``.
        """
        it = self.it
        while it < max_iter:
            # Gradient wrt the convolved model, then (via the conjugate
            # kernel) wrt the un-convolved model
            grad_log_likelihood = self._grad_log_likelihood()
            _grad_log_likelihood = self.observation.convolve(grad_log_likelihood, grad=True)
            # Use the local counter: ``self.it`` is only written back after
            # the loop, so testing it here would give the same answer on
            # every iteration.
            do_resize = resize is not None and it > 0 and it % resize == 0
            # Update each component given the current gradient
            for component in self.components:
                overlap = component.bbox & self.bbox
                component.update(it, _grad_log_likelihood[overlap].data)
                # Check to see if any components need to be resized
                if do_resize:
                    component.resize(self.bbox)

            # Update the PSF.
            # NOTE(review): `self.get_model()` is evaluated after the
            # component updates above, so the kernel step uses the updated
            # model rather than the one that produced `grad_log_likelihood`.
            cast(FittedPsfObservation, self.observation).update(
                it, grad_log_likelihood.data, self.get_model()
            )

            it += 1
            # Stopping criteria: compare the last two loss values (only once
            # at least two exist, otherwise `loss[-2]` would fail).
            if (
                it > min_iter
                and len(self.loss) > 1
                and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1])
            ):
                break
        self.it = it
        return it, self.loss[-1]