Coverage for tests/test_metrics.py: 26%
125 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-05 18:32 -0800
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-05 18:32 -0800
1# This file is part of ip_diffim.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import unittest
24import astropy.units as u
25from astropy.tests.helper import assert_quantity_allclose
27from lsst.afw.table import SourceCatalog
28import lsst.utils.tests
29import lsst.pipe.base.testUtils
30from lsst.verify import Name
31from lsst.verify.gen2tasks.testUtils import MetricTaskTestCase
32from lsst.verify.tasks import MetricComputationError
34from lsst.ip.diffim.metrics import \
35 NumberSciSourcesMetricTask, \
36 FractionDiaSourcesToSciSourcesMetricTask
def _makeDummyCatalog(size, skyFlag=False, priFlag=False):
    """Build a minimal source catalog for exercising source-count metrics.

    Parameters
    ----------
    size : `int`
        How many rows the catalog should contain.
    skyFlag : `bool`
        When `True`, a ``sky_source`` flag field is added to the schema and
        is set on exactly one row (the last one added), provided the catalog
        is not empty. When `False`, the field is absent from the schema.
    priFlag : `bool`
        Same behavior as ``skyFlag``, applied to a ``detect_isPrimary`` flag
        field.

    Returns
    -------
    catalog : `lsst.afw.table.SourceCatalog`
        A freshly created catalog holding ``size`` rows.
    """
    schema = SourceCatalog.Table.makeMinimalSchema()
    # (requested, field name, field doc); fields are added in this order.
    optionalFlags = (
        (skyFlag, "sky_source", "Sky source."),
        (priFlag, "detect_isPrimary", "Primary source."),
    )
    for requested, fieldName, fieldDoc in optionalFlags:
        if requested:
            schema.addField(fieldName, type="Flag", doc=fieldDoc)

    catalog = SourceCatalog(schema)
    # Track the most recently added row; stays `None` for an empty catalog.
    lastRecord = None
    for _ in range(size):
        lastRecord = catalog.addNew()

    # Only the final row carries the requested flags, so each flag is set on
    # exactly one row of a non-empty catalog.
    if lastRecord is not None:
        if priFlag:
            lastRecord["detect_isPrimary"] = True
        if skyFlag:
            lastRecord["sky_source"] = True
    return catalog
