Coverage for tests/test_simpleAssociation.py: 18%

95 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 11:06 +0000

1# This file is part of pipe_tasks. 

2 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22 

23import numpy as np 

24import pandas as pd 

25import unittest 

26 

27import lsst.afw.table as afwTable 

28import lsst.geom as geom 

29import lsst.utils.tests 

30from lsst.pipe.tasks.associationUtils import toIndex 

31from lsst.pipe.tasks.simpleAssociation import SimpleAssociationTask 

32 

33 

class TestSimpleAssociation(lsst.utils.tests.TestCase):
    """Tests for `SimpleAssociationTask`.

    ``setUp`` builds ``nDiaObjects`` DiaObjects spaced evenly along
    ra = dec = 45..46 deg, with object 3 moved next to object 2. It then
    creates two DiaSources at every DiaObject position (visits 1234 and 1235)
    plus ``nNewDiaSources`` DiaSources far from all DiaObjects (visit
    ``newDiaObjectVisit``).
    """

    def setUp(self):
        simpleAssoc = SimpleAssociationTask()

        self.nDiaObjects = 10
        self.diaObjRas = np.linspace(45, 46, self.nDiaObjects)
        self.diaObjDecs = np.linspace(45, 46, self.nDiaObjects)
        # Offset object 3 from object 2 by 0.1 arcsec in both ra and dec so
        # that a search around object 2 returns multiple matches.
        self.diaObjRas[3] = self.diaObjRas[2] + 0.1/3600
        self.diaObjDecs[3] = self.diaObjDecs[2] + 0.1/3600
        self.diaObjects = [
            simpleAssoc.createDiaObject(objId, ra, dec)
            for objId, ra, dec in zip(
                np.arange(self.nDiaObjects, dtype=int),
                self.diaObjRas,
                self.diaObjDecs)]

        # Pixel index of each DiaObject (``toIndex`` at the configured
        # ``nside``), in the same order as ``self.diaObjects``.
        self.hpIndices = [toIndex(simpleAssoc.config.nside,
                                  diaObj["ra"],
                                  diaObj["dec"])
                          for diaObj in self.diaObjects]

        self.newDiaObjectVisit = 1236
        # Drop in two copies of the DiaObject locations to make DiaSources.
        diaSourceList = [
            {"visit": 1234,
             "detector": 42,
             "diaSourceId": idx,
             "diaObjectId": 0,
             "ra": ra,
             "dec": dec}
            for idx, (ra, dec) in enumerate(zip(self.diaObjRas,
                                                self.diaObjDecs))]
        # One list of source coordinates per DiaObject, in DiaObject order.
        self.coordList = [
            [geom.SpherePoint(diaSrc["ra"], diaSrc["dec"], geom.degrees)]
            for diaSrc in diaSourceList]
        moreDiaSources = [
            {"visit": 1235,
             "detector": 43,
             "diaSourceId": idx + self.nDiaObjects,
             "diaObjectId": 0,
             "ra": ra,
             "dec": dec}
            for idx, (ra, dec) in enumerate(zip(self.diaObjRas,
                                                self.diaObjDecs))]
        for coords, diaSrc in zip(self.coordList, moreDiaSources):
            coords.append(geom.SpherePoint(diaSrc["ra"],
                                           diaSrc["dec"],
                                           geom.degrees))
        diaSourceList.extend(moreDiaSources)

        self.nNewDiaSources = 2
        # Drop in two more DiaSources, far from every DiaObject, that should
        # not match any existing DiaObject.
        diaSourceList.append({"visit": self.newDiaObjectVisit,
                              "detector": 44,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 0.0,
                              "dec": 0.0})
        diaSourceList.append({"visit": self.newDiaObjectVisit,
                              "detector": 45,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 1.0,
                              "dec": 89.0})
        self.diaSources = pd.DataFrame(data=diaSourceList)

    def tearDown(self):
        del self.diaObjects
        del self.hpIndices
        del self.diaSources
        del self.coordList
        del self.diaObjRas
        del self.diaObjDecs

    def testRun(self):
        """Test the full run method of the simple associator.

        Each pre-placed DiaObject should end up with two DiaSources, and each
        unassociated DiaSource should produce a new DiaObject holding only
        itself.
        """
        simpleAssoc = SimpleAssociationTask()
        result = simpleAssoc.run(self.diaSources)

        # Test the number of expected DiaObjects are created.
        self.assertEqual(len(result.diaObjects),
                         self.nDiaObjects + self.nNewDiaSources)

        # Test that DiaSources are assigned the correct ``diaObjectId``.
        assocDiaSources = result.assocDiaSources.reset_index().set_index(["diaObjectId", "diaSourceId"])
        for idx, diaObjId in enumerate(result.diaObjects.index):
            # The first ``nDiaObjects`` rows are expected to be the
            # pre-placed objects (two sources each); the rest were created
            # from the unassociated DiaSources (one source each).
            nExpected = 2 if idx < self.nDiaObjects else 1
            self.assertEqual(len(assocDiaSources.loc[diaObjId]), nExpected)

    def testUpdateCatalogs(self):
        """Test adding data to existing DiaObject/Source catalogs.

        Associate DiaSource ``matchIndex`` with DiaObject ``matchIndex``.
        Check that no DiaObject is added, that the object gains a third
        source coordinate, and that the DiaSource's ``diaObjectId`` becomes
        the DiaObject's id.
        """
        matchIndex = 4
        diaSrc = self.diaSources.iloc[matchIndex]
        # Give the DiaObject a distinctive id so the assignment can be checked.
        self.diaObjects[matchIndex]["diaObjectId"] = 1234
        visit = diaSrc["visit"]
        detector = diaSrc["detector"]
        diaSourceId = diaSrc["diaSourceId"]
        # Index DiaSources by ((visit, detector), diaSourceId), which is how
        # the final assertion looks them up.
        self.diaSources["visit,detector"] = list(zip(self.diaSources["visit"],
                                                     self.diaSources["detector"]))
        self.diaSources.set_index(["visit,detector", "diaSourceId"], inplace=True)

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.updateCatalogs(matchIndex,
                                   diaSrc,
                                   self.diaSources,
                                   visit,
                                   detector,
                                   diaSourceId,
                                   self.diaObjects,
                                   self.coordList,
                                   self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects)
        self.assertEqual(len(self.coordList), self.nDiaObjects)
        # Should be 3 source coordinates: the two from setUp plus this one.
        self.assertEqual(len(self.coordList[matchIndex]), 3)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects)
        self.assertEqual(self.diaSources.loc[((visit, detector), diaSourceId), "diaObjectId"].iloc[0],
                         self.diaObjects[matchIndex]["diaObjectId"])

    def testAddDiaObject(self):
        """Test creating a new DiaObject from an unassociated DiaSource.

        The last DiaSource lies far from every existing DiaObject. Adding it
        should append one entry each to the DiaObject, coordinate and
        pixel-index lists. It should also set the DiaSource's ``diaObjectId``
        to the id recorded in ``idCat``.
        """
        diaSrc = self.diaSources.iloc[-1]
        visit = diaSrc["visit"]
        detector = diaSrc["detector"]
        diaSourceId = diaSrc["diaSourceId"]
        # Index DiaSources by ((visit, detector), diaSourceId), which is how
        # the final assertion looks them up.
        self.diaSources["visit,detector"] = list(zip(self.diaSources["visit"],
                                                     self.diaSources["detector"]))
        self.diaSources.set_index(["visit,detector", "diaSourceId"], inplace=True)
        # Empty catalog into which the task records the new DiaObject's id
        # (read back below via ``idCat[0]``).
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema()))

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.addNewDiaObject(diaSrc,
                                    self.diaSources,
                                    visit,
                                    detector,
                                    diaSourceId,
                                    self.diaObjects,
                                    idCat,
                                    self.coordList,
                                    self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects + 1)
        self.assertEqual(len(self.coordList), self.nDiaObjects + 1)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects + 1)
        self.assertEqual(self.diaSources.loc[((visit, detector), diaSourceId), "diaObjectId"].iloc[0],
                         idCat[0].get("id"))

    def testFindMatches(self):
        """Test the simple brute force matching algorithm.

        Covers three cases: a search with no match, a search with exactly
        one match, and a search with two matches returned in DiaObject
        order.
        """
        simpleAssoc = SimpleAssociationTask()
        tolerance = 2*simpleAssoc.config.tolerance

        # No match: (0, 0) is far from every DiaObject.
        matchResult = simpleAssoc.findMatches(
            0.0,
            0.0,
            tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertIsNone(matchResult.dists)
        self.assertIsNone(matchResult.matches)

        # One match: object 4 has no close neighbour.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[4],
            self.diaObjDecs[4],
            tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 1)
        self.assertEqual(len(matchResult.matches), 1)
        self.assertEqual(matchResult.matches[0], 4)

        # Two matches: object 3 was placed next to object 2 in setUp.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[2],
            self.diaObjDecs[2],
            tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 2)
        self.assertEqual(len(matchResult.matches), 2)
        self.assertEqual(matchResult.matches[0], 2)
        self.assertEqual(matchResult.matches[1], 3)

223 

224 

def setup_module(module):
    """Module-level setup hook: initialise the LSST test utilities.

    Parameters
    ----------
    module : `module`
        The module being set up. It is not used here.
    """
    lsst.utils.tests.init()

227 

228 

class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Module-local subclass of `lsst.utils.tests.MemoryTestCase`.

    Defining it in this module makes the inherited checks run alongside the
    tests above; it adds no behavior of its own.
    """

231 

232 

if __name__ == "__main__":
    # Run as a script: initialise the LSST test utilities first, then hand
    # control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()