Coverage for tests/test_metrics.py: 26%

110 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-06 02:46 -0800

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import astropy.units as u 

25from astropy.tests.helper import assert_quantity_allclose 

26 

27from lsst.afw.table import SourceCatalog 

28import lsst.utils.tests 

29import lsst.pipe.base.testUtils 

30from lsst.verify import Name 

31from lsst.verify.tasks.testUtils import MetricTaskTestCase 

32from lsst.verify.tasks import MetricComputationError 

33 

34from lsst.ip.diffim.metrics import \ 

35 NumberSciSourcesMetricTask, \ 

36 FractionDiaSourcesToSciSourcesMetricTask 

37 

38 

def _makeDummyCatalog(size, skyFlag=False, priFlag=False):
    """Create a trivial catalog for testing source counts.

    Parameters
    ----------
    size : `int`
        The number of entries in the catalog.
    skyFlag : `bool`
        If set, the schema is guaranteed to have the ``sky_source`` flag, and,
        if the catalog is non-empty, its last row has it set to `True`. If not
        set, the ``sky_source`` flag is not present.
    priFlag : `bool`
        As ``skyFlag``, but for a ``detect_isPrimary`` flag. If both flags are
        requested, both are set on the same (last) row.

    Returns
    -------
    catalog : `lsst.afw.table.SourceCatalog`
        A new catalog with ``size`` rows.
    """
    schema = SourceCatalog.Table.makeMinimalSchema()
    if skyFlag:
        schema.addField("sky_source", type="Flag", doc="Sky source.")
    if priFlag:
        schema.addField("detect_isPrimary", type="Flag", doc="Primary source.")
    catalog = SourceCatalog(schema)

    # Keep an explicit handle on the last row added instead of relying on the
    # loop variable leaking out of the loop; it stays None for an empty catalog.
    lastRecord = None
    for _ in range(size):
        lastRecord = catalog.addNew()

    # Flag exactly one row. The tests in this module count on a single flagged
    # row (e.g. one sky source out of three rows).
    if lastRecord is not None:
        if priFlag:
            lastRecord["detect_isPrimary"] = True
        if skyFlag:
            lastRecord["sky_source"] = True
    return catalog

71 

72 

class TestNumSciSources(MetricTaskTestCase):
    """Tests for `NumberSciSourcesMetricTask`."""

    @classmethod
    def makeTask(cls):
        return NumberSciSourcesMetricTask()

    def _runAndCheck(self, catalog, expectedCount):
        """Run the task on ``catalog`` and verify the resulting measurement.

        Parameters
        ----------
        catalog : `lsst.afw.table.SourceCatalog`
            Science-source catalog passed to the task.
        expectedCount : `int`
            Number of sources the measurement is expected to report.
        """
        output = self.task.run(catalog)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, output)
        measurement = output.measurement

        self.assertEqual(measurement.metric_name, Name(metric="ip_diffim.numSciSources"))
        assert_quantity_allclose(measurement.quantity, expectedCount * u.count)

    def testValid(self):
        catalog = _makeDummyCatalog(3)
        self._runAndCheck(catalog, len(catalog))

    def testEmptyCatalog(self):
        self._runAndCheck(_makeDummyCatalog(0), 0)

    def testSkySources(self):
        # One row is a sky source and must not be counted.
        catalog = _makeDummyCatalog(3, skyFlag=True)
        self._runAndCheck(catalog, len(catalog) - 1)

    def testPrimarySources(self):
        # Only the single primary row is counted.
        self._runAndCheck(_makeDummyCatalog(3, priFlag=True), 1)

115 

class TestFractionDiaSources(MetricTaskTestCase):
    """Tests for `FractionDiaSourcesToSciSourcesMetricTask`."""

    @classmethod
    def makeTask(cls):
        return FractionDiaSourcesToSciSourcesMetricTask()

    def _runAndCheck(self, sciCatalog, diaCatalog, expectedFraction):
        """Run the task on both catalogs and verify the resulting measurement.

        Parameters
        ----------
        sciCatalog : `lsst.afw.table.SourceCatalog`
            Science-source catalog passed to the task.
        diaCatalog : `lsst.afw.table.SourceCatalog`
            DIA-source catalog passed to the task.
        expectedFraction : `float`
            Ratio the measurement is expected to report.
        """
        output = self.task.run(sciCatalog, diaCatalog)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, output)
        measurement = output.measurement

        self.assertEqual(measurement.metric_name, Name(metric="ip_diffim.fracDiaSourcesToSciSources"))
        assert_quantity_allclose(measurement.quantity, expectedFraction * u.dimensionless_unscaled)

    def testValid(self):
        sciCatalog = _makeDummyCatalog(5)
        diaCatalog = _makeDummyCatalog(3)
        self._runAndCheck(sciCatalog, diaCatalog, len(diaCatalog) / len(sciCatalog))

    def testEmptyDiaCatalog(self):
        self._runAndCheck(_makeDummyCatalog(5), _makeDummyCatalog(0), 0.0)

    def testEmptySciCatalog(self):
        sciCatalog = _makeDummyCatalog(0)
        diaCatalog = _makeDummyCatalog(3)
        # No science sources means the ratio is undefined.
        self.assertRaises(MetricComputationError, self.task.run, sciCatalog, diaCatalog)

    def testEmptyCatalogs(self):
        sciCatalog = _makeDummyCatalog(0)
        diaCatalog = _makeDummyCatalog(0)
        self.assertRaises(MetricComputationError, self.task.run, sciCatalog, diaCatalog)

    def testSkySources(self):
        # The single sky source is excluded from the denominator.
        sciCatalog = _makeDummyCatalog(5, skyFlag=True)
        diaCatalog = _makeDummyCatalog(3)
        self._runAndCheck(sciCatalog, diaCatalog, len(diaCatalog) / (len(sciCatalog) - 1))

    def testPrimarySources(self):
        # Exactly one science row is primary, so the ratio is the DIA count.
        sciCatalog = _makeDummyCatalog(5, skyFlag=True, priFlag=True)
        diaCatalog = _makeDummyCatalog(3)
        self._runAndCheck(sciCatalog, diaCatalog, len(diaCatalog))

174 

175 

# Hack around unittest's hacky test setup system: MetricTaskTestCase was
# imported into this module's namespace only to serve as a base class.
# Removing the name here presumably keeps test discovery from collecting and
# running the base class itself as a test case. This must stay after the
# subclasses above, which reference the name when their classes are created.
del MetricTaskTestCase

178 

179 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Hook `lsst.utils.tests.MemoryTestCase` into this module; it adds no
    tests of its own.
    """

182 

183 

def setup_module(module):
    """Module-level setup hook: initialize ``lsst.utils.tests``.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up (supplied by the test runner); unused here.
    """
    lsst.utils.tests.init()

186 

187 

if __name__ == "__main__":
    # Executed directly as a script: initialize the LSST test utilities, then
    # hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()