Coverage for python/lsst/scarlet/lite/blend.py: 18%

126 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:54 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Blend"] 

25 

26from typing import Callable, Sequence, cast 

27 

28import numpy as np 

29 

30from .bbox import Box 

31from .component import Component, FactorizedComponent 

32from .image import Image 

33from .observation import Observation 

34from .source import Source 

35 

36 

37class Blend: 

38 """A single blend. 

39 

40 This class holds all of the sources and observation that are to be fit, 

41 as well as performing fitting and joint initialization of the 

42 spectral components (when applicable). 

43 

44 Parameters 

45 ---------- 

46 sources: 

47 The sources to fit. 

48 observation: 

49 The observation that contains the images, 

50 PSF, etc. that are being fit. 

51 """ 

52 

53 def __init__(self, sources: Sequence[Source], observation: Observation): 

54 self.sources = list(sources) 

55 self.observation = observation 

56 

57 # Initialize the iteration count and loss function 

58 self.it = 0 

59 self.loss: list[float] = [] 

60 

61 @property 

62 def shape(self) -> tuple[int, int, int]: 

63 """Shape of the model for the entire `Blend`.""" 

64 return self.observation.shape 

65 

66 @property 

67 def bbox(self) -> Box: 

68 """The bounding box of the entire blend.""" 

69 return self.observation.bbox 

70 

71 @property 

72 def components(self) -> list[Component]: 

73 """The list of all components in the blend. 

74 

75 Since the list of sources might change, 

76 this is always built on the fly. 

77 """ 

78 return [c for src in self.sources for c in src.components] 

79 

80 def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image: 

81 """Generate a model of the entire blend. 

82 

83 Parameters 

84 ---------- 

85 convolve: 

86 Whether to convolve the model with the observed PSF in each band. 

87 use_flux: 

88 Whether to use the re-distributed flux associated with the sources 

89 instead of the component models. 

90 

91 Returns 

92 ------- 

93 model: 

94 The model created by combining all of the source models. 

95 """ 

96 model = Image( 

97 np.zeros(self.shape, dtype=self.observation.images.dtype), 

98 bands=self.observation.bands, 

99 yx0=cast(tuple[int, int], self.observation.bbox.origin[-2:]), 

100 ) 

101 

102 if use_flux: 

103 for src in self.sources: 

104 if src.flux_weighted_image is None: 

105 raise ValueError( 

106 "Some sources do not have 'flux' attribute set. Run measure.conserve_flux" 

107 ) 

108 src.flux_weighted_image.insert_into(model) 

109 else: 

110 for component in self.components: 

111 component.get_model().insert_into(model) 

112 if convolve: 

113 return self.observation.convolve(model) 

114 return model 

115 

116 def _grad_log_likelihood(self) -> Image: 

117 """Gradient of the likelihood wrt the unconvolved model""" 

118 model = self.get_model(convolve=True) 

119 # Update the loss 

120 self.loss.append(self.observation.log_likelihood(model)) 

121 # Calculate the gradient wrt the model d(logL)/d(model) 

122 result = self.observation.weights * (model - self.observation.images) 

123 result = self.observation.convolve(result, grad=True) 

124 return result 

125 

126 @property 

127 def log_likelihood(self) -> float: 

128 """The current log-likelihood 

129 

130 This is calculated on the fly to ensure that it is always up to date 

131 with the current model parameters. 

132 """ 

133 return self.observation.log_likelihood(self.get_model(convolve=True)) 

134 

135 def fit_spectra(self, clip: bool = False) -> Blend: 

136 """Fit all of the spectra given their current morphologies with a 

137 linear least squares algorithm. 

138 

139 Parameters 

140 ---------- 

141 clip: 

142 Whether or not to clip components that were not 

143 assigned any flux during the fit. 

144 

145 Returns 

146 ------- 

147 blend: 

148 The blend with updated components is returned. 

149 """ 

150 from .initialization import multifit_spectra 

151 

152 morphs = [] 

153 spectra = [] 

154 factorized_indices = [] 

155 model = Image.from_box( 

156 self.observation.bbox, 

157 bands=self.observation.bands, 

158 dtype=self.observation.dtype, 

159 ) 

160 components = self.components 

161 for idx, component in enumerate(components): 

162 if hasattr(component, "morph") and hasattr(component, "spectrum"): 

163 component = cast(FactorizedComponent, component) 

164 morphs.append(component.morph) 

165 spectra.append(component.spectrum) 

166 factorized_indices.append(idx) 

167 else: 

168 model.insert(component.get_model()) 

169 model = self.observation.convolve(model, mode="real") 

170 

171 boxes = [c.bbox for c in components] 

172 fit_spectra = multifit_spectra( 

173 self.observation, 

174 [Image(morph, yx0=cast(tuple[int, int], bbox.origin)) for morph, bbox in zip(morphs, boxes)], 

175 model, 

176 ) 

177 for idx in range(len(morphs)): 

178 component = cast(FactorizedComponent, components[factorized_indices[idx]]) 

179 component.spectrum[:] = fit_spectra[idx] 

180 component.spectrum[component.spectrum < 0] = 0 

