Coverage for tests/test_diaCalculation.py: 37%
99 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-09 02:11 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-09 02:11 -0800
# This file is part of ap_association.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import pandas as pd
24import unittest
26from lsst.meas.base import (
27 DiaObjectCalculationTask,
28 DiaObjectCalculationConfig,
29 DiaObjectCalculationPlugin)
30from lsst.meas.base.pluginRegistry import register
31import lsst.utils.tests
@register("testCount")
class CountDiaPlugin(DiaObjectCalculationPlugin):
    """Count the DiaSources associated with a single DiaObject.

    Writes ``len(diaSources["psFlux"])`` into the ``count`` column. Note that
    it counts the rows of ``diaSources``, not of ``filterDiaSources``.
    """
    outputCols = ["count"]

    @classmethod
    def getExecutionOrder(cls):
        """Return the default catalog-calculation execution order."""
        return cls.DEFAULT_CATALOGCALCULATION

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  filterName,
                  **kwargs):
        """Store the number of DiaSources for one DiaObject.

        Parameters
        ----------
        diaObjects : `pandas.DataFrame`
            Catalog to update; ``diaObjects.at[diaObjectId, "count"]`` is set.
        diaObjectId : `int`
            Row label of the object being updated.
        diaSources : `pandas.DataFrame`
            Sources for this object; must contain a ``psFlux`` column.
        filterDiaSources
            Not used by this plugin.
        filterName : `str`
            Not used by this plugin.
        **kwargs
            Ignored.
        """
        diaObjects.at[diaObjectId, "count"] = len(diaSources["psFlux"])
@register("testDiaPlugin")
class DiaPlugin(DiaObjectCalculationPlugin):
    """Per-band mean and standard deviation of ``psFlux``.

    A ``multi`` plugin: a single call writes the ``<band>MeanFlux`` and
    ``<band>StdFlux`` columns for every row of ``diaObjects``.
    """
    outputCols = ["MeanFlux", "StdFlux"]

    # Update all objects in one call rather than one object per call.
    plugType = "multi"

    @classmethod
    def getExecutionOrder(cls):
        """Return the default catalog-calculation execution order."""
        return cls.DEFAULT_CATALOGCALCULATION

    def calculate(self,
                  diaObjects,
                  diaSources,
                  filterDiaSources,
                  filterName,
                  **kwargs):
        """Compute the per-object mean and standard deviation of ``psFlux``.

        Parameters
        ----------
        diaObjects : `pandas.DataFrame`
            Catalog to update; ``<filterName>MeanFlux`` and
            ``<filterName>StdFlux`` are assigned for all rows.
        diaSources
            Not used by this plugin.
        filterDiaSources
            Sources in band ``filterName``. Presumably grouped by
            diaObjectId so the aggregates align with ``diaObjects``' index
            on assignment -- TODO confirm against the task.
        filterName : `str`
            Band name used as the output column prefix.
        **kwargs
            Ignored.
        """
        # Use pandas' named reductions. The previous np.nanmean/np.nanstd
        # callables were silently translated by pandas to "mean"/"std"
        # (NaN-skipping; std with ddof=1, which testRun's expected values
        # rely on). pandas >= 2.1 warns that this translation will be
        # removed, which would switch the std to numpy's ddof=0.
        diaObjects.loc[:, "%sMeanFlux" % filterName] = (
            filterDiaSources.psFlux.agg("mean"))
        diaObjects.loc[:, "%sStdFlux" % filterName] = (
            filterDiaSources.psFlux.agg("std"))
@register("testDependentDiaPlugin")
class DependentDiaPlugin(DiaObjectCalculationPlugin):
    """Chi-squared-like statistic built on a previously computed mean.

    Declares ``MeanFlux`` as an input column, so it can only be configured
    together with a plugin that produces that column.
    """
    inputCols = ["MeanFlux"]
    outputCols = ["ChiFlux"]

    @classmethod
    def getExecutionOrder(cls):
        """Return the ``FLUX_MOMENTS_CALCULATED`` execution order."""
        return cls.FLUX_MOMENTS_CALCULATED

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  filterName,
                  **kwargs):
        """Store ``sum(((psFlux - mean) / psFluxErr) ** 2)`` for one object.

        The mean is read from ``diaObjects`` (column
        ``<filterName>MeanFlux``) and the result is written to
        ``<filterName>ChiFlux`` for ``diaObjectId``.
        """
        meanColumn = "{}MeanFlux".format(filterName)
        chiColumn = "{}ChiFlux".format(filterName)

        meanFlux = diaObjects.at[diaObjectId, meanColumn]
        residuals = filterDiaSources["psFlux"] - meanFlux
        pulls = residuals / filterDiaSources["psFluxErr"]

        diaObjects.at[diaObjectId, chiColumn] = np.sum(pulls ** 2)
