Coverage for tests/test_metrics.py: 26%

125 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-05 18:32 -0800

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import astropy.units as u 

25from astropy.tests.helper import assert_quantity_allclose 

26 

27from lsst.afw.table import SourceCatalog 

28import lsst.utils.tests 

29import lsst.pipe.base.testUtils 

30from lsst.verify import Name 

31from lsst.verify.gen2tasks.testUtils import MetricTaskTestCase 

32from lsst.verify.tasks import MetricComputationError 

33 

34from lsst.ip.diffim.metrics import \ 

35 NumberSciSourcesMetricTask, \ 

36 FractionDiaSourcesToSciSourcesMetricTask 

37 

38 

def _makeDummyCatalog(size, skyFlag=False, priFlag=False):
    """Build a minimal source catalog for exercising source-count metrics.

    Parameters
    ----------
    size : `int`
        How many rows the returned catalog holds.
    skyFlag : `bool`
        When true, a ``sky_source`` flag field is added to the schema and,
        if the catalog is non-empty, set to `True` on the last row added.
        When false, the schema has no ``sky_source`` field.
    priFlag : `bool`
        Same as ``skyFlag``, but for the ``detect_isPrimary`` flag field.

    Returns
    -------
    catalog : `lsst.afw.table.SourceCatalog`
        A freshly built catalog containing ``size`` rows.
    """
    schema = SourceCatalog.Table.makeMinimalSchema()
    # Optional flag fields, registered in a fixed order: sky first, then
    # primary.
    optionalFields = (
        (skyFlag, "sky_source", "Sky source."),
        (priFlag, "detect_isPrimary", "Primary source."),
    )
    for requested, fieldName, fieldDoc in optionalFields:
        if requested:
            schema.addField(fieldName, type="Flag", doc=fieldDoc)

    catalog = SourceCatalog(schema)
    lastRecord = None
    for _ in range(size):
        lastRecord = catalog.addNew()

    # Only a non-empty catalog has a row to flag; it is always the last one.
    if lastRecord is not None:
        if priFlag:
            lastRecord["detect_isPrimary"] = True
        if skyFlag:
            lastRecord["sky_source"] = True
    return catalog

71 

72 

class TestNumSciSources(MetricTaskTestCase):
    """Tests of `NumberSciSourcesMetricTask` against small dummy catalogs."""

    @classmethod
    def makeTask(cls):
        return NumberSciSourcesMetricTask()

    def _measureCatalog(self, catalog):
        """Run the task on ``catalog``, validate the output structure, and
        return the measurement it carries.
        """
        outcome = self.task.run(catalog)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, outcome)
        return outcome.measurement

    def _checkCount(self, measurement, expectedCount):
        """Assert that ``measurement`` is the ``numSciSources`` metric with a
        value of ``expectedCount`` counts.
        """
        self.assertEqual(measurement.metric_name, Name(metric="ip_diffim.numSciSources"))
        assert_quantity_allclose(measurement.quantity, expectedCount * u.count)

    def testValid(self):
        sources = _makeDummyCatalog(3)
        self._checkCount(self._measureCatalog(sources), len(sources))

    def testEmptyCatalog(self):
        sources = _makeDummyCatalog(0)
        self._checkCount(self._measureCatalog(sources), 0)

    def testSkySources(self):
        sources = _makeDummyCatalog(3, skyFlag=True)
        # One row of the dummy catalog is flagged as a sky source.
        self._checkCount(self._measureCatalog(sources), len(sources) - 1)

    def testPrimarySources(self):
        sources = _makeDummyCatalog(3, priFlag=True)
        # One row of the dummy catalog is flagged as primary.
        self._checkCount(self._measureCatalog(sources), 1)

    def testMissingData(self):
        self.assertIsNone(self._measureCatalog(None))

120 

121 

class TestFractionDiaSources(MetricTaskTestCase):
    """Tests of `FractionDiaSourcesToSciSourcesMetricTask` against small
    dummy catalogs.
    """

    @classmethod
    def makeTask(cls):
        return FractionDiaSourcesToSciSourcesMetricTask()

    def _measureCatalogs(self, *args, **kwargs):
        """Run the task with the given arguments, validate the output
        structure, and return the measurement it carries.
        """
        outcome = self.task.run(*args, **kwargs)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, outcome)
        return outcome.measurement

    def _checkFraction(self, measurement, expectedFraction):
        """Assert that ``measurement`` is the ``fracDiaSourcesToSciSources``
        metric with a dimensionless value of ``expectedFraction``.
        """
        self.assertEqual(measurement.metric_name, Name(metric="ip_diffim.fracDiaSourcesToSciSources"))
        assert_quantity_allclose(measurement.quantity, expectedFraction * u.dimensionless_unscaled)

    def _checkRunFails(self, sciSize, diaSize):
        """Assert that running the task on dummy catalogs of the given sizes
        raises `MetricComputationError`.
        """
        science = _makeDummyCatalog(sciSize)
        diffim = _makeDummyCatalog(diaSize)
        with self.assertRaises(MetricComputationError):
            self.task.run(science, diffim)

    def testValid(self):
        science = _makeDummyCatalog(5)
        diffim = _makeDummyCatalog(3)
        measurement = self._measureCatalogs(science, diffim)
        self._checkFraction(measurement, len(diffim) / len(science))

    def testEmptyDiaCatalog(self):
        science = _makeDummyCatalog(5)
        diffim = _makeDummyCatalog(0)
        measurement = self._measureCatalogs(science, diffim)
        self._checkFraction(measurement, 0.0)

    def testEmptySciCatalog(self):
        self._checkRunFails(0, 3)

    def testEmptyCatalogs(self):
        self._checkRunFails(0, 0)

    def testMissingData(self):
        self.assertIsNone(self._measureCatalogs(None, None))

    def testSemiMissingData(self):
        measurement = self._measureCatalogs(sciSources=_makeDummyCatalog(3), diaSources=None)
        self.assertIsNone(measurement)

    def testSkySources(self):
        science = _makeDummyCatalog(5, skyFlag=True)
        diffim = _makeDummyCatalog(3)
        measurement = self._measureCatalogs(science, diffim)
        # One row of the science catalog is flagged as a sky source.
        self._checkFraction(measurement, len(diffim) / (len(science) - 1))

    def testPrimarySources(self):
        science = _makeDummyCatalog(5, skyFlag=True, priFlag=True)
        diffim = _makeDummyCatalog(3)
        measurement = self._measureCatalogs(science, diffim)
        self._checkFraction(measurement, len(diffim))

192 

193 

194# Hack around unittest's test collection: remove the imported base class from this module so it is not picked up and run as a test case itself

195del MetricTaskTestCase 

196 

197 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Use `lsst.utils.tests.MemoryTestCase` as-is; nothing is overridden."""

200 

201 

def setup_module(module):
    """Call `lsst.utils.tests.init`; ``module`` is accepted but not used."""
    lsst.utils.tests.init()

204 

205 

if __name__ == "__main__":
    # Run as a script: make the same init call as setup_module, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()