Coverage for tests / test_component.py: 16%
242 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:28 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:28 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from abc import ABC
25from typing import Any, Callable
27import numpy as np
28from lsst.scarlet.lite import Box, Image, Parameter
29from lsst.scarlet.lite.component import (
30 Component,
31 CubeComponent,
32 FactorizedComponent,
33 default_adaprox_parameterization,
34 default_fista_parameterization,
35)
36from lsst.scarlet.lite.operators import Monotonicity
37from lsst.scarlet.lite.utils import integrated_circular_gaussian
38from numpy.testing import assert_almost_equal, assert_array_equal
39from utils import ScarletTestCase
class DummyComponent(Component):
    """Minimal concrete ``Component`` used to probe parameterization errors.

    Every abstract hook is stubbed so the class can be instantiated. The
    stubs all return ``None``; only ``parameterize`` does anything, and it
    simply forwards this component to the given callable.
    """

    def resize(self) -> bool:
        """Stub: performs no work and returns ``None``."""

    def update(self, it: int, input_grad: np.ndarray):
        """Stub: performs no work and returns ``None``."""

    def get_model(self) -> Image:
        """Stub: performs no work and returns ``None``."""

    def parameterize(self, parameterization: Callable) -> None:
        """Forward this component to ``parameterization``."""
        parameterization(self)

    def to_data(self) -> DummyComponent:
        """Stub: performs no work and returns ``None``."""

    def __getitem__(self, indices: Any) -> DummyComponent:
        """Stub: performs no work and returns ``None``."""

    def __copy__(self) -> DummyComponent:
        """Stub: performs no work and returns ``None``."""

    def __deepcopy__(self, memo: dict[int, Any]) -> DummyComponent:
        """Stub: performs no work and returns ``None``."""
class _ComponentTestBase(ABC):
    """Band-indexing tests shared by the concrete component test cases.

    This mixin is combined with ``ScarletTestCase``. The concrete class's
    ``setUp`` must assign ``self.component`` to a component whose bands are
    ``("g", "r", "i")``, because the expected slices below depend on that order.
    """

    def test_slice(self):
        """A band-name slice selects the matching contiguous bands."""
        component = self.component
        # Unlike positional slices, the end band ("r") is included.
        component_sliced = component["g":"r"]
        self.assertTupleEqual(component_sliced.bands, ("g", "r"))
        assert_array_equal(component_sliced.get_model(), component.get_model().data[0:2])

    def test_reorder(self):
        """Indexing by a sequence of band names reorders the bands."""
        component = self.component
        indices = ("i", "g", "r")
        component_reordered = component["i", "g", "r"]
        self.assertTupleEqual(component_reordered.bands, indices)
        assert_array_equal(
            component_reordered.get_model(),
            component.get_model().data[(2, 0, 1),],
        )

        # A string of single-character band names is expected to behave like
        # the equivalent tuple of names.
        component_reordered = component["igr"]
        self.assertTupleEqual(component_reordered.bands, indices)
        assert_array_equal(
            component_reordered.get_model(),
            component.get_model().data[(2, 0, 1),],
        )

    def test_subset(self):
        """Indexing by a single band name keeps a length-one band axis."""
        component = self.component
        indices = ("r",)
        component_subset = component["r"]
        self.assertTupleEqual(component_subset.bands, indices)
        assert_array_equal(
            component_subset.get_model(),
            component.get_model().data[1:2,],
        )

        # A multi-character string that is itself a band name selects that
        # band rather than being split into characters.
        component = self.component.copy(deep=True)
        component._bands = ("ab", "cd", "ef")
        indices = "ab"
        component_reordered = component["ab"]
        self.assertTupleEqual(component_reordered.bands, (indices,))
        assert_array_equal(
            component_reordered.get_model(),
            component.get_model().data[0:1,],
        )

    def test_indexing_errors(self):
        """Unknown bands and non-band (spatial/positional) indices raise IndexError."""
        component = self.component
        # Each entry is exactly what the subscript syntax would pass to
        # ``__getitem__``, e.g. ``slice("r", "z")`` for ``component["r":"z"]``
        # and ``(slice(None), slice(10, 20), slice(10, 20))`` for
        # ``component[:, 10:20, 10:20]``.
        bad_indices = (
            "z",
            slice("r", "z"),
            slice("z", "i"),
            ("g", "z", "i"),
            Box((0, 0), (10, 10)),
            (slice(None), slice(10, 20), slice(10, 20)),
            slice(1, None),
            1,
            (0, 1),
        )
        for index in bad_indices:
            # subTest reports every failing index instead of stopping at the first.
            with self.subTest(index=index), self.assertRaises(IndexError):
                component[index]
class TestFactorizedComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for ``FactorizedComponent``, whose model is spectrum x morphology."""

    def setUp(self) -> None:
        # Chain to the parent fixture so ScarletTestCase setup (if any) runs,
        # matching TestCubeComponent.setUp.
        super().setUp()
        spectrum = np.arange(3).astype(np.float32)
        morph = np.arange(20).reshape(4, 5).astype(np.float32)
        bands = ("g", "r", "i")
        # Box(shape, origin): a 4x5 footprint with origin (22, 31).
        bbox = Box((4, 5), (22, 31))
        self.model_box = Box((100, 100))
        center = (24, 33)

        self.component = FactorizedComponent(
            bands,
            spectrum,
            morph,
            bbox,
            center,
        )

        self.bands = bands
        self.spectrum = spectrum
        self.morph = morph
        self.full_shape = (3, 100, 100)

    def test_constructor(self):
        """Check constructor defaults and that optional arguments are stored."""
        # Only the required parameters: optional attributes take their defaults.
        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
        )

        self.assertIsInstance(component._spectrum, Parameter)
        assert_array_equal(component.spectrum, self.spectrum)
        self.assertIsInstance(component._morph, Parameter)
        assert_array_equal(component.morph, self.morph)
        self.assertBoxEqual(component.bbox, self.component.bbox)
        self.assertIsNone(component.peak)
        self.assertIsNone(component.bg_rms)
        self.assertEqual(component.bg_thresh, 0.25)
        self.assertEqual(component.floor, 1e-20)
        self.assertTupleEqual(component.shape, (3, 4, 5))

        # Optional parameters given positionally are stored unchanged.
        center = self.component.peak
        bg_rms = np.arange(5) / 10
        bg_thresh = 0.9
        floor = 1e-10

        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
            center,
            bg_rms,
            bg_thresh,
            floor,
        )

        self.assertTupleEqual(component.peak, center)
        assert_array_equal(component.bg_rms, bg_rms)  # type: ignore
        self.assertEqual(component.bg_thresh, bg_thresh)
        self.assertEqual(component.floor, floor)
        self.assertEqual(component.get_model().dtype, np.float32)

    def test_get_model(self):
        """The model is the outer product of spectrum and morphology."""
        component = self.component
        assert_array_equal(component.get_model(), self.spectrum[:, None, None] * self.morph[None, :, :])

        # Insert the component into a larger model; its bbox (origin (22, 31),
        # shape (4, 5)) covers rows 22:26 and columns 31:36.
        full_model = np.zeros(self.full_shape)
        full_model[:, 22:26, 31:36] = self.spectrum[:, None, None] * self.morph[None, :, :]

        test_model = Image(np.zeros(self.full_shape), bands=self.bands)
        test_model += component.get_model()

        assert_array_equal(test_model.data, full_model)

    def test_gradients(self):
        """Spectrum and morphology gradients match their closed forms."""
        component = self.component
        morph = self.morph
        spectrum = self.spectrum

        # Band b of the input gradient is (b + 1) * morph, so the spectrum
        # gradient in band b is sum((b + 1) * morph**2).
        input_grad = np.array([morph, 2 * morph, 3 * morph])
        true_spectrum_grad = np.array(
            [
                np.sum(morph**2),
                np.sum(2 * morph**2),
                np.sum(3 * morph**2),
            ]
        )
        assert_almost_equal(component.grad_spectrum(input_grad, spectrum, morph), true_spectrum_grad)

        true_morph_grad = np.sum(input_grad * spectrum[:, None, None], axis=0)
        assert_almost_equal(component.grad_morph(input_grad, morph, spectrum), true_morph_grad)

    def test_proximal_operators(self):
        """Test spectrum positivity, morphology thresholding and monotonicity."""
        spectrum = np.array([-1, 2, 3], dtype=float)
        morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, -1]], dtype=float)
        bbox = Box((3, 3), (10, 10))
        morph_bbox = Box((100, 100))
        # With the bbox origin at (10, 10), (11, 11) is morphology pixel (1, 1),
        # which holds the value 5.
        center = (11, 11)
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            center,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
        )

        # Expected: the negative spectrum entry is raised to the floor. In the
        # morphology, the 0.1 and -1 pixels become 0 and the 10 is reduced by
        # monotonicity; the result is divided by its peak value (5).
        proxed_spectrum = np.array([1e-20, 2, 3])
        proxed_morph = np.array([[2.6666666666666667, 2, 1], [1, 5, 3], [0, 4, 0]])
        proxed_morph = proxed_morph / 5

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_array_equal(component.morph, proxed_morph)

        # Without a peak, bg_rms or monotonicity, only the negative pixel is
        # expected to be clipped before dividing by the maximum (10).
        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
        )

        proxed_spectrum = np.array([1e-20, 2, 3])
        proxed_morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, 0]])
        proxed_morph = proxed_morph / 10

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_array_equal(component.morph, proxed_morph)

        self.assertFalse(component.resize(morph_bbox))

    def test_resize(self):
        """Resizing shrinks the bbox around the nonzero morphology pixels."""
        spectrum = np.array([1, 2, 3], dtype=float)
        morph = np.zeros((10, 10), dtype=float)
        morph[3:6, 5:8] = np.arange(9).reshape(3, 3)
        bbox = Box((10, 10), (3, 5))

        morph_bbox = Box((100, 100))
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
            padding=1,
        )

        self.assertTupleEqual(component.morph.shape, (10, 10))
        self.assertIsNone(component.component_center)

        # The nonzero pixels span rows 6:9 and columns 10:13 in model
        # coordinates, so the expected 5x5 box at (5, 9) corresponds to one
        # pixel of padding on every side.
        component.resize(morph_bbox)
        self.assertTupleEqual(component.morph.shape, (5, 5))
        self.assertTupleEqual(component.bbox.origin, (5, 9))
        self.assertTupleEqual(component.bbox.shape, (5, 5))
        self.assertIsNone(component.component_center)

    def test_parameterization(self):
        """Each optimizer parameterization installs its own helper arrays."""
        component = self.component
        assert_array_equal(component.get_model(), self.spectrum[:, None, None] * self.morph[None, :, :])

        component.parameterize(default_fista_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"z"})
        component.parameterize(default_adaprox_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"m", "v", "vhat"})

        # Component types the default parameterizations do not handle are rejected.
        params = (tuple("grizy"), Box((5, 5)))
        with self.assertRaises(NotImplementedError):
            default_fista_parameterization(DummyComponent(*params))

        with self.assertRaises(NotImplementedError):
            default_adaprox_parameterization(DummyComponent(*params))

    def test_shallow_copy(self):
        """A shallow copy is a new object that shares its attributes."""
        component = self.component
        component.monotonicity = Monotonicity((11, 11), fit_radius=0)

        component_copy = component.copy()

        self.assertIsNot(component, component_copy)
        assert_array_equal(component._spectrum.x, component_copy._spectrum.x)
        assert_array_equal(component._morph.x, component_copy._morph.x)
        self.assertIs(component.bbox, component_copy.bbox)
        self.assertIs(component.peak, component_copy.peak)
        self.assertIs(component.bg_thresh, component_copy.bg_thresh)
        self.assertIs(component.monotonicity, component_copy.monotonicity)

    def test_deep_copy(self):
        """A deep copy owns independent parameter arrays and attributes."""
        component = self.component
        component.monotonicity = Monotonicity((11, 11), fit_radius=0)
        component_deepcopy = component.copy(deep=True)

        self.assertIsNot(component, component_deepcopy)

        # Mutating the copy's parameters must leave the original untouched.
        assert_array_equal(component._spectrum.x, component_deepcopy._spectrum.x)
        component_deepcopy._spectrum.x += 1
        self.assertFalse(np.array_equal(component._spectrum.x, component_deepcopy._spectrum.x))

        assert_array_equal(component._morph.x, component_deepcopy._morph.x)
        component_deepcopy._morph.x += 1
        self.assertFalse(np.array_equal(component._morph.x, component_deepcopy._morph.x))

        self.assertIsNot(component.bbox, component_deepcopy.bbox)
        self.assertBoxEqual(component.bbox, component_deepcopy.bbox)

        self.assertTupleEqual(component.peak, component_deepcopy.peak)
        self.assertEqual(component.bg_thresh, component_deepcopy.bg_thresh)
        self.assertIsNot(component.monotonicity, component_deepcopy.monotonicity)
class TestCubeComponent(_ComponentTestBase, ScarletTestCase):
    """Tests for ``CubeComponent``, which stores a full (band, y, x) model image."""

    def setUp(self) -> None:
        super().setUp()
        self.bands = tuple("gri")
        peak = (27, 32)
        bbox = Box((15, 15), (20, 25))
        morph = integrated_circular_gaussian(sigma=0.8).astype(np.float32)
        spectrum = np.arange(3, dtype=np.float32)
        model = morph[None, :, :] * spectrum[:, None, None]
        # Keep an independent copy of the expected model so test_constructor has
        # a reference that cannot alias the component's own storage.
        self.model = model.copy()
        model_image = Image(model, yx0=bbox.origin, bands=self.bands)
        self.component = CubeComponent(model=model_image, peak=peak)

    def test_constructor(self):
        """The constructor stores the model image, its bands, bbox and peak."""
        component = self.component
        self.assertIsInstance(component._model, Image)
        # Compare with the reference built in setUp, not with the component itself.
        assert_array_equal(component._model.data, self.model)
        self.assertTupleEqual(component.bands, self.bands)
        self.assertBoxEqual(component.bbox, Box((15, 15), (20, 25)))
        self.assertTupleEqual(component.peak, (27, 32))

    def test_shallow_copy(self):
        """A shallow copy is a new object with an equal peak and model."""
        component = self.component
        component_copy = component.copy()

        self.assertIsNot(component_copy, component)
        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

    def test_deep_copy(self):
        """A deep copy owns an independent model image."""
        component = self.component
        component_copy = component.copy(deep=True)

        self.assertIsNot(component, component_copy)

        self.assertTupleEqual(component_copy.peak, component.peak)
        self.assertImageEqual(component_copy._model, component._model)

        # Mutate the copy outside the assertRaises block so that only the
        # comparison is expected to fail: the original must be unaffected.
        component_copy._model._data -= 1
        with self.assertRaises(AssertionError):
            self.assertImageEqual(component_copy._model, component._model)