Coverage for tests / test_initialization.py: 15%
122 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
24import numpy as np
25from deprecated.sphinx import deprecated
26from lsst.scarlet.lite import Box, Image, Observation
27from lsst.scarlet.lite.initialization import (
28 FactorizedInitialization,
29 FactorizedWaveletInitialization,
30 init_monotonic_morph,
31 multifit_spectra,
32 trim_morphology,
33)
34from lsst.scarlet.lite.operators import Monotonicity, prox_monotonic_mask
35from lsst.scarlet.lite.utils import integrated_circular_gaussian
36from numpy.testing import assert_almost_equal, assert_array_equal
37from scipy.signal import convolve as scipy_convolve
38from utils import ObservationData, ScarletTestCase
class TestInitialization(ScarletTestCase):
    """Tests for the helpers in ``lsst.scarlet.lite.initialization``.

    Most tests share a real HSC COSMOS blend (``hsc_cosmos_35.npz``) loaded
    in `setUp`; `test_trim_morphology` and `test_multifit_spectra` build
    their own synthetic inputs instead.
    """

    def setUp(self) -> None:
        """Load the HSC COSMOS test blend and build an `Observation`.

        Sets ``self.detect`` (the sum of the images over the band axis),
        ``self.centers`` (catalog ``(y, x)`` positions shifted by the image
        origin ``yx0``) and ``self.observation``.
        """
        yx0 = (1000, 2000)
        filename = os.path.join(__file__, "..", "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        # For ``.npz`` archives ``np.load`` returns an ``NpzFile`` that holds
        # the file open. Read every array once inside the ``with`` block so the
        # handle is closed instead of leaking on every test run.
        with np.load(filename) as data:
            images = data["images"]
            variance = data["variance"]
            catalog = data["catalog"]
            psfs = data["psfs"]
            bands = data["filters"]
        model_psf = integrated_circular_gaussian(sigma=0.8)
        self.detect = np.sum(images, axis=0)
        self.centers = np.array([catalog["y"], catalog["x"]]).T + np.array(yx0)
        self.observation = Observation(
            Image(images, bands=bands, yx0=yx0),
            Image(variance, bands=bands, yx0=yx0),
            # Weights are the inverse variance.
            Image(1 / variance, bands=bands, yx0=yx0),
            psfs,
            model_psf[None],
            bands=bands,
        )

    def test_trim_morphology(self):
        """`trim_morphology` returns the expected box and keeps float32."""
        # Test default parameters: the expected box extends the non-zero
        # region (rows 10:15, cols 12:27) by 5 pixels on every side.
        morph = np.zeros((50, 50), dtype=np.float32)
        morph[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph)
        assert_array_equal(trimmed, morph)
        self.assertTupleEqual(trimmed_box.origin, (5, 7))
        self.assertTupleEqual(trimmed_box.shape, (15, 25))
        self.assertEqual(trimmed.dtype, np.float32)

        # Test with parameters specified: the 0.1 background should be zeroed
        # by the 0.5 threshold, and the box padded by 1 pixel on each side.
        morph = np.full((50, 50), 0.1, dtype=np.float32)
        morph[10:15, 12:27] = 1
        truth = np.zeros(morph.shape)
        truth[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph, 0.5, 1)
        assert_array_equal(trimmed, truth)
        self.assertTupleEqual(trimmed_box.origin, (9, 11))
        self.assertTupleEqual(trimmed_box.shape, (7, 17))
        self.assertEqual(trimmed.dtype, np.float32)

    def test_init_monotonic_mask(self):
        """`init_monotonic_morph` without a `Monotonicity` operator.

        The result is compared against `prox_monotonic_mask` applied directly
        to the detection image.
        """
        full_box = self.observation.bbox
        center = self.centers[0]
        # Position of the source relative to the observation origin.
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box)
        self.assertBoxEqual(bbox, Box((38, 29), (1014, 2000)))
        _, masked_morph, _ = prox_monotonic_mask(self.detect.copy(), local_center, max_iter=0)
        assert_array_equal(morph, masked_morph / np.max(masked_morph))
        self.assertEqual(morph.dtype, np.float32)

        # Specifying parameters
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            None,  # monotonicity
            0.2,  # threshold
        )
        self.assertBoxEqual(bbox, Box((26, 21), (1021, 2003)))
        # Remove pixels below the threshold (no normalization this time, so
        # the unnormalized masked morphology from above is the reference).
        truth = masked_morph.copy()
        truth[truth < 0.2] = 0
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test an empty morphology: expect an empty box and no morphology.
        bbox, morph = init_monotonic_morph(np.zeros(self.detect.shape), center, full_box)
        self.assertBoxEqual(bbox, Box((0, 0)))
        self.assertIsNone(morph)

    def test_init_monotonic_weighted(self):
        """`init_monotonic_morph` with an explicit `Monotonicity` operator."""
        full_box = self.observation.bbox
        center = self.centers[0]
        # Position of the source relative to the observation origin.
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])
        monotonicity = Monotonicity((101, 101))

        # Default parameters: negative pixels clipped, then peak-normalized.
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box, monotonicity=monotonicity)
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0] = 0
        truth = truth / np.max(truth)
        self.assertBoxEqual(bbox, Box((58, 48), origin=(1000, 2000)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Specify parameters: no normalization, pixels below 0.2 zeroed.
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            monotonicity,  # monotonicity
            0.2,  # threshold
        )
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0.2] = 0
        self.assertBoxEqual(bbox, Box((45, 44), origin=(1010, 2003)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test zero morphology: expect an empty box and no morphology.
        zeros = np.zeros(self.detect.shape)
        bbox, morph = init_monotonic_morph(zeros, center, full_box, monotonicity=monotonicity)
        self.assertBoxEqual(bbox, Box((0, 0), (1000, 2000)))
        self.assertIsNone(morph)

    def test_multifit_spectra(self):
        """`multifit_spectra` recovers known spectra from a synthetic blend."""
        bands = ("g", "r", "i")
        variance = np.ones((3, 35, 35), dtype=np.float32)
        weights = 1 / variance
        psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
        psfs = psfs.astype(np.float32)
        model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)

        # The spectrum of each source (one row per source, one column per band)
        spectra = np.array(
            [
                [31, 10, 0],
                [0, 5, 20],
                [15, 8, 3],
                [20, 3, 4],
                [0, 30, 60],
            ],
            dtype=np.float32,
        )

        # Use a circular Gaussian morphology for every source
        morphs = [
            integrated_circular_gaussian(sigma=sigma).astype(np.float32)
            for sigma in [0.8, 3.1, 1.1, 2.1, 1.5]
        ]
        # Make the second component a disk component
        morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")

        # Give the first two components the same center, and unique centers
        # for the remaining sources
        centers = [
            (10, 12),
            (10, 12),
            (20, 23),
            (20, 10),
            (25, 20),
        ]

        # Create the Observation
        test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, dtype=np.float32)
        observation = Observation(
            test_data.convolved,
            variance,
            weights,
            psfs,
            model_psf[None],
            bands=bands,
        )

        fit_spectra = multifit_spectra(observation, test_data.morphs)
        self.assertEqual(fit_spectra.dtype, spectra.dtype)
        assert_almost_equal(fit_spectra, spectra, decimal=5)

    def test_factorized_chi2_init(self):
        """`FactorizedInitialization` defaults and source count."""
        # Test default parameters
        init = FactorizedInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertEqual(init.disk_percentile, 25)
        self.assertEqual(init.thresh, 0.5)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

        # Integer-valued centers plus one extra center at (1000, 2004) should
        # produce one additional source.
        centers = tuple(tuple(center.astype(int)) for center in self.centers) + ((1000, 2004),)
        init = FactorizedInitialization(self.observation, centers)
        self.assertEqual(len(init.sources), 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

    # NOTE(review): `deprecated` marks this *test method* itself as deprecated
    # (calling it emits a deprecation warning and the sphinx variant edits the
    # docstring); it does not silence warnings raised by
    # FactorizedWaveletInitialization. Presumably kept as a reminder to delete
    # the test together with the class after v29.0 -- confirm intent.
    @deprecated(
        version="v29.0",
        reason="FactorizedWaveletInitialization is deprecated and will be removed after v29.0",
    )
    def test_factorized_wavelet_init(self):
        """`FactorizedWaveletInitialization` defaults and component count."""
        # Test default parameters
        init = FactorizedWaveletInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        # 7 sources but 8 components in total, i.e. one source has two.
        components = np.sum([len(src.components) for src in init.sources])
        self.assertEqual(components, 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)