Coverage for python/lsst/scarlet/lite/component.py: 40%

129 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-20 03:40 -0700

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = [
    "Component",
    "FactorizedComponent",
    "default_fista_parameterization",
    "default_adaprox_parameterization",
]

28 

29from abc import ABC, abstractmethod 

30from functools import partial 

31from typing import Callable, cast 

32 

33import numpy as np 

34 

35from .bbox import Box 

36from .image import Image 

37from .operators import Monotonicity 

38from .parameters import AdaproxParameter, FistaParameter, Parameter, parameter, relative_step 

39 

40 

class Component(ABC):
    """Abstract base class for every component in scarlet lite.

    Subclasses must implement `resize`, `update`, `get_model`, and
    `parameterize`.

    Parameters
    ----------
    bands:
        The bands used when the component model is created.
    bbox: Box
        The bounding box for this component.
    """

    def __init__(
        self,
        bands: tuple,
        bbox: Box,
    ):
        self._bands, self._bbox = bands, bbox

    @property
    def bbox(self) -> Box:
        """Bounding box of the component in the full image."""
        return self._bbox

    @property
    def bands(self) -> tuple:
        """Bands contained in the component model."""
        return self._bands

    @abstractmethod
    def resize(self, model_box: Box) -> bool:
        """Check whether the component must be resized (and resize it).

        Subclasses should override this and return `True` when the
        component needed to be resized.

        Parameters
        ----------
        model_box:
            The bounding box of the full model.
        """

    @abstractmethod
    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Update the component parameters using an input gradient.

        Parameters
        ----------
        it:
            Index of the current optimizer iteration.
        input_grad:
            Gradient of the likelihood with respect to the component model.
        """

    @abstractmethod
    def get_model(self) -> Image:
        """Build the model image for this component.

        Subclasses must implement this.

        Returns
        -------
        model: Image
            Image of the component model.
        """

    @abstractmethod
    def parameterize(self, parameterization: Callable) -> None:
        """Turn the component's parameter arrays into `Parameter` instances.

        Parameters
        ----------
        parameterization: Callable
            Function that converts the parameters in place. It receives a
            single argument: the `Component` or `Source` being
            parameterized.
        """

113 

114 

class FactorizedComponent(Component):
    """A component that can be factorized into spectrum and morphology
    parameters.

    The component model is the outer product of ``spectrum`` (one value per
    band) and ``morph`` (a 2D image), placed at ``bbox.origin`` in the full
    model.

    Parameters
    ----------
    bands:
        The bands of the spectral dimension, in order.
    spectrum:
        The parameter (or array) used to store and update the spectrum.
    morph:
        The parameter (or array) used to store and update the morphology.
    bbox:
        The `Box` in the `model_bbox` that contains the source.
    peak:
        Location of the peak for the source, in full-model coordinates
        (not relative to ``bbox``).
    bg_rms:
        The RMS of the background in each band, used to threshold, grow,
        and shrink the component.
    bg_thresh:
        Multiple of ``bg_rms`` below which a pixel is treated as background.
        If either ``bg_thresh`` or ``bg_rms`` is `None`, no background
        thresholding or resizing is done; only positivity is enforced.
    floor:
        Minimum value of the spectrum and of the morphology's center pixel,
        used to keep the factorization from diverging.
    monotonicity:
        The monotonicity operator to use for making the source monotonic.
        If this parameter is `None`, the source will not be made monotonic.
    padding:
        Number of pixels to grow the bounding box by when it is resized.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: Parameter | np.ndarray,
        morph: Parameter | np.ndarray,
        bbox: Box,
        peak: tuple[int, int] | None = None,
        bg_rms: np.ndarray | None = None,
        bg_thresh: float | None = 0.25,
        floor: float = 1e-20,
        monotonicity: Monotonicity | None = None,
        padding: int = 5,
    ):
        # Initialize all of the base attributes
        super().__init__(
            bands=bands,
            bbox=bbox,
        )
        # Wrap raw arrays so that spectrum and morph are always Parameters.
        self._spectrum = parameter(spectrum)
        self._morph = parameter(morph)
        self._peak = peak
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh

        self.floor = floor
        self.monotonicity = monotonicity
        self.padding = padding

    @property
    def peak(self) -> tuple[int, int] | None:
        """The peak of the component in full-model coordinates

        Returns
        -------
        peak:
            The peak of the component, or `None` if it was not specified.
        """
        return self._peak

    @property
    def component_center(self) -> tuple[int, int] | None:
        """The center of the component in its bounding box

        This differs from `peak`, which is the location of the component
        in the full model. `component_center` is the same location expressed
        relative to the origin of the component's bounding box.

        Returns
        -------
        center:
            The center of the component in its bounding box, or `None`
            if the component has no peak.
        """
        _center = self.peak
        if _center is None:
            return None
        center = (
            _center[0] - self.bbox.origin[-2],
            _center[1] - self.bbox.origin[-1],
        )
        return center

    @property
    def spectrum(self) -> np.ndarray:
        """The array of spectrum values"""
        return self._spectrum.x

    @property
    def morph(self) -> np.ndarray:
        """The array of morphology values"""
        return self._morph.x

    @property
    def shape(self) -> tuple:
        """Shape of the resulting model image (spectrum shape + morph shape)"""
        return self.spectrum.shape + self.morph.shape

    def get_model(self) -> Image:
        """Build the model from the spectrum and morphology

        Returns
        -------
        model:
            The outer product of spectrum and morphology, positioned at the
            origin of the component's bounding box.
        """
        # `self.spectrum` / `self.morph` return the underlying arrays of the
        # Parameters, so plain numpy broadcasting can be used here.
        spectrum = self.spectrum
        morph = self.morph
        model = spectrum[:, None, None] * morph[None, :, :]
        return Image(model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    def grad_spectrum(self, input_grad: np.ndarray, spectrum: np.ndarray, morph: np.ndarray) -> np.ndarray:
        """Gradient of the likelihood wrt. the spectrum

        Applies the chain rule to ``input_grad`` (the gradient wrt. the
        component model) by summing it, weighted by ``morph``, over the two
        spatial axes. ``spectrum`` is unused but is part of the signature
        expected by the parameter's ``grad`` callback.
        """
        return np.einsum("...jk,jk", input_grad, morph)

    def grad_morph(self, input_grad: np.ndarray, morph: np.ndarray, spectrum: np.ndarray) -> np.ndarray:
        """Gradient of the likelihood wrt. the morphology

        Applies the chain rule to ``input_grad`` (the gradient wrt. the
        component model) by summing it, weighted by ``spectrum``, over the
        band axis. ``morph`` is unused but is part of the signature
        expected by the parameter's ``grad`` callback.
        """
        return np.einsum("i,i...", spectrum, input_grad)

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        Values below ``floor`` are raised to ``floor`` (in place).
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        return spectrum

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        The morphology is (optionally) made monotonic, then thresholded
        against the background (or clipped to be non-negative), floored at
        its center pixel, and normalized so that its maximum is one. The
        input array may be modified in place.

        Parameters
        ----------
        morph:
            The morphology to constrain.

        Returns
        -------
        morph:
            The constrained morphology.
        """
        # Use a single center pixel for both the monotonicity operator and
        # the floor. Without a peak, fall back to the geometric center of
        # the morphology instead of passing `None` to the operator.
        center = self.component_center
        if center is None:
            center = (morph.shape[0] // 2, morph.shape[1] // 2)

        # monotonicity
        if self.monotonicity is not None:
            morph = self.monotonicity(morph, center)

        if self.bg_thresh is not None and self.bg_rms is not None:
            bg_thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero pixels that are below
            # the per-band threshold in every band.
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < bg_thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        # prevent divergent morphology
        morph[center] = max(morph[center], self.floor)

        # Normalize the morphology. Guard against an all-zero image
        # (possible when `floor <= 0`), which would otherwise produce NaNs.
        morph_max = morph.max()
        if morph_max > 0:
            morph[:] = morph / morph_max
        return morph

    def resize(self, model_box: Box) -> bool:
        """Test whether or not the component needs to be resized

        If it does, the bounding box and the morphology are resized in place.

        Parameters
        ----------
        model_box:
            The bounding box of the full model. The new box is clipped to it.

        Returns
        -------
        resized:
            `True` if the bounding box changed.
        """
        # No need to resize if there is no size threshold.
        # To allow box sizing but no thresholding use `bg_thresh=0`.
        if self.bg_thresh is None or self.bg_rms is None:
            return False

        model = self.spectrum[:, None, None] * self.morph[None, :, :]
        bg_thresh = self.bg_rms * self.bg_thresh
        significant = np.any(model >= bg_thresh[:, None, None], axis=0)
        if not np.any(significant):
            # There are no significant pixels,
            # so make a small box around the center
            center = self.peak
            if center is None:
                # Without a peak use the center of the current bounding box
                # (full-model coordinates) rather than the model origin.
                origin = self.bbox.origin
                box_shape = self.bbox.shape
                center = (
                    origin[-2] + box_shape[-2] // 2,
                    origin[-1] + box_shape[-1] // 2,
                )
            new_box = Box((1, 1), center).grow(self.padding) & model_box
        else:
            new_box = (
                Box.from_data(significant, threshold=0).grow(self.padding) + self.bbox.origin
            ) & model_box
        if new_box == self.bbox:
            return False

        old_box = self.bbox
        self._bbox = new_box
        self._morph.resize(old_box, new_box)
        return True

    def update(self, it: int, input_grad: np.ndarray) -> None:
        """Update the spectrum and morphology parameters

        Parameters
        ----------
        it:
            The current iteration of the optimizer.
        input_grad:
            Gradient of the likelihood wrt. the component model.
        """
        # Store the input spectrum so that the morphology update uses the
        # same (pre-update) spectrum that the spectrum update saw.
        spectrum = self.spectrum.copy()
        self._spectrum.update(it, input_grad, self.morph)
        self._morph.update(it, input_grad, spectrum)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the spectrum and morph in place
        parameterization(self)
        # Attach this component's gradient and prox callbacks to the
        # new parameters.
        self._spectrum.grad = self.grad_spectrum
        self._spectrum.prox = self.prox_spectrum
        self._morph.grad = self.grad_morph
        self._morph.prox = self.prox_morph

    def __str__(self):
        result = (
            f"FactorizedComponent<\n bands={self.bands},\n center={self.peak},\n "
            f"spectrum={self.spectrum},\n morph_shape={self.morph.shape}\n>"
        )
        return result

    def __repr__(self):
        return self.__str__()

335 

336 

def default_fista_parameterization(component: Component):
    """Prepare a factorized component for FISTA proximal gradient descent.

    The spectrum and morphology arrays are replaced by `FistaParameter`
    instances that use a fixed step of 0.5.

    Raises
    ------
    NotImplementedError
        Raised when ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")
    component._spectrum = FistaParameter(component.spectrum, step=0.5)
    component._morph = FistaParameter(component.morph, step=0.5)

344 

345 

def default_adaprox_parameterization(component: Component, noise_rms: float | None = None):
    """Prepare a factorized component for proximal ADAM (adaprox) optimization.

    The spectrum uses a relative step (factor 1e-2) whose minimum is
    ``noise_rms``. The morphology uses a fixed step of 1e-2.

    Parameters
    ----------
    component:
        The component to parameterize in place.
    noise_rms:
        Lower bound for the spectrum step size. If `None`, 1e-16 is used.

    Raises
    ------
    NotImplementedError
        Raised when ``component`` is not a `FactorizedComponent`.
    """
    if not isinstance(component, FactorizedComponent):
        raise NotImplementedError(f"Unrecognized component type {component}")
    min_step = 1e-16 if noise_rms is None else noise_rms
    component._spectrum = AdaproxParameter(
        component.spectrum,
        step=partial(relative_step, factor=1e-2, minimum=min_step),
    )
    component._morph = AdaproxParameter(component.morph, step=1e-2)