Coverage for tests / test_blend.py: 12%

186 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:40 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from typing import cast 

25 

26import numpy as np 

27from lsst.scarlet.lite import Blend, Box, Image, Observation, Source 

28from lsst.scarlet.lite.component import CubeComponent, FactorizedComponent, default_adaprox_parameterization 

29from lsst.scarlet.lite.initialization import FactorizedInitialization 

30from lsst.scarlet.lite.operators import Monotonicity 

31from lsst.scarlet.lite.utils import integrated_circular_gaussian 

32from numpy.testing import assert_almost_equal, assert_raises 

33from scipy.signal import convolve as scipy_convolve 

34from utils import ObservationData, ScarletTestCase 

35 

36 

class TestBlend(ScarletTestCase):
    """Tests for building, fitting, copying, and band-indexing a `Blend`."""

    def setUp(self):
        """Build a simulated three-band scene and a blend of its exact model.

        Five factorized components are created from known spectra and
        morphologies. The first two share a center and are grouped into a
        single source; each remaining component is its own source, giving
        five components in four sources.
        """
        bands = ("g", "r", "i")
        yx0 = (1000, 2000)
        # The PSF in each band of the "observation"
        psfs = np.array([integrated_circular_gaussian(sigma=sigma) for sigma in [1.05, 0.9, 1.2]])
        # The PSF of the model
        model_psf = integrated_circular_gaussian(sigma=0.8)

        # The spectrum of each source
        spectra = np.array(
            [
                [40, 10, 0],
                [0, 25, 40],
                [15, 8, 3],
                [20, 3, 4],
                [0, 30, 60],
            ],
            dtype=float,
        )

        # Use a circular Gaussian morphology (of varying width) for all of the sources
        morphs = [integrated_circular_gaussian(sigma=sigma) for sigma in [0.8, 2.5, 1.1, 2.1, 1.5]]
        # Make the second component a disk component
        morphs[1] = scipy_convolve(morphs[1], model_psf, mode="same")

        # Give the first two components the same center, and unique centers
        # for the remaining sources
        centers = [
            (1010, 2012),
            (1010, 2012),
            (1020, 2023),
            (1020, 2010),
            (1025, 2020),
        ]

        # Create the simulated image and associated data products
        test_data = ObservationData(bands, psfs, spectra, morphs, centers, model_psf, yx0=yx0)

        # Create the Observation with uniform variance; the weights are
        # normalized so that their maximum value is 1
        variance = np.ones((3, 35, 35), dtype=float) * 1e-2
        weights = 1 / variance
        weights = weights / np.max(weights)
        self.observation = Observation(
            test_data.convolved,
            variance,
            weights,
            psfs,
            model_psf[None],
            bands=bands,
            bbox=Box(variance.shape[-2:], origin=yx0),
        )
        self.data = test_data
        self.spectra = spectra
        self.centers = centers
        self.morphs = morphs

        # One factorized component per (spectrum, morphology, center), using
        # the bounding box of the matching simulated morphology
        components = [
            FactorizedComponent(
                bands=bands,
                spectrum=spectrum,
                morph=morph,
                bbox=data_morph.bbox,
                peak=center,
            )
            for spectrum, center, morph, data_morph in zip(
                self.spectra, self.centers, self.morphs, self.data.morphs
            )
        ]

        # The two co-centered components form a single source
        sources = [Source(components[:2])]
        sources += [Source([component]) for component in components[2:]]

        self.blend = Blend(sources, self.observation)

    def test_exact(self):
        """Test that a blend model initialized with the exact solution
        builds the model correctly
        """
        blend = self.blend
        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 4)
        self.assertBoxEqual(blend.bbox, Box(self.data.images.shape[1:], self.observation.bbox.origin))
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)
        self.assertImageAlmostEqual(blend.get_model(convolve=True), self.observation.images)
        self.assertImageAlmostEqual(
            self.observation.convolve(blend.get_model(), mode="real"),
            self.observation.images,
        )

        # Test that the log likelihood is very small
        assert_almost_equal([blend.log_likelihood], [0])

        # Test that grad_log_likelihood updates the loss
        self.assertListEqual(blend.loss, [])
        blend._grad_log_likelihood()
        assert_almost_equal(blend.loss, [0])

        # Remove one of the sources and calculate the non-zero log_likelihood
        del blend.sources[-1]
        # Update the loss function and check that the loss changed
        blend._grad_log_likelihood()
        assert_almost_equal(blend.log_likelihood, -60.011720889007485)
        assert_almost_equal(blend.loss, [0, -60.011720889007485])

    def test_fit_spectra(self):
        """Test that fitting the spectra with exact morphologies is
        identical to the multiband image
        """
        np.random.seed(0)
        blend = self.blend

        # Change the initial spectra so that they can be fit later
        for component in blend.components:
            c = cast(FactorizedComponent, component)
            c.spectrum[:] = np.random.rand(3) * 10

        with assert_raises(AssertionError):
            # Since the spectra have not yet been fit,
            # the model and images should not be equal
            self.assertImageEqual(blend.get_model(), self.data.images)

        # We initialized all of the morphologies exactly,
        # so fitting the spectra should give a nearly exact solution
        blend.fit_spectra()

        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 4)
        self.assertBoxEqual(blend.bbox, self.observation.bbox)
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)
        self.assertImageAlmostEqual(blend.get_model(convolve=True), self.observation.images)

    def test_fit(self):
        """Test that a blend initialized from noisy data and fit with the
        default adaprox parameterization reproduces the noiseless convolved
        images to one decimal place.
        """
        observation = self.observation
        np.random.seed(0)
        # Keep a copy of the noiseless images to compare against after fitting
        images = observation.images.copy()
        noise = np.random.normal(size=observation.images.shape) * 1e-2
        observation.images._data += noise

        monotonicity = Monotonicity((101, 101))
        init = FactorizedInitialization(observation, self.centers, monotonicity=monotonicity)

        blend = Blend(init.sources, observation).fit_spectra()
        blend.parameterize(default_adaprox_parameterization)
        blend.fit(100)

        self.assertImageAlmostEqual(blend.get_model(convolve=True), images, decimal=1)

    def test_non_factorized(self):
        """Test fitting spectra when the blend contains a non-factorized
        (`CubeComponent`) source alongside factorized components.
        """
        np.random.seed(1)
        blend = self.blend
        # Build the exact spectrum x morphology cube of the disk component and
        # record its origin, then remove the disk from the first source
        model = self.spectra[1][:, None, None] * self.morphs[1][None, :, :]
        yx0 = blend.sources[0].components[1].bbox.origin
        blend.sources[0].components = blend.sources[0].components[:1]

        # Change the initial spectra so that they can be fit later
        for component in blend.components:
            c = cast(FactorizedComponent, component)
            c.spectrum[:] = np.random.rand(3) * 10

        with assert_raises(AssertionError):
            # Since the spectra have not yet been fit,
            # the model and images should not be equal
            self.assertImageEqual(blend.get_model(), self.data.images)

        # Add the disk back as a new source with a non-factorized component
        component = CubeComponent(Image(model, bands=self.blend.observation.bands, yx0=yx0), (0, 0))
        blend.sources.append(Source([component]))

        blend.fit_spectra()

        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 5)
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)

    def test_clipping(self):
        """Test that ``fit_spectra(clip=True)`` recovers the exact model when a
        source with an all-zero model is present.

        The assertions expect the empty component to no longer be counted
        (5 components) while its source is kept (5 sources).
        """
        # Seed so the randomized initial spectra are reproducible
        np.random.seed(0)
        blend = self.blend

        # Change the initial spectra so that they can be fit later
        for component in blend.components:
            c = cast(FactorizedComponent, component)
            c.spectrum[:] = np.random.rand(3) * 10

        with assert_raises(AssertionError):
            # Since the spectra have not yet been fit,
            # the model and images should not be equal
            self.assertImageEqual(blend.get_model(), self.data.images)

        # Add a source whose model is all zeros
        zero_model = Image.from_box(Box((5, 5), (30, 0)), bands=blend.observation.bands)
        component = CubeComponent(zero_model, (0, 0))
        blend.sources.append(Source([component]))

        blend.fit_spectra(clip=True)

        self.assertEqual(len(blend.components), 5)
        self.assertEqual(len(blend.sources), 5)
        self.assertImageAlmostEqual(blend.get_model(), self.data.images)

    def test_shallow_copy(self):
        """Test that ``Blend.copy()`` returns a new blend whose sources,
        observation, and metadata compare equal to the original.
        """
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_copy = blend.copy()

        self.assertIsNot(blend_copy, blend)
        self.assertEqual(len(blend_copy.sources), len(blend.sources))
        for source_copy, source in zip(blend_copy.sources, blend.sources):
            self.assertSourceEqual(source_copy, source)

        self.assertObservationEqual(blend_copy.observation, blend.observation)

        self.assertDictEqual(blend_copy.metadata, blend.metadata)

    def test_deepcopy(self):
        """Test that a deep copy compares equal to the original but does not
        share source parameters or metadata with it.
        """
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_copy = blend.copy(deep=True)

        self.assertIsNot(blend_copy, blend)
        self.assertEqual(len(blend_copy.sources), len(blend.sources))
        for source_copy, source in zip(blend_copy.sources, blend.sources):
            self.assertSourceEqual(source_copy, source)

        # Mutating a parameter of the copy must not propagate to the original,
        # so the two sources should no longer compare equal.
        source_copy = blend_copy.sources[-1]
        source = blend.sources[-1]
        source_copy.components[0]._spectrum.x += 1
        with self.assertRaises(AssertionError):
            self.assertSourceEqual(source_copy, source)

        self.assertObservationEqual(blend_copy.observation, blend.observation)
        self.assertDictEqual(blend_copy.metadata, blend.metadata)
        # Metadata must also be copied rather than shared
        blend_copy.metadata["test"] = "new_value"
        with self.assertRaises(AssertionError):
            self.assertDictEqual(blend_copy.metadata, blend.metadata)

    def test_slice(self):
        """Test that slicing a blend by a band range slices every source and
        the observation, and keeps the metadata.
        """
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_sliced = blend["g":"r"]
        self.assertEqual(len(blend.sources), len(blend_sliced.sources))

        for source_sliced, source in zip(blend_sliced.sources, blend.sources):
            self.assertSourceEqual(source_sliced, source["g":"r"])

        self.assertObservationEqual(blend_sliced.observation, blend.observation["g":"r"])
        self.assertDictEqual(blend_sliced.metadata, blend.metadata)

    def test_reorder(self):
        """Test that indexing a blend with a tuple of bands reorders every
        source and the observation, and keeps the metadata.
        """
        blend = self.blend
        blend.metadata = {"test": "value"}
        indices = ("i", "g", "r")
        blend_reordered = blend[indices]
        self.assertEqual(len(blend.sources), len(blend_reordered.sources))

        for source_reordered, source in zip(blend_reordered.sources, blend.sources):
            self.assertSourceEqual(source_reordered, source[indices])

        self.assertObservationEqual(blend_reordered.observation, blend.observation[indices])
        self.assertDictEqual(blend_reordered.metadata, blend.metadata)

    def test_subset(self):
        """Test that selecting a single band with a one-element tuple matches
        indexing each source and the observation by that band name.
        """
        blend = self.blend
        blend.metadata = {"test": "value"}
        blend_subset = blend[("r",)]
        self.assertEqual(len(blend.sources), len(blend_subset.sources))

        for source_subset, source in zip(blend_subset.sources, blend.sources):
            self.assertSourceEqual(source_subset, source["r"])

        self.assertObservationEqual(blend_subset.observation, blend.observation["r"])
        self.assertDictEqual(blend_subset.metadata, blend.metadata)

    def test_indexing_errors(self):
        """Test that indexing a blend with unknown band names, a `Box`,
        spatial slices, or integer indices raises `IndexError`.
        """
        blend = self.blend

        with self.assertRaises(IndexError):
            blend["x"]

        with self.assertRaises(IndexError):
            blend[("r", "x")]

        with self.assertRaises(IndexError):
            blend["r":"x"]

        with self.assertRaises(IndexError):
            blend["x":"i"]

        with self.assertRaises(IndexError):
            blend["g", "x", "i"]

        with self.assertRaises(IndexError):
            blend[Box((0, 0), (10, 10))]

        with self.assertRaises(IndexError):
            blend[:, 10:20, 10:20]

        with self.assertRaises(IndexError):
            blend[1:]

        with self.assertRaises(IndexError):
            blend[1]

        with self.assertRaises(IndexError):
            blend[0, 1]