Coverage for tests / test_component.py: 16%

242 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:40 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from abc import ABC 

25from typing import Any, Callable 

26 

27import numpy as np 

28from lsst.scarlet.lite import Box, Image, Parameter 

29from lsst.scarlet.lite.component import ( 

30 Component, 

31 CubeComponent, 

32 FactorizedComponent, 

33 default_adaprox_parameterization, 

34 default_fista_parameterization, 

35) 

36from lsst.scarlet.lite.operators import Monotonicity 

37from lsst.scarlet.lite.utils import integrated_circular_gaussian 

38from numpy.testing import assert_almost_equal, assert_array_equal 

39from utils import ScarletTestCase 

40 

41 

class DummyComponent(Component):
    """Minimal concrete `Component` with every abstract method stubbed out.

    It is only used to hand the parameterization helpers a component type
    they do not know how to handle, so none of the stubs does real work.
    """

    def resize(self, model_box: Box | None = None) -> bool:
        # Accept the model box the way other components' ``resize`` is
        # called (``component.resize(model_box)``) so this override stays
        # call-compatible; the dummy never needs to be resized.
        return False

    def update(self, it: int, input_grad: np.ndarray):
        # Stub: nothing to optimize.
        pass

    def get_model(self) -> Image:
        # Stub: returns None, the dummy has no model.
        pass

    def parameterize(self, parameterization: Callable) -> None:
        # Hand ourselves to the parameterization so tests can observe how it
        # reacts to an unsupported component type.
        parameterization(self)

    def to_data(self) -> DummyComponent:
        # Stub: returns None.
        pass

    def __getitem__(self, indices: Any) -> DummyComponent:
        # Stub: returns None.
        pass

    def __copy__(self) -> DummyComponent:
        # Stub: returns None.
        pass

    def __deepcopy__(self, memo: dict[int, Any]) -> DummyComponent:
        # Stub: returns None.
        pass

66 

67 

class _ComponentTestBase(ABC):
    """Band-indexing tests shared by the component test cases.

    This is a mixin: concrete subclasses must also inherit from a
    ``unittest.TestCase`` (here ``ScarletTestCase``) and must set
    ``self.component`` in ``setUp`` to a component whose bands are
    ``("g", "r", "i")`` in that order, since the expected array slices
    below assume that ordering.
    """

    def test_slice(self):
        """A band-name slice selects the bands between its endpoints, inclusive."""
        component = self.component
        component_sliced = component["g":"r"]
        self.assertTupleEqual(component_sliced.bands, ("g", "r"))
        assert_array_equal(component_sliced.get_model(), component.get_model().data[0:2])

    def test_reorder(self):
        """Several band names return those bands in the requested order.

        The names may be given as a tuple or as one concatenated string.
        """
        component = self.component
        indices = ("i", "g", "r")
        # ("i", "g", "r") corresponds to positions (2, 0, 1) of ("g", "r", "i").
        expected = component.get_model().data[(2, 0, 1),]
        for key in (indices, "igr"):
            with self.subTest(key=key):
                component_reordered = component[key]
                self.assertTupleEqual(component_reordered.bands, indices)
                assert_array_equal(component_reordered.get_model(), expected)

    def test_subset(self):
        """A single band name returns a one-band component.

        Multi-character band names must be matched whole, not split into
        individual characters.
        """
        component = self.component
        indices = ("r",)
        component_subset = component["r"]
        self.assertTupleEqual(component_subset.bands, indices)
        assert_array_equal(
            component_subset.get_model(),
            component.get_model().data[1:2,],
        )

        # Work on a deep copy so renaming the bands does not leak into the
        # shared fixture.
        component = self.component.copy(deep=True)
        component._bands = ("ab", "cd", "ef")
        band = "ab"
        component_subset = component[band]
        self.assertTupleEqual(component_subset.bands, (band,))
        assert_array_equal(
            component_subset.get_model(),
            component.get_model().data[0:1,],
        )

    def test_indexing_errors(self):
        """Unknown bands and any non-band index raise ``IndexError``."""
        component = self.component

        # Band names that are not in the component.
        with self.assertRaises(IndexError):
            component["z"]

        with self.assertRaises(IndexError):
            component["r":"z"]

        with self.assertRaises(IndexError):
            component["z":"i"]

        with self.assertRaises(IndexError):
            component["g", "z", "i"]

        # Spatial indexing is not supported through ``__getitem__``.
        with self.assertRaises(IndexError):
            component[Box((0, 0), (10, 10))]

        with self.assertRaises(IndexError):
            component[:, 10:20, 10:20]

        # Integer (positional) band indices are rejected.
        with self.assertRaises(IndexError):
            component[1:]

        with self.assertRaises(IndexError):
            component[1]

        with self.assertRaises(IndexError):
            component[0, 1]

141 

142 

class TestFactorizedComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for `FactorizedComponent`.

    Its model is the outer product of a per-band spectrum and a single 2D
    morphology.
    """

    def setUp(self) -> None:
        # Run the inherited fixture setup as well.
        super().setUp()
        spectrum = np.arange(3).astype(np.float32)
        morph = np.arange(20).reshape(4, 5).astype(np.float32)
        bands = ("g", "r", "i")
        # 4x5 box with origin (22, 31); see the slice used in test_get_model.
        bbox = Box((4, 5), (22, 31))
        self.model_box = Box((100, 100))
        center = (24, 33)

        self.component = FactorizedComponent(
            bands,
            spectrum,
            morph,
            bbox,
            center,
        )

        self.bands = bands
        self.spectrum = spectrum
        self.morph = morph
        self.full_shape = (3, 100, 100)

    def test_constructor(self):
        """The constructor stores its inputs and falls back to default values."""
        # Only the required parameters: the optional ones take defaults.
        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
        )

        self.assertIsInstance(component._spectrum, Parameter)
        assert_array_equal(component.spectrum, self.spectrum)
        self.assertIsInstance(component._morph, Parameter)
        assert_array_equal(component.morph, self.morph)
        self.assertBoxEqual(component.bbox, self.component.bbox)
        self.assertIsNone(component.peak)
        self.assertIsNone(component.bg_rms)
        self.assertEqual(component.bg_thresh, 0.25)
        self.assertEqual(component.floor, 1e-20)
        self.assertTupleEqual(component.shape, (3, 4, 5))

        # Optional parameters passed positionally must be stored as given.
        center = self.component.peak
        # One background RMS value per band, consistent with the other tests.
        bg_rms = np.arange(len(self.bands)) / 10
        bg_thresh = 0.9
        floor = 1e-10

        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
            center,
            bg_rms,
            bg_thresh,
            floor,
        )

        self.assertTupleEqual(component.peak, center)
        assert_array_equal(component.bg_rms, bg_rms)  # type: ignore
        self.assertEqual(component.bg_thresh, bg_thresh)
        self.assertEqual(component.floor, floor)
        self.assertEqual(component.get_model().dtype, np.float32)

    def test_get_model(self):
        """The model is spectrum x morph and lands at the bbox inside a full image."""
        component = self.component
        assert_array_equal(component.get_model(), self.spectrum[:, None, None] * self.morph[None, :, :])

        # Insert the component into a larger model; rows 22:26 and columns
        # 31:36 are the component bbox (origin (22, 31), shape (4, 5)).
        full_model = np.zeros(self.full_shape)
        full_model[:, 22:26, 31:36] = self.spectrum[:, None, None] * self.morph[None, :, :]

        test_model = Image(np.zeros(self.full_shape), bands=self.bands)
        test_model += component.get_model()

        assert_array_equal(test_model.data, full_model)

    def test_gradients(self):
        """Spectrum and morphology gradients match their closed forms."""
        component = self.component
        morph = self.morph
        spectrum = self.spectrum

        # Per-band input gradient: the morphology scaled by 1, 2 and 3.
        input_grad = np.array([morph, 2 * morph, 3 * morph])
        # Spectrum gradient: per band, sum over pixels of input_grad * morph.
        true_spectrum_grad = np.array(
            [
                np.sum(morph**2),
                np.sum(2 * morph**2),
                np.sum(3 * morph**2),
            ]
        )
        assert_almost_equal(component.grad_spectrum(input_grad, spectrum, morph), true_spectrum_grad)

        # Morphology gradient: spectrum-weighted sum over bands.
        true_morph_grad = np.sum(input_grad * spectrum[:, None, None], axis=0)
        assert_almost_equal(component.grad_morph(input_grad, morph, spectrum), true_morph_grad)

    def test_proximal_operators(self):
        """Spectrum positivity, morphology thresholding and monotonicity."""
        spectrum = np.array([-1, 2, 3], dtype=float)
        morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, -1]], dtype=float)
        bbox = Box((3, 3), (10, 10))
        morph_bbox = Box((100, 100))
        # (11, 11) is the middle pixel of the 3x3 box at origin (10, 10).
        center = (11, 11)
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            center,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
        )

        # The negative spectrum entry is clipped to the default floor (1e-20).
        proxed_spectrum = np.array([1e-20, 2, 3])
        # Monotonicity lowers the bright corner (10 -> 2.67). Background
        # thresholding zeroes the 0.1 pixel, and the negative pixel becomes 0.
        # The expected morphology is normalized by its maximum.
        proxed_morph = np.array([[2.6666666666666667, 2, 1], [1, 5, 3], [0, 4, 0]])
        proxed_morph = proxed_morph / 5

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_array_equal(component.morph, proxed_morph)

        # Without a peak, background RMS or monotonicity, only positivity
        # and normalization apply; the 0.1 pixel survives.
        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
        )

        proxed_spectrum = np.array([1e-20, 2, 3])
        proxed_morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, 0]])
        proxed_morph = proxed_morph / 10

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_array_equal(component.morph, proxed_morph)

        self.assertFalse(component.resize(morph_bbox))

    def test_resize(self):
        """Resizing shrinks the box to the nonzero footprint plus padding."""
        spectrum = np.array([1, 2, 3], dtype=float)
        morph = np.zeros((10, 10), dtype=float)
        # 3x3 nonzero patch at local rows 3:6, columns 5:8.
        morph[3:6, 5:8] = np.arange(9).reshape(3, 3)
        bbox = Box((10, 10), (3, 5))

        morph_bbox = Box((100, 100))
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
            padding=1,
        )

        self.assertTupleEqual(component.morph.shape, (10, 10))
        self.assertIsNone(component.component_center)

        component.resize(morph_bbox)
        # The patch lies at global rows 6-8 and columns 10-12. With
        # padding=1 the box becomes 5x5 with origin (5, 9).
        self.assertTupleEqual(component.morph.shape, (5, 5))
        self.assertTupleEqual(component.bbox.origin, (5, 9))
        self.assertTupleEqual(component.bbox.shape, (5, 5))
        self.assertIsNone(component.component_center)

    def test_parameterization(self):
        """Each parameterization installs its optimizer helpers on the morphology."""
        component = self.component
        assert_array_equal(component.get_model(), self.spectrum[:, None, None] * self.morph[None, :, :])

        component.parameterize(default_fista_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"z"})
        component.parameterize(default_adaprox_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"m", "v", "vhat"})

        # Unsupported component types are rejected by both parameterizations.
        params = (tuple("grizy"), Box((5, 5)))
        with self.assertRaises(NotImplementedError):
            default_fista_parameterization(DummyComponent(*params))

        with self.assertRaises(NotImplementedError):
            default_adaprox_parameterization(DummyComponent(*params))

    def test_shallow_copy(self):
        """A shallow copy is a new object that shares its attribute objects."""
        component = self.component
        component.monotonicity = Monotonicity((11, 11), fit_radius=0)

        component_copy = component.copy()

        self.assertIsNot(component, component_copy)
        assert_array_equal(component._spectrum.x, component_copy._spectrum.x)
        assert_array_equal(component._morph.x, component_copy._morph.x)
        self.assertIs(component.bbox, component_copy.bbox)
        self.assertIs(component.peak, component_copy.peak)
        self.assertIs(component.bg_thresh, component_copy.bg_thresh)
        self.assertIs(component.monotonicity, component_copy.monotonicity)

    def test_deep_copy(self):
        """A deep copy has equal but independent parameters, bbox and monotonicity."""
        component = self.component
        component.monotonicity = Monotonicity((11, 11), fit_radius=0)
        component_deepcopy = component.copy(deep=True)

        self.assertIsNot(component, component_deepcopy)

        # Mutating the copy's parameters must not affect the original.
        assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x)
        component_deepcopy._spectrum.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x)

        assert_array_equal(component._morph.x, component_deepcopy._morph.x)
        component_deepcopy._morph.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(component._morph.x, component_deepcopy._morph.x)

        self.assertIsNot(component.bbox, component_deepcopy.bbox)
        self.assertBoxEqual(component.bbox, component_deepcopy.bbox)

        self.assertTupleEqual(component.peak, component_deepcopy.peak)
        self.assertEqual(component.bg_thresh, component_deepcopy.bg_thresh)
        self.assertIsNot(component.monotonicity, component_deepcopy.monotonicity)

374 

375 

class TestCubeComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for `CubeComponent`, which wraps a full 3D (band, y, x) model image."""

    def setUp(self) -> None:
        super().setUp()
        self.bands = tuple("gri")
        peak = (27, 32)
        bbox = Box((15, 15), (20, 25))
        morph = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        spectrum = np.arange(3, dtype=np.float32)
        model = morph[None, :, :] * spectrum[:, None, None]
        # Keep an independent reference copy of the expected model, so tests
        # do not compare the component's data against itself.
        self.model = model.copy()
        model_image = Image(model, yx0=bbox.origin, bands=self.bands)
        self.component = CubeComponent(model=model_image, peak=peak)

    def test_constructor(self):
        """The constructor stores the model image, bands, bbox and peak."""
        component = self.component
        self.assertIsInstance(component._model, Image)
        assert_array_equal(component._model.data, self.model)
        self.assertTupleEqual(component.bands, self.bands)
        self.assertBoxEqual(component.bbox, Box((15, 15), (20, 25)))
        self.assertTupleEqual(component.peak, (27, 32))

    def test_shallow_copy(self):
        """A shallow copy is a new component with an equal peak and model."""
        component = self.component
        component_copy = component.copy()

        self.assertIsNot(component_copy, component)
        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

    def test_deep_copy(self):
        """A deep copy owns an independent copy of the model image."""
        component = self.component
        component_copy = component.copy(deep=True)

        self.assertIsNot(component, component_copy)
        self.assertIsNot(component_copy._model, component._model)

        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

        # Mutate only the copy. The original must be untouched and the two
        # images must now differ.
        component_copy._model._data -= 1
        assert_array_equal(component._model.data, self.model)
        with self.assertRaises(AssertionError):
            self.assertImageEqual(component_copy._model, component._model)