Coverage for tests/test_deblend.py: 18%

# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest

import numpy as np
from scarlet.bbox import Box
from scarlet.lite.measure import weight_sources

from lsst.geom import Point2I, Point2D
import lsst.utils.tests
import lsst.afw.image as afwImage
from lsst.meas.algorithms import SourceDetectionTask
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask, getFootprintMask
from lsst.meas.extensions.scarlet.source import bboxToScarletBox, scarletBoxToBBox
from lsst.meas.extensions.scarlet.io import dataToScarlet, DummyObservation
from lsst.afw.table import SourceCatalog
from lsst.afw.detection import Footprint
from lsst.afw.detection.multiband import heavyFootprintToImage
from lsst.afw.geom import SpanSet, Stencil

from utils import initData


class TestDeblend(lsst.utils.tests.TestCase):
    def test_deblend_task(self):
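        """Exercise ScarletDeblendTask end to end on a simulated scene.

        Build a five-band coadd containing a three-source blend and an
        isolated source, run detection and deblending, and check the catalog
        flags, footprints, and scarlet models, both with and without flux
        re-distribution. Two synthetic footprints (one too large, one with
        too many peaks) are injected to verify that such parents are skipped.
        """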

        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        result = initData(shape, coords, amplitudes)
        targetPsfImage, psfImages, images, channels, seds, morphs, targetPsf, psfs = result
        B, Ny, Nx = shape

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise
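        # The squared noise realization is reused below as the variance plane
        # of the multiband masked image, so the task sees a non-zero variance
        # in every pixel.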

        filters = "grizy"
        _images = afwImage.MultibandMaskedImage.fromArrays(filters, images.astype(np.float32), None, noise**2)
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        coadds = afwImage.MultibandExposure.fromExposures(filters, coadds)
        for b, coadd in enumerate(coadds):
            coadd.setPsf(psfs[b])

        schema = SourceCatalog.Table.makeMinimalSchema()

        detectionTask = SourceDetectionTask(schema=schema)

        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        deblendTask = ScarletDeblendTask(schema=schema, config=config)

        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)
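        # With a BOX stencil the footprint above covers (2*halfLength + 1)**2
        # pixels (67**2 = 4489 for maxFootprintArea = 1000), so this parent
        # should be flagged as too big.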

        # Add a footprint with too many peaks
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)
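        # Five peaks exceed maxNumberOfPeaks = 4, so this parent should be
        # flagged with deblend_tooManyPeaks.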

        # Run the deblender
        catalog, modelData = deblendTask.run(coadds, catalog)

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for band in filters:
                bandIndex = filters.index(band)
                coadd = coadds[band]
                psfModel = coadd.getPsf()

                if useFlux:
                    redistributeImage = coadd.image
                else:
                    redistributeImage = None

                modelData.updateCatalogFootprints(
                    catalog,
                    band=band,
                    psfModel=psfModel,
                    redistributeImage=redistributeImage,
                    removeScarletData=False,
                )
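                # removeScarletData=False keeps the scarlet models in
                # modelData so they can be checked against the catalog
                # footprints below.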

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated
                    # to their children
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = heavyFootprintToImage(fp, fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center is not the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img.image[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    blend = dataToScarlet(
                        blendData=blendData,
                        nBands=1,
                        bandIndex=bandIndex,
                    )
                    # We need to set an observation in order to convolve
                    # the model.
                    position = Point2D(*blendData.psfCenter)
                    _psfs = coadds[band].getPsf().computeKernelImage(position).array[None, :, :]
                    modelBox = Box((1,) + tuple(blendData.extent[::-1]), origin=(0, 0, 0))
                    blend.observation = DummyObservation(
                        psfs=_psfs,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )
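                    # The dummy observation supplies only the band PSF, the
                    # model PSF, and the model bounding box; it carries no
                    # image data unless the flux-redistribution branch below
                    # fills it in.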

                    # Get the scarlet model for the source
                    source = [src for src in blend.sources if src.recordId == child.getId()][0]

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        parentFootprint = catalog[catalog["id"] == child["parent"]][0].getFootprint()
                        blend.observation.images = redistributeImage[parentFootprint.getBBox()].array
                        blend.observation.images = blend.observation.images[None, :, :]
                        blend.observation.weights = ~getFootprintMask(parentFootprint, None)[None, :, :]
                        weight_sources(blend)
                        model = source.flux[0]
                        bbox = scarletBoxToBBox(source.flux_box, Point2I(*blendData.xy0))
                        image = afwImage.ImageF(model, xy0=bbox.getMin())
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        bbox = fp.getBBox()
                        bbox = bboxToScarletBox(1, bbox, Point2I(*blendData.xy0))
                        model = blend.observation.convolve(source.get_model(bbox=bbox))[0]
                        np.testing.assert_almost_equal(img.image.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))
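        # The index-based checks below assume that detection found exactly two
        # parents (the blend and the isolated source), so the oversized and
        # dense footprints added above sit at catalog indices 2 and 3.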

        # Check that only the large footprint was flagged as too big
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[2] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[3] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])


class MemoryTester(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()