Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_tasks. 

2 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21# 

22 

23import numpy as np 

24import pandas as pd 

25import unittest 

26 

27import lsst.afw.table as afwTable 

28import lsst.geom as geom 

29import lsst.utils.tests 

30from lsst.pipe.tasks.associationUtils import toIndex 

31from lsst.pipe.tasks.simpleAssociation import SimpleAssociationTask 

32 

33 

class TestSimpleAssociation(lsst.utils.tests.TestCase):
    """Unit tests for ``SimpleAssociationTask``.

    ``setUp`` builds ``nDiaObjects`` DiaObjects, two DiaSources at each
    DiaObject position (one per visit), and ``nNewDiaSources`` additional
    DiaSources positioned away from every DiaObject.
    """

    def _makeDiaSources(self, ccdVisitId, firstDiaSourceId):
        """Build one DiaSource record per DiaObject position.

        Parameters
        ----------
        ccdVisitId : `int`
            Value stored in the ``ccdVisitId`` field of every record.
        firstDiaSourceId : `int`
            ``diaSourceId`` given to the first record; following records
            count up from it.

        Returns
        -------
        diaSources : `list` of `dict`
            Records with keys ``ccdVisitId``, ``diaSourceId``,
            ``diaObjectId`` (always 0), ``ra`` and ``decl``.
        """
        return [
            {"ccdVisitId": ccdVisitId,
             "diaSourceId": firstDiaSourceId + idx,
             "diaObjectId": 0,
             "ra": ra,
             "decl": decl}
            for idx, (ra, decl) in enumerate(zip(self.diaObjRas,
                                                 self.diaObjDecs))]

    def setUp(self):
        """Create the DiaObjects, DiaSources and lookup lists used by the
        tests.
        """
        simpleAssoc = SimpleAssociationTask()
        self.tractPatchId = 1234
        self.skymapBits = 16

        self.nDiaObjects = 10
        self.diaObjRas = np.linspace(45, 46, self.nDiaObjects)
        self.diaObjDecs = np.linspace(45, 46, self.nDiaObjects)
        # Copy a coord to get multiple matches: object 3 is moved to within
        # 0.1/3600 of object 2 in both coordinates.
        self.diaObjRas[3] = self.diaObjRas[2] + 0.1/3600
        self.diaObjDecs[3] = self.diaObjDecs[2] + 0.1/3600
        self.diaObjects = [
            simpleAssoc.createDiaObject(objId, ra, decl)
            for objId, ra, decl in zip(
                np.arange(self.nDiaObjects, dtype=int),
                self.diaObjRas,
                self.diaObjDecs)]

        self.hpIndices = [toIndex(simpleAssoc.config.nside,
                                  diaObj["ra"],
                                  diaObj["decl"])
                          for diaObj in self.diaObjects]

        self.newDiaObjectVisit = 1236
        # Drop in two copies of the DiaObject locations to make DiaSources,
        # one copy per visit, with consecutive diaSourceIds.
        diaSourceList = self._makeDiaSources(1234, 0)
        moreDiaSources = self._makeDiaSources(1235, self.nDiaObjects)
        # coordList holds, per DiaObject, the coordinates of its DiaSources:
        # first the source from visit 1234, then the one from visit 1235.
        self.coordList = [
            [geom.SpherePoint(diaSrc["ra"], diaSrc["decl"], geom.degrees)]
            for diaSrc in diaSourceList]
        for coords, diaSrc in zip(self.coordList, moreDiaSources):
            coords.append(geom.SpherePoint(diaSrc["ra"],
                                           diaSrc["decl"],
                                           geom.degrees))
        diaSourceList.extend(moreDiaSources)

        # Drop in two more DiaSources that are unassociated; their positions
        # are far from the 45-46 range used for the DiaObjects.
        newCoords = [(0.0, 0.0), (1.0, 89.0)]
        self.nNewDiaSources = len(newCoords)
        for ra, decl in newCoords:
            # len() is evaluated before the append, so ids keep counting up.
            diaSourceList.append({"ccdVisitId": self.newDiaObjectVisit,
                                  "diaSourceId": len(diaSourceList),
                                  "diaObjectId": 0,
                                  "ra": ra,
                                  "decl": decl})
        self.diaSources = pd.DataFrame(data=diaSourceList)

    def tearDown(self):
        """Drop the catalogs and lookup lists created in ``setUp``."""
        del self.diaObjects
        del self.hpIndices
        del self.diaSources
        del self.coordList

    def testRun(self):
        """Test the full run method of the simple associator.
        """
        simpleAssoc = SimpleAssociationTask()
        result = simpleAssoc.run(self.diaSources,
                                 self.tractPatchId,
                                 self.skymapBits)

        # Test the number of expected DiaObjects are created.
        self.assertEqual(len(result.diaObjects),
                         self.nDiaObjects + self.nNewDiaSources)
        # Test that DiaSources are assigned the correct ``diaObjectId``
        assocDiaObjects = result.diaObjects.set_index(["diaObjectId"])
        assocDiaSources = result.assocDiaSources.set_index(
            ["diaObjectId", "diaSourceId"])
        for idx, diaObjId in enumerate(assocDiaObjects.index):
            # The first nDiaObjects entries are expected to hold one source
            # from each of the two visits; the rest hold a single source.
            nExpected = 2 if idx < self.nDiaObjects else 1
            self.assertEqual(len(assocDiaSources.loc[diaObjId]), nExpected)

    def testUpdateCatalogs(self):
        """Test adding data to existing DiaObject/Source catalogs.
        """
        matchIndex = 4
        # Read the row before re-indexing: after ``set_index`` the id columns
        # are no longer regular columns of the DataFrame.
        diaSrc = self.diaSources.iloc[matchIndex]
        self.diaObjects[matchIndex]["diaObjectId"] = 1234
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.updateCatalogs(matchIndex,
                                   diaSrc,
                                   self.diaSources,
                                   ccdVisit,
                                   diaSourceId,
                                   self.diaObjects,
                                   self.coordList,
                                   self.hpIndices)
        # No new DiaObject is expected, so the per-object lists keep their
        # length.
        self.assertEqual(len(self.hpIndices), self.nDiaObjects)
        self.assertEqual(len(self.coordList), self.nDiaObjects)
        # Should be 3 source coordinates.
        self.assertEqual(len(self.coordList[matchIndex]), 3)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         self.diaObjects[matchIndex]["diaObjectId"])

    def testAddDiaObject(self):
        """Test creating a new DiaObject from a DiaSource.
        """
        # The last DiaSource is one of the isolated ones added in setUp.
        diaSrc = self.diaSources.iloc[-1]
        ccdVisit = diaSrc["ccdVisitId"]
        diaSourceId = diaSrc["diaSourceId"]
        self.diaSources.set_index(["ccdVisitId", "diaSourceId"], inplace=True)
        idFactory = afwTable.IdFactory.makeSource(self.tractPatchId,
                                                  64 - self.skymapBits)
        idCat = afwTable.SourceCatalog(
            afwTable.SourceTable.make(afwTable.SourceTable.makeMinimalSchema(),
                                      idFactory))

        simpleAssoc = SimpleAssociationTask()
        simpleAssoc.addNewDiaObject(diaSrc,
                                    self.diaSources,
                                    ccdVisit,
                                    diaSourceId,
                                    self.diaObjects,
                                    idCat,
                                    self.coordList,
                                    self.hpIndices)
        # Exactly one entry is expected to be appended to each list.
        self.assertEqual(len(self.hpIndices), self.nDiaObjects + 1)
        self.assertEqual(len(self.coordList), self.nDiaObjects + 1)
        self.assertEqual(len(self.diaObjects), self.nDiaObjects + 1)
        self.assertEqual(self.diaSources.loc[(ccdVisit, diaSourceId),
                                             "diaObjectId"],
                         idCat[0].get("id"))

    def testFindMatches(self):
        """Test the simple brute force matching algorithm.
        """
        simpleAssoc = SimpleAssociationTask()
        searchRadius = 2*simpleAssoc.config.tolerance

        # No match
        matchResult = simpleAssoc.findMatches(
            0.0,
            0.0,
            searchRadius,
            self.hpIndices,
            self.diaObjects)
        self.assertIsNone(matchResult.dists)
        self.assertIsNone(matchResult.matches)

        # One match
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[4],
            self.diaObjDecs[4],
            searchRadius,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 1)
        self.assertEqual(len(matchResult.matches), 1)
        self.assertEqual(matchResult.matches[0], 4)

        # 2 match: objects 2 and 3 were placed next to each other in setUp.
        matchResult = simpleAssoc.findMatches(
            self.diaObjRas[2],
            self.diaObjDecs[2],
            searchRadius,
            self.hpIndices,
            self.diaObjects)
        self.assertEqual(len(matchResult.dists), 2)
        self.assertEqual(len(matchResult.matches), 2)
        self.assertEqual(matchResult.matches[0], 2)
        self.assertEqual(matchResult.matches[1], 3)

220 

221 

def setup_module(module):
    """Run the LSST test-utilities initialisation for this module.

    The ``module`` argument is accepted but not used. Presumably this is
    picked up by pytest as its module-level setup hook - confirm against the
    test runner configuration.
    """
    lsst.utils.tests.init()

224 

225 

class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Local subclass of ``lsst.utils.tests.MemoryTestCase``; adds nothing.

    Presumably declared here so that the inherited checks are collected
    together with this module's tests - confirm against lsst.utils.tests.
    """

228 

229 

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()