Coverage for tests / test_initialization.py: 15%

122 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:40 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import os 

23 

24import numpy as np 

25from deprecated.sphinx import deprecated 

26from lsst.scarlet.lite import Box, Image, Observation 

27from lsst.scarlet.lite.initialization import ( 

28 FactorizedInitialization, 

29 FactorizedWaveletInitialization, 

30 init_monotonic_morph, 

31 multifit_spectra, 

32 trim_morphology, 

33) 

34from lsst.scarlet.lite.operators import Monotonicity, prox_monotonic_mask 

35from lsst.scarlet.lite.utils import integrated_circular_gaussian 

36from numpy.testing import assert_almost_equal, assert_array_equal 

37from scipy.signal import convolve as scipy_convolve 

38from utils import ObservationData, ScarletTestCase 

39 

40 

class TestInitialization(ScarletTestCase):
    """Tests for the morphology/spectrum initialization helpers and the
    ``FactorizedInitialization`` / ``FactorizedWaveletInitialization`` classes.
    """

    def setUp(self) -> None:
        """Build an `Observation` from the ``hsc_cosmos_35.npz`` test blend.

        The images are placed at ``yx0 = (1000, 2000)``. The tests then check
        global-frame boxes and centers against a non-zero origin.
        """
        yx0 = (1000, 2000)
        # The data directory sits next to the directory containing this file.
        # This resolves to the same path as the original
        # ``join(__file__, "..", "..", "data", ...)``.
        filename = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        )
        # ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps
        # the archive open. Read everything needed inside a ``with`` block so
        # the file handle is released instead of leaking for every test.
        with np.load(filename) as data:
            images = data["images"]
            variance = data["variance"]
            psfs = data["psfs"]
            bands = data["filters"]
            catalog = data["catalog"]

        model_psf = integrated_circular_gaussian(sigma=0.8)
        # Detection image: the sum of all bands.
        self.detect = np.sum(images, axis=0)
        # Catalog positions are local to the image; shift them to the global
        # frame used by the observation's bounding box.
        self.centers = np.array([catalog["y"], catalog["x"]]).T + np.array(yx0)
        self.observation = Observation(
            Image(images, bands=bands, yx0=yx0),
            Image(variance, bands=bands, yx0=yx0),
            Image(1 / variance, bands=bands, yx0=yx0),
            psfs,
            model_psf[None],
            bands=bands,
        )

    def test_trim_morphology(self):
        """`trim_morphology` crops to the non-zero (or above-threshold) region,
        applies the requested padding, and keeps the input dtype.
        """
        # Test default parameters
        morph = np.zeros((50, 50)).astype(np.float32)
        morph[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph)
        assert_array_equal(trimmed, morph)
        self.assertTupleEqual(trimmed_box.origin, (5, 7))
        self.assertTupleEqual(trimmed_box.shape, (15, 25))
        self.assertEqual(trimmed.dtype, np.float32)

        # Test with parameters specified
        morph = np.full((50, 50), 0.1).astype(np.float32)
        morph[10:15, 12:27] = 1
        truth = np.zeros(morph.shape)
        truth[10:15, 12:27] = 1
        trimmed, trimmed_box = trim_morphology(morph, 0.5, 1)
        assert_array_equal(trimmed, truth)
        self.assertTupleEqual(trimmed_box.origin, (9, 11))
        self.assertTupleEqual(trimmed_box.shape, (7, 17))
        self.assertEqual(trimmed.dtype, np.float32)

    def test_init_monotonic_mask(self):
        """`init_monotonic_morph` without a `Monotonicity` operator should match
        the projection produced by `prox_monotonic_mask`.
        """
        full_box = self.observation.bbox
        center = self.centers[0]
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])

        # Default parameters: the result is the masked morphology
        # normalized so its peak is 1.
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box)
        self.assertBoxEqual(bbox, Box((38, 29), (1014, 2000)))
        _, masked_morph, _ = prox_monotonic_mask(self.detect.copy(), local_center, max_iter=0)
        assert_array_equal(morph, masked_morph / np.max(masked_morph))
        self.assertEqual(morph.dtype, np.float32)

        # Specifying parameters
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            None,  # monotonicity
            0.2,  # threshold
        )
        self.assertBoxEqual(bbox, Box((26, 21), (1021, 2003)))
        # Remove pixels below the threshold
        truth = masked_morph.copy()
        truth[truth < 0.2] = 0
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test an empty morphology: no box and no morphology are returned.
        bbox, morph = init_monotonic_morph(np.zeros(self.detect.shape), center, full_box)
        self.assertBoxEqual(bbox, Box((0, 0)))
        self.assertIsNone(morph)

    def test_init_monotonic_weighted(self):
        """`init_monotonic_morph` with a `Monotonicity` operator should match
        applying that operator directly, with negative pixels clipped to zero.
        """
        full_box = self.observation.bbox
        center = self.centers[0]
        local_center = (center[0] - full_box.origin[0], center[1] - full_box.origin[1])
        monotonicity = Monotonicity((101, 101))

        # Default parameters
        bbox, morph = init_monotonic_morph(self.detect.copy(), center, full_box, monotonicity=monotonicity)
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0] = 0
        truth = truth / np.max(truth)
        self.assertBoxEqual(bbox, Box((58, 48), origin=(1000, 2000)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Specify parameters
        bbox, morph = init_monotonic_morph(
            self.detect.copy(),
            center,
            full_box,
            0,  # padding
            False,  # normalize
            monotonicity,  # monotonicity
            0.2,  # threshold
        )
        truth = monotonicity(self.detect.copy(), local_center)
        truth[truth < 0.2] = 0
        self.assertBoxEqual(bbox, Box((45, 44), origin=(1010, 2003)))
        assert_array_equal(morph, truth)
        self.assertEqual(morph.dtype, np.float32)

        # Test zero morphology
        zeros = np.zeros(self.detect.shape)
        bbox, morph = init_monotonic_morph(zeros, center, full_box, monotonicity=monotonicity)
        self.assertBoxEqual(bbox, Box((0, 0), (1000, 2000)))
        self.assertIsNone(morph)

    def test_multifit_spectra(self):
        """`multifit_spectra` should recover the known spectra of a synthetic
        blend built from known morphologies, and keep their dtype.
        """
        bands = ("g", "r", "i")
        variance = np.ones((3, 35, 35), dtype=np.float32)
        weights = 1 / variance
        psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
        psfs = psfs.astype(np.float32)
        model_psf = integrated_circular_gaussian(sigma=0.8).astype(np.float32)

        # The spectrum of each source
        spectra = np.array(
            [
                [31, 10, 0],
                [0, 5, 20],
                [15, 8, 3],
                [20, 3, 4],
                [0, 30, 60],
            ],
            dtype=np.float32,
        )

        # Use a circular Gaussian profile (of varying width) for every source
        morphs = [
            integrated_circular_gaussian(sigma=sigma).astype(np.float32)
            for sigma in [0.8, 3.1, 1.1, 2.1, 1.5]
        ]
        # Make the second component a disk component
        morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")

        # Give the first two components the same center, and unique centers
        # for the remaining sources
        centers = [
            (10, 12),
            (10, 12),
            (20, 23),
            (20, 10),
            (25, 20),
        ]

        # Create the Observation
        test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, dtype=np.float32)
        observation = Observation(
            test_data.convolved,
            variance,
            weights,
            psfs,
            model_psf[None],
            bands=bands,
        )

        fit_spectra = multifit_spectra(observation, test_data.morphs)
        self.assertEqual(fit_spectra.dtype, spectra.dtype)
        assert_almost_equal(fit_spectra, spectra, decimal=5)

    def test_factorized_chi2_init(self):
        """Check `FactorizedInitialization` defaults, the number of sources
        created, and the model dtype.
        """
        # Test default parameters
        init = FactorizedInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertEqual(init.disk_percentile, 25)
        self.assertEqual(init.thresh, 0.5)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

        # Integer tuple centers plus one extra center should also be accepted.
        centers = tuple(tuple(center.astype(int)) for center in self.centers) + ((1000, 2004),)
        init = FactorizedInitialization(self.observation, centers)
        self.assertEqual(len(init.sources), 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)

    # The ``@deprecated`` marker flags this test for removal together with
    # ``FactorizedWaveletInitialization`` (see ``reason``).
    @deprecated(
        version="v29.0",
        reason="FactorizedWaveletInitialization is deprecated and will be removed after v29.0",
    )
    def test_factorized_wavelet_init(self):
        """Check `FactorizedWaveletInitialization` defaults, the source and
        component counts, and the model dtype.
        """
        # Test default parameters
        init = FactorizedWaveletInitialization(self.observation, self.centers)
        self.assertEqual(init.observation, self.observation)
        self.assertEqual(init.min_snr, 50)
        self.assertIsNone(init.monotonicity)
        self.assertTupleEqual((init.py, init.px), (7, 7))
        self.assertEqual(len(init.sources), 7)
        components = np.sum([len(src.components) for src in init.sources])
        self.assertEqual(components, 8)
        for src in init.sources:
            self.assertEqual(src.get_model().dtype, np.float32)