Coverage for tests/test_association_task.py: 22%

74 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-19 11:23 +0000

1# This file is part of ap_association. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23import pandas as pd 

24import unittest 

25import lsst.geom as geom 

26import lsst.utils.tests 

27 

28from lsst.ap.association import AssociationTask 

29 

30 

class TestAssociationTask(unittest.TestCase):
    """Tests for AssociationTask: the full ``run``, source association,
    scoring/matching, and removal of diaSources with NaN positions.
    """

    def setUp(self):
        """Create sets of diaSources and diaObjects.

        diaObject ``k`` (ids 1..nObjects) sits at ra = dec = ``0.04*k`` deg.
        diaSource ``i`` (i = 0..nSources-1) sits at ra = dec = ``0.04*i``
        deg, so for ``i >= 1`` source ``i`` lies on top of diaObject ``i``.
        Source 0 has no diaObject nearby, and the last diaObject has no
        source. ``diaSources`` adds up to 0.1 arcsec of uniform scatter per
        axis; ``diaSourceZeroScatter`` is the same set with no scatter. Both
        frames share diaSourceIds, row order and ``trailLength`` values
        (``5.5*i``).
        """
        rng = np.random.default_rng(1234)
        self.nObjects = 5
        scatter = 0.1/3600
        self.diaObjects = pd.DataFrame(data=[
            {"ra": 0.04*(idx + 1), "dec": 0.04*(idx + 1),
             "diaObjectId": idx + 1}
            for idx in range(self.nObjects)])
        self.diaObjects.set_index("diaObjectId", drop=False, inplace=True)
        self.nSources = 5
        self.diaSources = pd.DataFrame(data=[
            {"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
             "dec": 0.04*idx + scatter*rng.uniform(-1, 1),
             "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
            for idx in range(self.nSources)])
        self.diaSourceZeroScatter = pd.DataFrame(data=[
            {"ra": 0.04*idx,
             "dec": 0.04*idx,
             "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
            for idx in range(self.nSources)])
        self.exposure_time = 30.0

    def test_run(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects.
        """
        config = AssociationTask.ConfigClass()
        config.doTrailedSourceFilter = False
        assocTask = AssociationTask(config=config)
        results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1)
        self.assertEqual(results.nUnassociatedDiaObjects, 1)
        self.assertEqual(len(results.matchedDiaSources),
                         len(self.diaObjects) - 1)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4])
        np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_trailed_sources(self):
        """Test the full task by associating a set of diaSources to
        existing diaObjects when trailed sources are filtered.

        This should filter out two of the five sources based on trail length,
        leaving one unassociated diaSource and two associated diaSources.
        """
        assocTask = AssociationTask()
        results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)

        self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3)
        self.assertEqual(results.nUnassociatedDiaObjects, 3)
        self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3)
        self.assertEqual(len(results.unAssocDiaSources), 1)
        np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2])
        np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])

    def test_run_no_existing_objects(self):
        """Test the run method with a completely empty database.
        """
        assocTask = AssociationTask()
        results = assocTask.run(
            self.diaSources,
            pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]),
            exposure_time=self.exposure_time)
        self.assertEqual(results.nUpdatedDiaObjects, 0)
        self.assertEqual(results.nUnassociatedDiaObjects, 0)
        self.assertEqual(len(results.matchedDiaSources), 0)
        self.assertTrue(np.all(results.unAssocDiaSources["diaObjectId"] == 0))

    def test_associate_sources(self):
        """Test that associate_sources gives each diaSource the id of the
        diaObject it lies on, and leaves source 0 (no nearby diaObject)
        with diaObjectId 0.
        """
        assoc_task = AssociationTask()
        assoc_result = assoc_task.associate_sources(
            self.diaObjects, self.diaSources)

        # assert_array_equal also checks the length, which a zip-based
        # element loop would silently truncate.
        np.testing.assert_array_equal(
            assoc_result.diaSources["diaObjectId"].values, [0, 1, 2, 3, 4])

    def test_score_and_match(self):
        """Test scoring and matching a set of sources against an existing
        set of diaObjects.
        """
        assoc_task = AssociationTask()
        score_struct = assoc_task.score(self.diaObjects,
                                        self.diaSourceZeroScatter,
                                        1.0 * geom.arcseconds)
        # Source 0 is 0.04 deg from the nearest diaObject, well outside
        # the 1 arcsec radius, so it must not get a finite score.
        self.assertFalse(np.isfinite(score_struct.scores[0]))
        for src_idx in range(1, len(self.diaSourceZeroScatter)):
            # Our scores should be extremely close to 0 but not exactly so due
            # to machine noise.
            self.assertAlmostEqual(score_struct.scores[src_idx], 0.0,
                                   places=16)

        # The scores were computed from the zero-scatter sources but are
        # matched against the scattered ones. Both frames share diaSourceIds
        # and row order (see setUp). NOTE(review): this assumes match pairs
        # scores with sources by position. Sources 1-4 pair with diaObjects
        # 1-4. The last diaObject has no source near it and stays
        # unassociated.
        match_result = assoc_task.match(
            self.diaObjects, self.diaSources, score_struct)
        self.assertEqual(match_result.nUpdatedDiaObjects, 4)
        self.assertEqual(match_result.nUnassociatedDiaObjects, 1)

    def test_remove_nan_dia_sources(self):
        """Test removing DiaSources with NaN locations.

        Rows 2 and 3 each get one NaN coordinate and row 4 gets both, so
        three sources should be dropped.
        """
        self.diaSources.loc[2, "ra"] = np.nan
        self.diaSources.loc[3, "dec"] = np.nan
        self.diaSources.loc[4, ["ra", "dec"]] = np.nan
        assoc_task = AssociationTask()
        out_dia_sources = assoc_task.check_dia_source_radec(self.diaSources)
        self.assertEqual(len(out_dia_sources), len(self.diaSources) - 3)

153 

154 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the LSST memory/resource-leak checks on this test module."""

157 

158 

def setup_module(module):
    """Module-level test hook: initialise LSST test utilities.

    The ``module`` argument is required by the hook signature but unused.
    """
    lsst.utils.tests.init()

161 

162 

if __name__ == "__main__":
    # When run as a script, initialise LSST test utilities first and then
    # hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()