Coverage for python/lsst/scarlet/lite/component.py: 39%

133 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 02:46 -0700

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: the component classes and the two default
# parameterization helpers that wrap component arrays in optimizer
# `Parameter` instances.
__all__ = [
    "Component",
    "FactorizedComponent",
    "default_fista_parameterization",
    "default_adaprox_parameterization",
]

28 

29from abc import ABC, abstractmethod 

30from functools import partial 

31from typing import Callable, cast 

32 

33import numpy as np 

34 

35from .bbox import Box 

36from .image import Image 

37from .operators import Monotonicity 

38from .parameters import AdaproxParameter, FistaParameter, Parameter, parameter, relative_step 

39 

40 

class Component(ABC):
    """Abstract base for every component in scarlet lite.

    A component knows which bands it models and which region of the
    full image (its bounding box) it occupies. Concrete subclasses supply
    the model, the update rule, the resizing rule and the conversion of
    their arrays into optimizer parameters.

    Parameters
    ----------
    bands:
        The bands used when the component model is created.
    bbox: Box
        The bounding box for this component.
    """

    def __init__(
        self,
        bands: tuple,
        bbox: Box,
    ):
        # Both values are exposed read-only through the properties below.
        self._bbox = bbox
        self._bands = bands

    @property
    def bbox(self) -> Box:
        """Bounding box of the component within the full image"""
        return self._bbox

    @property
    def bands(self) -> tuple:
        """Bands covered by the component model"""
        return self._bands

    @abstractmethod
    def resize(self, model_box: Box) -> bool:
        """Check whether the component must be resized.

        Subclasses should override this and return `True` when the
        component has to be resized.

        Parameters
        ----------
        model_box:
            The bounding box of the full model.
        """

    @abstractmethod
    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Advance the component parameters using an input gradient.

        Parameters
        ----------
        it:
            Index of the current optimizer iteration.
        input_grad:
            Gradient of the likelihood with respect to the component model.
        """

    @abstractmethod
    def get_model(self) -> Image:
        """Build the model image of the component.

        Subclasses are required to implement this.

        Returns
        -------
        model: Image
            The rendered component model.
        """

    @abstractmethod
    def parameterize(self, parameterization: Callable) -> None:
        """Turn the component parameter arrays into `Parameter` instances.

        Parameters
        ----------
        parameterization: Callable
            Function that converts parameters of a given type into a
            `Parameter` in place. Its single argument is the `Component`
            or `Source` being parameterized.
        """

113 

114 

class FactorizedComponent(Component):
    """A component that can be factorized into spectrum and morphology
    parameters.

    The model is the outer product of the spectrum (one value per band)
    and the 2D morphology covering ``bbox``.

    Parameters
    ----------
    bands:
        The bands of the spectral dimension, in order.
    spectrum:
        The parameter to store and update the spectrum.
    morph:
        The parameter to store and update the morphology.
    bbox:
        The `Box` in the `model_bbox` that contains the source.
    peak:
        Location of the peak for the source, in the same coordinate frame
        as ``bbox.origin`` (``peak[0]`` pairs with ``origin[-2]``,
        ``peak[1]`` with ``origin[-1]``).
    bg_rms:
        The RMS of the background in each band, used to threshold, grow,
        and shrink the component.
    bg_thresh:
        Multiple of ``bg_rms`` used as the significance threshold.
        Pixels below it in every band are zeroed in the morphology and
        excluded when resizing. If this or ``bg_rms`` is `None`, no
        thresholding or resizing is done (only positivity is enforced).
    floor:
        Minimum value of the spectrum or center morphology pixel
        (depending on which is normalized).
    monotonicity:
        The monotonicity operator to use for making the source monotonic.
        If this parameter is `None`, the source will not be made monotonic.
    padding:
        Number of pixels by which the box of significant pixels is grown
        when the component is resized.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: Parameter | np.ndarray,
        morph: Parameter | np.ndarray,
        bbox: Box,
        peak: tuple[int, int] | None = None,
        bg_rms: np.ndarray | None = None,
        bg_thresh: float | None = 0.25,
        floor: float = 1e-20,
        monotonicity: Monotonicity | None = None,
        padding: int = 5,
    ):
        # Initialize all of the base attributes
        super().__init__(
            bands=bands,
            bbox=bbox,
        )
        # `parameter` wraps plain arrays so both attributes expose `.x`.
        self._spectrum = parameter(spectrum)
        self._morph = parameter(morph)
        self._peak = peak
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh

        self.floor = floor
        self.monotonicity = monotonicity
        self.padding = padding

    @property
    def peak(self) -> tuple[int, int] | None:
        """The peak of the component

        Returns
        -------
        peak:
            The peak of the component, in the same frame as ``bbox.origin``,
            or `None` if no peak was given.
        """
        return self._peak

    @property
    def component_center(self) -> tuple[int, int] | None:
        """The center of the component in its bounding box

        This is likely to be different than `peak`, since `peak` is given
        in the frame of the full model, whereas `component_center` is the
        peak position relative to the origin of this component's bounding box.

        Returns
        -------
        center:
            The center of the component in its bounding box, or `None`
            if the component has no peak.
        """
        _center = self.peak
        if _center is None:
            return None
        center = (
            _center[0] - self.bbox.origin[-2],
            _center[1] - self.bbox.origin[-1],
        )
        return center

    @property
    def spectrum(self) -> np.ndarray:
        """The array of spectrum values"""
        return self._spectrum.x

    @property
    def morph(self) -> np.ndarray:
        """The array of morphology values"""
        return self._morph.x

    @property
    def shape(self) -> tuple:
        """Shape of the resulting model image"""
        return self.spectrum.shape + self.morph.shape

    def get_model(self) -> Image:
        """Build the model from the spectrum and morphology"""
        # `spectrum` and `morph` already return the underlying arrays
        # (`Parameter.x`), so they can be broadcast directly.
        spectrum = self.spectrum
        morph = self.morph
        model = spectrum[:, None, None] * morph[None, :, :]
        return Image(model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    def grad_spectrum(self, input_grad: np.ndarray, spectrum: np.ndarray, morph: np.ndarray):
        """Gradient of the spectrum wrt. the component model

        Contracts the last two axes of ``input_grad`` with ``morph``.
        ``spectrum`` is unused; it is presumably accepted so the signature
        matches what the `Parameter` update calls (TODO confirm).
        """
        return np.einsum("...jk,jk", input_grad, morph)

    def grad_morph(self, input_grad: np.ndarray, morph: np.ndarray, spectrum: np.ndarray):
        """Gradient of the morph wrt. the component model

        Contracts the leading (band) axis of ``input_grad`` with ``spectrum``.
        ``morph`` is unused; see `grad_spectrum`.
        """
        return np.einsum("i,i...", spectrum, input_grad)

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        Values below ``floor`` and non-finite values are replaced by
        ``floor``, in place.
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        spectrum[~np.isfinite(spectrum)] = self.floor
        return spectrum

    def _morph_peak(self, shape: tuple) -> tuple[int, int] | None:
        """Index of the peak pixel inside a morphology array of ``shape``.

        Falls back to the center of the array when the component has no
        peak. Returns `None` when the peak lies outside the array, which
        can happen if `resize` produced a box that excludes the peak;
        indexing with such a position would otherwise wrap around
        (negative offsets) or raise `IndexError`.
        """
        center = self.component_center
        if center is None:
            return (shape[0] // 2, shape[1] // 2)
        if 0 <= center[0] < shape[0] and 0 <= center[1] < shape[1]:
            return center
        return None

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        Steps: optional monotonicity, background thresholding (or plain
        positivity when thresholding is disabled), removal of non-finite
        values, flooring of the peak pixel, and normalization so that the
        maximum is 1 (when the maximum is positive).
        """
        # monotonicity
        if self.monotonicity is not None:
            morph = self.monotonicity(morph, cast(tuple[int, int], self.component_center))

        if self.bg_thresh is not None and self.bg_rms is not None:
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero every pixel that is
            # below threshold in all bands.
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < bg_thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        # Ensure that the morphology is finite. This must happen before the
        # peak is floored, otherwise a non-finite peak pixel would be
        # floored and then reset to 0.
        morph[~np.isfinite(morph)] = 0

        # prevent divergent morphology
        peak = self._morph_peak(morph.shape)
        if peak is not None:
            morph[peak] = np.max([morph[peak], self.floor])

        # Normalize the morphology
        max_value = np.max(morph)
        if max_value > 0:
            morph[:] = morph / max_value
        return morph

    def resize(self, model_box: Box) -> bool:
        """Test whether or not the component needs to be resized

        If a resize is needed, the bounding box and the morphology are
        updated and `True` is returned.
        """
        # No need to resize if there is no size threshold.
        # To allow box sizing but no thresholding use `bg_thresh=0`.
        if self.bg_thresh is None or self.bg_rms is None:
            return False

        model = self.spectrum[:, None, None] * self.morph[None, :, :]
        bg_thresh = self.bg_rms * self.bg_thresh
        # A pixel is significant if it reaches the threshold in any band.
        significant = np.any(model >= bg_thresh[:, None, None], axis=0)
        if np.sum(significant) == 0:
            # There are no significant pixels,
            # so make a small box around the center
            center = self.peak
            if center is None:
                center = (0, 0)
            new_box = Box((1, 1), center).grow(self.padding) & model_box
        else:
            # The significant-pixel box is in component coordinates, so it is
            # shifted by the current origin before clipping to the model box.
            new_box = (
                Box.from_data(significant, threshold=0).grow(self.padding) + self.bbox.origin
            ) & model_box
        if new_box == self.bbox:
            return False

        old_box = self.bbox
        self._bbox = new_box
        self._morph.resize(old_box, new_box)
        return True

    def update(self, it: int, input_grad: np.ndarray):
        """Update the spectrum and morphology parameters"""
        # Store the input spectrum so that the morphology can
        # have a consistent update
        spectrum = self.spectrum.copy()
        self._spectrum.update(it, input_grad, self.morph)
        self._morph.update(it, input_grad, spectrum)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the spectrum and morph in place
        parameterization(self)
        # Attach this component's gradient and prox functions to the new parameters
        self._spectrum.grad = self.grad_spectrum
        self._spectrum.prox = self.prox_spectrum
        self._morph.grad = self.grad_morph
        self._morph.prox = self.prox_morph

    def __str__(self):
        result = (
            f"FactorizedComponent<\n    bands={self.bands},\n    center={self.peak},\n    "
            f"spectrum={self.spectrum},\n    morph_shape={self.morph.shape}\n>"
        )
        return result

    def __repr__(self):
        return self.__str__()

342 

343 

def default_fista_parameterization(component: Component):
    """Prepare a factorized component for FISTA proximal-gradient optimization.

    The spectrum and morphology arrays of ``component`` are replaced in
    place by `FistaParameter` instances that both use a fixed step of 0.5.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")
    fista_step = 0.5
    component._spectrum = FistaParameter(component.spectrum, step=fista_step)
    component._morph = FistaParameter(component.morph, step=fista_step)

351 

352 

def default_adaprox_parameterization(component: Component, noise_rms: float | None = None):
    """Prepare a factorized component for Proximal ADAM optimization.

    The spectrum becomes an `AdaproxParameter` whose step is computed by
    `relative_step` with ``factor=1e-2``, bounded below by ``noise_rms``.
    The morphology becomes an `AdaproxParameter` with a fixed step of 1e-2.
    Both are assigned to ``component`` in place.

    Parameters
    ----------
    component:
        The component to parameterize.
    noise_rms:
        Minimum value passed to `relative_step` for the spectrum.
        Defaults to 1e-16 when `None`.

    Raises
    ------
    NotImplementedError
        If ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")
    minimum_step = 1e-16 if noise_rms is None else noise_rms
    spectrum_step = partial(relative_step, factor=1e-2, minimum=minimum_step)
    component._spectrum = AdaproxParameter(component.spectrum, step=spectrum_step)
    component._morph = AdaproxParameter(component.morph, step=1e-2)