Coverage for tests / test_detect.py: 14%
137 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
24import numpy as np
25from lsst.scarlet.lite import Box, Image
26from lsst.scarlet.lite.detect import (
27 bbox_to_bounds,
28 bounds_to_bbox,
29 detect_footprints,
30 footprints_to_image,
31 get_detect_wavelets,
32 get_wavelets,
33)
34from lsst.scarlet.lite.detect_pybind11 import (
35 Footprint,
36 Peak,
37 get_connected_multipeak,
38 get_connected_pixels,
39 get_footprints,
40)
41from lsst.scarlet.lite.utils import integrated_circular_gaussian
42from numpy.testing import assert_array_equal
43from utils import ScarletTestCase
class TestDetect(ScarletTestCase):
    """Tests for ``lsst.scarlet.lite.detect`` and its pybind11 primitives.

    Most tests use a synthetic 51x51 image built from four small circular
    Gaussians. The wavelet tests use real HSC COSMOS cutouts loaded from
    ``data/hsc_cosmos_35.npz``.
    """

    def setUp(self):
        # (y, x) centers of the four synthetic sources and their widths.
        centers = (
            (17, 9),
            (27, 14),
            (41, 25),
            (10, 42),
        )
        sigmas = (1.0, 0.95, 0.9, 1.5)

        sources = []
        for sigma, center in zip(sigmas, centers):
            # Shift the stamp origin by 7 pixels so its peak lands on
            # ``center`` (presumably the default stamp is 15x15 -- TODO confirm).
            yx0 = center[0] - 7, center[1] - 7
            source = Image(integrated_circular_gaussian(sigma=sigma).astype(np.float32), yx0=yx0)
            sources.append(source)

        image = Image.from_box(Box((51, 51)))
        for source in sources:
            image += source
        # A small extra feature that does not belong to any center.
        # test_connected expects get_connected_multipeak to exclude it.
        image.data[30:32, 40] = 0.5

        self.image = image
        self.centers = centers
        self.sources = sources

        # The data directory is a sibling of this test file's directory.
        filename = os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        self.hsc_data = np.load(filename)

    def tearDown(self):
        # np.load on an .npz archive returns an NpzFile that keeps the file
        # open. Close it explicitly instead of relying on garbage collection.
        self.hsc_data.close()
        del self.hsc_data

    def test_connected(self):
        image = self.image.copy()

        # Check that the first 3 footprints are all connected
        # when thresholding at zero.
        truth = self.sources[0] + self.sources[1] + self.sources[2]
        bbox = truth.bbox
        truth = truth.data > 0

        unchecked = np.ones(self.image.shape, dtype=bool)
        footprint = np.zeros(self.image.shape, dtype=bool)
        y, x = self.centers[0]
        get_connected_pixels(
            y,
            x,
            image.data,
            unchecked,
            footprint,
            # Seed bounds at the start pixel. Presumably this uses the
            # (y_min, y_max, x_min, x_max) layout of bounds_to_bbox.
            np.array([y, y, x, x]).astype(np.int32),
            0,
        )
        assert_array_equal(footprint[bbox.slices], truth)

        # Check that only the first 2 footprints are connected
        # when thresholding at 1e-15.
        truth = self.sources[0] + self.sources[1]
        bbox = truth.bbox
        truth = truth.data > 1e-15

        unchecked = np.ones(self.image.shape, dtype=bool)
        footprint = np.zeros(self.image.shape, dtype=bool)
        y, x = self.centers[0]
        get_connected_pixels(
            y,
            x,
            image.data,
            unchecked,
            footprint,
            np.array([y, y, x, x]).astype(np.int32),
            1e-15,
        )
        assert_array_equal(footprint[bbox.slices], truth)

        # Growing from every center at once should pick up every pixel above
        # the threshold, except the isolated feature added in setUp.
        footprint = get_connected_multipeak(self.image.data, self.centers, 1e-15)
        truth = self.image.data > 1e-15
        truth[30:32, 40] = False
        assert_array_equal(footprint, truth)

    def _check_footprints(self, footprints):
        """Assert that ``footprints`` matches the synthetic scene from setUp.

        The expected footprint order and within-footprint peak order are
        the ones asserted below.
        """
        self.assertEqual(len(footprints), 3)

        # The first footprint has a single peak.
        assert_array_equal(footprints[0].data, self.sources[3].data > 1e-15)
        self.assertEqual(len(footprints[0].peaks), 1)
        self.assertBoxEqual(footprints[0].bbox, self.sources[3].bbox)
        self.assertEqual(footprints[0].peaks[0].y, self.centers[3][0])
        self.assertEqual(footprints[0].peaks[0].x, self.centers[3][1])

        # The second footprint has two peaks (sources 0 and 1 are blended).
        truth = self.sources[0] + self.sources[1]
        assert_array_equal(footprints[1].data, truth.data > 1e-15)
        self.assertEqual(len(footprints[1].peaks), 2)
        self.assertBoxEqual(footprints[1].bbox, truth.bbox)
        self.assertEqual(footprints[1].peaks[0].y, self.centers[1][0])
        self.assertEqual(footprints[1].peaks[0].x, self.centers[1][1])
        self.assertEqual(footprints[1].peaks[1].y, self.centers[0][0])
        self.assertEqual(footprints[1].peaks[1].x, self.centers[0][1])

        # The third footprint has a single peak.
        assert_array_equal(footprints[2].data, self.sources[2].data > 1e-15)
        self.assertEqual(len(footprints[2].peaks), 1)
        self.assertBoxEqual(footprints[2].bbox, self.sources[2].bbox)
        self.assertEqual(footprints[2].peaks[0].y, self.centers[2][0])
        self.assertEqual(footprints[2].peaks[0].x, self.centers[2][1])

        # Render the footprints into one image. The expected image weights
        # each footprint's sources by (footprint index + 1).
        truth = 1 * self.sources[3] + 2 * (self.sources[0] + self.sources[1]) + 3 * self.sources[2]
        truth.data[truth.data < 1e-15] = 0
        fp_image = footprints_to_image(footprints, truth.bbox)
        assert_array_equal(fp_image, truth.data)

    def test_get_footprints(self):
        # Positional args presumably mirror the detect_footprints keywords
        # below: min_separation=1, min_area=4, peak_thresh=1e-15,
        # footprint_thresh=1e-15, find_peaks=True -- TODO confirm.
        footprints = get_footprints(self.image.data, 1, 4, 1e-15, 1e-15, True)
        self._check_footprints(footprints)

    def _check_peaks(self, peaks):
        """Assert that every entry in ``self.centers`` coincides with some peak."""
        for center in self.centers:
            matched = any(peak.y == center[0] and peak.x == center[1] for peak in peaks)
            self.assertTrue(matched, f"No peak found at {center}")

    def test_detect_footprints(self):
        # This test doesn't check accuracy. The synthetic image has no noise
        # model, so the variance is simply set to ones.
        variance = np.ones(self.image.shape, dtype=self.image.dtype)

        # Add a leading (single-band) axis to the image and variance.
        footprints = detect_footprints(
            self.image.data[None, :, :],
            variance[None, :, :],
            scales=1,
            generation=2,
            origin=(0, 0),
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=False,
            min_pixel_detect=1,
        )

        self.assertEqual(len(footprints), 3)
        peaks = [peak for footprint in footprints for peak in footprint.peaks]
        self._check_peaks(peaks)

        footprints = detect_footprints(
            self.image.data[None, :, :],
            variance[None, :, :],
            scales=1,
            generation=1,
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=True,
            min_pixel_detect=1,
        )

        self.assertEqual(len(footprints), 2)
        peaks = [peak for footprint in footprints for peak in footprint.peaks]
        self._check_peaks(peaks)

    def test_bounds_to_bbox(self):
        # Bounds are inclusive (y_min, y_max, x_min, x_max):
        # height = 27 - 3 + 1 = 25, width = 52 - 11 + 1 = 42.
        bounds = (3, 27, 11, 52)
        truth = Box((25, 42), (3, 11))
        bbox = bounds_to_bbox(bounds)
        self.assertBoxEqual(bbox, truth)

        # Check that the reverse operation also works.
        new_bounds = bbox_to_bounds(bbox)
        self.assertTupleEqual(new_bounds, bounds)

    def _make_footprint(self, index):
        """Build a single-peak ``Footprint`` for ``self.sources[index]``.

        NOTE: pixels below 1e-15 are zeroed *in place* in
        ``self.sources[index].data``. test_footprint builds its truth
        images from the thresholded sources afterwards, so it depends on
        this mutation. setUp rebuilds the sources for every test, so the
        change does not leak between tests.
        """
        source = self.sources[index]
        data = source.data
        data[data < 1e-15] = 0
        # Inclusive (y_min, y_max, x_min, x_max) bounds, hence the ``- 1``
        # on the exclusive ``stop`` coordinates.
        bounds = [
            source.bbox.start[0],
            source.bbox.stop[0] - 1,
            source.bbox.start[1],
            source.bbox.stop[1] - 1,
        ]
        center = self.centers[index]
        peaks = [Peak(center[0], center[1], self.image.data[center])]
        return Footprint(data, peaks, bounds)

    def test_footprint(self):
        footprint1 = self._make_footprint(0)
        footprint2 = self._make_footprint(1)

        # sources[0] and sources[1] were thresholded in place above.
        truth = self.sources[0] + self.sources[1]
        truth.data[truth.data < 1e-15] = 0
        image = footprints_to_image([footprint1, footprint2], truth.bbox)
        assert_array_equal(image, truth.data)

        # Test intersection.
        truth = (self.sources[0] > 1e-15) & (self.sources[1] > 1e-15)
        intersection = footprint1.intersection(footprint2)
        self.assertImageEqual(intersection, truth)

        # Test union.
        truth = (self.sources[0] > 1e-15) | (self.sources[1] > 1e-15)
        union = footprint1.union(footprint2)
        self.assertImageEqual(union, truth)

    def test_get_wavelets(self):
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (5, 5, 58, 48))
        self.assertEqual(wavelets.dtype, np.float32)

    def test_get_detect_wavelets(self):
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_detect_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (4, 58, 48))