Coverage for tests/test_detect.py: 15%

113 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:54 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23 

24import numpy as np 

25from lsst.scarlet.lite import Box, Image 

26from lsst.scarlet.lite.detect import bounds_to_bbox, footprints_to_image, get_detect_wavelets, get_wavelets 

27from lsst.scarlet.lite.detect_pybind11 import ( 

28 Footprint, 

29 Peak, 

30 get_connected_multipeak, 

31 get_connected_pixels, 

32 get_footprints, 

33) 

34from lsst.scarlet.lite.utils import integrated_circular_gaussian 

35from numpy.testing import assert_array_equal 

36from utils import ScarletTestCase 

37 

38 

class TestDetect(ScarletTestCase):
    """Tests for the detection helpers in ``lsst.scarlet.lite.detect`` and
    the compiled ``lsst.scarlet.lite.detect_pybind11`` module."""

    def setUp(self):
        # Build a 51x51 test image from four integrated circular Gaussians
        # plus a short vertical stripe of constant flux (0.5) at x=40, y=30..31.
        centers = (
            (17, 9),
            (27, 14),
            (41, 25),
            (10, 42),
        )
        sigmas = (1.0, 0.95, 0.9, 1.5)

        sources = []
        for sigma, center in zip(sigmas, centers):
            # NOTE(review): the -7 offset assumes integrated_circular_gaussian
            # returns a 15x15 stamp by default, so the stamp is centered on
            # ``center`` — confirm.
            yx0 = center[0] - 7, center[1] - 7
            source = Image(integrated_circular_gaussian(sigma=sigma).astype(np.float32), yx0=yx0)
            sources.append(source)

        image = Image.from_box(Box((51, 51)))
        for source in sources:
            image += source
        image.data[30:32, 40] = 0.5

        self.image = image
        self.centers = centers
        self.sources = sources

        # The data directory sits one level above the directory containing
        # this test file.
        filename = os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        self.hsc_data = np.load(filename)

    def tearDown(self):
        # ``np.load`` on an .npz archive returns an NpzFile that holds the
        # underlying file open; close it explicitly instead of relying on
        # garbage collection (avoids ResourceWarning / leaked handles).
        self.hsc_data.close()
        del self.hsc_data

    def _connected_from_center(self, image, center, threshold):
        """Run ``get_connected_pixels`` seeded at ``center``.

        Parameters
        ----------
        image : Image
            Image whose ``data`` is searched.
        center : tuple of int
            ``(y, x)`` seed pixel.
        threshold : float
            Threshold passed through to ``get_connected_pixels``.

        Returns
        -------
        footprint : numpy.ndarray
            Boolean array with the same shape as ``image``, filled in by
            ``get_connected_pixels``.
        """
        unchecked = np.ones(image.shape, dtype=bool)
        footprint = np.zeros(image.shape, dtype=bool)
        y, x = center
        get_connected_pixels(
            y,
            x,
            image.data,
            unchecked,
            footprint,
            # Initial bounds: the single seed pixel, as [y, y, x, x].
            np.array([y, y, x, x], dtype=np.int32),
            threshold,
        )
        return footprint

    def test_connected(self):
        image = self.image.copy()

        # Check that the first 3 footprints are all connected
        # with thresholding at zero
        truth = self.sources[0] + self.sources[1] + self.sources[2]
        footprint = self._connected_from_center(image, self.centers[0], 0)
        assert_array_equal(footprint[truth.bbox.slices], truth.data > 0)

        # Check that only the first 2 footprints are all connected
        # with thresholding at 1e-15
        truth = self.sources[0] + self.sources[1]
        footprint = self._connected_from_center(image, self.centers[0], 1e-15)
        assert_array_equal(footprint[truth.bbox.slices], truth.data > 1e-15)

        # Test finding all peaks
        footprint = get_connected_multipeak(self.image.data, self.centers, 1e-15)
        truth = self.image.data > 1e-15
        # The stripe at x=40 is not expected to be grown from any of the
        # seed centers, so it is removed from the expected mask.
        truth[30:32, 40] = False
        assert_array_equal(footprint, truth)

    def test_get_footprints(self):
        # NOTE(review): the positional arguments after the image are not
        # named here; see get_footprints for their meaning.
        footprints = get_footprints(self.image.data, 1, 4, 1e-15, True)
        self.assertEqual(len(footprints), 3)

        # The first footprint has a single peak
        assert_array_equal(footprints[0].data, self.sources[3].data > 1e-15)
        self.assertEqual(len(footprints[0].peaks), 1)
        self.assertBoxEqual(footprints[0].bbox, self.sources[3].bbox)
        self.assertEqual(footprints[0].peaks[0].y, self.centers[3][0])
        self.assertEqual(footprints[0].peaks[0].x, self.centers[3][1])

        # The second footprint has two peaks; centers[1] is expected first
        truth = self.sources[0] + self.sources[1]
        assert_array_equal(footprints[1].data, truth.data > 1e-15)
        self.assertEqual(len(footprints[1].peaks), 2)
        self.assertBoxEqual(footprints[1].bbox, truth.bbox)
        self.assertEqual(footprints[1].peaks[0].y, self.centers[1][0])
        self.assertEqual(footprints[1].peaks[0].x, self.centers[1][1])
        self.assertEqual(footprints[1].peaks[1].y, self.centers[0][0])
        self.assertEqual(footprints[1].peaks[1].x, self.centers[0][1])

        # The third footprint has a single peak
        assert_array_equal(footprints[2].data, self.sources[2].data > 1e-15)
        self.assertEqual(len(footprints[2].peaks), 1)
        self.assertBoxEqual(footprints[2].bbox, self.sources[2].bbox)
        self.assertEqual(footprints[2].peaks[0].y, self.centers[2][0])
        self.assertEqual(footprints[2].peaks[0].x, self.centers[2][1])

        truth = 1 * self.sources[3] + 2 * (self.sources[0] + self.sources[1]) + 3 * self.sources[2]
        truth.data[truth.data < 1e-15] = 0
        fp_image = footprints_to_image(footprints, truth.shape)
        assert_array_equal(fp_image, truth.data)

    def test_bounds_to_bbox(self):
        # Bounds are (y_min, y_max, x_min, x_max) with inclusive maxima:
        # 27 - 3 + 1 = 25 rows, 52 - 11 + 1 = 42 columns.
        bounds = (3, 27, 11, 52)
        truth = Box((25, 42), (3, 11))
        bbox = bounds_to_bbox(bounds)
        self.assertBoxEqual(bbox, truth)

    def _footprint_from_source(self, index):
        """Build a single-peak ``Footprint`` from ``self.sources[index]``.

        Pixels below 1e-15 are zeroed *in place*, so the corresponding entry
        of ``self.sources`` is modified as well; ``setUp`` rebuilds the
        sources for every test, and later truth images in the calling test
        intentionally use the thresholded sources.
        """
        source = self.sources[index]
        data = source.data
        data[data < 1e-15] = 0
        # NOTE(review): these bounds use the exclusive ``bbox.stop``, while
        # test_bounds_to_bbox shows bounds_to_bbox treating its maxima as
        # inclusive — confirm which convention Footprint expects.
        bounds = [
            source.bbox.start[0],
            source.bbox.stop[0],
            source.bbox.start[1],
            source.bbox.stop[1],
        ]
        y, x = self.centers[index]
        peaks = [Peak(y, x, self.image.data[self.centers[index]])]
        return Footprint(data, peaks, bounds)

    def test_footprint(self):
        footprint1 = self._footprint_from_source(0)
        footprint2 = self._footprint_from_source(1)

        truth = self.sources[0] + self.sources[1]
        truth.data[truth.data < 1e-15] = 0
        image = footprints_to_image([footprint1, footprint2], truth.shape)
        assert_array_equal(image, truth.data)

        # Test intersection
        truth = (self.sources[0] > 1e-15) & (self.sources[1] > 1e-15)
        intersection = footprint1.intersection(footprint2)
        self.assertImageEqual(intersection, truth)

        # Test union
        truth = (self.sources[0] > 1e-15) | (self.sources[1] > 1e-15)
        union = footprint1.union(footprint2)
        self.assertImageEqual(union, truth)

    def test_get_wavelets(self):
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (5, 5, 58, 48))
        self.assertEqual(wavelets.dtype, np.float32)

    def test_get_detect_wavelets(self):
        images = self.hsc_data["images"]
        variance = self.hsc_data["variance"]
        wavelets = get_detect_wavelets(images, variance)

        self.assertTupleEqual(wavelets.shape, (4, 58, 48))