Coverage for tests/test_simpleAssociation.py: 18%

91 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-02 10:25 +0000

1# This file is part of pipe_tasks. 

2 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22 

23import numpy as np 

24import pandas as pd 

25import unittest 

26 

27import lsst.afw.table as afwTable 

28import lsst.geom as geom 

29import lsst.utils.tests 

30from lsst.pipe.tasks.associationUtils import toIndex 

31from lsst.pipe.tasks.simpleAssociation import SimpleAssociationTask 

32 

33 

class TestSimpleAssociation(lsst.utils.tests.TestCase):
    """Unit tests for `SimpleAssociationTask`.

    The fixture places ``nDiaObjects`` DiaObjects along a diagonal in RA/Dec,
    two DiaSources (visits 1234 and 1235) at each DiaObject position, and
    ``nNewDiaSources`` DiaSources (visit ``newDiaObjectVisit``) far from every
    DiaObject.
    """

    def setUp(self):
        """Build the DiaObject/DiaSource fixtures shared by the tests."""
        simpleAssoc = SimpleAssociationTask()

        self.nDiaObjects = 10
        self.diaObjRas = np.linspace(45, 46, self.nDiaObjects)
        self.diaObjDecs = np.linspace(45, 46, self.nDiaObjects)
        # Move DiaObject 3 to within 0.1 arcsec (in both RA and Dec) of
        # DiaObject 2 so that a search around DiaObject 2 finds two matches
        # (exercised in testFindMatches).
        self.diaObjRas[3] = self.diaObjRas[2] + 0.1/3600
        self.diaObjDecs[3] = self.diaObjDecs[2] + 0.1/3600
        self.diaObjects = [
            simpleAssoc.createDiaObject(objId, ra, decl)
            for objId, ra, decl in zip(
                np.arange(self.nDiaObjects, dtype=int),
                self.diaObjRas,
                self.diaObjDecs)]

        # Healpix index of each DiaObject, computed with the task's ``nside``.
        self.hpIndices = [toIndex(simpleAssoc.config.nside,
                                  diaObj["ra"],
                                  diaObj["decl"])
                          for diaObj in self.diaObjects]

        self.newDiaObjectVisit = 1236
        # Drop in two copies of the DiaObject locations to make DiaSources.
        diaSourceList = [
            {"ccdVisitId": 1234,
             "diaSourceId": idx,
             "diaObjectId": 0,
             "ra": ra,
             "decl": decl}
            for idx, (ra, decl) in enumerate(zip(self.diaObjRas,
                                                 self.diaObjDecs))]
        # coordList[i] holds the SpherePoints of the DiaSources at DiaObject i.
        self.coordList = [
            [geom.SpherePoint(diaSrc["ra"], diaSrc["decl"], geom.degrees)]
            for diaSrc in diaSourceList]
        moreDiaSources = [
            {"ccdVisitId": 1235,
             "diaSourceId": idx + self.nDiaObjects,
             "diaObjectId": 0,
             "ra": ra,
             "decl": decl}
            for idx, (ra, decl) in enumerate(zip(self.diaObjRas,
                                                 self.diaObjDecs))]
        for coords, diaSrc in zip(self.coordList, moreDiaSources):
            coords.append(geom.SpherePoint(diaSrc["ra"],
                                           diaSrc["decl"],
                                           geom.degrees))
        diaSourceList.extend(moreDiaSources)

        self.nNewDiaSources = 2
        # Drop in two more DiaSources that are far from every DiaObject and
        # therefore unassociated.
        diaSourceList.append({"ccdVisitId": self.newDiaObjectVisit,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 0.0,
                              "decl": 0.0})
        diaSourceList.append({"ccdVisitId": self.newDiaObjectVisit,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 1.0,
                              "decl": 89.0})
        self.diaSources = pd.DataFrame(data=diaSourceList)

    def tearDown(self):
        """Release the per-test fixtures."""
        del self.diaObjects
        del self.hpIndices
        del self.diaSources
        del self.coordList

    def testRun(self):
        """Test the full run method of the simple associator.
        """
        simpleAssoc = SimpleAssociationTask()
        result = simpleAssoc.run(self.diaSources)

        # Test the number of expected DiaObjects are created.
        self.assertEqual(len(result.diaObjects),
                         self.nDiaObjects + self.nNewDiaSources)

        # Test that DiaSources are assigned the correct ``diaObjectId``.
        # The first ``nDiaObjects`` DiaObjects are expected to each hold the
        # two DiaSources placed at their position; every later DiaObject is
        # expected to come from a single unassociated DiaSource.
        assocDiaSources = result.assocDiaSources.reset_index().set_index(
            ["diaObjectId", "diaSourceId"])
        for idx, diaObjId in enumerate(result.diaObjects.index):
            nExpected = 2 if idx < self.nDiaObjects else 1
            self.assertEqual(len(assocDiaSources.loc[diaObjId]), nExpected)

    def testUpdateCatalogs(self):
        """Test adding a DiaSource to an existing DiaObject/Source catalog.
        """
        matchIndex = 4
        diaSrc = self.diaSources.iloc[matchIndex]
        self.diaObjects[matchIndex]["diaObjectId"] = 1234
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        # updateCatalogs looks DiaSources up by (ccdVisitId, diaSourceId).
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.updateCatalogs(matchIndex,
                                   diaSrc,
                                   self.diaSources,
                                   ccdVisit,
                                   diaSourceId,
                                   self.diaObjects,
                                   self.coordList,
                                   self.hpIndices)
        # No DiaObject is added; only the matched one gains a coordinate.
        self.assertEqual(len(self.hpIndices), self.nDiaObjects)
        self.assertEqual(len(self.coordList), self.nDiaObjects)
        # Should be 3 source coordinates: the two from setUp plus this one.
        self.assertEqual(len(self.coordList[matchIndex]), 3)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         self.diaObjects[matchIndex]["diaObjectId"])

    def testAddDiaObject(self):
        """Test creating a new DiaObject from an unassociated DiaSource.
        """
        # The last DiaSource is one of the unassociated ones from setUp.
        diaSrc = self.diaSources.iloc[-1]
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema()))

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.addNewDiaObject(diaSrc,
                                    self.diaSources,
                                    ccdVisit,
                                    diaSourceId,
                                    self.diaObjects,
                                    idCat,
                                    self.coordList,
                                    self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects + 1)
        self.assertEqual(len(self.coordList), self.nDiaObjects + 1)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects + 1)
        # The DiaSource must point at the id of the record added to idCat.
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         idCat[0].get("id"))

    def testFindMatches(self):
        """Test the simple brute force matching algorithm.
        """
        simpleAssoc = SimpleAssociationTask()
        # No match
        matchResult = simpleAssoc.findMatches(
            0.0,
            0.0,
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertIsNone(matchResult.dists)
        self.assertIsNone(matchResult.matches)

        # One match
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[4],
            self.diaObjDecs[4],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 1)
        self.assertEqual(len(matchResult.matches), 1)
        self.assertEqual(matchResult.matches[0], 4)

        # Two matches: DiaObjects 2 and 3 were placed 0.1 arcsec apart.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[2],
            self.diaObjDecs[2],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 2)
        self.assertEqual(len(matchResult.matches), 2)
        self.assertEqual(matchResult.matches[0], 2)
        self.assertEqual(matchResult.matches[1], 3)

214 

215 

def setup_module(module):
    """Module-level setup hook (pytest convention).

    Initializes the LSST test utilities before any test in this module runs.

    Parameters
    ----------
    module : `module`
        The module being set up; not used here.
    """
    lsst.utils.tests.init()

218 

219 

class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the inherited `lsst.utils.tests.MemoryTestCase` checks here."""

222 

223 

if __name__ == "__main__":
    # Allow running this test file directly as a script.
    lsst.utils.tests.init()
    unittest.main()