Coverage for tests/test_initialization.py: 13%

120 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 02:46 -0700

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23 

24import numpy as np 

25from lsst.scarlet.lite import Box, Image, Observation 

26from lsst.scarlet.lite.initialization import ( 

27 FactorizedChi2Initialization, 

28 FactorizedWaveletInitialization, 

29 init_monotonic_morph, 

30 multifit_spectra, 

31 trim_morphology, 

32) 

33from lsst.scarlet.lite.operators import Monotonicity, prox_monotonic_mask 

34from lsst.scarlet.lite.utils import integrated_circular_gaussian 

35from numpy.testing import assert_almost_equal, assert_array_equal 

36from scipy.signal import convolve as scipy_convolve 

37from utils import ObservationData, ScarletTestCase 

38 

39 

class TestInitialization(ScarletTestCase):
    """Tests for the ``lsst.scarlet.lite.initialization`` helpers and the
    factorized initialization classes built on top of them.
    """

    def setUp(self) -> None:
        """Load the ``hsc_cosmos_35.npz`` test data and build an Observation.

        Sets ``self.detect`` (the images summed over their first, band,
        axis), ``self.centers`` (an ``(N, 2)`` array of catalog ``(y, x)``
        positions offset by ``yx0``) and ``self.observation``.
        """
        yx0 = (1000, 2000)
        # The ``data`` directory is a sibling of the directory holding this file.
        filename = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        )
        model_psf = integrated_circular_gaussian(sigma=0.8)
        # ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps
        # the underlying file open; use it as a context manager so the handle
        # is released once every array has been read into memory.
        with np.load(filename) as data:
            images = data["images"]
            variance = data["variance"]
            psfs = data["psfs"]
            catalog = data["catalog"]
            bands = data["filters"]

        self.detect = np.sum(images, axis=0)
        # Offset the catalog positions by ``yx0``, the same origin given to
        # every Image below, so centers and images share one coordinate frame.
        self.centers = np.array([catalog["y"], catalog["x"]]).T + np.array(yx0)
        self.observation = Observation(
            Image(images, bands=bands, yx0=yx0),
            Image(variance, bands=bands, yx0=yx0),
            # Inverse variance is used as the weights.
            Image(1 / variance, bands=bands, yx0=yx0),
            psfs,
            model_psf[None],
            bands=bands,
        )

    def test_trim_morphology(self):
        """``trim_morphology`` returns the trimmed morph and its bounding box,
        preserving the float32 dtype."""
        # Test default parameters
        morph = np.zeros((50, 50), dtype=np.float32)
        morph[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph)
        assert_array_equal(trimmed, morph)
        # By default the 5x15 source is padded by 5 pixels on every side.
        self.assertTupleEqual(trimmed_box.origin, (5, 7))
        self.assertTupleEqual(trimmed_box.shape, (15, 25))
        self.assertEqual(trimmed.dtype, np.float32)

        # Test with parameters specified: judging by the expected values,
        # 0.5 acts as a threshold (the 0.1 background is zeroed) and 1 as the
        # padding (the box grows by one pixel on each side).
        morph = np.full((50, 50), 0.1, dtype=np.float32)
        morph[10:15, 12:27] = 1
        truth = np.zeros(morph.shape)
        truth[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph, 0.5, 1)
        assert_array_equal(trimmed, truth)
        self.assertTupleEqual(trimmed_box.origin, (9, 11))
        self.assertTupleEqual(trimmed_box.shape, (7, 17))
        self.assertEqual(trimmed.dtype, np.float32)

    def test_init_monotonic_mask(self):
        """``init_monotonic_morph`` without a Monotonicity operator agrees
        with ``prox_monotonic_mask`` and handles an all-zero image."""
        full_box = self.observation.bbox
        center = self.centers[0]
        # ``prox_monotonic_mask`` works in array coordinates, so shift the
        # global center by the observation's origin.
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box)
        self.assertBoxEqual(bbox, Box((38, 29), (1014, 2000)))
        _, masked_morph, _ = prox_monotonic_mask(self.detect.copy(), local_center, max_iter=0)
        assert_array_equal(morph, masked_morph / np.max(masked_morph))
        self.assertEqual(morph.dtype, np.float32)

        # Specifying parameters
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            None,  # monotonicity
            0.2,  # threshold
        )
        self.assertBoxEqual(bbox, Box((26, 21), (1021, 2003)))
        # Remove pixels below the threshold
        truth = masked_morph.copy()
        truth[truth < 0.2] = 0
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test an empty morphology
        bbox, morph = init_monotonic_morph(np.zeros(self.detect.shape), center, full_box)
        self.assertBoxEqual(bbox, Box((0, 0)))
        self.assertIsNone(morph)

    def test_init_monotonic_weighted(self):
        """``init_monotonic_morph`` with a weighted ``Monotonicity`` operator
        matches applying the operator directly."""
        full_box = self.observation.bbox
        center = self.centers[0]
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])
        monotonicity = Monotonicity((101, 101))

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box, monotonicity=monotonicity)
        # Expected result: clip negatives, then normalize to a peak of 1.
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0] = 0
        truth = truth / np.max(truth)
        self.assertBoxEqual(bbox, Box((58, 48), origin=(1000, 2000)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Specify parameters
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            monotonicity,  # monotonicity
            0.2,  # threshold
        )
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0.2] = 0
        self.assertBoxEqual(bbox, Box((45, 44), origin=(1010, 2003)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test zero morphology
        zeros = np.zeros(self.detect.shape)
        bbox, morph = init_monotonic_morph(zeros, center, full_box, monotonicity=monotonicity)
        self.assertBoxEqual(bbox, Box((0, 0), (1000, 2000)))
        self.assertIsNone(morph)

    def test_multifit_spectra(self):
        """``multifit_spectra`` recovers the input spectra of a synthetic
        blend (including two components sharing a center) in float32."""
        bands = ("g", "r", "i")
        variance = np.ones((3, 35, 35), dtype=np.float32)
        weights = 1 / variance
        psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
        psfs = psfs.astype(np.float32)
        model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)

        # The spectrum of each source (one row per source, one column per band)
        spectra = np.array(
            [
                [31, 10, 0],
                [0, 5, 20],
                [15, 8, 3],
                [20, 3, 4],
                [0, 30, 60],
            ],
            dtype=np.float32,
        )

        # Start every source with a compact circular Gaussian morphology
        morphs = [
            integrated_circular_gaussian(sigma=sigma).astype(np.float32)
            for sigma in [0.8, 3.1, 1.1, 2.1, 1.5]
        ]
        # Make the second component a disk component
        morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")

        # Give the first two components the same center, and unique centers
        # for the remaining sources
        centers = [
            (10, 12),
            (10, 12),
            (20, 23),
            (20, 10),
            (25, 20),
        ]

        # Create the Observation
        test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, dtype=np.float32)
        observation = Observation(
            test_data.convolved,
            variance,
            weights,
            psfs,
            model_psf[None],
            bands=bands,
        )

        fit_spectra = multifit_spectra(observation, test_data.morphs)
        self.assertEqual(fit_spectra.dtype, spectra.dtype)
        assert_almost_equal(fit_spectra, spectra, decimal=5)

    def test_factorized_chi2_init(self):
        """``FactorizedChi2Initialization`` defaults and source creation,
        including an extra center supplied as integer pixel coordinates."""
        # Test default parameters
        init = FactorizedChi2Initialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertEqual(init.disk_percentile, 25)
        self.assertEqual(init.thresh, 0.5)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

        # Integer centers plus one additional source at (1000, 2004).
        centers = tuple(tuple(center.astype(int)) for center in self.centers) + ((1000, 2004),)
        init = FactorizedChi2Initialization(self.observation, centers)
        self.assertEqual(len(init.sources), 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

    def test_factorized_wavelet_init(self):
        """``FactorizedWaveletInitialization`` defaults and the number of
        sources/components it produces for the test blend."""
        # Test default parameters
        init = FactorizedWaveletInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        # Seven sources but eight components: at least one source is
        # initialized with more than one component.
        components = np.sum([len(src.components) for src in init.sources])
        self.assertEqual(components, 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)