Coverage for tests/test_deblend.py: 16%

95 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-12 03:18 -0700

1# This file is part of meas_extensions_scarlet. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25 

26from lsst.geom import Point2I 

27import lsst.utils.tests 

28import lsst.afw.image as afwImage 

29from lsst.meas.algorithms import SourceDetectionTask 

30from lsst.meas.extensions.scarlet import ScarletDeblendTask 

31from lsst.afw.table import SourceCatalog 

32from lsst.afw.detection import Footprint 

33from lsst.afw.detection.multiband import heavyFootprintToImage 

34from lsst.afw.geom import SpanSet, Stencil 

35 

36from utils import initData 

37 

38 

class TestDeblend(lsst.utils.tests.TestCase):
    """Tests for ``ScarletDeblendTask`` on a small synthetic multiband scene."""

    def test_deblend_task(self):
        """Run detection and scarlet deblending, then check the output catalogs.

        The synthetic scene built by ``initData`` contains a blend of three
        sources and one isolated source. After detection, two extra parents
        are appended by hand. One has a footprint larger than
        ``maxFootprintArea`` and the other has more than
        ``maxNumberOfPeaks`` peaks. Together they exercise the deblender's
        parent-skipping logic.
        """
        # Set the random seed so that the noise field is unaffected
        np.random.seed(0)
        shape = (5, 100, 115)
        coords = [
            # blend
            (15, 25), (10, 30), (17, 38),
            # isolated source
            (85, 90),
        ]
        amplitudes = [
            # blend
            80, 60, 90,
            # isolated source
            20,
        ]
        result = initData(shape, coords, amplitudes)
        # Only the images and the per-band PSFs are used by this test.
        _, _, images, _, _, _, _, psfs = result

        # Add some noise, otherwise the task will blow up due to
        # zero variance
        noise = 10*(np.random.rand(*images.shape).astype(np.float32)-.5)
        images += noise

        filters = "grizy"
        # ``noise`` is drawn symmetrically about zero, so using it directly
        # as the variance plane would give negative variances in about half
        # of the pixels. Use its square so the variance is non-negative.
        variance = noise**2
        _images = afwImage.MultibandMaskedImage.fromArrays(
            filters, images.astype(np.float32), None, variance
        )
        coadds = [afwImage.Exposure(img, dtype=img.image.array.dtype) for img in _images]
        coadds = afwImage.MultibandExposure.fromExposures(filters, coadds)
        for b, coadd in enumerate(coadds):
            coadd.setPsf(psfs[b])

        schema = SourceCatalog.Table.makeMinimalSchema()

        detectionTask = SourceDetectionTask(schema=schema)

        # Adjust config options to test skipping parents
        config = ScarletDeblendTask.ConfigClass()
        config.maxIter = 100
        config.maxFootprintArea = 1000
        config.maxNumberOfPeaks = 4
        deblendTask = ScarletDeblendTask(schema=schema, config=config)

        table = SourceCatalog.Table.make(schema)
        detectionResult = detectionTask.run(table, coadds["r"])
        catalog = detectionResult.sources

        # Add a footprint that is too large
        src = catalog.addNew()
        halfLength = int(np.ceil(np.sqrt(config.maxFootprintArea) + 1))
        ss = SpanSet.fromShape(halfLength, Stencil.BOX, offset=(50, 50))
        bigfoot = Footprint(ss)
        bigfoot.addPeak(50, 50, 100)
        src.setFootprint(bigfoot)

        # Add a footprint with too many peaks
        src = catalog.addNew()
        ss = SpanSet.fromShape(10, Stencil.BOX, offset=(75, 20))
        denseFoot = Footprint(ss)
        for n in range(config.maxNumberOfPeaks+1):
            denseFoot.addPeak(70+2*n, 15+2*n, 10*n)
        src.setFootprint(denseFoot)

        # Run the deblender
        result = deblendTask.run(coadds, catalog)

        # Make sure that the catalogs have the same sources in all bands,
        # and check that band-independent columns are equal
        bandIndependentColumns = [
            "id",
            "parent",
            "deblend_nPeaks",
            "deblend_nChild",
            "deblend_peak_center_x",
            "deblend_peak_center_y",
            "deblend_runtime",
            "deblend_iterations",
            "deblend_logL",
            "deblend_spectrumInitFlag",
            "deblend_blendConvergenceFailedFlag",
        ]
        self.assertEqual(len(filters), len(result))
        ref = result[filters[0]]
        for f in filters[1:]:
            for col in bandIndependentColumns:
                np.testing.assert_array_equal(result[f][col], ref[col])

        # Check that other columns are consistent
        for f, _catalog in result.items():
            parents = _catalog[_catalog["parent"] == 0]
            # Check that the number of deblended children is consistent
            self.assertEqual(np.sum(_catalog["deblend_nChild"]), len(_catalog)-len(parents))

            for parent in parents:
                children = _catalog[_catalog["parent"] == parent.get("id")]
                # Check that nChild is set correctly
                self.assertEqual(len(children), parent.get("deblend_nChild"))
                # Check that parent columns are propagated to their children
                for parentCol, childCol in config.columnInheritance.items():
                    np.testing.assert_array_equal(parent.get(parentCol), children[childCol])

            children = _catalog[_catalog["parent"] != 0]
            for child in children:
                fp = child.getFootprint()
                img = heavyFootprintToImage(fp)
                # Check that the flux at the center is correct.
                # Note: this only works in this test image because the
                # detected peak is in the same location as the scarlet peak.
                # If the peak is shifted, the flux value will be correct
                # but deblend_peak_center is not the correct location.
                px = child.get("deblend_peak_center_x")
                py = child.get("deblend_peak_center_y")
                flux = img.image[Point2I(px, py)]
                self.assertEqual(flux, child.get("deblend_peak_instFlux"))

                # Check that the peak positions match the catalog entry
                peaks = fp.getPeaks()
                self.assertEqual(px, peaks[0].getIx())
                self.assertEqual(py, peaks[0].getIy())

            # Check that all sources have the correct number of peaks
            for src in _catalog:
                fp = src.getFootprint()
                self.assertEqual(len(fp.peaks), src.get("deblend_nPeaks"))

            # NOTE(review): the hard-coded indices below assume detection
            # found exactly two parents (the blend and the isolated source)
            # and that parents keep their input order at the start of the
            # output catalog. Under those assumptions the two hand-added
            # parents sit at indices 2 and 3. Confirm if the scene changes.

            # Check that only the large footprint was flagged as too big
            largeFootprint = np.zeros(len(_catalog), dtype=bool)
            largeFootprint[2] = True
            np.testing.assert_array_equal(largeFootprint, _catalog["deblend_parentTooBig"])

            # Check that only the dense footprint was flagged as too dense
            denseFootprint = np.zeros(len(_catalog), dtype=bool)
            denseFootprint[3] = True
            np.testing.assert_array_equal(denseFootprint, _catalog["deblend_tooManyPeaks"])

            # Check that only the appropriate parents were skipped
            skipped = largeFootprint | denseFootprint
            np.testing.assert_array_equal(skipped, _catalog["deblend_skipped"])

178 

179 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Attach ``lsst.utils.tests.MemoryTestCase`` to this test module.

    No behavior is added; everything is inherited from the base class.
    """

182 

183 

def setup_module(module):
    """Initialize ``lsst.utils.tests`` once for this test module.

    Parameters
    ----------
    module : `module`
        Presumably the module object passed in by the test runner;
        it is not used here.
    """
    lsstTestUtils = lsst.utils.tests
    lsstTestUtils.init()

186 

187 

if __name__ == "__main__":
    # When the file is executed directly, perform the same initialization
    # as setup_module, then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()