class TestNumSciSources(MetricTaskTestCase):
    """Tests of `NumberSciSourcesMetricTask` run on dummy catalogs."""

    @classmethod
    def makeTask(cls):
        return NumberSciSourcesMetricTask()

    def _measure(self, sources):
        # Run the task, check that its output is structurally valid, and
        # hand back the measurement for further assertions.
        output = self.task.run(sources)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, output)
        return output.measurement

    def testValid(self):
        # Without any flag fields, every row is counted.
        sources = _makeDummyCatalog(3)
        measurement = self._measure(sources)

        expectedName = Name(metric="ip_diffim.numSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        assert_quantity_allclose(measurement.quantity, len(sources) * u.count)

    def testEmptyCatalog(self):
        # An empty catalog yields a count of zero.
        sources = _makeDummyCatalog(0)
        measurement = self._measure(sources)

        expectedName = Name(metric="ip_diffim.numSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        assert_quantity_allclose(measurement.quantity, 0 * u.count)

    def testSkySources(self):
        # One row is flagged as a sky source and is expected to be excluded.
        sources = _makeDummyCatalog(3, skyFlag=True)
        measurement = self._measure(sources)

        expectedName = Name(metric="ip_diffim.numSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        expectedCount = len(sources) - 1
        assert_quantity_allclose(measurement.quantity, expectedCount * u.count)

    def testPrimarySources(self):
        # One row is flagged as primary; the expected count is exactly one.
        sources = _makeDummyCatalog(3, priFlag=True)
        measurement = self._measure(sources)

        expectedName = Name(metric="ip_diffim.numSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        assert_quantity_allclose(measurement.quantity, 1 * u.count)

    def testMissingData(self):
        # A missing input catalog must produce no measurement.
        measurement = self._measure(None)
        self.assertIsNone(measurement)
class TestFractionDiaSources(MetricTaskTestCase):
    """Tests of `FractionDiaSourcesToSciSourcesMetricTask` run on dummy
    catalogs.
    """

    @classmethod
    def makeTask(cls):
        return FractionDiaSourcesToSciSourcesMetricTask()

    def _measure(self, *args, **kwargs):
        # Run the task with the given inputs, check that its output is
        # structurally valid, and hand back the measurement.
        output = self.task.run(*args, **kwargs)
        lsst.pipe.base.testUtils.assertValidOutput(self.task, output)
        return output.measurement

    def testValid(self):
        # Without any flag fields, the ratio uses the full catalog sizes.
        science = _makeDummyCatalog(5)
        dia = _makeDummyCatalog(3)
        measurement = self._measure(science, dia)

        expectedName = Name(metric="ip_diffim.fracDiaSourcesToSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        expectedFraction = len(dia) / len(science)
        assert_quantity_allclose(measurement.quantity,
                                 expectedFraction * u.dimensionless_unscaled)

    def testEmptyDiaCatalog(self):
        # No DIA sources gives a fraction of zero.
        science = _makeDummyCatalog(5)
        dia = _makeDummyCatalog(0)
        measurement = self._measure(science, dia)

        expectedName = Name(metric="ip_diffim.fracDiaSourcesToSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        assert_quantity_allclose(measurement.quantity, 0.0 * u.dimensionless_unscaled)

    def testEmptySciCatalog(self):
        # An empty science catalog must be reported as a computation error.
        science = _makeDummyCatalog(0)
        dia = _makeDummyCatalog(3)
        with self.assertRaises(MetricComputationError):
            self.task.run(science, dia)

    def testEmptyCatalogs(self):
        # Both catalogs empty must also be reported as a computation error.
        science = _makeDummyCatalog(0)
        dia = _makeDummyCatalog(0)
        with self.assertRaises(MetricComputationError):
            self.task.run(science, dia)

    def testMissingData(self):
        # Both inputs missing must produce no measurement.
        measurement = self._measure(None, None)
        self.assertIsNone(measurement)

    def testSemiMissingData(self):
        # A single missing input must also produce no measurement.
        measurement = self._measure(sciSources=_makeDummyCatalog(3), diaSources=None)
        self.assertIsNone(measurement)

    def testSkySources(self):
        # One science row is flagged as a sky source and is expected to be
        # left out of the denominator.
        science = _makeDummyCatalog(5, skyFlag=True)
        dia = _makeDummyCatalog(3)
        measurement = self._measure(science, dia)

        expectedName = Name(metric="ip_diffim.fracDiaSourcesToSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        expectedFraction = len(dia) / (len(science) - 1)
        assert_quantity_allclose(measurement.quantity,
                                 expectedFraction * u.dimensionless_unscaled)

    def testPrimarySources(self):
        # With both flags present the expected fraction equals the DIA
        # source count, i.e. the denominator is expected to be one.
        science = _makeDummyCatalog(5, skyFlag=True, priFlag=True)
        dia = _makeDummyCatalog(3)
        measurement = self._measure(science, dia)

        expectedName = Name(metric="ip_diffim.fracDiaSourcesToSciSources")
        self.assertEqual(measurement.metric_name, expectedName)
        assert_quantity_allclose(measurement.quantity, len(dia) * u.dimensionless_unscaled)
# Drop the imported base class from this module's namespace so that
# unittest's module-level test collection does not pick it up and try to run
# it as a test case in its own right.
del MetricTaskTestCase
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`
    without modification.
    """
def setup_module(module):
    """Module-level setup hook: run ``lsst.utils.tests.init()``.

    Parameters
    ----------
    module : `module`
        The module being set up; not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Executed as a script: perform the same initialization as
    # ``setup_module`` and then hand control to unittest's command-line
    # runner.
    lsst.utils.tests.init()
    unittest.main()