Coverage for tests/test_deblend.py: 11%

188 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-27 11:16 +0000

1# This file is part of meas_extensions_scarlet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25from numpy.testing import assert_almost_equal 

26 

27from lsst.geom import Point2I, Point2D 

28import lsst.utils.tests 

29import lsst.afw.image as afwImage 

30from lsst.meas.algorithms import SourceDetectionTask 

31from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask 

32from lsst.meas.extensions.scarlet.utils import bboxToScarletBox, scarletBoxToBBox 

33from lsst.meas.extensions.scarlet.io import monochromaticDataToScarlet, updateCatalogFootprints 

34import lsst.scarlet.lite as scl 

35from lsst.afw.table import SourceCatalog 

36from lsst.afw.detection import Footprint 

37from lsst.afw.geom import SpanSet, Stencil 

38 

39from utils import initData 

40 

41 

class TestDeblend(lsst.utils.tests.TestCase):
    """Tests for `ScarletDeblendTask` on a small synthetic multiband image.

    The image built in `setUp` contains a blend of three sources and one
    isolated source. `_deblend` additionally appends a parent whose footprint
    is too large and a parent with too many peaks, so the skip logic of the
    deblender is exercised as well.
    """

    def setUp(self):
        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        # Only the images and the per-band PSFs returned by initData are used here.
        _, _, images, _, _, _, psfs = initData(shape, coords, amplitudes)

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        self.bands = "grizy"
        maskedImages = afwImage.MultibandMaskedImage.fromArrays(
            self.bands,
            images.astype(np.float32),
            None,
            noise**2,
        )
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in maskedImages]
        self.coadds = afwImage.MultibandExposure.fromExposures(self.bands, coadds)
        for b, coadd in enumerate(self.coadds):
            coadd.setPsf(psfs[b])

    def _insert_blank_source(self, modelData, catalog):
        """Append a parent with a single zero-flux child.

        A parent record with a circular footprint and one peak is added to
        ``catalog``, followed by one child record. A matching blend is stored
        in ``modelData.blends`` whose only source has an all-zero spectrum and
        morphology. `test_deblend_task` uses it to check that such a source
        is flagged with ``deblend_zeroFlux``.

        Parameters
        ----------
        modelData
            Scarlet model data returned by `_deblend`; its ``blends`` mapping
            is modified in place.
        catalog
            Catalog returned by `_deblend`; two records are appended to it.

        Returns
        -------
        parentId : `int`
            ID of the new parent record (also the key of the new blend).
        sourceId : `int`
            ID of the new zero-flux child record.
        """
        radius = 5
        width = 2*radius + 1
        # (y, x) position of the zero-flux source
        center = (70, 30)
        origin = (center[0]-radius, center[1]-radius)
        dtype = np.float32

        # Add parent
        parent = catalog.addNew()
        parent.setParent(0)
        parent["deblend_nChild"] = 1
        parent["deblend_nPeaks"] = 1
        # The SpanSet offset and the peak position are given in (x, y) order
        ss = SpanSet.fromShape(radius, Stencil.CIRCLE, offset=(center[1], center[0]))
        footprint = Footprint(ss)
        peak = footprint.addPeak(center[1], center[0], 0)
        parent.setFootprint(footprint)

        # Add the zero flux source, reusing the PSF of an existing blend
        psf = next(iter(modelData.blends.values())).psf
        src = catalog.addNew()
        src.setParent(parent.getId())
        src["deblend_peak_center_x"] = center[1]
        src["deblend_peak_center_y"] = center[0]
        src["deblend_nPeaks"] = 1

        sources = {
            src.getId(): {
                "components": [],
                "factorized": [{
                    "origin": origin,
                    "peak": center,
                    "spectrum": np.zeros((len(self.bands),), dtype=dtype),
                    "morph": np.zeros((width, width), dtype=dtype),
                    "shape": (width, width),
                }],
                "peak_id": peak.getId(),
            }
        }

        blendData = scl.io.ScarletBlendData.from_dict({
            "origin": origin,
            "shape": (width, width),
            "psf_center": center,
            "psf_shape": psf.shape,
            "psf": psf.flatten(),
            "sources": sources,
            "bands": self.bands,
        })
        parentId = parent.getId()
        modelData.blends[parentId] = blendData
        return parentId, src.getId()

    def _deblend(self, version):
        """Detect sources in the r-band coadd and deblend them.

        Before deblending, two extra parents are appended to the detection
        catalog:

        - one with a square footprint whose area exceeds
          ``config.maxFootprintArea``;
        - one with ``config.maxNumberOfPeaks + 1`` peaks.

        Parameters
        ----------
        version : `str`
            Value assigned to ``config.version`` (this file uses ``"lite"``
            and ``"old_lite"``).

        Returns
        -------
        catalog
            Deblended source catalog returned by ``ScarletDeblendTask.run``.
        modelData
            Scarlet model data returned by ``ScarletDeblendTask.run``.
        config
            The deblender config that was used.
        """
        schema = SourceCatalog.Table.makeMinimalSchema()
        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        config.catchFailures = False
        config.version = version

        # Detect sources
        detectionTask = SourceDetectionTask(schema=schema)
        deblendTask = ScarletDeblendTask(schema=schema, config=config)
        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, self.coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large: the box has side
        # 2*halfLength + 1, so its area is well above maxFootprintArea
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        catalog, modelData = deblendTask.run(self.coadds, catalog)
        return catalog, modelData, config

    def test_deblend_task(self):
        """Check the deblended catalog, its flags and its HeavyFootprints."""
        catalog, modelData, config = self._deblend("lite")

        bad_blend_id, bad_src_id = self._insert_blank_source(modelData, catalog)

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for bandIndex, band in enumerate(self.bands):
                coadd = self.coadds[band]
                imageForRedistribution = coadd if useFlux else None

                updateCatalogFootprints(
                    modelData,
                    catalog,
                    band=band,
                    imageForRedistribution=imageForRedistribution,
                    removeScarletData=False,
                )

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated to their
                    # children. The hand-built zero-flux blend is skipped,
                    # presumably because _insert_blank_source never fills
                    # the inherited columns.
                    if parent.getId() == bad_blend_id:
                        continue
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = fp.extractImage(fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center is not the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    # We need to set an observation in order to convolve
                    # the model. psf_center is stored (y, x); Point2D wants (x, y).
                    position = Point2D(*blendData.psf_center[::-1])
                    psfImages = coadd.getPsf().computeKernelImage(position).array[None, :, :]
                    modelBox = scl.Box(blendData.shape, origin=blendData.origin)
                    observation = scl.Observation.empty(
                        bands=("dummy", ),
                        psfs=psfImages,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )
                    blend = monochromaticDataToScarlet(
                        blendData=blendData,
                        bandIndex=bandIndex,
                        observation=observation,
                    )
                    # The stored PSF should be the same as the calculated one
                    assert_almost_equal(blendData.psf[bandIndex:bandIndex+1], psfImages)

                    # Get the scarlet model for the source
                    source = next(src for src in blend.sources if src.record_id == child.getId())
                    self.assertEqual(source.center[1], px)
                    self.assertEqual(source.center[0], py)

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        parentFootprint = catalog[catalog["id"] == child["parent"]][0].getFootprint()
                        parentImage = imageForRedistribution[parentFootprint.getBBox()].image.array
                        blend.observation.images = scl.Image(
                            parentImage[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.observation.weights = scl.Image(
                            parentFootprint.spans.asArray()[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.conserve_flux()
                        model = source.flux_weighted_image.data[0]
                        bbox = scarletBoxToBBox(source.flux_weighted_image.bbox)
                        # Copy the model: constructing an Image from an array
                        # does not copy by default, so without the copy
                        # fp.insert would also overwrite ``model`` and the
                        # comparison below would compare a buffer with itself.
                        image = afwImage.ImageF(model.copy(), xy0=bbox.getMin())
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        scarletBox = bboxToScarletBox(fp.getBBox())
                        model = blend.observation.convolve(
                            source.get_model().project(bbox=scarletBox), mode="real"
                        ).data[0]
                        np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            self.assertEqual(len(src.getFootprint().peaks), src.get("deblend_nPeaks"))

        # Check that only the large footprint was flagged as too big.
        # NOTE(review): indices 2 and 3 assume detection produced exactly two
        # records before _deblend appended the oversized and dense parents.
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[2] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[3] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

        # Check that the zero flux source was flagged
        for src in catalog:
            np.testing.assert_equal(src["deblend_zeroFlux"], src.getId() == bad_src_id)

    def test_continuity(self):
        """This test ensures that lsst.scarlet.lite gives roughly the same
        result as scarlet.lite

        TODO: This test can be removed once the deprecated scarlet.lite
        module is removed from the science pipelines.
        """
        oldCatalog, oldModelData, oldConfig = self._deblend("old_lite")
        catalog, modelData, config = self._deblend("lite")

        # Ensure that the deblender used different versions
        self.assertEqual(oldConfig.version, "old_lite")
        self.assertEqual(config.version, "lite")

        # Check that the PSF and other properties are the same
        assert_almost_equal(oldModelData.psf, modelData.psf)
        self.assertTupleEqual(tuple(oldModelData.blends.keys()), tuple(modelData.blends.keys()))

        # Make sure that the sources have the same IDs. Compare the lengths
        # first so that extra records in either catalog are not missed.
        self.assertEqual(len(catalog), len(oldCatalog))
        for record, oldRecord in zip(catalog, oldCatalog):
            self.assertEqual(record["id"], oldRecord["id"])

        for blendId, blendData in modelData.blends.items():
            oldBlendData = oldModelData.blends[blendId]

            # Check that blend properties are the same
            self.assertTupleEqual(oldBlendData.origin, blendData.origin)
            self.assertTupleEqual(oldBlendData.shape, blendData.shape)
            self.assertTupleEqual(oldBlendData.bands, blendData.bands)
            self.assertTupleEqual(oldBlendData.psf_center, blendData.psf_center)
            self.assertTupleEqual(tuple(oldBlendData.sources.keys()), tuple(blendData.sources.keys()))
            assert_almost_equal(oldBlendData.psf, blendData.psf)

            for sourceId in blendData.sources.keys():
                oldSourceData = oldBlendData.sources[sourceId]
                sourceData = blendData.sources[sourceId]
                # Check that source properties are the same
                self.assertEqual(len(oldSourceData.components), 0)
                self.assertEqual(len(sourceData.components), 0)
                self.assertEqual(
                    len(oldSourceData.factorized_components),
                    len(sourceData.factorized_components)
                )

                # The lengths were asserted equal above, so zip covers every component
                for oldComponentData, componentData in zip(
                    oldSourceData.factorized_components, sourceData.factorized_components
                ):
                    # Check that component properties are the same
                    self.assertTupleEqual(oldComponentData.peak, componentData.peak)
                    # NOTE(review): this compares the old component's origin with
                    # its own peak and shape; the new component's origin is not
                    # checked here.
                    self.assertTupleEqual(
                        tuple(oldComponentData.peak[i]-oldComponentData.shape[i]//2 for i in range(2)),
                        oldComponentData.origin,
                    )

370 

371 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks provided by `lsst.utils.tests.MemoryTestCase` for this module.

    No tests are defined here; everything is inherited from the base class.
    """

374 

375 

def setup_module(module):
    """Call ``lsst.utils.tests.init()`` before the tests in this module run.

    Presumably invoked by pytest as its module-level setup hook; the
    ``module`` argument is accepted for that signature and is not used.
    """
    lsst.utils.tests.init()

378 

379 

# Allow the tests to be run directly as a script as well as through pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()