Coverage for tests/test_metrics.py: 23%

153 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-12 11:10 +0000

1# This file is part of pipe_tasks. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25import astropy.units as u 

26from astropy.tests.helper import assert_quantity_allclose 

27 

28from lsst.afw.table import SourceCatalog 

29import lsst.utils.tests 

30import lsst.pipe.base.testUtils 

31from lsst.verify import Name 

32from lsst.verify.tasks.testUtils import MetricTaskTestCase 

33 

34from lsst.pipe.tasks.metrics import \ 

35 NumberDeblendedSourcesMetricTask, NumberDeblendChildSourcesMetricTask 

36 

37 

def _makeDummyCatalog(nParents, *, skyFlags=False, deblendFlags=False, nChildren=0, nGrandchildren=0):
    """Build a minimal source catalog for exercising deblending metrics.

    Parameters
    ----------
    nParents : `int`
        Number of top-level records added to the catalog before any
        deblending takes place.
    skyFlags : `bool`, optional
        If set, a ``sky_source`` flag is added to the schema and set on the
        last top-level record, which is also the record that gets deblended
        when children are requested.
    deblendFlags : `bool`, optional
        If set, the ``deblend_*`` fields are added to the schema and every
        top-level record starts with ``deblend_nPeaks`` equal to 1.
    nChildren : `int`, optional
        If positive, the last top-level record receives ``nChildren``
        children. Has no effect unless ``deblendFlags`` is set.
    nGrandchildren : `int`, optional
        If positive, the first of the new children receives
        ``nGrandchildren`` children of its own. Has no effect unless
        children were actually added.

    Returns
    -------
    catalog : `lsst.afw.table.SourceCatalog`
        A new catalog holding the ``nParents`` top-level records plus any
        children and grandchildren that were added.
    """
    schema = SourceCatalog.Table.makeMinimalSchema()
    if skyFlags:
        schema.addField("sky_source", type="Flag", doc="Sky source.")
    if deblendFlags:
        # See https://community.lsst.org/t/4957 for flag definitions.
        # The detect_ flags are deliberately left out: they are defined by a
        # postprocessing task and some post-deblend catalogs may lack them.
        deblendFields = (
            'deblend_nChild',
            'deblend_nPeaks',
            'deblend_parentNPeaks',
            'deblend_parentNChild',
        )
        for fieldName in deblendFields:
            schema.addField(fieldName, type=np.int32, doc='')

    catalog = SourceCatalog(schema)
    if not nParents > 0:
        # Nothing to populate; hand back the empty catalog.
        return catalog

    for _ in range(nParents):
        lastParent = catalog.addNew()
        if deblendFlags:
            lastParent["deblend_nPeaks"] = 1

    # Only the final top-level record is flagged as sky and/or deblended.
    if skyFlags:
        lastParent["sky_source"] = True
    if deblendFlags and nChildren > 0:
        firstGeneration = _addChildren(catalog, lastParent, nChildren)
        if nGrandchildren > 0:
            _addChildren(catalog, firstGeneration[0], nGrandchildren)
    return catalog

88 

89 

90def _addChildren(catalog, parent, nChildren): 

91 """Add children to a catalog source. 

92 

93 Parameters 

94 ---------- 

95 catalog : `lsst.afw.table.SourceCatalog` 

96 The catalog to update. Its schema must contain all supported 

97 deblender flags. 

98 parent : `lsst.afw.table.SourceRecord` 

99 The source record to serve as the parent for any new children. Must be 

100 an element of ``catalog`` (not validated). 

101 nChildren : `int` 

102 The number of children of ``parent`` to add to ``catalog``. 

103 

104 Returns 

105 ------- 

106 children : `list` [`lsst.afw.table.SourceRecord`] 

107 A list of the ``nChildren`` new children. 

108 """ 

109 newRecords = [] 

110 if nChildren > 0: 

111 parent["deblend_nChild"] = nChildren 

112 parent["deblend_nPeaks"] = nChildren 

113 parentId = parent.getId() 

114 for i in range(nChildren): 

115 child = catalog.addNew() 

116 child.setParent(parentId) 

117 child["deblend_parentNPeaks"] = nChildren 

118 child["deblend_parentNChild"] = nChildren 

119 if "sky_source" in parent.schema: 

120 child["sky_source"] = parent["sky_source"] 

121 newRecords.append(child) 

122 return newRecords 

123 

124 