181 

182 # Run the proxes for all of the components to make sure that the 

183 # spectra are consistent with the constraints. 

184 # In practice this usually means making sure that they are 

185 # non-negative. 

186 for src in self.sources: 

187 for component in src.components: 

188 if ( 

189 hasattr(component, "spectrum") 

190 and hasattr(component, "prox_spectrum") 

191 and component.prox_spectrum is not None # type: ignore 

192 ): 

193 component.prox_spectrum(component.spectrum) # type: ignore 

194 

195 if clip: 

196 # Remove components with no positive flux 

197 for src in self.sources: 

198 _components = [] 

199 for component in src.components: 

200 component_model = component.get_model() 

201 component_model.data[component_model.data < 0] = 0 

202 if np.sum(component_model.data) > 0: 

203 _components.append(component) 

204 src.components = _components 

205 

206 return self 

207 

208 def fit( 

209 self, 

210 max_iter: int, 

211 e_rel: float = 1e-4, 

212 min_iter: int = 15, 

213 resize: int = 10, 

214 ) -> tuple[int, float]: 

215 """Fit all of the parameters 

216 

217 Parameters 

218 ---------- 

219 max_iter: 

220 The maximum number of iterations 

221 e_rel: 

222 The relative error to use for determining convergence. 

223 min_iter: 

224 The minimum number of iterations. 

225 resize: 

226 Number of iterations before attempting to resize the 

227 resizable components. If `resize` is `None` then 

228 no resizing is ever attempted. 

229 

230 Returns 

231 ------- 

232 it: 

233 Number of iterations. 

234 loss: 

235 Loss for the last solution 

236 """ 

237 while self.it < max_iter: 

238 # Calculate the gradient wrt the on-convolved model 

239 grad_log_likelihood = self._grad_log_likelihood() 

240 if resize is not None and self.it > 0 and self.it % resize == 0: 

241 do_resize = True 

242 else: 

243 do_resize = False 

244 # Update each component given the current gradient 

245 for component in self.components: 

246 overlap = component.bbox & self.bbox 

247 component.update(self.it, grad_log_likelihood[overlap].data) 

248 # Check to see if any components need to be resized 

249 if do_resize: 

250 component.resize(self.bbox) 

251 # Stopping criteria 

252 self.it += 1 

253 if self.it > min_iter and np.abs(self.loss[-1] - self.loss[-2]) < e_rel * np.abs(self.loss[-1]): 

254 break 

255 return self.it, self.loss[-1] 

256 

257 def parameterize(self, parameterization: Callable): 

258 """Convert the component parameter arrays into Parameter instances 

259 

260 Parameters 

261 ---------- 

262 parameterization: 

263 A function to use to convert parameters of a given type into 

264 a `Parameter` in place. It should take a single argument that 

265 is the `Component` or `Source` that is to be parameterized. 

266 """ 

267 for source in self.sources: 

268 source.parameterize(parameterization) 

269 

270 def conserve_flux(self, mask_footprint: bool = True) -> None: 

271 """Use the source models as templates to re-distribute flux 

272 from the data 

273 

274 The source models are used as approximations to the data, 

275 which redistribute the flux in the data according to the 

276 ratio of the models for each source. 

277 There is no return value for this function, 

278 instead it adds (or modifies) a ``flux_weighted_image`` 

279 attribute to each the sources with the flux attributed to 

280 that source. 

281 

282 Parameters 

283 ---------- 

284 blend: 

285 The blend that is being fit 

286 mask_footprint: 

287 Whether or not to apply a mask for pixels with zero weight. 

288 """ 

289 observation = self.observation 

290 py = observation.psfs.shape[-2] // 2 

291 px = observation.psfs.shape[-1] // 2 

292 

293 images = observation.images.copy() 

294 if mask_footprint: 

295 images.data[observation.weights.data == 0] = 0 

296 model = self.get_model() 

297 # Always convolve in real space to avoid FFT artifacts 

298 model = observation.convolve(model, mode="real") 

299 model.data[model.data < 0] = 0 

300 

301 for src in self.sources: 

302 if src.is_null: 

303 src.flux_weighted_image = Image.from_box(Box((0, 0)), bands=observation.bands) # type: ignore 

304 continue 

305 src_model = src.get_model() 

306 

307 # Grow the model to include the wings of the PSF 

308 src_box = src.bbox.grow((py, px)) 

309 overlap = observation.bbox & src_box 

310 src_model = src_model.project(bbox=overlap) 

311 src_model = observation.convolve(src_model, mode="real") 

312 src_model.data[src_model.data < 0] = 0 

313 numerator = src_model.data 

314 denominator = model[overlap].data 

315 cuts = denominator != 0 

316 ratio = np.zeros(numerator.shape, dtype=numerator.dtype) 

317 ratio[cuts] = numerator[cuts] / denominator[cuts] 

318 ratio[denominator == 0] = 0 

319 # sometimes numerical errors can cause a hot pixel to have a 

320 # slightly higher ratio than 1 

321 ratio[ratio > 1] = 1 

322 src.flux_weighted_image = src_model.copy_with(data=ratio) * images[overlap]