Coverage for tests/test_deblend.py: 15%

127 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-08 22:56 -0800

1# This file is part of meas_extensions_scarlet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25from scarlet.bbox import Box 

26from scarlet.lite.measure import weight_sources 

27 

28from lsst.geom import Point2I, Point2D 

29import lsst.utils.tests 

30import lsst.afw.image as afwImage 

31from lsst.meas.algorithms import SourceDetectionTask 

32from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask, getFootprintMask 

33from lsst.meas.extensions.scarlet.source import bboxToScarletBox, scarletBoxToBBox 

34from lsst.meas.extensions.scarlet.io import dataToScarlet, DummyObservation 

35from lsst.afw.table import SourceCatalog 

36from lsst.afw.detection import Footprint 

37from lsst.afw.detection.multiband import heavyFootprintToImage 

38from lsst.afw.geom import SpanSet, Stencil 

39 

40from utils import initData 

41 

42 

class TestDeblend(lsst.utils.tests.TestCase):
    """Tests for `ScarletDeblendTask` on a synthetic multiband scene."""

    def test_deblend_task(self):
        """Deblend a synthetic scene and validate the catalog and models.

        The scene has a three-source blend and an isolated source.
        Two more parents are injected by hand:

        - one whose footprint is larger than ``config.maxFootprintArea``;
        - one with more than ``config.maxNumberOfPeaks`` peaks.

        After running the deblender, the catalog footprints are rebuilt in
        every band, both with and without flux re-distribution. Each result
        is compared against the scarlet models stored in ``modelData``.
        Finally, the skip flags are checked on the two injected parents.
        """
        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        result = initData(shape, coords, amplitudes)
        targetPsfImage, psfImages, images, channels, seds, morphs, targetPsf, psfs = result

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        filters = "grizy"
        _images = afwImage.MultibandMaskedImage.fromArrays(filters, images.astype(np.float32), None, noise**2)
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        coadds = afwImage.MultibandExposure.fromExposures(filters, coadds)
        for b, coadd in enumerate(coadds):
            coadd.setPsf(psfs[b])

        schema = SourceCatalog.Table.makeMinimalSchema()

        detectionTask = SourceDetectionTask(schema=schema)

        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        deblendTask = ScarletDeblendTask(schema=schema, config=config)

        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large.
        # Record its catalog index instead of hard-coding it, so the
        # skip-flag checks below do not depend on how many sources
        # were detected.
        src = catalog.addNew()
        bigFootIndex = len(catalog) - 1
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks
        # (maxNumberOfPeaks + 1 peaks).
        src = catalog.addNew()
        denseFootIndex = len(catalog) - 1
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        catalog, modelData = deblendTask.run(coadds, catalog)

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for bandIndex, band in enumerate(filters):
                coadd = coadds[band]
                psfModel = coadd.getPsf()

                if useFlux:
                    redistributeImage = coadd.image
                else:
                    redistributeImage = None

                modelData.updateCatalogFootprints(
                    catalog,
                    band=band,
                    psfModel=psfModel,
                    redistributeImage=redistributeImage,
                    removeScarletData=False,
                )

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated
                    # to their children
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = heavyFootprintToImage(fp, fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center is not the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img.image[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    blend = dataToScarlet(
                        blendData=blendData,
                        nBands=1,
                        bandIndex=bandIndex,
                    )
                    # We need to set an observation in order to convolve
                    # the model.
                    position = Point2D(*blendData.psfCenter)
                    _psfs = psfModel.computeKernelImage(position).array[None, :, :]
                    modelBox = Box((1,) + tuple(blendData.extent[::-1]), origin=(0, 0, 0))
                    blend.observation = DummyObservation(
                        psfs=_psfs,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )

                    # Get the scarlet model for the source
                    source = [
                        scarletSource for scarletSource in blend.sources
                        if scarletSource.recordId == child.getId()
                    ][0]
                    # Look the parent up once; its footprint is needed both
                    # for the center check and for flux re-distribution.
                    parentFootprint = catalog.find(child["parent"]).getFootprint()
                    parentBox = parentFootprint.getBBox()
                    self.assertEqual(source.center[1], px - parentBox.getMinX())
                    self.assertEqual(source.center[0], py - parentBox.getMinY())

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        blend.observation.images = redistributeImage[parentBox].array[None, :, :]
                        blend.observation.weights = ~getFootprintMask(parentFootprint, None)[None, :, :]
                        weight_sources(blend)
                        model = source.flux[0]
                        bbox = scarletBoxToBBox(source.flux_box, Point2I(*blendData.xy0))
                        # ImageF built from an ndarray may share that array's
                        # buffer (non-deep by default). If it did, fp.insert
                        # would also overwrite `model`, and the comparison
                        # below would always pass. Build the image from a
                        # copy so the HeavyFootprint is really compared
                        # against the model.
                        image = afwImage.ImageF(model.copy(), xy0=bbox.getMin())
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        bbox = bboxToScarletBox(1, fp.getBBox(), Point2I(*blendData.xy0))
                        model = blend.observation.convolve(source.get_model(bbox=bbox))[0]
                        np.testing.assert_almost_equal(img.image.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

        # Check that only the large footprint was flagged as too big
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[bigFootIndex] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[denseFootIndex] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

237 

238 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Include the standard ``lsst.utils.tests.MemoryTestCase`` checks in this module."""

241 

242 

def setup_module(module):
    """Module-level setup hook: initialize ``lsst.utils.tests``.

    Parameters
    ----------
    module : `module`
        The module being set up (not used here).
    """
    lsst.utils.tests.init()

245 

246 

if __name__ == "__main__":
    # Script entry point: initialize the LSST test utilities, then let
    # unittest discover and run the test cases defined in this file.
    lsst.utils.tests.init()
    unittest.main()