Coverage for python / lsst / scarlet / lite / models / free_form.py: 27%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:28 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ["FactorizedFreeFormComponent"] 

24 

25from copy import deepcopy 

26from typing import TYPE_CHECKING, Any, Callable, cast 

27 

28import numpy as np 

29 

30from ..bbox import Box 

31from ..component import Component, FactorizedComponent 

32from ..detect import footprints_to_image 

33from ..detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore 

34from ..image import Image 

35from ..parameters import Parameter, parameter 

36 

37if TYPE_CHECKING: 

38 from ..io.component import ScarletComponentBaseData 

39 

40 

class FactorizedFreeFormComponent(FactorizedComponent):
    """Implements a free-form component

    With no constraints this component is typically either a garbage collector,
    or part of a set of components to deconvolve an image by separating out
    the different spectral components.

    See `FactorizedComponent` for a list of parameters not shown here.

    Parameters
    ----------
    peaks: `list` of `tuple`
        A set of ``(cy, cx)`` peaks for detected sources.
        If `peaks` is not ``None`` then only pixels in the same "footprint"
        as one of the peaks are included in the morphology.
        If `peaks` is ``None`` then there is no constraint applied.
    min_area: float
        The minimum area for a peak.
        If `min_area` is greater than zero then all regions of the morphology
        with fewer than `min_area` connected pixels are removed.
    """

    def __init__(
        self,
        bands: tuple,
        spectrum: np.ndarray | Parameter,
        morph: np.ndarray | Parameter,
        model_bbox: Box,
        bg_thresh: float | None = None,
        bg_rms: np.ndarray | None = None,
        floor: float = 1e-20,
        peaks: list[tuple[int, int]] | None = None,
        min_area: float = 0,
    ):
        # No single peak is passed to the parent; the optional list of
        # peaks is stored separately and used by `prox_morph`.
        super().__init__(
            bands=bands,
            spectrum=spectrum,
            morph=morph,
            bbox=model_bbox,
            peak=None,
            bg_rms=bg_rms,
            bg_thresh=bg_thresh,
            floor=floor,
        )

        self.peaks = peaks
        self.min_area = min_area

    def prox_spectrum(self, spectrum: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the spectrum

        This differs from `FactorizedComponent` because the spectrum of
        this component is normalized to unity.

        Parameters
        ----------
        spectrum: np.ndarray
            The spectrum to update. Values below ``self.floor`` are clipped
            to ``self.floor`` in place.

        Returns
        -------
        spectrum: np.ndarray
            The clipped spectrum divided by its sum.
        """
        # prevent divergent spectrum
        spectrum[spectrum < self.floor] = self.floor
        # Normalize the spectrum (the sum is positive after clipping,
        # as long as ``self.floor`` is positive)
        return spectrum / np.sum(spectrum)

    def prox_morph(self, morph: np.ndarray) -> np.ndarray:
        """Apply a prox-like update to the morphology

        This is the main difference between this component and a
        `FactorizedComponent`, since this component has fewer constraints:
        background thresholding (or positivity), the optional peak
        footprint constraint and the optional minimum area constraint.

        Parameters
        ----------
        morph: np.ndarray
            The 2D morphology to update. The thresholding/positivity step
            modifies it in place.

        Returns
        -------
        morph: np.ndarray
            The constrained morphology. If every pixel would be zero,
            pixel ``(0, 0)`` is set to ``self.floor``.
        """
        if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray):
            # Per-band threshold in units of the background RMS
            thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding: zero pixels whose model is
            # below the threshold in every band
            model = self.spectrum[:, None, None] * morph[None, :, :]
            morph[np.all(model < thresh[:, None, None], axis=0)] = 0
        else:
            # enforce positivity
            morph[morph < 0] = 0

        if self.peaks is not None:
            # Only keep pixels in the same footprint as one of the peaks
            footprint = get_connected_multipeak(morph, self.peaks, 0)
            morph = morph * footprint

        if self.min_area > 0:
            # Remove regions with fewer than min_area connected pixels.
            # The footprint image is built on a copy of the bbox moved to
            # the origin so that it aligns with the local morph array.
            footprints = get_footprints(morph, 4.0, self.min_area, 0, 0, False)
            bbox = self.bbox.copy()
            bbox.origin = (0, 0)
            footprint_image = footprints_to_image(footprints, bbox)
            morph = morph * (footprint_image > 0).data

        if np.all(morph == 0):
            # Avoid returning an identically-zero morphology
            # (presumably so the component does not vanish entirely)
            morph[0, 0] = self.floor

        return morph

    def resize(self, model_box: Box) -> bool:
        """Free-form components are never resized; always returns False."""
        return False

    def __str__(self):
        return (
            f"FactorizedFreeFormComponent(\n    bands={self.bands}\n    "
            f"spectrum={self.spectrum}\n    peaks={self.peaks}\n    "
            f"morph_shape={self.morph.shape}\n)"
        )

    def __repr__(self):
        return self.__str__()

146 

147 

class FreeFormComponent(Component):
    """Implements a component with no spectral or monotonicity constraints

    This is a FreeFormComponent that is not factorized into a
    spectrum and morphology with no monotonicity constraint.

    Parameters
    ----------
    bands: tuple
        The bands of the model.
    model: np.ndarray | Parameter
        The model, a 3D array with the band axis first.
    model_bbox: Box
        The bounding box of the model.
    bg_thresh: float | None
        Threshold, in units of ``bg_rms``, below which model pixels are set
        to zero. Only applied when ``bg_rms`` is an array.
    bg_rms: np.ndarray | None
        The background RMS in each band.
    floor: float
        Value given to the model when all of its pixels would be zero.
    peaks: list[tuple[int, int]] | None
        If not ``None``, only pixels connected to one of these peaks are kept.
    min_area: float
        If greater than zero, regions with fewer than `min_area` connected
        pixels are removed.
    """

    def __init__(
        self,
        bands: tuple,
        model: np.ndarray | Parameter,
        model_bbox: Box,
        bg_thresh: float | None = None,
        bg_rms: np.ndarray | None = None,
        floor: float = 1e-20,
        peaks: list[tuple[int, int]] | None = None,
        min_area: float = 0,
    ):
        super().__init__(bands=bands, bbox=model_bbox)
        self._model = parameter(model)
        self.bg_rms = bg_rms
        self.bg_thresh = bg_thresh
        self.floor = floor
        self.peaks = peaks
        self.min_area = min_area

    @property
    def model(self) -> np.ndarray:
        """The current model array (the value of the model `Parameter`)."""
        return self._model.x

    def get_model(self) -> Image:
        """Return the model as an `Image` located at the bbox origin."""
        return Image(self.model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin))

    @property
    def shape(self) -> tuple:
        """The shape of the model array."""
        return self.model.shape

    def grad_model(self, input_grad: np.ndarray, model: np.ndarray) -> np.ndarray:
        """The model is not transformed, so the gradient passes through."""
        return input_grad

    def prox_model(self, model: np.ndarray) -> np.ndarray:
        """Apply the constraints of this component to the model.

        Parameters
        ----------
        model: np.ndarray
            The 3D model (band axis first) to update. The
            thresholding/positivity step modifies it in place.

        Returns
        -------
        model: np.ndarray
            The constrained model. If every pixel would be zero, pixel
            ``(0, 0)`` is set to ``self.floor`` in every band.
        """
        if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray):
            # Per-band threshold, broadcast over the spatial axes
            thresh = self.bg_rms * self.bg_thresh
            # Enforce background thresholding
            model[model < thresh[:, None, None]] = 0
        else:
            # enforce positivity
            model[model < 0] = 0

        if self.peaks is not None:
            # Remove pixels not connected to one of the peaks
            model2d = np.sum(model, axis=0)
            footprint = get_connected_multipeak(model2d, self.peaks, 0)
            model = model * footprint[None, :, :]

        if self.min_area > 0:
            # Remove regions with fewer than min_area connected pixels.
            # The footprint image is built on a copy of the bbox moved to
            # the origin so that it aligns with the local model array.
            model2d = np.sum(model, axis=0)
            footprints = get_footprints(model2d, 4.0, self.min_area, 0, 0, False)
            bbox = self.bbox.copy()
            bbox.origin = (0, 0)
            footprint_image = footprints_to_image(footprints, bbox)
            model = model * (footprint_image > 0).data[None, :, :]

        if np.all(model == 0):
            # If the model is all zeros, set a single pixel (0, 0) to the
            # floor in every band. Indexing ``model[0, 0]`` on a 3D array
            # would instead overwrite an entire row of the first band.
            model[:, 0, 0] = self.floor

        return model

    def resize(self, model_box: Box) -> bool:
        """Free-form components are never resized; always returns False."""
        return False

    def update(self, it: int, grad_log_likelihood: np.ndarray):
        """Update the model parameter for iteration `it`."""
        self._model.update(it, grad_log_likelihood)

    def parameterize(self, parameterization: Callable) -> None:
        """Convert the component parameter arrays into Parameter instances

        Parameters
        ----------
        parameterization: Callable
            A function to use to convert parameters of a given type into
            a `Parameter` in place. It should take a single argument that
            is the `Component` or `Source` that is to be parameterized.
        """
        # Update the model parameter in place
        parameterization(self)
        # Attach this component's gradient and proximal operators
        self._model.grad = self.grad_model
        self._model.prox = self.prox_model

    def __str__(self):
        result = f"FreeFormComponent<bands={self.bands}, shape={self.shape}>"
        return result

    def __repr__(self):
        return self.__str__()

    def to_data(self) -> ScarletComponentBaseData:
        raise NotImplementedError("Serialization not implemented for FreeFormComponent")

    def __getitem__(self, indices: Any) -> FreeFormComponent:
        """Get a sub-component corresponding to the given bands.

        Parameters
        ----------
        indices: Any
            A single band, or an iterable of bands, to select.

        Returns
        -------
        component: FreeFormComponent
            A new component restricted to the requested bands. Its model is
            a copy of the selected band planes, and ``bg_rms`` (when it is
            an array with one entry per band) is restricted to the same
            bands.

        Raises
        ------
        IndexError :
            If the index is a ``Box``, includes spatial indices, or
            requests bands that are not in this component.
        """
        if isinstance(indices, Box):
            raise IndexError("FreeFormComponent cannot be sliced with a Box")

        if isinstance(indices, str) or not np.iterable(indices):
            bands = (indices,)
        else:
            bands = tuple(indices)

        # Anything that is not one of our bands (e.g. a slice or spatial
        # index) cannot be applied to the band axis.
        unknown = [band for band in bands if band not in self.bands]
        if unknown:
            raise IndexError(f"Indices {unknown} are not bands in {self.bands}")

        # Convert band names to positions along the band axis. A list index
        # keeps the band axis even when a single band is selected.
        band_indices = [self.bands.index(band) for band in bands]

        bg_rms = self.bg_rms
        if isinstance(bg_rms, np.ndarray) and bg_rms.ndim > 0 and len(bg_rms) == len(self.bands):
            bg_rms = bg_rms[band_indices]

        return FreeFormComponent(
            bands=bands,
            model=self.model[band_indices],
            model_bbox=self.bbox,
            bg_thresh=self.bg_thresh,
            bg_rms=bg_rms,
            floor=self.floor,
            peaks=self.peaks,
            min_area=self.min_area,
        )

    def __deepcopy__(self, memo: dict[int, Any]) -> FreeFormComponent:
        """Create a deep copy of this component.

        Parameters
        ----------
        memo: dict[int, Any]
            A dictionary to keep track of already copied objects.

        Returns
        -------
        component : FreeFormComponent
            A new component that is a deep copy of this one.
        """
        if id(self) in memo:
            return memo[id(self)]

        component = FreeFormComponent.__new__(FreeFormComponent)
        # Register before copying members so that cycles back to this
        # component resolve to the new instance.
        memo[id(self)] = component

        component.__init__(  # type: ignore[misc]
            bands=deepcopy(self.bands, memo),
            model=deepcopy(self.model, memo),
            model_bbox=deepcopy(self.bbox, memo),
            bg_thresh=self.bg_thresh,
            bg_rms=deepcopy(self.bg_rms, memo),
            floor=self.floor,
            peaks=deepcopy(self.peaks, memo),
            min_area=self.min_area,
        )
        return component

    def __copy__(self) -> FreeFormComponent:
        """Create a copy of this component.

        The new component is built from the same (uncopied) model array,
        bbox, ``bg_rms`` and ``peaks`` objects as this one.

        Returns
        -------
        component : FreeFormComponent
            A new component that is a copy of this one.
        """
        return FreeFormComponent(
            bands=self.bands,
            model=self.model,
            model_bbox=self.bbox,
            bg_thresh=self.bg_thresh,
            bg_rms=self.bg_rms,
            floor=self.floor,
            peaks=self.peaks,
            min_area=self.min_area,
        )