Coverage for tests/test_simpleAssociation.py: 18%
94 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-01 03:26 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-01 03:26 -0800
1# This file is part of pipe_tasks.
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21#
23import numpy as np
24import pandas as pd
25import unittest
27import lsst.afw.table as afwTable
28import lsst.geom as geom
29import lsst.utils.tests
30from lsst.pipe.tasks.associationUtils import toIndex
31from lsst.pipe.tasks.simpleAssociation import SimpleAssociationTask
class TestSimpleAssociation(lsst.utils.tests.TestCase):
    """Tests for `lsst.pipe.tasks.simpleAssociation.SimpleAssociationTask`.
    """

    def setUp(self):
        """Build DiaObjects and DiaSources for the association tests.

        ``nDiaObjects`` DiaObjects are laid out on a diagonal from
        (45, 45) to (46, 46) degrees. Object 3 is moved to a 0.1 arcsec
        offset (in both RA and Dec) from object 2, so that a search around
        object 2 returns two matches. Two DiaSources are placed at every
        DiaObject position (``ccdVisitId`` 1234 and 1235). In addition,
        ``nNewDiaSources`` DiaSources (``ccdVisitId`` 1236) are placed far
        away from every DiaObject.
        """
        simpleAssoc = SimpleAssociationTask()
        self.tractPatchId = 1234
        self.skymapBits = 16

        self.nDiaObjects = 10
        self.diaObjRas = np.linspace(45, 46, self.nDiaObjects)
        self.diaObjDecs = np.linspace(45, 46, self.nDiaObjects)
        # Copy a coord (with a 0.1 arcsec offset) to get multiple matches.
        self.diaObjRas[3] = self.diaObjRas[2] + 0.1/3600
        self.diaObjDecs[3] = self.diaObjDecs[2] + 0.1/3600
        self.diaObjects = [
            simpleAssoc.createDiaObject(objId, ra, decl)
            for objId, ra, decl in zip(
                np.arange(self.nDiaObjects, dtype=int),
                self.diaObjRas,
                self.diaObjDecs)]

        self.hpIndices = [toIndex(simpleAssoc.config.nside,
                                  diaObj["ra"],
                                  diaObj["decl"])
                          for diaObj in self.diaObjects]

        self.newDiaObjectVisit = 1236
        # Drop in two copies of the DiaObject locations to make DiaSources.
        diaSourceList = [
            {"ccdVisitId": 1234,
             "diaSourceId": idx,
             "diaObjectId": 0,
             "ra": ra,
             "decl": decl}
            for idx, (ra, decl) in enumerate(zip(self.diaObjRas,
                                                 self.diaObjDecs))]
        self.coordList = [
            [geom.SpherePoint(diaSrc["ra"], diaSrc["decl"], geom.degrees)]
            for diaSrc in diaSourceList]
        moreDiaSources = [
            {"ccdVisitId": 1235,
             "diaSourceId": idx + self.nDiaObjects,
             "diaObjectId": 0,
             "ra": ra,
             "decl": decl}
            for idx, (ra, decl) in enumerate(zip(self.diaObjRas,
                                                 self.diaObjDecs))]
        # Each DiaObject's coordinate list also records its second DiaSource.
        for coords, diaSrc in zip(self.coordList, moreDiaSources):
            coords.append(geom.SpherePoint(diaSrc["ra"],
                                           diaSrc["decl"],
                                           geom.degrees))
        diaSourceList.extend(moreDiaSources)

        self.nNewDiaSources = 2
        # Drop in two more DiaSources that are far from every DiaObject and
        # therefore unassociated.
        diaSourceList.append({"ccdVisitId": 1236,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 0.0,
                              "decl": 0.0})
        diaSourceList.append({"ccdVisitId": 1236,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 1.0,
                              "decl": 89.0})
        self.diaSources = pd.DataFrame(data=diaSourceList)

    def tearDown(self):
        del self.diaObjects
        del self.hpIndices
        del self.diaSources
        del self.coordList

    def testRun(self):
        """Test the full run method of the simple associator.
        """
        simpleAssoc = SimpleAssociationTask()
        result = simpleAssoc.run(self.diaSources,
                                 self.tractPatchId,
                                 self.skymapBits)

        # Test the number of expected DiaObjects are created.
        self.assertEqual(len(result.diaObjects),
                         self.nDiaObjects + self.nNewDiaSources)

        # Test that DiaSources are assigned the correct ``diaObjectId``.
        # The first ``nDiaObjects`` DiaObjects each collect the two
        # DiaSources placed at their position. Each remaining DiaObject was
        # created from one of the unassociated DiaSources.
        assocDiaSources = result.assocDiaSources.reset_index().set_index(
            ["diaObjectId", "diaSourceId"])
        for idx, diaObjId in enumerate(result.diaObjects.index):
            expectedNSources = 2 if idx < self.nDiaObjects else 1
            with self.subTest(diaObjectId=diaObjId):
                self.assertEqual(len(assocDiaSources.loc[diaObjId]),
                                 expectedNSources)

    def testUpdateCatalogs(self):
        """Test adding data to existing DiaObject/Source catalogs.
        """
        matchIndex = 4
        diaSrc = self.diaSources.iloc[matchIndex]
        self.diaObjects[matchIndex]["diaObjectId"] = 1234
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.updateCatalogs(matchIndex,
                                   diaSrc,
                                   self.diaSources,
                                   ccdVisit,
                                   diaSourceId,
                                   self.diaObjects,
                                   self.coordList,
                                   self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects)
        self.assertEqual(len(self.coordList), self.nDiaObjects)
        # Should be 3 source coordinates: the two from setUp plus this one.
        self.assertEqual(len(self.coordList[matchIndex]), 3)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         self.diaObjects[matchIndex]["diaObjectId"])

    def testAddDiaObject(self):
        """Test creating a new DiaObject from an unassociated DiaSource.
        """
        diaSrc = self.diaSources.iloc[-1]
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)
        # Use the same tract/patch id and bit budget that testRun passes to
        # ``run``.
        idFactory = afwTable.IdFactory.makeSource(
            self.tractPatchId,
            afwTable.IdFactory.computeReservedFromMaxBits(self.skymapBits))
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema(),
                                      idFactory))

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.addNewDiaObject(diaSrc,
                                    self.diaSources,
                                    ccdVisit,
                                    diaSourceId,
                                    self.diaObjects,
                                    idCat,
                                    self.coordList,
                                    self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects + 1)
        self.assertEqual(len(self.coordList), self.nDiaObjects + 1)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects + 1)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         idCat[0].get("id"))

    def testFindMatches(self):
        """Test the simple brute force matching algorithm.
        """
        simpleAssoc = SimpleAssociationTask()
        # No match: (0, 0) is far from every DiaObject.
        matchResult = simpleAssoc.findMatches(
            0.0,
            0.0,
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertIsNone(matchResult.dists)
        self.assertIsNone(matchResult.matches)

        # One match: object 4 is isolated.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[4],
            self.diaObjDecs[4],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 1)
        self.assertEqual(len(matchResult.matches), 1)
        self.assertEqual(matchResult.matches[0], 4)

        # Two matches: setUp placed object 3 at a 0.1 arcsec offset from
        # object 2.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[2],
            self.diaObjDecs[2],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 2)
        self.assertEqual(len(matchResult.matches), 2)
        self.assertEqual(matchResult.matches[0], 2)
        self.assertEqual(matchResult.matches[1], 3)
def setup_module(module):
    """Initialize the LSST test utilities for this test module.

    Parameters
    ----------
    module : `module`
        The module being set up; not used.
    """
    testUtils = lsst.utils.tests
    testUtils.init()
class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Module-level instance of `lsst.utils.tests.MemoryTestCase`.

    Adds no behavior of its own; everything is inherited.
    """
if __name__ == "__main__":
    # When run as a script rather than under pytest, initialize the LSST
    # test utilities before handing off to unittest.main().
    lsst.utils.tests.init()
    unittest.main()