Coverage for tests/test_association_task.py: 22%
74 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:00 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-11 11:00 +0000
1# This file is part of ap_association.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import pandas as pd
24import unittest
25import lsst.geom as geom
26import lsst.utils.tests
28from lsst.ap.association import AssociationTask
class TestAssociationTask(unittest.TestCase):
    """Unit tests for `lsst.ap.association.AssociationTask`."""

    def setUp(self):
        """Create sets of diaSources and diaObjects.

        ``diaObjects`` holds ``nObjects`` objects at
        (ra, dec) = (0.04*k, 0.04*k) for k = 1..nObjects, with
        ``diaObjectId = k`` (also used as the index).

        ``diaSources`` holds ``nSources`` sources at (0.04*i, 0.04*i) for
        i = 0..nSources-1, each coordinate jittered uniformly by up to
        ``scatter`` (0.1/3600, i.e. 0.1 arcsec assuming ra/dec are in
        degrees). ``diaSourceZeroScatter`` is the same set without jitter.

        With this layout source i (i >= 1) lies on diaObject i, source 0
        has no counterpart, and the last diaObject (id ``nObjects``) has no
        source. Every source starts with ``diaObjectId = 0`` and has
        ``trailLength = 5.5*i``.
        """
        # Fixed seed so the jittered positions are reproducible.
        rng = np.random.default_rng(1234)
        self.nObjects = 5
        scatter = 0.1/3600
        self.diaObjects = pd.DataFrame(data=[
            {"ra": 0.04*(idx + 1), "dec": 0.04*(idx + 1),
             "diaObjectId": idx + 1}
            for idx in range(self.nObjects)])
        self.diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        self.nSources = 5
        self.diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
             "dec": 0.04*idx + scatter*rng.uniform(-1, 1),
             "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
            for idx in range(self.nSources)])
        self.diaSourceZeroScatter = pd.DataFrame(data=[
            {"ra": 0.04*idx,
             "dec": 0.04*idx,
             "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
            for idx in range(self.nSources)])
        self.exposure_time = 30.0

    def test_run(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects.

        With the trailed-source filter disabled, sources 1-4 match
        diaObjects 1-4, source 0 stays unassociated, and the last diaObject
        receives no source.
        """
        config = AssociationTask.ConfigClass()
        config.doTrailedSourceFilter = False
        assocTask = AssociationTask(config=config)
        results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1)
        self.assertEqual(results.nUnassociatedDiaObjects, 1)
        self.assertEqual(len(results.matchedDiaSources),
                         len(self.diaObjects) - 1)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4])
        np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_trailed_sources(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects when trailed sources are filtered.

        This should filter out two of the five sources based on trail length.
        The assertions expect sources 3 and 4, the two longest trails, to be
        removed. That leaves one unassociated diaSource (source 0) and two
        associated diaSources (matched to diaObjects 1 and 2).
        """
        assocTask = AssociationTask()
        results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3)
        self.assertEqual(results.nUnassociatedDiaObjects, 3)
        self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2])
        np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_no_existing_objects(self):
        """Test the run method with a completely empty database.

        With no diaObjects to match against, nothing can be updated or
        matched, and every surviving source keeps ``diaObjectId == 0``.
        """
        assocTask = AssociationTask()
        results = assocTask.run(
            self.diaSources,
            pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]),
            exposure_time=self.exposure_time)
        self.assertEqual(results.nUpdatedDiaObjects, 0)
        self.assertEqual(results.nUnassociatedDiaObjects, 0)
        self.assertEqual(len(results.matchedDiaSources), 0)
        self.assertTrue(np.all(results.unAssocDiaSources["diaObjectId"] == 0))

    def test_associate_sources(self):
        """Test the performance of the associate_sources method in
        AssociationTask.

        Source i (i >= 1) should be assigned diaObjectId i. Source 0 has no
        nearby diaObject and keeps diaObjectId 0.
        """
        assoc_task = AssociationTask()
        assoc_result = assoc_task.associate_sources(
            self.diaObjects, self.diaSources)

        # A single element-wise comparison also catches a length mismatch,
        # which the previous zip-based loop silently truncated.
        np.testing.assert_array_equal(assoc_result.diaSources["diaObjectId"].values, [0, 1, 2, 3, 4])

    def test_score_and_match(self):
        """Test association between a set of sources and an existing
        DIAObjectCollection.
        """
        assoc_task = AssociationTask()
        score_struct = assoc_task.score(self.diaObjects,
                                        self.diaSourceZeroScatter,
                                        1.0 * geom.arcseconds)
        # Source 0 has no diaObject within the 1 arcsec radius, so its score
        # is expected to be non-finite.
        self.assertFalse(np.isfinite(score_struct.scores[0]))
        # Iterate over the same source catalog that was scored.
        for src_idx in range(1, len(self.diaSourceZeroScatter)):
            # Our scores should be extremely close to 0 but not exactly so due
            # to machine noise.
            self.assertAlmostEqual(score_struct.scores[src_idx], 0.0,
                                   places=16)

        # Sources 1-4 sit on diaObjects 1-4, so those four objects are
        # updated. Source 0 has no counterpart, and the last diaObject
        # (diaObjectId 5) receives no source, leaving one diaObject
        # unassociated. Match the same catalog that was scored so the scores
        # line up with the sources.
        match_result = assoc_task.match(
            self.diaObjects, self.diaSourceZeroScatter, score_struct)
        self.assertEqual(match_result.nUpdatedDiaObjects, 4)
        self.assertEqual(match_result.nUnassociatedDiaObjects, 1)

    def test_remove_nan_dia_sources(self):
        """Test removing DiaSources with NaN locations.

        Rows 2, 3 and 4 get a NaN ra, a NaN dec, and both NaN respectively.
        All three should be dropped.
        """
        n_before = len(self.diaSources)
        self.diaSources.loc[2, "ra"] = np.nan
        self.diaSources.loc[3, "dec"] = np.nan
        self.diaSources.loc[4, "ra"] = np.nan
        self.diaSources.loc[4, "dec"] = np.nan
        assoc_task = AssociationTask()
        out_dia_sources = assoc_task.check_dia_source_radec(self.diaSources)
        self.assertEqual(len(out_dia_sources), n_before - 3)
class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the ``lsst.utils.tests.MemoryTestCase`` checks for this module."""
def setup_module(module):
    """Module-level pytest hook: initialise ``lsst.utils.tests``.

    Parameters
    ----------
    module : `module`
        The test module being set up (unused).
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # When run directly (not via pytest), initialise the lsst.utils.tests
    # framework first, then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()