Coverage for tests/test_diaCalculation.py: 41%
99 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-27 02:49 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-27 02:49 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import pandas as pd
24import unittest
26from lsst.meas.base import (
27 DiaObjectCalculationTask,
28 DiaObjectCalculationConfig,
29 DiaObjectCalculationPlugin)
30from lsst.meas.base.pluginRegistry import register
31import lsst.utils.tests
@register("testCount")
class CountDiaPlugin(DiaObjectCalculationPlugin):
    """Test plugin that records how many ``psfFlux`` entries it was given.

    Registered under the name ``"testCount"``; its only output column is
    ``count``.
    """

    outputCols = ["count"]

    @classmethod
    def getExecutionOrder(cls):
        """Return the execution-order constant this plugin runs at."""
        return cls.DEFAULT_CATALOGCALCULATION

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Store ``len(diaSources["psfFlux"])`` in the ``count`` column.

        Parameters
        ----------
        diaObjects
            Table updated in place through ``.at`` (presumably a
            `pandas.DataFrame` -- confirm against the calling task).
        diaObjectId
            Row label in ``diaObjects`` that receives the count.
        diaSources
            Container whose ``"psfFlux"`` entry is measured with `len`.
        filterDiaSources, band, **kwargs
            Accepted for interface compatibility; not used here.
        """
        nSources = len(diaSources["psfFlux"])
        diaObjects.at[diaObjectId, "count"] = nSources
@register("testDiaPlugin")
class DiaPlugin(DiaObjectCalculationPlugin):
    """Test plugin that writes the mean and standard deviation of ``psfFlux``.

    Registered under the name ``"testDiaPlugin"`` with ``plugType`` set to
    ``"multi"``. Its output columns are ``MeanFlux`` and ``StdFlux``, which
    `calculate` writes with the band name prepended.
    """

    outputCols = ["MeanFlux", "StdFlux"]

    plugType = "multi"

    @classmethod
    def getExecutionOrder(cls):
        """Return the execution-order constant this plugin runs at."""
        return cls.DEFAULT_CATALOGCALCULATION

    def calculate(self,
                  diaObjects,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Fill ``<band>MeanFlux`` and ``<band>StdFlux`` for every row.

        Parameters
        ----------
        diaObjects
            Table updated in place through ``.loc[:, column]``.
        diaSources
            Accepted for interface compatibility; not used here.
        filterDiaSources
            Object exposing a ``psfFlux`` attribute that supports
            ``.agg("mean")`` and ``.agg("std")`` (presumably a pandas
            groupby restricted to ``band`` -- confirm against the task).
        band
            String prepended to the output column names.
        **kwargs
            Ignored.
        """
        meanColumn = "%sMeanFlux" % band
        stdColumn = "%sStdFlux" % band
        psfFluxes = filterDiaSources.psfFlux

        diaObjects.loc[:, meanColumn] = psfFluxes.agg("mean")
        diaObjects.loc[:, stdColumn] = psfFluxes.agg("std")
