# NOTE(review): the lines below were a coverage.py HTML report header fused into
# this file during extraction. Preserved as a comment so the file stays valid Python:
#   Coverage for tests/test_deblend.py: 14% of 126 statements
#   (coverage.py v7.5.0, created at 2024-04-25 00:40 -0700)

# This file is part of meas_extensions_scarlet.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import unittest

import numpy as np
from scarlet.bbox import Box
from scarlet.lite.measure import weight_sources

import lsst.afw.image as afwImage
import lsst.utils.tests
from lsst.afw.detection import Footprint
from lsst.afw.geom import SpanSet, Stencil
from lsst.afw.table import SourceCatalog
from lsst.geom import Point2I, Point2D
from lsst.meas.algorithms import SourceDetectionTask
from lsst.meas.extensions.scarlet.scarletDeblendTask import ScarletDeblendTask
from lsst.meas.extensions.scarlet.source import bboxToScarletBox, scarletBoxToBBox
from lsst.meas.extensions.scarlet.io import dataToScarlet, DummyObservation

from utils import initData

class TestDeblend(lsst.utils.tests.TestCase):
    """End-to-end exercise of ``ScarletDeblendTask`` on a synthetic multiband
    image, checking catalog bookkeeping, footprint/model round-tripping, and
    the parent-skipping flags.
    """

    def test_deblend_task(self):
        """Run detection + deblending on a synthetic 5-band image and verify
        the resulting catalog and scarlet model data are mutually consistent.
        """
        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        # (bands, height, width) of the synthetic cube
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        result = initData(shape, coords, amplitudes)
        targetPsfImage, psfImages, images, channels, seds, morphs, targetPsf, psfs = result
        B, Ny, Nx = shape

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        filters = "grizy"
        # Build a MultibandExposure from the noisy arrays; the variance plane
        # is set to noise**2 so it is strictly positive.
        _images = afwImage.MultibandMaskedImage.fromArrays(filters, images.astype(np.float32), None, noise**2)
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        coadds = afwImage.MultibandExposure.fromExposures(filters, coadds)
        for b, coadd in enumerate(coadds):
            coadd.setPsf(psfs[b])

        schema = SourceCatalog.Table.makeMinimalSchema()

        detectionTask = SourceDetectionTask(schema=schema)

        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        deblendTask = ScarletDeblendTask(schema=schema, config=config)

        # Detect sources in the r band only; deblending then uses all bands.
        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large
        # (area exceeds config.maxFootprintArea, so the parent must be skipped)
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks
        # (maxNumberOfPeaks + 1 peaks, so the parent must be skipped)
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        catalog, modelData = deblendTask.run(coadds, catalog)

        # Attach the footprints in each band and compare to the full
        # data model. This is done in each band, both with and without
        # flux re-distribution to test all of the different possible
        # options of loading catalog footprints.
        for useFlux in [False, True]:
            for band in filters:
                bandIndex = filters.index(band)
                coadd = coadds[band]
                psfModel = coadd.getPsf()

                if useFlux:
                    redistributeImage = coadd.image
                else:
                    redistributeImage = None

                # removeScarletData=False keeps the blend models in memory so
                # they can be compared against the footprints below.
                modelData.updateCatalogFootprints(
                    catalog,
                    band=band,
                    psfModel=psfModel,
                    maskImage=coadd.mask,
                    redistributeImage=redistributeImage,
                    removeScarletData=False,
                )

                # Check that the number of deblended children is consistent
                parents = catalog[catalog["parent"] == 0]
                self.assertEqual(np.sum(catalog["deblend_nChild"]), len(catalog)-len(parents))

                # Check that the models have not been cleared
                # from the modelData
                self.assertEqual(len(modelData.blends), np.sum(~parents["deblend_skipped"]))

                for parent in parents:
                    children = catalog[catalog["parent"] == parent.get("id")]
                    # Check that nChild is set correctly
                    self.assertEqual(len(children), parent.get("deblend_nChild"))
                    # Check that parent columns are propagated
                    # to their children
                    for parentCol, childCol in config.columnInheritance.items():
                        np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

                children = catalog[catalog["parent"] != 0]
                for child in children:
                    fp = child.getFootprint()
                    img = fp.extractImage(fill=0.0)
                    # Check that the flux at the center is correct.
                    # Note: this only works in this test image because the
                    # detected peak is in the same location as the
                    # scarlet peak.
                    # If the peak is shifted, the flux value will be correct
                    # but deblend_peak_center is not the correct location.
                    px = child.get("deblend_peak_center_x")
                    py = child.get("deblend_peak_center_y")
                    flux = img[Point2I(px, py)]
                    self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                    # Check that the peak positions match the catalog entry
                    peaks = fp.getPeaks()
                    self.assertEqual(px, peaks[0].getIx())
                    self.assertEqual(py, peaks[0].getIy())

                    # Load the data to check against the HeavyFootprint
                    blendData = modelData.blends[child["parent"]]
                    blend = dataToScarlet(
                        blendData=blendData,
                        bandIndex=bandIndex,
                    )
                    # We need to set an observation in order to convolve
                    # the model.
                    position = Point2D(*blendData.psfCenter)
                    _psfs = coadds[band].getPsf().computeKernelImage(position).array[None, :, :]
                    # extent is reversed to go from (x, y) to (y, x) ordering;
                    # the leading 1 is the single-band axis.
                    modelBox = Box((1,) + tuple(blendData.extent[::-1]), origin=(0, 0, 0))
                    blend.observation = DummyObservation(
                        psfs=_psfs,
                        model_psf=modelData.psf[None, :, :],
                        bbox=modelBox,
                        dtype=np.float32,
                    )

                    # Get the scarlet model for the source
                    source = [src for src in blend.sources if src.recordId == child.getId()][0]
                    # source.center is (y, x) in the parent footprint frame.
                    parentBox = catalog.find(child["parent"]).getFootprint().getBBox()
                    self.assertEqual(source.center[1], px - parentBox.getMinX())
                    self.assertEqual(source.center[0], py - parentBox.getMinY())

                    if useFlux:
                        # Get the flux re-weighted model and test against
                        # the HeavyFootprint.
                        # The HeavyFootprint needs to be projected onto
                        # the image of the flux-redistributed model,
                        # since the HeavyFootprint may trim rows or columns.
                        parentFootprint = catalog[catalog["id"] == child["parent"]][0].getFootprint()
                        blend.observation.images = redistributeImage[parentFootprint.getBBox()].array
                        blend.observation.images = blend.observation.images[None, :, :]
                        blend.observation.weights = parentFootprint.spans.asArray()[None, :, :]
                        weight_sources(blend)
                        model = source.flux[0]
                        bbox = scarletBoxToBBox(source.flux_box, Point2I(*blendData.xy0))
                        image = afwImage.ImageF(model, xy0=bbox.getMin())
                        fp.insert(image)
                        np.testing.assert_almost_equal(image.array, model)
                    else:
                        # Get the model for the source and test
                        # against the HeavyFootprint
                        bbox = fp.getBBox()
                        bbox = bboxToScarletBox(1, bbox, Point2I(*blendData.xy0))
                        model = blend.observation.convolve(source.get_model(bbox=bbox))[0]
                        np.testing.assert_almost_equal(img.array, model)

        # Check that all sources have the correct number of peaks
        for src in catalog:
            fp = src.getFootprint()
            self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

        # Check that only the large footprint was flagged as too big
        # (index 2: the first record added after the two detected parents
        # — NOTE(review): assumes detection found exactly two parents; confirm)
        largeFootprint = np.zeros(len(catalog), dtype=bool)
        largeFootprint[2] = True
        np.testing.assert_array_equal(largeFootprint, catalog["deblend_parentTooBig"])

        # Check that only the dense footprint was flagged as too dense
        denseFootprint = np.zeros(len(catalog), dtype=bool)
        denseFootprint[3] = True
        np.testing.assert_array_equal(denseFootprint, catalog["deblend_tooManyPeaks"])

        # Check that only the appropriate parents were skipped
        skipped = largeFootprint | denseFootprint
        np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

236 

237 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Standard LSST memory/leak test case; all checks are inherited."""

240 

241 

def setup_module(module):
    """Pytest hook: initialize the LSST testing framework for this module."""
    lsst.utils.tests.init()

244 

245 

# Script entry point: initialize the LSST test framework, then run the tests.
# (The coverage-report branch annotation that was fused into this line during
# extraction has been removed to restore valid Python.)
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()