Coverage for tests/test_simpleAssociation.py: 18%

94 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-16 01:25 -0800

1# This file is part of pipe_tasks. 

2 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22 

23import numpy as np 

24import pandas as pd 

25import unittest 

26 

27import lsst.afw.table as afwTable 

28import lsst.geom as geom 

29import lsst.utils.tests 

30from lsst.pipe.tasks.associationUtils import toIndex 

31from lsst.pipe.tasks.simpleAssociation import SimpleAssociationTask 

32 

33 

class TestSimpleAssociation(lsst.utils.tests.TestCase):
    """Unit tests for `lsst.pipe.tasks.simpleAssociation.SimpleAssociationTask`.

    The fixture (``setUp``) builds:

    - ``nDiaObjects`` DiaObjects spaced along a line from (45, 45) to
      (46, 46) degrees in (ra, decl), with object 3 moved next to object 2;
    - two DiaSources at every DiaObject position, one from each of two
      ``ccdVisitId`` values (1234 and 1235);
    - ``nNewDiaSources`` extra DiaSources (``ccdVisitId`` 1236) placed far
      from every DiaObject.
    """

    def setUp(self):
        simpleAssoc = SimpleAssociationTask()
        self.tractPatchId = 1234
        self.skymapBits = 16

        self.nDiaObjects = 10
        self.diaObjRas = np.linspace(45, 46, self.nDiaObjects)
        self.diaObjDecs = np.linspace(45, 46, self.nDiaObjects)
        # Place object 3 within 0.1 arcsec (in each coordinate) of object 2
        # so that a search around object 2 returns multiple matches.
        self.diaObjRas[3] = self.diaObjRas[2] + 0.1/3600
        self.diaObjDecs[3] = self.diaObjDecs[2] + 0.1/3600
        self.diaObjects = [
            simpleAssoc.createDiaObject(objId, ra, decl)
            for objId, ra, decl in zip(
                np.arange(self.nDiaObjects, dtype=int),
                self.diaObjRas,
                self.diaObjDecs)]

        # One healpix index per DiaObject, computed at the task's configured nside.
        self.hpIndices = [toIndex(simpleAssoc.config.nside,
                                  diaObj["ra"],
                                  diaObj["decl"])
                          for diaObj in self.diaObjects]

        self.newDiaObjectVisit = 1236
        # Drop in two copies of the DiaObject locations to make DiaSources.
        diaSourceList = [
            {"ccdVisitId": 1234,
             "diaSourceId": idx,
             "diaObjectId": 0,
             "ra": ra,
             "decl": decl}
            for idx, (ra, decl) in enumerate(zip(self.diaObjRas,
                                                 self.diaObjDecs))]
        # coordList[i] holds the SpherePoints of every DiaSource associated
        # with DiaObject i.
        self.coordList = [
            [geom.SpherePoint(diaSrc["ra"], diaSrc["decl"], geom.degrees)]
            for diaSrc in diaSourceList]
        moreDiaSources = [
            {"ccdVisitId": 1235,
             "diaSourceId": idx + self.nDiaObjects,
             "diaObjectId": 0,
             "ra": ra,
             "decl": decl}
            for idx, (ra, decl) in enumerate(zip(self.diaObjRas,
                                                 self.diaObjDecs))]
        for idx in range(self.nDiaObjects):
            self.coordList[idx].append(
                geom.SpherePoint(moreDiaSources[idx]["ra"],
                                 moreDiaSources[idx]["decl"],
                                 geom.degrees))
        diaSourceList.extend(moreDiaSources)

        self.nNewDiaSources = 2
        # Drop in two more DiaSources, far from every DiaObject, that should
        # not associate with anything.
        diaSourceList.append({"ccdVisitId": 1236,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 0.0,
                              "decl": 0.0})
        diaSourceList.append({"ccdVisitId": 1236,
                              "diaSourceId": len(diaSourceList),
                              "diaObjectId": 0,
                              "ra": 1.0,
                              "decl": 89.0})
        self.diaSources = pd.DataFrame(data=diaSourceList)

    def tearDown(self):
        del self.diaObjects
        del self.hpIndices
        del self.diaSources
        del self.coordList

    def testRun(self):
        """Test the full run method of the simple associator.

        The output should contain one DiaObject per original position plus
        one per unassociated DiaSource. The first ``nDiaObjects`` DiaObjects
        should each have two DiaSources, and every remaining DiaObject should
        have exactly one.
        """
        simpleAssoc = SimpleAssociationTask()
        result = simpleAssoc.run(self.diaSources,
                                 self.tractPatchId,
                                 self.skymapBits)

        # Test the number of expected DiaObjects are created.
        self.assertEqual(len(result.diaObjects),
                         self.nDiaObjects + self.nNewDiaSources)

        # Test that DiaSources are assigned the correct ``diaObjectId``.
        assocDiaSources = result.assocDiaSources.reset_index().set_index(
            ["diaObjectId", "diaSourceId"])
        for idx, diaObjId in enumerate(result.diaObjects.index):
            expectedNSources = 2 if idx < self.nDiaObjects else 1
            self.assertEqual(len(assocDiaSources.loc[diaObjId]), expectedNSources)

    def testUpdateCatalogs(self):
        """Test associating a DiaSource with an existing DiaObject.

        ``updateCatalogs`` should leave the number of DiaObjects, healpix
        indices and coordinate lists unchanged. It should append the
        DiaSource's coordinate to the matched DiaObject's coordinate list and
        stamp that DiaObject's id onto the DiaSource row.
        """
        matchIndex = 4
        diaSrc = self.diaSources.iloc[matchIndex]
        self.diaObjects[matchIndex]["diaObjectId"] = 1234
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.updateCatalogs(matchIndex,
                                   diaSrc,
                                   self.diaSources,
                                   ccdVisit,
                                   diaSourceId,
                                   self.diaObjects,
                                   self.coordList,
                                   self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects)
        self.assertEqual(len(self.coordList), self.nDiaObjects)
        # The fixture gives each DiaObject two source coordinates; the update
        # appends a third.
        self.assertEqual(len(self.coordList[matchIndex]), 3)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         self.diaObjects[matchIndex]["diaObjectId"])

    def testAddDiaObject(self):
        """Test creating a new DiaObject from an unassociated DiaSource.

        ``addNewDiaObject`` should grow the DiaObject, healpix-index and
        coordinate lists by one. It should also label the DiaSource with the
        id generated in ``idCat``.
        """
        # The last DiaSource is one of the far-away, unassociated ones.
        diaSrc = self.diaSources.iloc[-1]
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)
        # Reuse the fixture's tract/patch id and skymap bits rather than
        # repeating the literals.
        idFactory = afwTable.IdFactory.makeSource(
            self.tractPatchId,
            afwTable.IdFactory.computeReservedFromMaxBits(self.skymapBits))
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema(),
                                      idFactory))

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.addNewDiaObject(diaSrc,
                                    self.diaSources,
                                    ccdVisit,
                                    diaSourceId,
                                    self.diaObjects,
                                    idCat,
                                    self.coordList,
                                    self.hpIndices)
        self.assertEqual(len(self.hpIndices), self.nDiaObjects + 1)
        self.assertEqual(len(self.coordList), self.nDiaObjects + 1)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects + 1)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         idCat[0].get("id"))

    def testFindMatches(self):
        """Test the simple brute force matching algorithm.

        Each search uses a radius of twice the configured tolerance.
        """
        simpleAssoc = SimpleAssociationTask()
        # No match: (0, 0) is far from every DiaObject.
        matchResult = simpleAssoc.findMatches(
            0.0,
            0.0,
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertIsNone(matchResult.dists)
        self.assertIsNone(matchResult.matches)

        # One match: object 4 has no close neighbour.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[4],
            self.diaObjDecs[4],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 1)
        self.assertEqual(len(matchResult.matches), 1)
        self.assertEqual(matchResult.matches[0], 4)

        # Two matches: object 3 was placed next to object 2 in setUp.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[2],
            self.diaObjDecs[2],
            2*simpleAssoc.config.tolerance,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 2)
        self.assertEqual(len(matchResult.matches), 2)
        self.assertEqual(matchResult.matches[0], 2)
        self.assertEqual(matchResult.matches[1], 3)

220 

221 

def setup_module(module):
    """pytest xunit-style module setup hook.

    Initializes the LSST test utilities once before this module's tests run.

    Parameters
    ----------
    module : `module`
        The test module being set up (unused).
    """
    testUtils = lsst.utils.tests
    testUtils.init()

224 

225 

class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Reuse every check defined by `lsst.utils.tests.MemoryTestCase`.

    No tests are added here; the inherited ones run against this module.
    (Presumably these are resource/leak checks — see `lsst.utils.tests`.)
    """

228 

229 

# Allow running this file directly as a script. The LSST test utilities are
# initialized first, then unittest discovers and runs the tests here.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()