Coverage for tests / test_detect.py: 14%

137 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:40 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23 

24import numpy as np 

25from lsst.scarlet.lite import Box, Image 

26from lsst.scarlet.lite.detect import ( 

27 bbox_to_bounds, 

28 bounds_to_bbox, 

29 detect_footprints, 

30 footprints_to_image, 

31 get_detect_wavelets, 

32 get_wavelets, 

33) 

34from lsst.scarlet.lite.detect_pybind11 import ( 

35 Footprint, 

36 Peak, 

37 get_connected_multipeak, 

38 get_connected_pixels, 

39 get_footprints, 

40) 

41from lsst.scarlet.lite.utils import integrated_circular_gaussian 

42from numpy.testing import assert_array_equal 

43from utils import ScarletTestCase 

44 

45 

class TestDetect(ScarletTestCase):
    """Tests for the detection utilities in ``lsst.scarlet.lite.detect``.

    ``setUp`` builds a 51x51 synthetic image containing four integrated
    circular Gaussians plus a short, isolated two-pixel stripe, and loads the
    ``hsc_cosmos_35.npz`` fixture used by the wavelet tests.
    """

    def setUp(self):
        # (y, x) centers of the four synthetic sources and their widths.
        centers = (
            (17, 9),
            (27, 14),
            (41, 25),
            (10, 42),
        )
        sigmas = (1.0, 0.95, 0.9, 1.5)

        sources = []
        for sigma, center in zip(sigmas, centers):
            # Offset the stamp so that its central pixel lands on ``center``
            # (presumably a 15x15 stamp -- TODO confirm the default size of
            # ``integrated_circular_gaussian``).
            yx0 = center[0] - 7, center[1] - 7
            source = Image(integrated_circular_gaussian(sigma=sigma).astype(np.float32), yx0=yx0)
            sources.append(source)

        image = Image.from_box(Box((51, 51)))
        for source in sources:
            image += source
        # A two-pixel stripe that is not connected to any of the sources.
        image.data[30:32, 40] = 0.5

        self.image = image
        self.centers = centers
        self.sources = sources

        # ``<tests dir>/../data/hsc_cosmos_35.npz``; equivalent to the
        # normalized form of ``join(__file__, "..", "..", "data", ...)``.
        filename = os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        self.hsc_data = np.load(os.path.abspath(filename))

    def tearDown(self):
        # ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps
        # the underlying file open; close it explicitly so that each test does
        # not leak a file handle (and trigger a ResourceWarning).
        self.hsc_data.close()
        del self.hsc_data

    def _connected_footprint(self, data, center, threshold):
        """Return the boolean mask of pixels in ``data`` connected to ``center``.

        Thin wrapper around ``get_connected_pixels`` that allocates the
        ``unchecked``/``footprint`` work arrays and the initial
        ``[y, y, x, x]`` bounds seeded at ``center``.
        """
        y, x = center
        unchecked = np.ones(data.shape, dtype=bool)
        footprint = np.zeros(data.shape, dtype=bool)
        get_connected_pixels(
            y,
            x,
            data,
            unchecked,
            footprint,
            np.array([y, y, x, x]).astype(np.int32),
            threshold,
        )
        return footprint

    def test_connected(self):
        """Check single-seed and multi-seed connected-pixel searches."""
        image = self.image.copy()

        # With thresholding at zero the first 3 sources are all connected.
        truth = self.sources[0] + self.sources[1] + self.sources[2]
        footprint = self._connected_footprint(image.data, self.centers[0], 0)
        assert_array_equal(footprint[truth.bbox.slices], truth.data > 0)

        # With thresholding at 1e-15 only the first 2 sources are connected.
        truth = self.sources[0] + self.sources[1]
        footprint = self._connected_footprint(image.data, self.centers[0], 1e-15)
        assert_array_equal(footprint[truth.bbox.slices], truth.data > 1e-15)

        # Seeding from every center finds all sources but not the stripe,
        # which is not connected to any of them.
        footprint = get_connected_multipeak(self.image.data, self.centers, 1e-15)
        truth = self.image.data > 1e-15
        truth[30:32, 40] = False
        assert_array_equal(footprint, truth)

    def _check_footprints(self, footprints):
        """Assert the expected footprints for the synthetic image.

        Expects three footprints: source 3 alone, sources 0 and 1 merged
        (with source 1's peak listed first), and source 2 alone. The stripe
        must not produce a footprint.
        """
        self.assertEqual(len(footprints), 3)

        # The first footprint has a single peak
        assert_array_equal(footprints[0].data, self.sources[3].data > 1e-15)
        self.assertEqual(len(footprints[0].peaks), 1)
        self.assertBoxEqual(footprints[0].bbox, self.sources[3].bbox)
        self.assertEqual(footprints[0].peaks[0].y, self.centers[3][0])
        self.assertEqual(footprints[0].peaks[0].x, self.centers[3][1])

        # The second footprint has two peaks
        truth = self.sources[0] + self.sources[1]
        assert_array_equal(footprints[1].data, truth.data > 1e-15)
        self.assertEqual(len(footprints[1].peaks), 2)
        self.assertBoxEqual(footprints[1].bbox, truth.bbox)
        self.assertEqual(footprints[1].peaks[0].y, self.centers[1][0])
        self.assertEqual(footprints[1].peaks[0].x, self.centers[1][1])
        self.assertEqual(footprints[1].peaks[1].y, self.centers[0][0])
        self.assertEqual(footprints[1].peaks[1].x, self.centers[0][1])

        # The third footprint has a single peak
        assert_array_equal(footprints[2].data, self.sources[2].data > 1e-15)
        self.assertEqual(len(footprints[2].peaks), 1)
        self.assertBoxEqual(footprints[2].bbox, self.sources[2].bbox)
        self.assertEqual(footprints[2].peaks[0].y, self.centers[2][0])
        self.assertEqual(footprints[2].peaks[0].x, self.centers[2][1])

        truth = 1 * self.sources[3] + 2 * (self.sources[0] + self.sources[1]) + 3 * self.sources[2]
        truth.data[truth.data < 1e-15] = 0
        fp_image = footprints_to_image(footprints, truth.bbox)
        assert_array_equal(fp_image, truth.data)

    def test_get_footprints(self):
        footprints = get_footprints(self.image.data, 1, 4, 1e-15, 1e-15, True)
        self._check_footprints(footprints)

    def _check_peaks(self, peaks):
        """Assert that every synthetic source center coincides with some peak."""
        peak_positions = {(peak.y, peak.x) for peak in peaks}
        missing = [center for center in self.centers if tuple(center) not in peak_positions]
        self.assertEqual(missing, [], f"No peak found at centers {missing}")

    def test_detect_footprints(self):
        # This method doesn't test for accuracy: there is no variance in the
        # synthetic image, so we set it to ones.
        variance = np.ones(self.image.shape, dtype=self.image.dtype)

        footprints = detect_footprints(
            self.image.data[None, :, :],
            variance[None, :, :],
            scales=1,
            generation=2,
            origin=(0, 0),
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=False,
            min_pixel_detect=1,
        )

        self.assertEqual(len(footprints), 3)
        peaks = [peak for footprint in footprints for peak in footprint.peaks]
        self._check_peaks(peaks)

        footprints = detect_footprints(
            self.image.data[None, :, :],
            variance[None, :, :],
            scales=1,
            generation=1,
            min_separation=1,
            min_area=4,
            peak_thresh=1e-15,
            footprint_thresh=1e-15,
            find_peaks=True,
            remove_high_freq=True,
            min_pixel_detect=1,
        )

        self.assertEqual(len(footprints), 2)
        peaks = [peak for footprint in footprints for peak in footprint.peaks]
        self._check_peaks(peaks)

    def test_bounds_to_bbox(self):
        # Bounds are inclusive (y_min, y_max, x_min, x_max).
        bounds = (3, 27, 11, 52)
        truth = Box((25, 42), (3, 11))
        bbox = bounds_to_bbox(bounds)
        self.assertBoxEqual(bbox, truth)

        # Check that the reverse operation also works
        new_bounds = bbox_to_bounds(bbox)
        self.assertTupleEqual(new_bounds, bounds)

    def _source_footprint(self, index):
        """Build a ``Footprint`` for ``self.sources[index]`` with one peak at its center.

        Pixels below 1e-15 are zeroed *in place* in the source's data, so any
        later arithmetic on ``self.sources[index]`` sees the thresholded values.
        """
        source = self.sources[index]
        data = source.data
        data[data < 1e-15] = 0
        bbox = source.bbox
        # Inclusive (y_min, y_max, x_min, x_max) bounds of the source stamp.
        bounds = [bbox.start[0], bbox.stop[0] - 1, bbox.start[1], bbox.stop[1] - 1]
        center = self.centers[index]
        peaks = [Peak(center[0], center[1], self.image.data[center])]
        return Footprint(data, peaks, bounds)

    def test_footprint(self):
        """Check Footprint rasterization, intersection and union."""
        footprint1 = self._source_footprint(0)
        footprint2 = self._source_footprint(1)

        # The sources were thresholded in place above, so this sum matches
        # what the two footprints contain.
        truth = self.sources[0] + self.sources[1]
        truth.data[truth.data < 1e-15] = 0
        image = footprints_to_image([footprint1, footprint2], truth.bbox)
        assert_array_equal(image, truth.data)

        # Test intersection
        truth = (self.sources[0] > 1e-15) & (self.sources[1] > 1e-15)
        intersection = footprint1.intersection(footprint2)
        self.assertImageEqual(intersection, truth)

        # Test union
        truth = (self.sources[0] > 1e-15) | (self.sources[1] > 1e-15)
        union = footprint1.union(footprint2)
        self.assertImageEqual(union, truth)

    def test_get_wavelets(self):
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (5, 5, 58, 48))
        self.assertEqual(wavelets.dtype, np.float32)

    def test_get_detect_wavelets(self):
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_detect_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (4, 58, 48))