Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of ip_diffim. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import astropy.units as u 

25 

26from lsst.afw.table import SourceCatalog 

27import lsst.utils.tests 

28import lsst.pipe.base.testUtils 

29from lsst.verify import Name 

30from lsst.verify.gen2tasks.testUtils import MetricTaskTestCase 

31from lsst.verify.tasks import MetricComputationError 

32 

33from lsst.ip.diffim.metrics import \ 

34 NumberSciSourcesMetricTask, \ 

35 FractionDiaSourcesToSciSourcesMetricTask 

36 

37 

def _makeDummyCatalog(size, skyFlag=False):
    """Create a trivial catalog for testing source counts.

    Parameters
    ----------
    size : `int`
        The number of entries in the catalog.
    skyFlag : `bool`
        If set, the schema is guaranteed to have the ``sky_source`` flag, and
        exactly one row (the last one, when ``size > 0``) has it set to
        `True`. If not set, the ``sky_source`` flag is not present.

    Returns
    -------
    catalog : `lsst.afw.table.SourceCatalog`
        A new catalog with ``size`` rows.
    """
    schema = SourceCatalog.Table.makeMinimalSchema()
    if skyFlag:
        schema.addField("sky_source", type="Flag", doc="Sky objects.")
    catalog = SourceCatalog(schema)
    lastRecord = None
    for _ in range(size):
        lastRecord = catalog.addNew()
    # Flag exactly one row, and do it outside the loop: the sky-source tests
    # expect ``len(catalog) - 1`` non-sky sources. An empty catalog has no
    # row to flag.
    if skyFlag and lastRecord is not None:
        lastRecord["sky_source"] = True
    return catalog

64 

65 

class TestNumSciSources(MetricTaskTestCase):
    """Tests for `NumberSciSourcesMetricTask`."""

    _METRIC_NAME = "ip_diffim.numSciSources"

    @classmethod
    def makeTask(cls):
        return NumberSciSourcesMetricTask()

    def _runAndValidate(self, catalog):
        """Run the task on ``catalog``, check the output struct, and return
        its ``measurement`` field (which may be `None`).
        """
        result = self.task.run(catalog)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
        return result.measurement

    def _assertMetricName(self, meas):
        """Assert that ``meas`` is a measurement of this task's metric."""
        self.assertEqual(meas.metric_name, Name(metric=self._METRIC_NAME))

    def testValid(self):
        catalog = _makeDummyCatalog(3)
        meas = self._runAndValidate(catalog)

        self._assertMetricName(meas)
        self.assertEqual(meas.quantity, len(catalog) * u.count)

    def testEmptyCatalog(self):
        meas = self._runAndValidate(_makeDummyCatalog(0))

        self._assertMetricName(meas)
        self.assertEqual(meas.quantity, 0 * u.count)

    def testSkySources(self):
        catalog = _makeDummyCatalog(3, skyFlag=True)
        meas = self._runAndValidate(catalog)

        self._assertMetricName(meas)
        # One row is flagged as a sky source and must not be counted.
        expectedCount = len(catalog) - 1
        self.assertEqual(meas.quantity, expectedCount * u.count)

    def testMissingData(self):
        self.assertIsNone(self._runAndValidate(None))

104 

105 

class TestFractionDiaSources(MetricTaskTestCase):
    """Tests for `FractionDiaSourcesToSciSourcesMetricTask`."""

    _METRIC_NAME = "ip_diffim.fracDiaSourcesToSciSources"

    @classmethod
    def makeTask(cls):
        return FractionDiaSourcesToSciSourcesMetricTask()

    def _runAndValidate(self, *args, **kwargs):
        """Run the task with the given arguments, check the output struct,
        and return its ``measurement`` field (which may be `None`).
        """
        result = self.task.run(*args, **kwargs)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, result)
        return result.measurement

    def _assertMetricName(self, meas):
        """Assert that ``meas`` is a measurement of this task's metric."""
        self.assertEqual(meas.metric_name, Name(metric=self._METRIC_NAME))

    def testValid(self):
        sciCatalog = _makeDummyCatalog(5)
        diaCatalog = _makeDummyCatalog(3)
        meas = self._runAndValidate(sciCatalog, diaCatalog)

        self._assertMetricName(meas)
        expected = len(diaCatalog) / len(sciCatalog) * u.dimensionless_unscaled
        self.assertEqual(meas.quantity, expected)

    def testEmptyDiaCatalog(self):
        sciCatalog = _makeDummyCatalog(5)
        diaCatalog = _makeDummyCatalog(0)
        meas = self._runAndValidate(sciCatalog, diaCatalog)

        self._assertMetricName(meas)
        self.assertEqual(meas.quantity, 0.0 * u.dimensionless_unscaled)

    def testEmptySciCatalog(self):
        sciCatalog, diaCatalog = _makeDummyCatalog(0), _makeDummyCatalog(3)
        # A zero denominator cannot produce a fraction.
        with self.assertRaises(MetricComputationError):
            self.task.run(sciCatalog, diaCatalog)

    def testEmptyCatalogs(self):
        sciCatalog, diaCatalog = _makeDummyCatalog(0), _makeDummyCatalog(0)
        with self.assertRaises(MetricComputationError):
            self.task.run(sciCatalog, diaCatalog)

    def testMissingData(self):
        self.assertIsNone(self._runAndValidate(None, None))

    def testSemiMissingData(self):
        meas = self._runAndValidate(sciSources=_makeDummyCatalog(3), diaSources=None)
        self.assertIsNone(meas)

    def testSkySources(self):
        sciCatalog = _makeDummyCatalog(5, skyFlag=True)
        diaCatalog = _makeDummyCatalog(3)
        meas = self._runAndValidate(sciCatalog, diaCatalog)

        self._assertMetricName(meas)
        # One science row is a sky source and is excluded from the denominator.
        nSci = len(sciCatalog) - 1
        self.assertEqual(meas.quantity, len(diaCatalog) / nSci * u.dimensionless_unscaled)

165 

166 

# Remove the shared base class from the module namespace. unittest (and
# pytest) collect TestCase subclasses found as module-level attributes, so
# leaving MetricTaskTestCase here would make them try to run the base class
# on its own. That is presumably not meaningful, since the subclasses above
# are what supply ``makeTask``.
del MetricTaskTestCase

169 

170 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Module-level instance of `lsst.utils.tests.MemoryTestCase`.

    All behavior is inherited unchanged from the base class (presumably a
    leak check run alongside this module's tests, per the base class name).
    """

173 

174 

def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run.

    This is the pytest xunit-style module hook; ``module`` is the module
    object pytest passes in and is not used here.
    """
    lsst.utils.tests.init()

177 

178 

# Script entry point: initialize the LSST test utilities, then hand control
# to unittest's command-line runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()