@register("testDependentDiaPlugin")
class DependentDiaPlugin(DiaObjectCalculationPlugin):
    """Test plugin that consumes a previously written ``MeanFlux`` column.

    Registered under the name ``"testDependentDiaPlugin"``. It declares
    ``MeanFlux`` as an input column and ``ChiFlux`` as its output column.
    """

    inputCols = ["MeanFlux"]
    outputCols = ["ChiFlux"]

    @classmethod
    def getExecutionOrder(cls):
        """Return the execution-order constant this plugin runs at."""
        return cls.FLUX_MOMENTS_CALCULATED

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Write the summed squared, error-scaled flux residuals.

        For the row ``diaObjectId`` this stores
        ``sum(((psfFlux - <band>MeanFlux) / psfFluxErr) ** 2)`` in the
        ``<band>ChiFlux`` column.

        Parameters
        ----------
        diaObjects
            Table read and updated in place through ``.at``.
        diaObjectId
            Row label in ``diaObjects`` to read the mean from and write to.
        diaSources
            Accepted for interface compatibility; not used here.
        filterDiaSources
            Container providing the ``"psfFlux"`` and ``"psfFluxErr"``
            values.
        band
            String prepended to the input and output column names.
        **kwargs
            Ignored.
        """
        fluxes = filterDiaSources["psfFlux"]
        meanFlux = diaObjects.at[diaObjectId, "%sMeanFlux" % band]
        fluxErrs = filterDiaSources["psfFluxErr"]

        scaledResiduals = (fluxes - meanFlux) / fluxErrs
        diaObjects.at[diaObjectId, "%sChiFlux" % band] = np.sum(
            scaledResiduals ** 2)
@register("testCollidingDiaPlugin")
class CollidingDiaPlugin(DiaObjectCalculationPlugin):
    """Test plugin whose output column duplicates one from ``DiaPlugin``.

    Registered under the name ``"testCollidingDiaPlugin"``. It declares
    ``MeanFlux`` as an output column, which ``DiaPlugin`` also declares, so
    the two can be configured together to exercise the task's handling of
    conflicting plugins.
    """

    outputCols = ["MeanFlux"]

    @classmethod
    def getExecutionOrder(cls):
        """Return the execution-order constant this plugin runs at."""
        return cls.FLUX_MOMENTS_CALCULATED

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  band,
                  **kwargs):
        """Set ``<band>MeanFlux`` to ``0.0`` for the row ``diaObjectId``.

        Parameters
        ----------
        diaObjects
            Table updated in place through ``.at``.
        diaObjectId
            Row label in ``diaObjects`` to overwrite.
        band
            String prepended to the output column name.
        diaSources, filterDiaSources, **kwargs
            Accepted for interface compatibility; not used here.
        """
        meanColumn = "%sMeanFlux" % band
        diaObjects.at[diaObjectId, meanColumn] = 0.0
class TestDiaCalcluation(unittest.TestCase):
    """Tests of ``DiaObjectCalculationTask`` using the test plugins that are
    registered in this file.
    """

    def setUp(self):
        """Build the input catalogs and a task running ``testDiaPlugin``
        followed by ``testDependentDiaPlugin``.
        """
        # Create diaObjects. No index is set, so the frame keeps the default
        # integer index and ``diaObjectId`` stays an ordinary column.
        self.newDiaObjectId = 13
        self.diaObjects = pd.DataFrame(
            data=[{"diaObjectId": objId}
                  for objId in [0, 1, 2, 3, 4, 5, self.newDiaObjectId]])

        # Create diaSources from "previous runs" and newly created ones.
        # One zero-flux g-band and one zero-flux r-band source for each of
        # the objects 0-4.
        diaSources = [{"diaSourceId": objId, "diaObjectId": objId,
                       "psfFlux": 0., "psfFluxErr": 1.,
                       "scienceFlux": 0., "scienceFluxErr": 1.,
                       "midpointMjdTai": 0, "band": "g"}
                      for objId in range(5)]
        diaSources.extend([{"diaSourceId": 5 + objId, "diaObjectId": objId,
                            "psfFlux": 0., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 1.,
                            "midpointMjdTai": 0, "band": "r"}
                           for objId in range(5)])
        # Additional g-band sources: unit flux for objects 0 and 1, a NaN
        # flux for object 2, and the only source of the new object.
        diaSources.extend([{"diaSourceId": 10, "diaObjectId": 0,
                            "psfFlux": 1., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"},
                           {"diaSourceId": 11, "diaObjectId": 1,
                            "psfFlux": 1., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"},
                           {"diaSourceId": 12, "diaObjectId": 2,
                            "psfFlux": np.nan, "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"},
                           {"diaSourceId": self.newDiaObjectId,
                            "diaObjectId": self.newDiaObjectId,
                            "psfFlux": 1., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 0.,
                            "midpointMjdTai": 0, "band": "g"}])
        self.diaSources = pd.DataFrame(data=diaSources)

        # Only the objects that received one of the additional sources are
        # flagged as updated.
        self.updatedDiaObjectIds = np.array([0, 1, 2, self.newDiaObjectId],
                                            dtype=np.int64)

        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDiaPlugin",
                        "testDependentDiaPlugin"]
        self.diaObjCalTask = DiaObjectCalculationTask(config=conf)

    def testRun(self):
        """Test the run method and that diaObjects are updated correctly.
        """
        results = self.diaObjCalTask.run(self.diaObjects,
                                         self.diaSources,
                                         self.updatedDiaObjectIds,
                                         ["g"])
        diaObjectCat = results.diaObjectCat
        # Index a copy rather than mutating the frame owned by ``results``.
        updatedDiaObjects = results.updatedDiaObjects.set_index("diaObjectId")

        # Test the lengths of the output dataframes.
        self.assertEqual(len(diaObjectCat), len(self.diaObjects))
        self.assertEqual(len(updatedDiaObjects),
                         len(self.updatedDiaObjectIds))

        # Test the values computed and stored by the task.
        for objId, diaObject in updatedDiaObjects.iterrows():
            if objId == self.newDiaObjectId:
                # A single g-band source with unit flux.
                self.assertEqual(diaObject["gMeanFlux"], 1.)
                self.assertTrue(np.isnan(diaObject["gStdFlux"]))
                self.assertAlmostEqual(diaObject["gChiFlux"], 0.0)
            elif objId == 2:
                # g-band fluxes are 0 and NaN.
                self.assertAlmostEqual(diaObject["gMeanFlux"], 0.0)
                self.assertTrue(np.isnan(diaObject["gStdFlux"]))
                self.assertAlmostEqual(diaObject["gChiFlux"], 0.0)
            else:
                # Objects 0 and 1: g-band fluxes are 0 and 1.
                self.assertAlmostEqual(diaObject["gMeanFlux"], 0.5)
                self.assertAlmostEqual(diaObject["gStdFlux"],
                                       0.7071067811865476)
                self.assertAlmostEqual(diaObject["gChiFlux"], 0.5)

    def testRunUnindexed(self):
        """Test inputting un-indexed catalogs.
        """
        unindexedDiaSources = pd.DataFrame(data=[
            {"diaSourceId": objId, "diaObjectId": 0,
             "psfFlux": 0., "psfFluxErr": 1.,
             "scienceFlux": 0., "scienceFluxErr": 1.,
             "midpointMjdTai": 0, "band": "g"}
            for objId in range(1000)])
        # The concatenation keeps each frame's own index labels, so the
        # labels 0-9 occur twice in the combined catalog.
        unindexedDiaSources = pd.concat(
            (
                unindexedDiaSources,
                pd.DataFrame(
                    data=[
                        {
                            "diaSourceId": objId + 1000,
                            "diaObjectId": 0,
                            "psfFlux": 0., "psfFluxErr": 1.,
                            "scienceFlux": 0., "scienceFluxErr": 1.,
                            "midpointMjdTai": 0, "band": "g",
                        }
                        for objId in range(10)
                    ]
                )
            )
        )

        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testCount"]
        diaObjectCalTask = DiaObjectCalculationTask(config=conf)
        # ``self.diaObjects`` is passed as built in ``setUp``, where it was
        # created without an explicit index.
        results = diaObjectCalTask.run(self.diaObjects,
                                       unindexedDiaSources,
                                       np.array([0], dtype=np.int64),
                                       ["g"])
        updatedDiaObjects = results.updatedDiaObjects
        # Every source belongs to object 0, so all of them must be counted.
        self.assertEqual(updatedDiaObjects.at[0, "count"],
                         len(unindexedDiaSources))

    def testConflictingPlugins(self):
        """Test that code properly exits upon plugin collision.
        """
        # Each config is built outside the ``assertRaises`` context so that
        # only a ValueError raised while constructing the task can satisfy
        # the assertion.

        # A dependent plugin configured without the plugin that produces
        # its input column.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDependentDiaPlugin"]
        with self.assertRaises(ValueError):
            DiaObjectCalculationTask(config=conf)

        # Two plugins that both declare the ``MeanFlux`` output column.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDiaPlugin",
                        "testCollidingDiaPlugin",
                        "testDependentDiaPlugin"]
        with self.assertRaises(ValueError):
            DiaObjectCalculationTask(config=conf)

        # Test that ordering in the config does not matter and dependent
        # plugin is instantiated after independent plugin. Would raise
        # ValueError on failure.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDependentDiaPlugin",
                        "testDiaPlugin"]
        DiaObjectCalculationTask(config=conf)
def setup_module(module):
    """Call ``lsst.utils.tests.init()`` once for this module.

    Presumably invoked by pytest as its module-level setup hook -- confirm
    against the test runner in use.

    Parameters
    ----------
    module
        Not used by this function.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When executed as a script, call the same init as ``setup_module``
    # and then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()