Coverage for tests/test_initialization.py: 13%
120 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 02:46 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 02:46 -0700
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
24import numpy as np
25from lsst.scarlet.lite import Box, Image, Observation
26from lsst.scarlet.lite.initialization import (
27 FactorizedChi2Initialization,
28 FactorizedWaveletInitialization,
29 init_monotonic_morph,
30 multifit_spectra,
31 trim_morphology,
32)
33from lsst.scarlet.lite.operators import Monotonicity, prox_monotonic_mask
34from lsst.scarlet.lite.utils import integrated_circular_gaussian
35from numpy.testing import assert_almost_equal, assert_array_equal
36from scipy.signal import convolve as scipy_convolve
37from utils import ObservationData, ScarletTestCase
class TestInitialization(ScarletTestCase):
    """Tests for ``lsst.scarlet.lite.initialization``.

    Exercises the morphology helpers (``trim_morphology``,
    ``init_monotonic_morph``), the spectrum fitter (``multifit_spectra``)
    and the factorized initializer classes, mostly against the
    ``hsc_cosmos_35.npz`` fixture loaded in ``setUp``.
    """

    def setUp(self) -> None:
        """Load the ``hsc_cosmos_35.npz`` fixture and build an Observation.

        Sets ``self.detect`` (band-summed image), ``self.centers`` (catalog
        positions in (y, x) order, offset by the image origin) and
        ``self.observation``.
        """
        yx0 = (1000, 2000)
        # The data directory is a sibling of the directory holding this file.
        filename = os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        # np.load on an .npz archive returns an NpzFile that holds the file
        # open; copy the arrays out inside a context manager so the handle is
        # closed deterministically instead of whenever it is garbage collected.
        with np.load(filename) as data:
            images = data["images"]
            variance = data["variance"]
            psfs = data["psfs"]
            catalog = data["catalog"]
            bands = data["filters"]
        model_psf = integrated_circular_gaussian(sigma=0.8)
        # Detection image: sum over the band axis.
        self.detect = np.sum(images, axis=0)
        # Centers in (y, x) order, shifted into the frame whose origin is yx0.
        self.centers = np.array([catalog["y"], catalog["x"]]).T + np.array(yx0)
        self.observation = Observation(
            Image(images, bands=bands, yx0=yx0),
            Image(variance, bands=bands, yx0=yx0),
            # Weights are the inverse variance.
            Image(1 / variance, bands=bands, yx0=yx0),
            psfs,
            model_psf[None],
            bands=bands,
        )

    def test_trim_morphology(self):
        """Check the array, box and dtype returned by ``trim_morphology``."""
        # Test default parameters
        morph = np.zeros((50, 50)).astype(np.float32)
        morph[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph)
        assert_array_equal(trimmed, morph)
        # The expected box is the 5x15 non-zero region grown by 5 pixels on
        # every side.
        self.assertTupleEqual(trimmed_box.origin, (5, 7))
        self.assertTupleEqual(trimmed_box.shape, (15, 25))
        self.assertEqual(trimmed.dtype, np.float32)

        # Test with parameters specified: presumably threshold=0.5 and
        # padding=1, which is consistent with the expected values below.
        morph = np.full((50, 50), 0.1).astype(np.float32)
        morph[10:15, 12:27] = 1
        # Pixels below the threshold are expected to be zeroed.
        truth = np.zeros(morph.shape)
        truth[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph, 0.5, 1)
        assert_array_equal(trimmed, truth)
        self.assertTupleEqual(trimmed_box.origin, (9, 11))
        self.assertTupleEqual(trimmed_box.shape, (7, 17))
        self.assertEqual(trimmed.dtype, np.float32)

    def test_init_monotonic_mask(self):
        """Check ``init_monotonic_morph`` when no Monotonicity operator is given.

        The result is compared against ``prox_monotonic_mask`` applied to the
        same detection image.
        """
        full_box = self.observation.bbox
        center = self.centers[0]
        # prox_monotonic_mask works in array coordinates, so remove the origin.
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box)
        self.assertBoxEqual(bbox, Box((38, 29), (1014, 2000)))
        _, masked_morph, _ = prox_monotonic_mask(self.detect.copy(), local_center, max_iter=0)
        # By default the morphology is expected to be normalized to a peak of 1.
        assert_array_equal(morph, masked_morph / np.max(masked_morph))
        self.assertEqual(morph.dtype, np.float32)

        # Specifying parameters
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            None,  # monotonicity
            0.2,  # threshold
        )
        self.assertBoxEqual(bbox, Box((26, 21), (1021, 2003)))
        # Remove pixels below the threshold
        truth = masked_morph.copy()
        truth[truth < 0.2] = 0
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test an empty morphology: an empty box and no morphology are returned.
        bbox, morph = init_monotonic_morph(np.zeros(self.detect.shape), center, full_box)
        self.assertBoxEqual(bbox, Box((0, 0)))
        self.assertIsNone(morph)

    def test_init_monotonic_weighted(self):
        """Check ``init_monotonic_morph`` with an explicit Monotonicity operator."""
        full_box = self.observation.bbox
        center = self.centers[0]
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])
        monotonicity = Monotonicity((101, 101))

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box, monotonicity=monotonicity)
        # Expected: operator output, clipped at zero, normalized to a peak of 1.
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0] = 0
        truth = truth / np.max(truth)
        self.assertBoxEqual(bbox, Box((58, 48), origin=(1000, 2000)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Specify parameters
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            monotonicity,  # monotonicity
            0.2,  # threshold
        )
        # Expected: operator output with pixels below the threshold zeroed.
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0.2] = 0
        self.assertBoxEqual(bbox, Box((45, 44), origin=(1010, 2003)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test zero morphology: an empty box at the image origin, no morphology.
        zeros = np.zeros(self.detect.shape)
        bbox, morph = init_monotonic_morph(zeros, center, full_box, monotonicity=monotonicity)
        self.assertBoxEqual(bbox, Box((0, 0), (1000, 2000)))
        self.assertIsNone(morph)

    def test_multifit_spectra(self):
        """Recover known spectra of blended synthetic sources with ``multifit_spectra``."""
        bands = ("g", "r", "i")
        variance = np.ones((3, 35, 35), dtype=np.float32)
        weights = 1 / variance
        # A different PSF width in each band.
        psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
        psfs = psfs.astype(np.float32)
        model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)

        # The spectrum of each source
        spectra = np.array(
            [
                [31, 10, 0],
                [0, 5, 20],
                [15, 8, 3],
                [20, 3, 4],
                [0, 30, 60],
            ],
            dtype=np.float32,
        )

        # Use a circular Gaussian morphology for every source
        morphs = [
            integrated_circular_gaussian(sigma=sigma).astype(np.float32)
            for sigma in [0.8, 3.1, 1.1, 2.1, 1.5]
        ]
        # Make the second component a disk component
        morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")

        # Give the first two components the same center, and unique centers
        # for the remaining sources
        centers = [
            (10, 12),
            (10, 12),
            (20, 23),
            (20, 10),
            (25, 20),
        ]

        # Create the Observation
        test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, dtype=np.float32)
        observation = Observation(
            test_data.convolved,
            variance,
            weights,
            psfs,
            model_psf[None],
            bands=bands,
        )

        fit_spectra = multifit_spectra(observation, test_data.morphs)
        self.assertEqual(fit_spectra.dtype, spectra.dtype)
        assert_almost_equal(fit_spectra, spectra, decimal=5)

    def test_factorized_chi2_init(self):
        """Check defaults and source count of ``FactorizedChi2Initialization``."""
        # Test default parameters
        init = FactorizedChi2Initialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertEqual(init.disk_percentile, 25)
        self.assertEqual(init.thresh, 0.5)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

        # Integer-valued centers plus one extra center at (1000, 2004) are
        # expected to produce one additional source.
        centers = tuple(tuple(center.astype(int)) for center in self.centers) + ((1000, 2004),)
        init = FactorizedChi2Initialization(self.observation, centers)
        self.assertEqual(len(init.sources), 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

    def test_factorized_wavelet_init(self):
        """Check defaults, source count and component count of
        ``FactorizedWaveletInitialization``."""
        # Test default parameters
        init = FactorizedWaveletInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        # 7 sources but 8 components: one source is expected to be modeled
        # with two components.
        components = np.sum([len(src.components) for src in init.sources])
        self.assertEqual(components, 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)