@register("testCollidingDiaPlugin")
class CollidingDiaPlugin(DiaObjectCalculationPlugin):
    """Plugin whose output column deliberately collides with another.

    Declares ``MeanFlux`` as an output column -- the same column that the
    ``testDiaPlugin`` plugin produces -- and simply overwrites it with 0.0.
    ``testConflictingPlugins`` configures both together and expects
    ``ValueError``.
    """
    outputCols = ["MeanFlux"]

    @classmethod
    def getExecutionOrder(cls):
        """Return the ``FLUX_MOMENTS_CALCULATED`` execution order."""
        return cls.FLUX_MOMENTS_CALCULATED

    def calculate(self,
                  diaObjects,
                  diaObjectId,
                  diaSources,
                  filterDiaSources,
                  filterName,
                  **kwargs):
        """Overwrite ``<filterName>MeanFlux`` of one DiaObject with 0.0."""
        diaObjects.at[diaObjectId, "%sMeanFlux" % filterName] = 0.0
class TestDiaCalcluation(unittest.TestCase):
    """Test DiaObjectCalculationTask with the toy plugins registered in this
    module (``testCount``, ``testDiaPlugin``, ``testDependentDiaPlugin`` and
    ``testCollidingDiaPlugin``).
    """
    # NOTE(review): "Calcluation" is a typo; the name is kept so any external
    # references to this class keep working.

    def setUp(self):
        """Build toy DiaObject/DiaSource catalogs and a task running the
        mean/std plugin plus the dependent chi plugin.
        """
        # Create diaObjects.
        self.newDiaObjectId = 13
        self.diaObjects = pd.DataFrame(
            data=[{"diaObjectId": objId}
                  for objId in [0, 1, 2, 3, 4, 5, self.newDiaObjectId]])

        # Create diaSources from "previous runs" and newly created ones.
        # Objects 0-4 each get one g-band and one r-band source with
        # psFlux=0.
        diaSources = [{"diaSourceId": objId, "diaObjectId": objId,
                       "psFlux": 0., "psFluxErr": 1.,
                       "totFlux": 0., "totFluxErr": 1.,
                       "midPointTai": 0, "filterName": "g"}
                      for objId in range(5)]
        diaSources.extend([{"diaSourceId": 5 + objId, "diaObjectId": objId,
                            "psFlux": 0., "psFluxErr": 1.,
                            "totFlux": 0., "totFluxErr": 1.,
                            "midPointTai": 0, "filterName": "r"}
                           for objId in range(5)])
        # New g-band sources: objects 0 and 1 get psFlux=1, object 2 gets a
        # NaN flux, and the new object gets its first source.
        diaSources.extend([{"diaSourceId": 10, "diaObjectId": 0,
                            "psFlux": 1., "psFluxErr": 1.,
                            "totFlux": 0., "totFluxErr": 0.,
                            "midPointTai": 0, "filterName": "g"},
                           {"diaSourceId": 11, "diaObjectId": 1,
                            "psFlux": 1., "psFluxErr": 1.,
                            "totFlux": 0., "totFluxErr": 0.,
                            "midPointTai": 0, "filterName": "g"},
                           {"diaSourceId": 12, "diaObjectId": 2,
                            "psFlux": np.nan, "psFluxErr": 1.,
                            "totFlux": 0., "totFluxErr": 0.,
                            "midPointTai": 0, "filterName": "g"},
                           {"diaSourceId": self.newDiaObjectId,
                            "diaObjectId": self.newDiaObjectId,
                            "psFlux": 1., "psFluxErr": 1.,
                            "totFlux": 0., "totFluxErr": 0.,
                            "midPointTai": 0, "filterName": "g"}])
        self.diaSources = pd.DataFrame(data=diaSources)

        # The objects that received one of the new sources above.
        self.updatedDiaObjectIds = np.array([0, 1, 2, self.newDiaObjectId],
                                            dtype=np.int64)

        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDiaPlugin",
                        "testDependentDiaPlugin"]
        self.diaObjCalTask = DiaObjectCalculationTask(config=conf)

    def testRun(self):
        """Test the run method and that diaObjects are updated correctly.
        """
        results = self.diaObjCalTask.run(self.diaObjects,
                                         self.diaSources,
                                         self.updatedDiaObjectIds,
                                         ["g"])
        diaObjectCat = results.diaObjectCat
        updatedDiaObjects = results.updatedDiaObjects
        updatedDiaObjects.set_index("diaObjectId", inplace=True)
        # Test the lengths of the output dataframes.
        self.assertEqual(len(diaObjectCat), len(self.diaObjects))
        self.assertEqual(len(updatedDiaObjects),
                         len(self.updatedDiaObjectIds))

        # Test values stored computed in the task.
        for objId, diaObject in updatedDiaObjects.iterrows():
            if objId == self.newDiaObjectId:
                # Single g source with psFlux=1: the sample std of one
                # value is undefined (NaN) and the chi term is zero.
                self.assertEqual(diaObject["gMeanFlux"], 1.)
                self.assertTrue(np.isnan(diaObject["gStdFlux"]))
                self.assertAlmostEqual(diaObject["gChiFlux"], 0.0)
            elif objId == 2:
                # g fluxes [0, NaN]: the NaN is ignored, leaving one value.
                self.assertAlmostEqual(diaObject["gMeanFlux"], 0.0)
                self.assertTrue(np.isnan(diaObject["gStdFlux"]))
                self.assertAlmostEqual(diaObject["gChiFlux"], 0.0)
            else:
                # g fluxes [0, 1]: mean 0.5, sample std sqrt(0.5),
                # chi = 0.5**2 + 0.5**2.
                self.assertAlmostEqual(diaObject["gMeanFlux"], 0.5)
                self.assertAlmostEqual(diaObject["gStdFlux"],
                                       0.7071067811865476)
                self.assertAlmostEqual(diaObject["gChiFlux"], 0.5)

    def testRunUnindexed(self):
        """Test inputting un-indexed catalogs.
        """
        def makeDiaSources(firstId, nSources):
            # g-band sources all attached to diaObjectId 0.
            return pd.DataFrame(data=[
                {"diaSourceId": firstId + srcIdx, "diaObjectId": 0,
                 "psFlux": 0., "psFluxErr": 1.,
                 "totFlux": 0., "totFluxErr": 1.,
                 "midPointTai": 0, "filterName": "g"}
                for srcIdx in range(nSources)])

        # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0.
        # pd.concat without ignore_index keeps each frame's own RangeIndex,
        # so the row labels 0-9 appear twice, exactly as append produced.
        unindexedDiaSources = pd.concat([makeDiaSources(0, 1000),
                                         makeDiaSources(1000, 10)])

        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testCount"]
        diaObjectCalTask = DiaObjectCalculationTask(config=conf)
        # self.diaObjects is already un-indexed (default RangeIndex from
        # setUp), so it is passed as-is.
        results = diaObjectCalTask.run(self.diaObjects,
                                       unindexedDiaSources,
                                       np.array([0], dtype=np.int64),
                                       ["g"])
        updatedDiaObjects = results.updatedDiaObjects
        self.assertEqual(updatedDiaObjects.at[0, "count"],
                         len(unindexedDiaSources))

    def testConflictingPlugins(self):
        """Test that code properly exits upon plugin collision.
        """
        # The dependent plugin's input column ("MeanFlux") is not produced
        # by any other configured plugin. Only the task construction is
        # inside assertRaises, so a ValueError from config setup cannot
        # make this pass by accident.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDependentDiaPlugin"]
        with self.assertRaises(ValueError):
            DiaObjectCalculationTask(config=conf)

        # testDiaPlugin and testCollidingDiaPlugin both output "MeanFlux".
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDiaPlugin",
                        "testCollidingDiaPlugin",
                        "testDependentDiaPlugin"]
        with self.assertRaises(ValueError):
            DiaObjectCalculationTask(config=conf)

        # Test that ordering in the config does not matter and dependent
        # plugin is instantiated after independent plugin. Would raise
        # ValueError on failure.
        conf = DiaObjectCalculationConfig()
        conf.plugins = ["testDependentDiaPlugin",
                        "testDiaPlugin"]
        DiaObjectCalculationTask(config=conf)
def setup_module(module):
    """Module-level setup hook (pytest ``setup_module`` convention).

    Initializes the LSST test utilities; ``module`` is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: initialize the LSST test utilities, then
    # hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()