class TestNumDeblended(MetricTaskTestCase):
    """Tests of `NumberDeblendedSourcesMetricTask` on synthetic catalogs."""

    METRIC_NAME = Name(metric="pipe_tasks.numDeblendedSciSources")

    @classmethod
    def makeTask(cls):
        return NumberDeblendedSourcesMetricTask()

    def _checkMeasuredCount(self, catalog, expected):
        """Run the task on ``catalog`` and verify the resulting measurement.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            The catalog passed to the task's ``run`` method.
        expected : `int`
            The value the measured quantity must match.
        """
        result = self.task.run(catalog)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
        measurement = result.measurement

        self.assertEqual(measurement.metric_name, self.METRIC_NAME)
        assert_quantity_allclose(measurement.quantity, u.Quantity(expected))

    def testValid(self):
        # One of the three parents is deblended.
        self._checkMeasuredCount(
            _makeDummyCatalog(3, deblendFlags=True, nChildren=2), 1)

    def testEmptyCatalog(self):
        self._checkMeasuredCount(_makeDummyCatalog(0, deblendFlags=True), 0)

    def testNothingDeblended(self):
        self._checkMeasuredCount(
            _makeDummyCatalog(3, deblendFlags=True, nChildren=0), 0)

    def testSkyIgnored(self):
        # The only deblended parent is a sky source, so nothing is counted.
        self._checkMeasuredCount(
            _makeDummyCatalog(3, skyFlags=True, deblendFlags=True, nChildren=2), 0)

    def testMultiDeblending(self):
        self._checkMeasuredCount(
            _makeDummyCatalog(5, deblendFlags=True, nChildren=3, nGrandchildren=2), 1)

    def testNoDeblending(self):
        catalog = _makeDummyCatalog(3, deblendFlags=False)
        try:
            result = self.task.run(catalog)
        except lsst.pipe.base.NoWorkFound:
            # Raising NoWorkFound is one accepted outcome.
            return
        # The other accepted outcome: a valid result with no measurement.
        lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
        self.assertIsNone(result.measurement)

190 

191 

class TestNumDeblendChild(MetricTaskTestCase):
    """Tests of `NumberDeblendChildSourcesMetricTask` on synthetic catalogs."""

    METRIC_NAME = Name(metric="pipe_tasks.numDeblendChildSciSources")

    @classmethod
    def makeTask(cls):
        return NumberDeblendChildSourcesMetricTask()

    def _checkMeasuredCount(self, catalog, expected):
        """Run the task on ``catalog`` and verify the resulting measurement.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            The catalog passed to the task's ``run`` method.
        expected : `int`
            The value the measured quantity must match.
        """
        result = self.task.run(catalog)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
        measurement = result.measurement

        self.assertEqual(measurement.metric_name, self.METRIC_NAME)
        assert_quantity_allclose(measurement.quantity, u.Quantity(expected))

    def testValid(self):
        # The single deblended parent contributes its two children.
        self._checkMeasuredCount(
            _makeDummyCatalog(3, deblendFlags=True, nChildren=2), 2)

    def testEmptyCatalog(self):
        self._checkMeasuredCount(_makeDummyCatalog(0, deblendFlags=True), 0)

    def testNothingDeblended(self):
        self._checkMeasuredCount(
            _makeDummyCatalog(3, deblendFlags=True, nChildren=0), 0)

    def testSkyIgnored(self):
        # The children descend from a sky source, so none are counted.
        self._checkMeasuredCount(
            _makeDummyCatalog(3, skyFlags=True, deblendFlags=True, nChildren=2), 0)

    def testMultiDeblending(self):
        # Expect 2 from first-level children and 2 from subchildren
        self._checkMeasuredCount(
            _makeDummyCatalog(5, deblendFlags=True, nChildren=3, nGrandchildren=2), 4)

    def testNoDeblending(self):
        catalog = _makeDummyCatalog(3, deblendFlags=False)
        try:
            result = self.task.run(catalog)
        except lsst.pipe.base.NoWorkFound:
            # Raising NoWorkFound is one accepted outcome.
            return
        # The other accepted outcome: a valid result with no measurement.
        lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
        self.assertIsNone(result.measurement)

258 

259 

# Workaround for unittest's test setup: remove the imported base class from
# this module's namespace once the subclasses above are defined.
# NOTE(review): presumably this keeps the base class itself from being
# collected and run as a test case of this module -- confirm.
del MetricTaskTestCase

262 

263 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Memory test case inherited unchanged from
    `lsst.utils.tests.MemoryTestCase`.
    """

266 

267 

def setup_module(module):
    """Module-level setup hook that calls `lsst.utils.tests.init`.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used here.
        NOTE(review): presumably supplied by the test runner's module-level
        setup convention -- confirm.
    """
    lsst.utils.tests.init()

270 

271 

if __name__ == "__main__":
    # Run as a script: call lsst.utils.tests.init() first, then hand control
    # to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()