Coverage for tests/test_association_task.py: 22%

74 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-02 04:31 -0700

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23import pandas as pd 

24import unittest 

25 

26import lsst.geom as geom 

27import lsst.utils.tests 

28from lsst.ap.association import AssociationTask 

29 

30 

class TestAssociationTask(unittest.TestCase):
    """Exercise AssociationTask on small synthetic diaSource/diaObject
    catalogs.
    """

    def setUp(self):
        """Create sets of diaSources and diaObjects.

        Objects are placed at ``0.04*(idx + 1)`` and sources at ``0.04*idx``
        in both coordinates, so the first source has no object at its
        position and the last object has no source at its position.
        """
        rng = np.random.default_rng(1234)
        self.nObjects = 5
        # Half-width of the random offset applied to each source coordinate
        # (presumably 0.1 arcsec expressed in degrees -- TODO confirm units).
        scatter = 0.1/3600
        self.diaObjects = pd.DataFrame(data=[
            {"ra": 0.04*(idx + 1), "dec": 0.04*(idx + 1),
             "diaObjectId": idx + 1}
            for idx in range(self.nObjects)])
        # Keep "diaObjectId" available both as the index and as a column.
        self.diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        self.nSources = 5
        # Source ids continue after the object ids; "diaObjectId" == 0 marks
        # a source that has not been associated yet.
        self.diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
             "dec": 0.04*idx + scatter*rng.uniform(-1, 1),
             "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx,
             "flags": 0}
            for idx in range(self.nSources)])
        # Same catalog as above but without the random offsets.
        self.diaSourceZeroScatter = pd.DataFrame(data=[
            {"ra": 0.04*idx,
             "dec": 0.04*idx,
             "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx,
             "flags": 0}
            for idx in range(self.nSources)])
        self.exposure_time = 30.0

    def test_run(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects.
        """
        config = AssociationTask.ConfigClass()
        config.doTrailedSourceFilter = False
        assocTask = AssociationTask(config=config)
        results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

        # All but one object gain a source, and all but one source is matched.
        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1)
        self.assertEqual(results.nUnassociatedDiaObjects, 1)
        self.assertEqual(len(results.matchedDiaSources),
                         len(self.diaObjects) - 1)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4])
        np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_trailed_sources(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects when trailed sources are filtered.

        This should filter out two of the five sources based on trail length,
        leaving one unassociated diaSource and two associated diaSources.
        """
        # Default configuration; the trailed source filter is not disabled here.
        assocTask = AssociationTask()
        results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3)
        self.assertEqual(results.nUnassociatedDiaObjects, 3)
        self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2])
        np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_no_existing_objects(self):
        """Test the run method with a completely empty database.
        """
        assocTask = AssociationTask()
        results = assocTask.run(
            self.diaSources,
            pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]),
            exposure_time=self.exposure_time)
        self.assertEqual(results.nUpdatedDiaObjects, 0)
        self.assertEqual(results.nUnassociatedDiaObjects, 0)
        self.assertEqual(len(results.matchedDiaSources), 0)
        # With no objects to match against, every source keeps diaObjectId 0.
        self.assertTrue(np.all(results.unAssocDiaSources["diaObjectId"] == 0))

    def test_associate_sources(self):
        """Test that associate_sources assigns the expected diaObjectId to
        each diaSource.
        """
        assoc_task = AssociationTask()
        assoc_result = assoc_task.associate_sources(
            self.diaObjects, self.diaSources)

        # A single array comparison checks every id and also fails on a
        # length mismatch, which an element-wise zip() loop would silently
        # truncate.
        np.testing.assert_array_equal(assoc_result.diaSources["diaObjectId"].values, [0, 1, 2, 3, 4])

    def test_score_and_match(self):
        """Test association between a set of sources and an existing
        DIAObjectCollection.
        """
        assoc_task = AssociationTask()
        score_struct = assoc_task.score(self.diaObjects,
                                        self.diaSourceZeroScatter,
                                        1.0 * geom.arcseconds)
        # The first source has no diaObject at its position, so its score is
        # expected to be non-finite.
        self.assertFalse(np.isfinite(score_struct.scores[0]))
        # Iterate over the catalog that was actually scored.
        for src_idx in range(1, len(self.diaSourceZeroScatter)):
            # Our scores should be extremely close to 0 but not exactly so due
            # to machine noise.
            self.assertAlmostEqual(score_struct.scores[src_idx], 0.0,
                                   places=16)

        # The scores come from the zero-scatter catalog while the match is run
        # on self.diaSources; both are built in setUp with the same ids and
        # row order. Four of the five diaObjects should pick up a source and
        # one should be left without any.
        match_result = assoc_task.match(
            self.diaObjects, self.diaSources, score_struct)
        self.assertEqual(match_result.nUpdatedDiaObjects, 4)
        self.assertEqual(match_result.nUnassociatedDiaObjects, 1)

    def test_remove_nan_dia_sources(self):
        """Test removing DiaSources with NaN locations.
        """
        # Three distinct rows are affected: NaN ra, NaN dec, and both NaN.
        self.diaSources.loc[2, "ra"] = np.nan
        self.diaSources.loc[3, "dec"] = np.nan
        self.diaSources.loc[4, "ra"] = np.nan
        self.diaSources.loc[4, "dec"] = np.nan
        assoc_task = AssociationTask()
        out_dia_sources = assoc_task.check_dia_source_radec(self.diaSources)
        self.assertEqual(len(out_dia_sources), len(self.diaSources) - 3)

155 

156 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from lsst.utils.tests.MemoryTestCase.

    Nothing is added or overridden here; the subclass only makes the
    inherited test case part of this module.
    """

159 

160 

def setup_module(module):
    """Initialise the LSST test utilities for this module.

    The name matches pytest's module-level setup hook, so this is presumably
    invoked once before the tests in this file run -- confirm against the
    test runner in use.

    Parameters
    ----------
    module : module
        The module being set up; not used here.
    """
    lsst.utils.tests.init()

163 

164 

if __name__ == "__main__":
    # Direct execution as a script: perform the same initialisation as
    # setup_module(), then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()