Coverage for tests/test_detect.py: 15%
113 statements
coverage.py v7.4.4, created at 2024-03-20 03:40 -0700
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import os
24import numpy as np
25from lsst.scarlet.lite import Box, Image
26from lsst.scarlet.lite.detect import bounds_to_bbox, footprints_to_image, get_detect_wavelets, get_wavelets
27from lsst.scarlet.lite.detect_pybind11 import (
28 Footprint,
29 Peak,
30 get_connected_multipeak,
31 get_connected_pixels,
32 get_footprints,
33)
34from lsst.scarlet.lite.utils import integrated_circular_gaussian
35from numpy.testing import assert_array_equal
36from utils import ScarletTestCase
class TestDetect(ScarletTestCase):
    """Tests for connected-pixel detection, footprints, and wavelet helpers."""

    def setUp(self):
        # Four small Gaussian sources on a 51x51 canvas. The first three are
        # close enough that they connect when thresholding at zero (see
        # test_connected); the fourth stands alone (see test_get_footprints).
        centers = (
            (17, 9),
            (27, 14),
            (41, 25),
            (10, 42),
        )
        sigmas = (1.0, 0.95, 0.9, 1.5)

        sources = []
        for sigma, center in zip(sigmas, centers):
            # Shift the model so its pixel (7, 7) lands on ``center``.
            # NOTE(review): assumes integrated_circular_gaussian returns a
            # 15x15 array centered at index 7 — confirm.
            yx0 = center[0] - 7, center[1] - 7
            source = Image(integrated_circular_gaussian(sigma=sigma).astype(np.float32), yx0=yx0)
            sources.append(source)

        image = Image.from_box(Box((51, 51)))
        for source in sources:
            image += source
        # A small bright feature that is not reachable from any of ``centers``;
        # test_connected expects it to be excluded from the multi-peak footprint.
        image.data[30:32, 40] = 0.5

        self.image = image
        self.centers = centers
        self.sources = sources

        filename = os.path.join(__file__, "..", "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        # np.load on an .npz archive returns an NpzFile that keeps the
        # underlying file open until it is closed (see tearDown).
        self.hsc_data = np.load(filename)

    def tearDown(self):
        # Close the NpzFile explicitly rather than relying on garbage
        # collection to release the file handle.
        self.hsc_data.close()
        del self.hsc_data

    def test_connected(self):
        """Check single-seed and multi-seed connected-pixel searches."""
        image = self.image.copy()

        # Check that the first 3 footprints are all connected
        # with thresholding at zero
        truth = self.sources[0] + self.sources[1] + self.sources[2]
        bbox = truth.bbox
        truth = truth.data > 0

        unchecked = np.ones(self.image.shape, dtype=bool)
        footprint = np.zeros(self.image.shape, dtype=bool)
        y, x = self.centers[0]
        get_connected_pixels(
            y,
            x,
            image.data,
            unchecked,
            footprint,
            np.array([y, y, x, x]).astype(np.int32),
            0,
        )
        assert_array_equal(footprint[bbox.slices], truth)

        # Check that only the first 2 footprints are all connected
        # with thresholding at 1e-15
        truth = self.sources[0] + self.sources[1]
        bbox = truth.bbox
        truth = truth.data > 1e-15

        unchecked = np.ones(self.image.shape, dtype=bool)
        footprint = np.zeros(self.image.shape, dtype=bool)
        y, x = self.centers[0]
        get_connected_pixels(
            y,
            x,
            image.data,
            unchecked,
            footprint,
            np.array([y, y, x, x]).astype(np.int32),
            1e-15,
        )
        assert_array_equal(footprint[bbox.slices], truth)

        # Test finding all peaks
        footprint = get_connected_multipeak(self.image.data, self.centers, 1e-15)
        truth = self.image.data > 1e-15
        # The isolated feature added in setUp is not connected to any peak
        truth[30:32, 40] = False
        assert_array_equal(footprint, truth)

    def test_get_footprints(self):
        """Check footprint segmentation, peak assignment and label image."""
        footprints = get_footprints(self.image.data, 1, 4, 1e-15, True)
        self.assertEqual(len(footprints), 3)

        # The first footprint has a single peak
        assert_array_equal(footprints[0].data, self.sources[3].data > 1e-15)
        self.assertEqual(len(footprints[0].peaks), 1)
        self.assertBoxEqual(footprints[0].bbox, self.sources[3].bbox)
        self.assertEqual(footprints[0].peaks[0].y, self.centers[3][0])
        self.assertEqual(footprints[0].peaks[0].x, self.centers[3][1])

        # The second footprint has two peaks
        truth = self.sources[0] + self.sources[1]
        assert_array_equal(footprints[1].data, truth.data > 1e-15)
        self.assertEqual(len(footprints[1].peaks), 2)
        self.assertBoxEqual(footprints[1].bbox, truth.bbox)
        self.assertEqual(footprints[1].peaks[0].y, self.centers[1][0])
        self.assertEqual(footprints[1].peaks[0].x, self.centers[1][1])
        self.assertEqual(footprints[1].peaks[1].y, self.centers[0][0])
        self.assertEqual(footprints[1].peaks[1].x, self.centers[0][1])

        # The third footprint has a single peak
        assert_array_equal(footprints[2].data, self.sources[2].data > 1e-15)
        self.assertEqual(len(footprints[2].peaks), 1)
        self.assertBoxEqual(footprints[2].bbox, self.sources[2].bbox)
        self.assertEqual(footprints[2].peaks[0].y, self.centers[2][0])
        self.assertEqual(footprints[2].peaks[0].x, self.centers[2][1])

        truth = 1 * self.sources[3] + 2 * (self.sources[0] + self.sources[1]) + 3 * self.sources[2]
        truth.data[truth.data < 1e-15] = 0
        fp_image = footprints_to_image(footprints, truth.shape)
        assert_array_equal(fp_image, truth.data)

    def test_bounds_to_bbox(self):
        """Check conversion of (ymin, ymax, xmin, xmax) bounds to a Box."""
        # The expected box size is (max - min + 1) on each axis
        bounds = (3, 27, 11, 52)
        truth = Box((25, 42), (3, 11))
        bbox = bounds_to_bbox(bounds)
        self.assertBoxEqual(bbox, truth)

    def test_footprint(self):
        """Check Footprint construction, rendering, intersection and union."""
        # NOTE: thresholding is done in place on the source data; the
        # ``truth`` values below are built from the thresholded sources.
        footprint = self.sources[0].data
        footprint[footprint < 1e-15] = 0
        bounds = [
            self.sources[0].bbox.start[0],
            self.sources[0].bbox.stop[0],
            self.sources[0].bbox.start[1],
            self.sources[0].bbox.stop[1],
        ]
        peaks = [Peak(self.centers[0][0], self.centers[0][1], self.image.data[self.centers[0]])]
        footprint1 = Footprint(footprint, peaks, bounds)
        footprint = self.sources[1].data
        footprint[footprint < 1e-15] = 0
        bounds = [
            self.sources[1].bbox.start[0],
            self.sources[1].bbox.stop[0],
            self.sources[1].bbox.start[1],
            self.sources[1].bbox.stop[1],
        ]
        peaks = [Peak(self.centers[1][0], self.centers[1][1], self.image.data[self.centers[1]])]
        footprint2 = Footprint(footprint, peaks, bounds)

        truth = self.sources[0] + self.sources[1]
        truth.data[truth.data < 1e-15] = 0
        image = footprints_to_image([footprint1, footprint2], truth.shape)
        assert_array_equal(image, truth.data)

        # Test intersection
        truth = (self.sources[0] > 1e-15) & (self.sources[1] > 1e-15)
        intersection = footprint1.intersection(footprint2)
        self.assertImageEqual(intersection, truth)

        # Test union
        truth = (self.sources[0] > 1e-15) | (self.sources[1] > 1e-15)
        union = footprint1.union(footprint2)
        self.assertImageEqual(union, truth)

    def test_get_wavelets(self):
        """Check the shape and dtype of the multiband wavelet transform."""
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (5, 5, 58, 48))
        self.assertEqual(wavelets.dtype, np.float32)

    def test_get_detect_wavelets(self):
        """Check the shape of the detection wavelet transform."""
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_detect_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (4, 58, 48))