Coverage for tests/test_association_task.py: 22%
74 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-01 17:01 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-01 17:01 -0700
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import pandas as pd
24import unittest
26import lsst.geom as geom
27import lsst.utils.tests
28from lsst.ap.association import AssociationTask
class TestAssociationTask(unittest.TestCase):
    """Tests for `AssociationTask`: the full ``run`` method plus the
    ``associate_sources``, ``score``, ``match`` and ``check_dia_source_radec``
    steps.
    """

    def setUp(self):
        """Create sets of diaSources and diaObjects.

        The five diaObjects (ids 1-5) lie on the diagonal at
        ``0.04*(idx + 1)`` deg, while the five diaSources lie at
        ``0.04*idx`` deg. Hence diaSource 0 sits at the origin, 0.04 deg from
        the nearest diaObject; diaSources 1-4 sit on top of diaObjects 1-4;
        and diaObject 5 has no diaSource on top of it.
        """
        rng = np.random.default_rng(1234)
        self.nObjects = 5
        # Maximum positional jitter: 0.1 arcsec expressed in degrees.
        scatter = 0.1/3600
        self.diaObjects = pd.DataFrame(data=[
            {"ra": 0.04*(idx + 1), "dec": 0.04*(idx + 1),
             "diaObjectId": idx + 1}
            for idx in range(self.nObjects)])
        self.diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        self.nSources = 5
        self.diaSources = self._makeDiaSources(
            lambda: scatter*rng.uniform(-1, 1))
        self.diaSourceZeroScatter = self._makeDiaSources(lambda: 0.0)
        self.exposure_time = 30.0

    def _makeDiaSources(self, jitter):
        """Build ``self.nSources`` unassociated diaSources on the diagonal.

        Parameters
        ----------
        jitter : callable
            Zero-argument callable returning the offset (deg) added to a
            coordinate. It is called once for ra and then once for dec, per
            source, in source order.

        Returns
        -------
        diaSources : `pandas.DataFrame`
            One row per source, with ``diaObjectId`` set to 0 and
            ``trailLength`` growing as ``5.5*idx``.
        """
        return pd.DataFrame(data=[
            {"ra": 0.04*idx + jitter(),
             "dec": 0.04*idx + jitter(),
             "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0,
             "trailLength": 5.5*idx,
             "flags": 0}
            for idx in range(self.nSources)])

    def test_run(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects.

        With the trailed-source filter disabled, four diaSources match
        diaObjects 1-4. The source at the origin stays unassociated, as does
        diaObject 5.
        """
        config = AssociationTask.ConfigClass()
        config.doTrailedSourceFilter = False
        assocTask = AssociationTask(config=config)
        results = assocTask.run(self.diaSources, self.diaObjects,
                                exposure_time=self.exposure_time)

        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1)
        self.assertEqual(results.nUnassociatedDiaObjects, 1)
        self.assertEqual(len(results.matchedDiaSources),
                         len(self.diaObjects) - 1)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(
            results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4])
        np.testing.assert_array_equal(
            results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_trailed_sources(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects when trailed sources are filtered.

        This should filter out two of the five sources based on trail length,
        leaving one unassociated diaSource and two associated diaSources.
        """
        assocTask = AssociationTask()
        results = assocTask.run(self.diaSources, self.diaObjects,
                                exposure_time=self.exposure_time)

        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3)
        self.assertEqual(results.nUnassociatedDiaObjects, 3)
        self.assertEqual(len(results.matchedDiaSources),
                         len(self.diaObjects) - 3)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(
            results.matchedDiaSources["diaObjectId"].values, [1, 2])
        np.testing.assert_array_equal(
            results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_no_existing_objects(self):
        """Test the run method with a completely empty database.
        """
        assocTask = AssociationTask()
        results = assocTask.run(
            self.diaSources,
            pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]),
            exposure_time=self.exposure_time)
        self.assertEqual(results.nUpdatedDiaObjects, 0)
        self.assertEqual(results.nUnassociatedDiaObjects, 0)
        self.assertEqual(len(results.matchedDiaSources), 0)
        self.assertTrue(np.all(results.unAssocDiaSources["diaObjectId"] == 0))

    def test_associate_sources(self):
        """Test the performance of the associate_sources method in
        AssociationTask.

        The source at the origin keeps diaObjectId 0; the remaining sources
        are assigned to diaObjects 1-4 in order.
        """
        assoc_task = AssociationTask()
        assoc_result = assoc_task.associate_sources(
            self.diaObjects, self.diaSources)

        # assert_array_equal also checks the length, unlike a zip() loop,
        # which would silently stop at the shorter sequence.
        np.testing.assert_array_equal(
            assoc_result.diaSources["diaObjectId"].to_numpy(),
            [0, 1, 2, 3, 4])

    def test_score_and_match(self):
        """Test association between a set of sources and an existing
        DIAObjectCollection.
        """
        assoc_task = AssociationTask()
        score_struct = assoc_task.score(self.diaObjects,
                                        self.diaSourceZeroScatter,
                                        1.0 * geom.arcseconds)
        # The first source sits at the origin, with no diaObject within the
        # 1 arcsec radius, so its score is not finite.
        self.assertFalse(np.isfinite(score_struct.scores[0]))
        for src_idx in range(1, len(self.diaSourceZeroScatter)):
            # Our scores should be extremely close to 0 but not exactly so due
            # to machine noise.
            self.assertAlmostEqual(score_struct.scores[src_idx], 0.0,
                                   places=16)

        # After matching, the four diaObjects with a diaSource on top of
        # them are updated. The last diaObject (at 0.2 deg), which has no
        # diaSource nearby, remains unassociated.
        match_result = assoc_task.match(
            self.diaObjects, self.diaSources, score_struct)
        self.assertEqual(match_result.nUpdatedDiaObjects, 4)
        self.assertEqual(match_result.nUnassociatedDiaObjects, 1)

    def test_remove_nan_dia_sources(self):
        """Test removing DiaSources with NaN locations.

        Three of the five sources get a NaN ra, dec, or both; only the two
        sources with finite coordinates should survive.
        """
        self.diaSources.loc[2, "ra"] = np.nan
        self.diaSources.loc[3, "dec"] = np.nan
        self.diaSources.loc[4, "ra"] = np.nan
        self.diaSources.loc[4, "dec"] = np.nan
        assoc_task = AssociationTask()
        out_dia_sources = assoc_task.check_dia_source_radec(self.diaSources)
        self.assertEqual(len(out_dia_sources), len(self.diaSources) - 3)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the shared ``lsst.utils.tests.MemoryTestCase`` checks for this
    test module; no additional test methods are defined here.
    """
def setup_module(module):
    """Module-level setup hook: initialize ``lsst.utils.tests``.

    Parameters
    ----------
    module : module
        The module being set up; not used, but part of the hook signature.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities first, then hand
    # control to the standard-library unittest runner.
    lsst.utils.tests.init()
    unittest.main()