Coverage for tests/test_blend.py: 12%
186 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
1# This file is part of scarlet_lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from typing import cast
26import numpy as np
27from lsst.scarlet.lite import Blend, Box, Image, Observation, Source
28from lsst.scarlet.lite.component import CubeComponent, FactorizedComponent, default_adaprox_parameterization
29from lsst.scarlet.lite.initialization import FactorizedInitialization
30from lsst.scarlet.lite.operators import Monotonicity
31from lsst.scarlet.lite.utils import integrated_circular_gaussian
32from numpy.testing import assert_almost_equal, assert_raises
33from scipy.signal import convolve as scipy_convolve
34from utils import ObservationData, ScarletTestCase
class TestBlend(ScarletTestCase):
    """Tests for `Blend`: model construction, spectrum fitting, full
    fitting, copying, and band indexing/slicing.
    """

    def setUp(self):
        """Build a three-band observation of five components grouped into
        four sources, and a `Blend` initialized with the exact solution.
        """
        bands = ("g", "r", "i")
        yx0 = (1000, 2000)
        # The PSF in each band of the "observation"
        psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
        # The PSF of the model
        model_psf = integrated_circular_gaussian(sigma=0.8)

        # The spectrum of each component: one row per component,
        # one column per band
        spectra = np.array(
            [
                [40, 10, 0],
                [0, 25, 40],
                [15, 8, 3],
                [20, 3, 4],
                [0, 30, 60],
            ],
            dtype=float,
        )

        # Use a circular Gaussian morphology (of varying width) for every component
        morphs = [integrated_circular_gaussian(sigma=sigma) for sigma in [0.8, 2.5, 1.1, 2.1, 1.5]]
        # Make the second component a broader "disk" component by
        # convolving it with the model PSF
        morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")

        # Give the first two components the same center, and unique centers
        # for the remaining components
        centers = [
            (1010, 2012),
            (1010, 2012),
            (1020, 2023),
            (1020, 2010),
            (1025, 2020),
        ]

        # Create the simulated image and associated data products
        test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, yx0=yx0)

        # Create the Observation with uniform variance and weights
        # normalized so that their maximum is one
        variance = np.ones((3, 35, 35), dtype=float) * 1e-2
        weights = 1 / variance
        weights = weights / np.max(weights)
        self.observation = Observation(
            test_data.convolved,
            variance,
            weights,
            psfs,
            model_psf[None],
            bands=bands,
            bbox=Box(variance.shape[-2:], origin=yx0),
        )
        self.data = test_data
        self.spectra = spectra
        self.centers = centers
        self.morphs = morphs

        # Initialize every component with its exact spectrum and morphology
        components = [
            FactorizedComponent(
                bands=bands,
                spectrum=spectrum,
                morph=morph,
                bbox=data_morph.bbox,
                peak=center,
            )
            for spectrum, center, morph, data_morph in zip(
                self.spectra, self.centers, self.morphs, self.data.morphs
            )
        ]

        # The first source holds the first two (co-centered) components;
        # every remaining component is its own source
        sources = [Source(components[:2])]
        sources += [Source([component]) for component in components[2:]]

        self.blend = Blend(sources, self.observation)

    def test_exact(self):
        """Test that a blend model initialized with the exact solution
        builds the model correctly.
        """
        blend = self.blend
        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 4)
        self.assertBoxEqual(blend.bbox, Box(self.data.images.shape[1:], self.observation.bbox.origin))
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)
        self.assertImageAlmostEqual(blend.get_model(convolve=True), self.observation.images)
        self.assertImageAlmostEqual(
            self.observation.convolve(blend.get_model(), mode="real"),
            self.observation.images,
        )

        # The log likelihood of the exact model should be ~0
        assert_almost_equal([blend.log_likelihood], [0])

        # Test that grad_log_likelihood appends to the loss history
        self.assertListEqual(blend.loss, [])
        blend._grad_log_likelihood()
        assert_almost_equal(blend.loss, [0])

        # Remove one of the sources so the model no longer matches the data
        del blend.sources[-1]
        # Update the loss and check that the new (non-zero) value was appended
        expected_log_likelihood = -60.011720889007485
        blend._grad_log_likelihood()
        assert_almost_equal(blend.log_likelihood, expected_log_likelihood)
        assert_almost_equal(blend.loss, [0, expected_log_likelihood])

    def test_fit_spectra(self):
        """Test that fitting the spectra with exact morphologies
        reproduces the multiband image.
        """
        np.random.seed(0)
        blend = self.blend

        # Change the initial spectra so that they can be fit later
        for component in blend.components:
            c = cast(FactorizedComponent, component)
            c.spectrum[:] = np.random.rand(3) * 10

        with assert_raises(AssertionError):
            # Since the spectra have not yet been fit,
            # the model and images should not be equal
            self.assertImageEqual(blend.get_model(), self.data.images)

        # We initialized all of the morphologies exactly,
        # so fitting the spectra should give a nearly exact solution
        blend.fit_spectra()

        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 4)
        self.assertBoxEqual(blend.bbox, self.observation.bbox)
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)
        self.assertImageAlmostEqual(blend.get_model(convolve=True), self.observation.images)

    def test_fit(self):
        """Test that initializing and fitting a blend on a noisy observation
        recovers the noise-free convolved images to one decimal place.
        """
        observation = self.observation
        np.random.seed(0)
        # Keep the noise-free images to compare the fitted model against
        images = observation.images.copy()
        noise = np.random.normal(size=observation.images.shape) * 1e-2
        observation.images._data += noise

        monotonicity = Monotonicity((101, 101))
        init = FactorizedInitialization(observation, self.centers, monotonicity=monotonicity)

        blend = Blend(init.sources, observation).fit_spectra()
        blend.parameterize(default_adaprox_parameterization)
        blend.fit(100)

        self.assertImageAlmostEqual(blend.get_model(convolve=True), images, decimal=1)

    def test_non_factorized(self):
        """Test fitting spectra when the disk component is replaced by an
        equivalent non-factorized `CubeComponent` in its own source.
        """
        np.random.seed(1)
        blend = self.blend
        # Build the disk component as a data cube (spectrum x morphology)
        # located at the disk component's origin. This happens before the
        # spectra are randomized below.
        # NOTE(review): the component spectra may share memory with
        # ``self.spectra``, so this ordering may matter — confirm.
        model = self.spectra[1][:, None, None] * self.morphs[1][None, :, :]
        yx0 = blend.sources[0].components[1].bbox.origin
        # Remove the disk component from the first source
        blend.sources[0].components = blend.sources[0].components[:1]

        # Change the initial spectra so that they can be fit later
        for component in blend.components:
            c = cast(FactorizedComponent, component)
            c.spectrum[:] = np.random.rand(3) * 10

        with assert_raises(AssertionError):
            # Since the spectra have not yet been fit,
            # the model and images should not be equal
            self.assertImageEqual(blend.get_model(), self.data.images)

        # Create a new source for the disk with a non-factorized component
        component = CubeComponent(Image(model, bands=blend.observation.bands, yx0=yx0), (0, 0))
        blend.sources.append(Source([component]))

        blend.fit_spectra()

        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 5)
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)

    def test_clipping(self):
        """Test ``fit_spectra(clip=True)`` on a blend that also contains an
        additional empty source.
        """
        # Seed the RNG so this test is reproducible, like the other tests
        np.random.seed(2)
        blend = self.blend

        # Change the initial spectra so that they can be fit later
        for component in blend.components:
            c = cast(FactorizedComponent, component)
            c.spectrum[:] = np.random.rand(3) * 10

        with assert_raises(AssertionError):
            # Since the spectra have not yet been fit,
            # the model and images should not be equal
            self.assertImageEqual(blend.get_model(), self.data.images)

        # Add an empty source.
        # NOTE(review): the box origin (30, 0) lies outside the observation
        # bbox (origin (1000, 2000), shape 35x35) — confirm this is intended.
        zero_model = Image.from_box(Box((5, 5), (30, 0)), bands=blend.observation.bands)
        component = CubeComponent(zero_model, (0, 0))
        blend.sources.append(Source([component]))

        blend.fit_spectra(clip=True)

        # Six components are present after the append; expecting five here
        # suggests the clipped fit drops the empty component while keeping
        # its source. NOTE(review): confirm against Blend.fit_spectra.
        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 5)
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)

    def test_shallow_copy(self):
        """Test that ``copy()`` returns a new blend with equal sources,
        observation, and metadata.
        """
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_copy = blend.copy()

        self.assertIsNot(blend_copy, blend)
        self.assertEqual(len(blend_copy.sources), len(blend.sources))
        for source_copy, source in zip(blend_copy.sources, blend.sources):
            self.assertSourceEqual(source_copy, source)

        self.assertObservationEqual(blend_copy.observation, blend.observation)

        self.assertDictEqual(blend_copy.metadata, blend.metadata)

    def test_deepcopy(self):
        """Test that ``copy(deep=True)`` returns an independent blend:
        mutating the copy's spectra or metadata leaves the original unchanged.
        """
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_copy = blend.copy(deep=True)

        self.assertIsNot(blend_copy, blend)
        self.assertEqual(len(blend_copy.sources), len(blend.sources))
        for source_copy, source in zip(blend_copy.sources, blend.sources):
            self.assertSourceEqual(source_copy, source)

            # Modifying the copy must not modify the original
            with self.assertRaises(AssertionError):
                source_copy.components[0]._spectrum.x += 1
                self.assertSourceEqual(source_copy, source)

        self.assertObservationEqual(blend_copy.observation, blend.observation)
        self.assertDictEqual(blend_copy.metadata, blend.metadata)
        blend_copy.metadata["test"] = "new_value"
        with self.assertRaises(AssertionError):
            self.assertDictEqual(blend_copy.metadata, blend.metadata)

    def test_slice(self):
        """Test slicing a blend over a range of bands."""
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_sliced = blend["g":"r"]
        self.assertEqual(len(blend.sources), len(blend_sliced.sources))

        for source_sliced, source in zip(blend_sliced.sources, blend.sources):
            self.assertSourceEqual(source_sliced, source["g":"r"])

        self.assertObservationEqual(blend_sliced.observation, blend.observation["g":"r"])
        self.assertDictEqual(blend_sliced.metadata, blend.metadata)

    def test_reorder(self):
        """Test reordering a blend's bands with a tuple of band names."""
        blend = self.blend
        blend.metadata = {"test": "value"}
        indices = ("i", "g", "r")
        blend_reordered = blend[indices]
        self.assertEqual(len(blend.sources), len(blend_reordered.sources))

        for source_reordered, source in zip(blend_reordered.sources, blend.sources):
            self.assertSourceEqual(source_reordered, source[indices])

        self.assertObservationEqual(blend_reordered.observation, blend.observation[indices])
        self.assertDictEqual(blend_reordered.metadata, blend.metadata)

    def test_subset(self):
        """Test selecting a single-band subset of a blend."""
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_subset = blend[("r",)]
        self.assertEqual(len(blend.sources), len(blend_subset.sources))

        for source_subset, source in zip(blend_subset.sources, blend.sources):
            self.assertSourceEqual(source_subset, source["r"])

        self.assertObservationEqual(blend_subset.observation, blend.observation["r"])
        self.assertDictEqual(blend_subset.metadata, blend.metadata)

    def test_indexing_errors(self):
        """Test that indexing a blend with unknown band names, a `Box`,
        spatial slices, or integer indices raises `IndexError`.
        """
        blend = self.blend

        with self.assertRaises(IndexError):
            blend["x"]

        with self.assertRaises(IndexError):
            blend[("r", "x")]

        with self.assertRaises(IndexError):
            blend["r":"x"]

        with self.assertRaises(IndexError):
            blend["x":"i"]

        with self.assertRaises(IndexError):
            blend["g", "x", "i"]

        with self.assertRaises(IndexError):
            blend[Box((0, 0), (10, 10))]

        with self.assertRaises(IndexError):
            blend[:, 10:20, 10:20]

        with self.assertRaises(IndexError):
            blend[1:]

        with self.assertRaises(IndexError):
            blend[1]

        with self.assertRaises(IndexError):
            blend[0, 1]