Coverage for tests/test_deblend.py: 13%

160 statements  

« prev     ^ index     » next       coverage.py v7.3.3, created at 2023-12-15 12:27 +0000

1# This file is part of meas_extensions_scarlet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25from numpy.testing import assert_almost_equal 

26 

27from lsst.geom import Point2I, Point2D 

28import lsst.utils.tests 

29import lsst.afw.image as afwImage 

30from lsst.meas.algorithms import SourceDetectionTask 

31from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask 

32from lsst.meas.extensions.scarlet.utils import bboxToScarletBox, scarletBoxToBBox 

33from lsst.meas.extensions.scarlet.io import monochromaticDataToScarlet, updateCatalogFootprints 

34import lsst.scarlet.lite as scl 

35from lsst.afw.table import SourceCatalog 

36from lsst.afw.detection import Footprint 

37from lsst.afw.geom import SpanSet, Stencil 

38 

39from utils import initData 

40 

41 

class TestDeblend(lsst.utils.tests.TestCase):
    """Tests for ScarletDeblendTask on a small synthetic multiband coadd."""

    def setUp(self):
        """Build a five-band synthetic coadd containing a three-source blend
        and one isolated source, with additive noise and per-band PSFs.
        """
        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        # Only the images and the per-band PSFs from initData are needed here.
        _, _, images, _, _, _, psfs = initData(shape, coords, amplitudes)

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        self.bands = "grizy"
        _images = afwImage.MultibandMaskedImage.fromArrays(
            self.bands,
            images.astype(np.float32),
            None,
            noise**2
        )
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        self.coadds = afwImage.MultibandExposure.fromExposures(self.bands, coadds)
        for b, coadd in enumerate(self.coadds):
            coadd.setPsf(psfs[b])

    def _deblend(self, version):
        """Detect sources in the r band, inject two parents that must be
        skipped, and run the deblender.

        Parameters
        ----------
        version : `str`
            Value assigned to ``config.version`` of the deblender.

        Returns
        -------
        catalog : `lsst.afw.table.SourceCatalog`
            Catalog returned by ``ScarletDeblendTask.run``.
        modelData
            Model data returned by ``ScarletDeblendTask.run``.
        config
            The deblender configuration that was used.
        """
        schema = SourceCatalog.Table.makeMinimalSchema()
        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        config.catchFailures = False
        config.version = version

        # Detect sources
        detectionTask = SourceDetectionTask(schema=schema)
        deblendTask = ScarletDeblendTask(schema=schema, config=config)
        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, self.coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large: a box of this half-size is
        # always larger in area than config.maxFootprintArea.
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks (one more than allowed)
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        catalog, modelData = deblendTask.run(self.coadds, catalog)
        return catalog, modelData, config

    def test_deblend_task(self):
        """Check the deblended catalog, its HeavyFootprints and the stored
        scarlet models for consistency in every band, with and without
        flux re-distribution.
        """
        catalog, modelData, config = self._deblend("lite")

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for bandIndex, band in enumerate(self.bands):
                coadd = self.coadds[band]
                imageForRedistribution = coadd if useFlux else None

                updateCatalogFootprints(
                    modelData,
                    catalog,
                    band=band,
                    imageForRedistribution=imageForRedistribution,
                    removeScarletData=False,
                )

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated
                    # to their children
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = fp.extractImage(fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center is not the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    # We need to set an observation in order to convolve
                    # the model.
                    position = Point2D(*blendData.psf_center[::-1])
                    _psfs = self.coadds[band].getPsf().computeKernelImage(position).array[None, :, :]
                    modelBox = scl.Box(blendData.shape, origin=blendData.origin)
                    observation = scl.Observation.empty(
                        bands=("dummy", ),
                        psfs=_psfs,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )
                    blend = monochromaticDataToScarlet(
                        blendData=blendData,
                        bandIndex=bandIndex,
                        observation=observation,
                    )
                    # The stored PSF should be the same as the calculated one
                    assert_almost_equal(blendData.psf[bandIndex:bandIndex+1], _psfs)

                    # Get the scarlet model for the source
                    source = next(
                        (src for src in blend.sources if src.record_id == child.getId()),
                        None,
                    )
                    self.assertIsNotNone(source, f"No scarlet source for record {child.getId()}")
                    self.assertEqual(source.center[1], px)
                    self.assertEqual(source.center[0], py)

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        parentFootprint = catalog[catalog["id"] == child["parent"]][0].getFootprint()
                        _images = imageForRedistribution[parentFootprint.getBBox()].image.array
                        blend.observation.images = scl.Image(
                            _images[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.observation.weights = scl.Image(
                            parentFootprint.spans.asArray()[None, :, :],
                            yx0=blendData.origin,
                            bands=("dummy", ),
                        )
                        blend.conserve_flux()
                        model = source.flux_weighted_image.data[0]
                        bbox = scarletBoxToBBox(source.flux_weighted_image.bbox)
                        # Copy the model: an afw image built from an array
                        # without a deep copy shares memory with it, so
                        # fp.insert would also overwrite ``model`` and the
                        # comparison below could never fail.
                        image = afwImage.ImageF(model.copy(), xy0=bbox.getMin())
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        bbox = bboxToScarletBox(fp.getBBox())
                        model = blend.observation.convolve(
                            source.get_model().project(bbox=bbox), mode="real"
                        ).data[0]
                        np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

        # NOTE(review): the indices below assume detection found exactly two
        # parents, so the two footprints injected in _deblend sit at rows 2
        # and 3 -- confirm if the test image or detection config changes.
        # Check that only the large footprint was flagged as too big
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[2] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[3] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

    def test_continuity(self):
        """This test ensures that lsst.scarlet.lite gives roughly the same
        result as scarlet.lite

        TODO: This test can be removed once the deprecated scarlet.lite
        module is removed from the science pipelines.
        """
        oldCatalog, oldModelData, oldConfig = self._deblend("old_lite")
        catalog, modelData, config = self._deblend("lite")

        # Ensure that the deblender used different versions
        self.assertEqual(oldConfig.version, "old_lite")
        self.assertEqual(config.version, "lite")

        # Check that the PSF and other properties are the same
        assert_almost_equal(oldModelData.psf, modelData.psf)
        self.assertTupleEqual(tuple(oldModelData.blends.keys()), tuple(modelData.blends.keys()))

        # Make sure that the sources have the same IDs; check the lengths
        # first so that extra records in either catalog are not ignored.
        self.assertEqual(len(catalog), len(oldCatalog))
        for record, oldRecord in zip(catalog, oldCatalog):
            self.assertEqual(record["id"], oldRecord["id"])

        for blendId, blendData in modelData.blends.items():
            oldBlendData = oldModelData.blends[blendId]

            # Check that blend properties are the same
            self.assertTupleEqual(oldBlendData.origin, blendData.origin)
            self.assertTupleEqual(oldBlendData.shape, blendData.shape)
            self.assertTupleEqual(oldBlendData.bands, blendData.bands)
            self.assertTupleEqual(oldBlendData.psf_center, blendData.psf_center)
            self.assertTupleEqual(tuple(oldBlendData.sources.keys()), tuple(blendData.sources.keys()))
            assert_almost_equal(oldBlendData.psf, blendData.psf)

            for sourceId, sourceData in blendData.sources.items():
                oldSourceData = oldBlendData.sources[sourceId]
                # Check that source properties are the same
                self.assertEqual(len(oldSourceData.components), 0)
                self.assertEqual(len(sourceData.components), 0)
                self.assertEqual(
                    len(oldSourceData.factorized_components),
                    len(sourceData.factorized_components)
                )

                # Lengths were asserted equal above, so zip drops nothing.
                for oldComponentData, componentData in zip(
                    oldSourceData.factorized_components,
                    sourceData.factorized_components,
                ):
                    # Check that component properties are the same
                    self.assertTupleEqual(oldComponentData.peak, componentData.peak)
                    # The old component origin is expected to be its peak
                    # minus half of its shape.
                    self.assertTupleEqual(
                        tuple(oldComponentData.peak[i]-oldComponentData.shape[i]//2 for i in range(2)),
                        oldComponentData.origin,
                    )

313 

314 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the memory checks provided by ``lsst.utils.tests.MemoryTestCase``
    for this test module; no extra tests are added here.
    """

317 

318 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities once
    before any test in this module runs. ``module`` is unused.
    """
    lsst.utils.tests.init()

321 

322 

if __name__ == "__main__":
    # Running this file as a script: initialize the LSST test utilities,
    # then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()