Coverage for tests/test_component.py: 16%
121 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-20 03:40 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-20 03:40 -0700
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from typing import Callable
24import numpy as np
25from lsst.scarlet.lite import Box, Image, Parameter
26from lsst.scarlet.lite.component import (
27 Component,
28 FactorizedComponent,
29 default_adaprox_parameterization,
30 default_fista_parameterization,
31)
32from lsst.scarlet.lite.operators import Monotonicity
33from numpy.testing import assert_almost_equal, assert_array_equal
34from utils import ScarletTestCase
class DummyComponent(Component):
    """Minimal concrete `Component` used to exercise unsupported code paths.

    Every abstract hook is stubbed out so the class can be instantiated.
    In this module it is only used to check that the default FISTA and
    adaprox parameterizations raise ``NotImplementedError`` for component
    types other than `FactorizedComponent`.
    """

    def resize(self, model_box: "Box | None" = None) -> bool:
        """Report that the dummy component never needs resizing.

        Parameters
        ----------
        model_box:
            Bounding box of the full model. Accepted (and ignored) so the
            signature matches the ``resize(model_box)`` call used on
            `FactorizedComponent` in these tests; optional so zero-argument
            calls keep working.

        Returns
        -------
        resized:
            Always ``False``, honoring the ``bool`` return annotation
            (the previous stub implicitly returned ``None``).
        """
        return False

    def update(self, it: int, input_grad: np.ndarray) -> None:
        """No-op: the dummy component has nothing to optimize."""

    def get_model(self) -> Image:
        """Stub model accessor.

        Returns ``None``; nothing in this module renders a dummy component,
        so no real `Image` is built.
        """
        return None  # type: ignore[return-value]

    def parameterize(self, parameterization: Callable) -> None:
        """Apply ``parameterization`` to this component.

        The callable receives the component itself, which is how the tests
        trigger the ``NotImplementedError`` branch of the default
        parameterizations.
        """
        parameterization(self)
