Coverage for tests/test_simpleAssociation.py: 18%
91 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-29 10:48 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-29 10:48 +0000
1# This file is part of pipe_tasks.
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21#
23import numpy as np
24import pandas as pd
25import unittest
27import lsst.afw.table as afwTable
28import lsst.geom as geom
29import lsst.utils.tests
30from lsst.pipe.tasks.associationUtils import toIndex
31from lsst.pipe.tasks.simpleAssociation import SimpleAssociationTask
class TestSimpleAssociation(lsst.utils.tests.TestCase):
    """Unit tests for ``SimpleAssociationTask``.

    The fixture holds ``nDiaObjects`` DiaObjects laid out along a diagonal
    in RA/Dec, two DiaSources at each DiaObject position (one per visit),
    and ``nNewDiaSources`` extra DiaSources placed away from every
    DiaObject.
    """

    def setUp(self):
        """Build the DiaObjects, HEALPix indices, coordinates and DiaSources
        shared by all tests.
        """
        simpleAssoc = SimpleAssociationTask()

        self.nDiaObjects = 10
        self.diaObjRas = np.linspace(45, 46, self.nDiaObjects)
        self.diaObjDecs = np.linspace(45, 46, self.nDiaObjects)
        # Put object 3 at 0.1 arcsec (in both RA and Dec) from object 2 so
        # that a search around object 2 returns multiple matches.
        self.diaObjRas[3] = self.diaObjRas[2] + 0.1/3600
        self.diaObjDecs[3] = self.diaObjDecs[2] + 0.1/3600
        self.diaObjects = [
            simpleAssoc.createDiaObject(objId, ra, dec)
            for objId, ra, dec in zip(
                np.arange(self.nDiaObjects, dtype=int),
                self.diaObjRas,
                self.diaObjDecs)]

        self.hpIndices = [toIndex(simpleAssoc.config.nside,
                                  diaObj["ra"],
                                  diaObj["dec"])
                          for diaObj in self.diaObjects]

        # Visit id given to the DiaSources that match no DiaObject.
        self.newDiaObjectVisit = 1236
        # Drop in two copies of the DiaObject locations to make DiaSources.
        diaSourceList = [
            {"ccdVisitId": 1234,
             "diaSourceId": idx,
             "diaObjectId": 0,
             "ra": ra,
             "dec": dec}
            for idx, (ra, dec) in enumerate(zip(self.diaObjRas,
                                                self.diaObjDecs))]
        # One list of source coordinates per DiaObject; starts with the
        # source from the first visit.
        self.coordList = [
            [geom.SpherePoint(diaSrc["ra"], diaSrc["dec"], geom.degrees)]
            for diaSrc in diaSourceList]
        moreDiaSources = [
            {"ccdVisitId": 1235,
             "diaSourceId": idx + self.nDiaObjects,
             "diaObjectId": 0,
             "ra": ra,
             "dec": dec}
            for idx, (ra, dec) in enumerate(zip(self.diaObjRas,
                                                self.diaObjDecs))]
        # Add the second-visit source to each DiaObject's coordinate list.
        for coords, diaSrc in zip(self.coordList, moreDiaSources):
            coords.append(geom.SpherePoint(diaSrc["ra"],
                                           diaSrc["dec"],
                                           geom.degrees))
        diaSourceList.extend(moreDiaSources)

        self.nNewDiaSources = 2
        # Drop in two more DiaSources that are unassociated.
        # ``len(diaSourceList)`` is evaluated before each append, so the
        # new ids continue the existing sequence.
        diaSourceList.append({"ccdVisitId": self.newDiaObjectVisit,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 0.0,
                              "dec": 0.0})
        diaSourceList.append({"ccdVisitId": self.newDiaObjectVisit,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 1.0,
                              "dec": 89.0})
        self.diaSources = pd.DataFrame(data=diaSourceList)

    def tearDown(self):
        """Drop the fixture containers built in ``setUp``."""
        del self.diaObjects
        del self.hpIndices
        del self.diaSources
        del self.coordList

    def testRun(self):
        """Test the full run method of the simple associator.
        """
        simpleAssoc = SimpleAssociationTask()
        result = simpleAssoc.run(self.diaSources)

        # Test the number of expected DiaObjects are created.
        self.assertEqual(len(result.diaObjects),
                         self.nDiaObjects + self.nNewDiaSources)

        # Test that DiaSources are assigned the correct ``diaObjectId``
        assocDiaObjects = result.diaObjects
        assocDiaSources = result.assocDiaSources.reset_index().set_index(
            ["diaObjectId", "diaSourceId"])
        for idx, diaObjId in enumerate(assocDiaObjects.index):
            # The first ``nDiaObjects`` objects were observed in two visits;
            # the remaining ones come from a single unassociated source.
            expectedNSources = 2 if idx < self.nDiaObjects else 1
            self.assertEqual(len(assocDiaSources.loc[diaObjId]),
                             expectedNSources)

    def testUpdateCatalogs(self):
        """Test adding data to existing DiaObject/Source catalogs.
        """
        matchIndex = 4
        diaSrc = self.diaSources.iloc[matchIndex]
        self.diaObjects[matchIndex]["diaObjectId"] = 1234
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.updateCatalogs(matchIndex,
                                   diaSrc,
                                   self.diaSources,
                                   ccdVisit,
                                   diaSourceId,
                                   self.diaObjects,
                                   self.coordList,
                                   self.hpIndices)
        # Updating an existing DiaObject must not add any new entries.
        self.assertEqual(len(self.hpIndices), self.nDiaObjects)
        self.assertEqual(len(self.coordList), self.nDiaObjects)
        # Should be 3 source coordinates.
        self.assertEqual(len(self.coordList[matchIndex]), 3)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         self.diaObjects[matchIndex]["diaObjectId"])

    def testAddDiaObject(self):
        """Test creating a new DiaObject from an unassociated DiaSource.
        """
        # The last DiaSource is one of the two placed away from all objects.
        diaSrc = self.diaSources.iloc[-1]
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema()))

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.addNewDiaObject(diaSrc,
                                    self.diaSources,
                                    ccdVisit,
                                    diaSourceId,
                                    self.diaObjects,
                                    idCat,
                                    self.coordList,
                                    self.hpIndices)
        # Every per-object container grows by exactly one entry.
        self.assertEqual(len(self.hpIndices), self.nDiaObjects + 1)
        self.assertEqual(len(self.coordList), self.nDiaObjects + 1)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects + 1)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         idCat[0].get("id"))

    def testFindMatches(self):
        """Test the simple brute force matching algorithm.
        """
        simpleAssoc = SimpleAssociationTask()
        # No match
        matchResult = simpleAssoc.findMatches(
            0.0,
            0.0,
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertIsNone(matchResult.dists)
        self.assertIsNone(matchResult.matches)

        # One match
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[4],
            self.diaObjDecs[4],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 1)
        self.assertEqual(len(matchResult.matches), 1)
        self.assertEqual(matchResult.matches[0], 4)

        # Two matches: object 2 itself and its close neighbour, object 3.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[2],
            self.diaObjDecs[2],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 2)
        self.assertEqual(len(matchResult.matches), 2)
        self.assertEqual(matchResult.matches[0], 2)
        self.assertEqual(matchResult.matches[1], 3)
def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run.

    Parameters
    ----------
    module : module
        The test module being set up; supplied by the test runner and not
        used here.
    """
    lsst.utils.tests.init()
class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Local subclass of ``lsst.utils.tests.MemoryTestCase``.

    Adds no behavior of its own; defining it here lets the test runner
    collect the inherited checks as part of this module.
    """
# Allow the tests to be run directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()