Coverage for tests/test_simpleAssociation.py: 18%
95 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-07 03:40 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-07 03:40 -0700
1# This file is part of pipe_tasks.
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21#
23import numpy as np
24import pandas as pd
25import unittest
27import lsst.afw.table as afwTable
28import lsst.geom as geom
29import lsst.utils.tests
30from lsst.pipe.tasks.associationUtils import toIndex
31from lsst.pipe.tasks.simpleAssociation import SimpleAssociationTask
class TestSimpleAssociation(lsst.utils.tests.TestCase):
    """Unit tests for `SimpleAssociationTask`.

    The fixture builds ``nDiaObjects`` DiaObjects along a diagonal line in
    RA/Dec (both spanning 45-46 deg). It places two DiaSources at every
    DiaObject position, taken from two different visit/detector pairs. It
    also adds ``nNewDiaSources`` DiaSources located far from every DiaObject.
    """

    def setUp(self):
        simpleAssoc = SimpleAssociationTask()

        self.nDiaObjects = 10
        self.diaObjRas = np.linspace(45, 46, self.nDiaObjects)
        self.diaObjDecs = np.linspace(45, 46, self.nDiaObjects)
        # Offset object 3 from object 2 by 0.1/3600 deg in both RA and Dec so
        # that a search around object 2 returns more than one match.
        self.diaObjRas[3] = self.diaObjRas[2] + 0.1/3600
        self.diaObjDecs[3] = self.diaObjDecs[2] + 0.1/3600
        self.diaObjects = [
            simpleAssoc.createDiaObject(objId, ra, dec)
            for objId, ra, dec in zip(
                np.arange(self.nDiaObjects, dtype=int),
                self.diaObjRas,
                self.diaObjDecs)]

        self.hpIndices = [toIndex(simpleAssoc.config.nside,
                                  diaObj["ra"],
                                  diaObj["dec"])
                          for diaObj in self.diaObjects]

        self.newDiaObjectVisit = 1236
        # Drop in two copies of the DiaObject locations to make DiaSources.
        diaSourceList = [
            {"visit": 1234,
             "detector": 42,
             "diaSourceId": idx,
             "diaObjectId": 0,
             "ra": ra,
             "dec": dec}
            for idx, (ra, dec) in enumerate(zip(self.diaObjRas,
                                                self.diaObjDecs))]
        # One list of source coordinates per DiaObject, seeded with the
        # first copy of each DiaSource.
        self.coordList = [
            [geom.SpherePoint(diaSrc["ra"], diaSrc["dec"], geom.degrees)]
            for diaSrc in diaSourceList]
        moreDiaSources = [
            {"visit": 1235,
             "detector": 43,
             "diaSourceId": idx + self.nDiaObjects,
             "diaObjectId": 0,
             "ra": ra,
             "dec": dec}
            for idx, (ra, dec) in enumerate(zip(self.diaObjRas,
                                                self.diaObjDecs))]
        # Append the second copy's coordinate to the matching DiaObject.
        for coords, diaSrc in zip(self.coordList, moreDiaSources):
            coords.append(geom.SpherePoint(diaSrc["ra"],
                                           diaSrc["dec"],
                                           geom.degrees))
        diaSourceList.extend(moreDiaSources)

        self.nNewDiaSources = 2
        # Drop in two more DiaSources that are far from every DiaObject and
        # therefore unassociated.
        diaSourceList.append({"visit": self.newDiaObjectVisit,
                              "detector": 44,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 0.0,
                              "dec": 0.0})
        diaSourceList.append({"visit": self.newDiaObjectVisit,
                              "detector": 45,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 1.0,
                              "dec": 89.0})
        self.diaSources = pd.DataFrame(data=diaSourceList)

    def tearDown(self):
        del self.diaObjects
        del self.hpIndices
        del self.diaSources
        del self.coordList
        del self.diaObjRas
        del self.diaObjDecs

    def testRun(self):
        """Test the full run method of the simple associator.
        """
        simpleAssoc = SimpleAssociationTask()
        result = simpleAssoc.run(self.diaSources)

        # Test the number of expected DiaObjects are created.
        self.assertEqual(len(result.diaObjects),
                         self.nDiaObjects + self.nNewDiaSources)

        # Test that DiaSources are assigned the correct ``diaObjectId``.
        # The first ``nDiaObjects`` DiaObjects are expected to hold the two
        # DiaSources placed at their position. Every later DiaObject is
        # expected to hold exactly one (unassociated) DiaSource.
        assocDiaSources = result.assocDiaSources.reset_index().set_index(
            ["diaObjectId", "diaSourceId"])
        for idx, diaObjId in enumerate(result.diaObjects.index):
            expectedNSources = 2 if idx < self.nDiaObjects else 1
            self.assertEqual(len(assocDiaSources.loc[diaObjId]),
                             expectedNSources)

    def testUpdateCatalogs(self):
        """Test adding data to existing DiaObject/Source catalogs.
        """
        matchIndex = 4
        diaSrc = self.diaSources.iloc[matchIndex]
        self.diaObjects[matchIndex]["diaObjectId"] = 1234
        visit = diaSrc["visit"]
        detector = diaSrc["detector"]
        diaSourceId = diaSrc["diaSourceId"]
        # Index the sources by ((visit, detector), diaSourceId), the layout
        # passed to ``updateCatalogs`` below.
        self.diaSources["visit,detector"] = list(zip(self.diaSources["visit"],
                                                     self.diaSources["detector"]))
        self.diaSources.set_index(["visit,detector", "diaSourceId"], inplace=True)

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.updateCatalogs(matchIndex,
                                   diaSrc,
                                   self.diaSources,
                                   visit,
                                   detector,
                                   diaSourceId,
                                   self.diaObjects,
                                   self.coordList,
                                   self.hpIndices)
        # Updating an existing DiaObject must not add new DiaObjects.
        self.assertEqual(len(self.hpIndices), self.nDiaObjects)
        self.assertEqual(len(self.coordList), self.nDiaObjects)
        # Should be 3 source coordinates: the two from setUp plus this one.
        self.assertEqual(len(self.coordList[matchIndex]), 3)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects)
        self.assertEqual(self.diaSources.loc[((visit, detector), diaSourceId), "diaObjectId"].iloc[0],
                         self.diaObjects[matchIndex]["diaObjectId"])

    def testAddDiaObject(self):
        """Test creating a new DiaObject from an unassociated DiaSource.
        """
        # The last DiaSource is one of the unassociated ones from setUp.
        diaSrc = self.diaSources.iloc[-1]
        visit = diaSrc["visit"]
        detector = diaSrc["detector"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources["visit,detector"] = list(zip(self.diaSources["visit"],
                                                     self.diaSources["detector"]))
        self.diaSources.set_index(["visit,detector", "diaSourceId"], inplace=True)
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema()))

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.addNewDiaObject(diaSrc,
                                    self.diaSources,
                                    visit,
                                    detector,
                                    diaSourceId,
                                    self.diaObjects,
                                    idCat,
                                    self.coordList,
                                    self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects + 1)
        self.assertEqual(len(self.coordList), self.nDiaObjects + 1)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects + 1)
        # The DiaSource must now point at the id recorded in ``idCat``.
        self.assertEqual(self.diaSources.loc[((visit, detector), diaSourceId), "diaObjectId"].iloc[0],
                         idCat[0].get("id"))

    def testFindMatches(self):
        """Test the simple brute force matching algorithm.
        """
        simpleAssoc = SimpleAssociationTask()
        # No match: (0, 0) is far from every DiaObject.
        matchResult = simpleAssoc.findMatches(
            0.0,
            0.0,
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertIsNone(matchResult.dists)
        self.assertIsNone(matchResult.matches)

        # One match.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[4],
            self.diaObjDecs[4],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 1)
        self.assertEqual(len(matchResult.matches), 1)
        self.assertEqual(matchResult.matches[0], 4)

        # Two matches: objects 2 and 3 were placed next to each other in setUp.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[2],
            self.diaObjDecs[2],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 2)
        self.assertEqual(len(matchResult.matches), 2)
        self.assertEqual(matchResult.matches[0], 2)
        self.assertEqual(matchResult.matches[1], 3)
def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()
class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the standard LSST memory/resource checks after this module's tests.

    All behavior is inherited from `lsst.utils.tests.MemoryTestCase`.
    """
if __name__ == "__main__":
    # Script entry point: set up the LSST test utilities, then hand off
    # to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()