class TestFactorizedComponent(ScarletTestCase):
    """Tests for `FactorizedComponent`.

    Covers construction, model rendering, gradients, proximal operators,
    resizing and parameterization.
    """

    def setUp(self) -> None:
        # A 3-band component whose 4x5 morphology sits at origin (22, 31)
        # inside a 100x100 model.
        spectrum = np.arange(3).astype(np.float32)
        morph = np.arange(20).reshape(4, 5).astype(np.float32)
        bands = ("g", "r", "i")
        bbox = Box((4, 5), (22, 31))
        self.model_box = Box((100, 100))
        center = (24, 33)

        self.component = FactorizedComponent(
            bands,
            spectrum,
            morph,
            bbox,
            center,
        )

        self.bands = bands
        self.spectrum = spectrum
        self.morph = morph
        # Derive the full (band, y, x) shape from the bands and model box
        # instead of hard-coding it, so the two can never disagree.
        self.full_shape = (len(bands),) + tuple(self.model_box.shape)

    def test_constructor(self):
        # Test with only required parameters
        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
        )

        self.assertIsInstance(component._spectrum, Parameter)
        assert_array_equal(component.spectrum, self.spectrum)
        self.assertIsInstance(component._morph, Parameter)
        assert_array_equal(component.morph, self.morph)
        self.assertBoxEqual(component.bbox, self.component.bbox)
        self.assertIsNone(component.peak)
        self.assertIsNone(component.bg_rms)
        self.assertEqual(component.bg_thresh, 0.25)
        self.assertEqual(component.floor, 1e-20)
        self.assertTupleEqual(component.shape, (3, 4, 5))

        # Test that parameters are passed through
        center = self.component.peak
        # bg_rms holds one value per band (as in test_proximal_operators),
        # so size it from the bands rather than a hard-coded length of 5.
        bg_rms = np.arange(len(self.bands)) / 10
        bg_thresh = 0.9
        floor = 1e-10

        component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            self.component.bbox,
            center,
            bg_rms,
            bg_thresh,
            floor,
        )

        self.assertTupleEqual(component.peak, center)
        assert_array_equal(component.bg_rms, bg_rms)  # type: ignore
        self.assertEqual(component.bg_thresh, bg_thresh)
        self.assertEqual(component.floor, floor)
        self.assertEqual(component.get_model().dtype, np.float32)

    def test_get_model(self):
        component = self.component
        expected = self.spectrum[:, None, None] * self.morph[None, :, :]
        assert_array_equal(component.get_model(), expected)

        # Insert component into a larger model: the component's bbox starts
        # at (22, 31) with shape (4, 5), i.e. rows 22:26 and columns 31:36.
        full_model = np.zeros(self.full_shape)
        full_model[:, 22:26, 31:36] = expected

        test_model = Image(np.zeros(self.full_shape), bands=self.bands)
        test_model += component.get_model()

        assert_array_equal(test_model.data, full_model)

    def test_gradients(self):
        component = self.component
        morph = self.morph
        spectrum = self.spectrum

        # Band k of the input gradient is (k + 1) * morph, so the spectrum
        # gradient (sum over pixels of input_grad * morph) is
        # (k + 1) * sum(morph**2).
        input_grad = np.array([morph, 2 * morph, 3 * morph])
        true_spectrum_grad = np.array(
            [
                np.sum(morph**2),
                np.sum(2 * morph**2),
                np.sum(3 * morph**2),
            ]
        )
        assert_almost_equal(component.grad_spectrum(input_grad, spectrum, morph), true_spectrum_grad)

        true_morph_grad = np.sum(input_grad * spectrum[:, None, None], axis=0)
        assert_almost_equal(component.grad_morph(input_grad, morph, spectrum), true_morph_grad)

    def test_proximal_operators(self):
        # Test spectrum positivity, morph threshold, and monotonicity
        spectrum = np.array([-1, 2, 3], dtype=float)
        morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, -1]], dtype=float)
        bbox = Box((3, 3), (10, 10))
        morph_bbox = Box((100, 100))
        center = (11, 11)
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            center,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
        )

        # The negative spectrum entry is clipped to the default floor, and
        # the morphology ends up normalized by its peak value (5).
        proxed_spectrum = np.array([1e-20, 2, 3])
        proxed_morph = np.array([[2.6666666666666667, 2, 1], [1, 5, 3], [0, 4, 0]])
        proxed_morph = proxed_morph / 5

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        # The spectrum check stays exact so the 1e-20 floor is really pinned;
        # the morph values are derived floats, so compare them approximately.
        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_almost_equal(component.morph, proxed_morph)

        # Without a center there is no monotonicity or background threshold:
        # only positivity and normalization by the peak (10) apply.
        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
        )

        proxed_spectrum = np.array([1e-20, 2, 3])
        proxed_morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, 0]])
        proxed_morph = proxed_morph / 10

        component.prox_spectrum(component.spectrum)
        component.prox_morph(component.morph)

        assert_array_equal(component.spectrum, proxed_spectrum)
        assert_almost_equal(component.morph, proxed_morph)

        self.assertFalse(component.resize(morph_bbox))

    def test_resize(self):
        spectrum = np.array([1, 2, 3], dtype=float)
        # A 10x10 morphology at origin (3, 5) whose only non-zero content is
        # a 3x3 patch at local rows 3:6 and columns 5:8.
        morph = np.zeros((10, 10), dtype=float)
        morph[3:6, 5:8] = np.arange(9).reshape(3, 3)
        bbox = Box((10, 10), (3, 5))

        morph_bbox = Box((100, 100))
        monotonicity = Monotonicity((101, 101), fit_radius=0)

        component = FactorizedComponent(
            self.bands,
            spectrum.copy(),
            morph.copy(),
            bbox,
            None,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=monotonicity,
            padding=1,
        )

        self.assertTupleEqual(component.morph.shape, (10, 10))
        self.assertIsNone(component.component_center)

        # After resizing, the box is expected to shrink to the 3x3 patch plus
        # one pixel of padding on each side.
        component.resize(morph_bbox)
        self.assertTupleEqual(component.morph.shape, (5, 5))
        self.assertTupleEqual(component.bbox.origin, (5, 9))
        self.assertTupleEqual(component.bbox.shape, (5, 5))
        self.assertIsNone(component.component_center)

    def test_parameterization(self):
        component = self.component
        assert_array_equal(component.get_model(), self.spectrum[:, None, None] * self.morph[None, :, :])

        component.parameterize(default_fista_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"z"})
        component.parameterize(default_adaprox_parameterization)
        helpers = set(component._morph.helpers.keys())
        self.assertSetEqual(helpers, {"m", "v", "vhat"})

        # Non-factorized components are not supported by the default
        # parameterizations.
        params = (tuple("grizy"), Box((5, 5)))
        with self.assertRaises(NotImplementedError):
            default_fista_parameterization(DummyComponent(*params))

        with self.assertRaises(NotImplementedError):
            default_adaprox_parameterization(DummyComponent(*params))