Coverage for tests/helper/skyMapTestCase.py: 9%

230 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-01 20:23 +0000

1# 

2# LSST Data Management System 

3# Copyright 2008-2017 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22import itertools 

23import pickle 

24 

25import numpy as np 

26 

27import lsst.geom as geom 

28import lsst.utils.tests 

29 

30from lsst.skymap import skyMapRegistry 

31 

32 

def checkDm14809(testcase, skymap):
    """Test that DM-14809 has been fixed.

    The observed behaviour was:

        skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712

    and

        skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord()

    In order to be thorough, we generalise these over the entire skymap.

    Parameters
    ----------
    testcase : `unittest.TestCase`
        Test case providing ``assertListEqual`` and ``assertEqual``.
    skymap : sky map
        Sky map to check; it is iterated over its tracts and must support
        ``len`` and ``findTract``.

    Raises
    ------
    AssertionError
        Raised (by ``testcase``) if a tract centre is not found in its own
        tract, or if two tracts share the same centre.
    """
    # Evaluate each tract's identifier and centre once; both checks below
    # use the same centres.
    expect = [tract.getId() for tract in skymap]
    centerCoords = [tract.getCtrCoord() for tract in skymap]

    # Check that the tract found for central coordinate of a tract is that tract
    got = [skymap.findTract(coord).getId() for coord in centerCoords]
    testcase.assertListEqual(got, expect)

    # Check that the tract central coordinates are unique
    # Round to integer arcminutes so differences are relatively immune to small numerical inaccuracies
    centers = {
        (int(coord.getRa().asArcminutes()), int(coord.getDec().asArcminutes()))
        for coord in centerCoords
    }
    testcase.assertEqual(len(centers), len(skymap))

56 

57 

class SkyMapTestCase(lsst.utils.tests.TestCase):
    """An abstract base class for testing a SkyMap.

    To use, subclass and call `setAttributes` from `setUp`.
    """

    def setAttributes(self, *,
                      SkyMapClass,
                      name,
                      numTracts,
                      config=None,
                      neighborAngularSeparation=None,
                      numNeighbors=None):
        """Initialize the test (call from setUp in the subclass).

        Parameters
        ----------
        SkyMapClass : subclass of `lsst.skymap.BaseSkyMap`
            Class of sky map to test
        name : `str`
            Name of sky map in sky map registry
        numTracts : `int`
            Number of tracts in the default configuration
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Default configuration used by `getSkyMap`;
            if None use SkyMapClass.ConfigClass()
        neighborAngularSeparation : `lsst.geom.Angle`, optional
            Expected angular separation between tracts;
            if None then angular separation is not tested unless your
            subclass of SkyMapTestCase overrides `testTractSeparation`.
        numNeighbors : `int` or `None`
            Number of neighbors that should be within
            ``neighborAngularSeparation``;
            Required if ``neighborAngularSeparation`` is not None;
            ignored otherwise.
        """
        self.SkyMapClass = SkyMapClass
        self.config = config
        self.name = name
        self.numTracts = numTracts
        self.neighborAngularSeparation = neighborAngularSeparation
        self.numNeighbors = numNeighbors
        # Fixed seed so the tracts drawn with np.random.choice are repeatable.
        np.random.seed(47)

    def getSkyMap(self, config=None):
        """Provide an instance of the skymap.

        Parameters
        ----------
        config : sky map configuration, optional
            Configuration to construct the sky map with; if None the result
            of `getConfig` is used.

        Returns
        -------
        skyMap : instance of ``self.SkyMapClass``
            The constructed sky map.
        """
        if config is None:
            config = self.getConfig()
        return self.SkyMapClass(config=config)

    def getConfig(self):
        """Provide an instance of the configuration class.

        Returns
        -------
        config : sky map configuration
            A new ``SkyMapClass.ConfigClass()`` if no configuration was given
            to `setAttributes`, otherwise a copy of that configuration.
        """
        if self.config is None:
            return self.SkyMapClass.ConfigClass()
        # Want to return a copy of self.config, so it can be modified.
        # However, there is no Config.copy() method, so this is more complicated than desirable.
        return pickle.loads(pickle.dumps(self.config))

    def testRegistry(self):
        """Confirm that the skymap can be retrieved from the registry."""
        self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass)

    def testBasicAttributes(self):
        """Confirm that constructor attributes are available.

        For each tested value of ``tractOverlap``, ``patchBorder`` and
        ``patchInnerDimensions`` a sky map is built from a modified
        configuration; every tract must report the configured value, the
        number of tracts must be unchanged and the sky map must compare
        unequal to the default one.
        """
        defaultSkyMap = self.getSkyMap()

        for tractOverlap in (0.0, 0.01, 0.1):  # degrees
            config = self.getConfig()
            config.tractOverlap = tractOverlap
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for patchBorder in (0, 101):
            config = self.getConfig()
            config.patchBorder = patchBorder
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tractInfo.getPatchBorder(), patchBorder)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for innerDimensions in itertools.product((1005, 5062), (2032, 5431)):
            config = self.getConfig()
            config.patchInnerDimensions = innerDimensions
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), innerDimensions)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

    def assertUnpickledTractInfo(self, unpickled, original, patchBorder):
        """Assert that an unpickled TractInfo is functionally identical to
        the original.

        Parameters
        ----------
        unpickled : TractInfo
            The unpickled TractInfo
        original : TractInfo
            The original TractInfo
        patchBorder : `int`
            Border around each patch, from SkyMap.config.patchBorder
        """
        for getterName in ("getBBox",
                           "getCtrCoord",
                           "getId",
                           "getNumPatches",
                           "getPatchBorder",
                           "getPatchInnerDimensions",
                           "getTractOverlap",
                           "getVertexList",
                           "getWcs",
                           ):
            self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)())

        # test WCS at a few locations
        wcs = original.getWcs()
        unpickledWcs = unpickled.getWcs()
        for x, y in itertools.product((-1000.0, 0.0, 1000.0), (-532.5, 0.5, 532.5)):
            pixelPos = geom.Point2D(x, y)
            self.assertEqual(wcs.pixelToSky(pixelPos), unpickledWcs.pixelToSky(pixelPos))

        # compare a few patches
        # NOTE(review): the "-2" indices presumably assume at least two
        # patches in each direction; confirm for very small tracts.
        numPatches = original.getNumPatches()
        xIndices = (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1)
        yIndices = (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1)
        for xInd, yInd in itertools.product(xIndices, yIndices):
            patchInfo = original.getPatchInfo((xInd, yInd))
            unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd))
            self.assertEqual(patchInfo, unpickledPatchInfo)

            # check inner and outer bbox (nothing to do with pickle,
            # but a convenient place for the test)
            innerBBox = patchInfo.getInnerBBox()
            outerBBox = patchInfo.getOuterBBox()

            # The outer box extends the inner box by patchBorder on every
            # side, except on a side at the first or last patch index,
            # where the two boxes must coincide.
            minXBorder = 0 if xInd == 0 else patchBorder
            minYBorder = 0 if yInd == 0 else patchBorder
            maxXBorder = 0 if xInd == numPatches[0] - 1 else patchBorder
            maxYBorder = 0 if yInd == numPatches[1] - 1 else patchBorder
            self.assertEqual(innerBBox.getMinX() - minXBorder, outerBBox.getMinX())
            self.assertEqual(innerBBox.getMinY() - minYBorder, outerBBox.getMinY())
            self.assertEqual(innerBBox.getMaxX() + maxXBorder, outerBBox.getMaxX())
            self.assertEqual(innerBBox.getMaxY() + maxYBorder, outerBBox.getMaxY())

    def testPickle(self):
        """Test that pickling and unpickling restores the original exactly."""
        skyMap = self.getSkyMap()
        unpickledSkyMap = pickle.loads(pickle.dumps(skyMap))
        self.assertEqual(len(skyMap), len(unpickledSkyMap))
        self.assertEqual(unpickledSkyMap.config, skyMap.config)
        self.assertEqual(skyMap, unpickledSkyMap)
        for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap):
            self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder)

    def testTractSeparation(self):
        """Confirm that each sky tract has the proper distance to other
        tracts.

        Skipped when ``neighborAngularSeparation`` is None.
        """
        if self.neighborAngularSeparation is None:
            self.skipTest("Not testing angular separation for %s: neighborAngularSeparation is None" %
                          (self.SkyMapClass.__name__,))
        skyMap = self.getSkyMap()
        for tractId, tractInfo in enumerate(skyMap):
            self.assertEqual(tractInfo.getId(), tractId)

            ctrCoord = tractInfo.getCtrCoord()
            distList = sorted(ctrCoord.separation(otherInfo.getCtrCoord()) for otherInfo in skyMap)
            # The smallest separation is the tract compared with itself.
            self.assertEqual(distList[0], 0.0)
            # Entry 0 is the tract itself, so this slice checks
            # numNeighbors - 1 other tracts.
            for dist in distList[1:self.numNeighbors]:
                self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation)

    def testFindPatchList(self):
        """Test TractInfo.findPatchList.

        For a few patches of two randomly chosen tracts, a box strictly
        inside the patch must find only that patch, and the same box grown
        past the patch edge must find the patch and its in-range neighbors.
        """
        skyMap = self.getSkyMap()
        # pick two arbitrary tracts
        for tractId in np.random.choice(len(skyMap), 2):
            tractInfo = skyMap[tractId]
            wcs = tractInfo.getWcs()
            numPatches = tractInfo.getNumPatches()
            border = tractInfo.getPatchBorder()
            # NOTE(review): the fixed indices (5, 0) and (5, 6) presumably
            # assume the tract has at least 6 x 7 patches; confirm.
            for patchInd in ((0, 0),
                             (0, 1),
                             (5, 0),
                             (5, 6),
                             (numPatches[0] - 2, numPatches[1] - 1),
                             (numPatches[0] - 1, numPatches[1] - 2),
                             (numPatches[0] - 1, numPatches[1] - 1),
                             ):
                patchInfo = tractInfo.getPatchInfo(patchInd)
                patchIndex = patchInfo.getIndex()
                bbox = patchInfo.getInnerBBox()
                bbox.grow(-(border+1))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), 1)
                self.assertEqual(patchInfoList[0].getIndex(), patchIndex)

                # grow to include neighbors and test again
                bbox.grow(2)
                predFoundIndexSet = set()
                for dx, dy in itertools.product((-1, 0, 1), repeat=2):
                    nbrInd = (patchIndex[0] + dx, patchIndex[1] + dy)
                    if 0 <= nbrInd[0] < numPatches[0] and 0 <= nbrInd[1] < numPatches[1]:
                        predFoundIndexSet.add(nbrInd)
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), len(predFoundIndexSet))
                foundIndexSet = {foundInfo.getIndex() for foundInfo in patchInfoList}
                self.assertEqual(foundIndexSet, predFoundIndexSet)

    def testFindTractPatchList(self):
        """Test findTractPatchList.

        Note: this test uses single points for speed and to avoid really
        large regions.
        Note that findPatchList is being tested elsewhere.
        """
        skyMap = self.getSkyMap()
        # pick 3 arbitrary tracts
        for tractId in np.random.choice(len(skyMap), 3):
            tractInfo = skyMap[tractId]
            ctrCoord = tractInfo.getCtrCoord()
            self.assertTractPatchListOk(
                skyMap=skyMap,
                coordList=[ctrCoord],
                knownTractId=tractId,
            )
            self.assertClosestTractPatchList(skyMap, [ctrCoord], tractId)

            # Also probe with the first and third vertices, when present.
            vertices = tractInfo.getVertexList()
            if len(vertices) > 0:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[0]],
                    knownTractId=tractId,
                )

            if len(vertices) > 2:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[2]],
                    knownTractId=tractId,
                )

    def testTractContains(self):
        """Test that TractInfo.contains works."""
        skyMap = self.getSkyMap()
        for tract in skyMap:
            coord = tract.getCtrCoord()
            self.assertTrue(tract.contains(coord))
            # Shift longitude by 12 hours and negate latitude to get a point
            # that must not be contained in the tract.
            opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude())
            self.assertFalse(tract.contains(opposite))

    def testTractInfoGetPolygon(self):
        """Check the inner and outer sky polygons of every tract."""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(),
                                 vertexList=tractInfo.getVertexList(),
                                 centerCoord=centerCoord)
            self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(),
                                     bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs())

    def testPatchInfoGetPolygon(self):
        """Check the inner and outer sky polygons of a few patches in every
        tract.
        """
        skyMap = self.getSkyMap()
        # The patch counts of the first tract select the indices tested in
        # every tract.
        numPatches = skyMap[0].getNumPatches()

        def getIndices(numItems):
            """Return up to 3 indices for testing"""
            if numItems > 2:
                return (0, 1, numItems-1)
            elif numItems > 1:
                return (0, 1)
            return (0,)

        for tractInfo in skyMap:
            wcs = tractInfo.getWcs()
            for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])):
                with self.subTest(patchInd=patchInd):
                    patchInfo = tractInfo.getPatchInfo(patchInd)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getInnerBBox(), wcs=wcs)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getOuterBBox(), wcs=wcs)

    def testDm14809(self):
        """Generic version of test that DM-14809 has been fixed."""
        checkDm14809(self, self.getSkyMap())

    def testNumbering(self):
        """Check the numbering of tracts matches the indexing."""
        skymap = self.getSkyMap()
        expect = list(range(len(skymap)))
        got = [tt.getId() for tt in skymap]
        self.assertEqual(got, expect)

    def assertTractPatchListOk(self, skyMap, coordList, knownTractId):
        """Assert that findTractPatchList produces the correct results.

        Parameters
        ----------
        skyMap : sky map
            Sky map to test
        coordList : `list` of coordinates
            coordList of region to search for
        knownTractId : `int`
            This tractId must appear in the found list
        """
        tractPatchList = skyMap.findTractPatchList(coordList)
        tractPatchDict = {tp[0].getId(): tp[1] for tp in tractPatchList}
        self.assertIn(knownTractId, tractPatchDict)
        # Cross-check against each tract's own findPatchList: a tract is
        # listed exactly when it finds patches, with the same patch count.
        for tractInfo in skyMap:
            tractId = tractInfo.getId()
            patchList = tractInfo.findPatchList(coordList)
            if patchList:
                self.assertIn(tractId, tractPatchDict)
                self.assertEqual(len(patchList), len(tractPatchDict[tractId]))
            else:
                self.assertNotIn(tractId, tractPatchDict)

    def assertClosestTractPatchList(self, skyMap, coordList, knownTractId):
        """Assert that findClosestTractPatchList produces the correct
        results.

        The calling test is skipped if the sky map has no
        ``findClosestTractPatchList`` attribute.

        Parameters
        ----------
        skyMap : sky map
            Sky map to test
        coordList : `list` of coordinates
            Coordinates to look up
        knownTractId : `int`
            Tract ID expected for every coordinate in ``coordList``
        """
        if not hasattr(skyMap, "findClosestTractPatchList"):
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")
        tractPatchList = skyMap.findClosestTractPatchList(coordList)
        self.assertEqual(len(coordList), len(tractPatchList))  # One tract+patchList per coordinate
        for coord, (tract, patchList) in zip(coordList, tractPatchList):
            self.assertEqual(tract.getId(), knownTractId)
            self.assertEqual(patchList, tract.findPatchList([coord]))

    def assertBBoxPolygonOk(self, polygon, bbox, wcs):
        """Assert that an on-sky polygon from a pixel bbox
        covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon
        bbox : pixel bounding box
            Bounding box the polygon was built from; must be accepted by
            the `lsst.geom.Box2D` constructor.
        wcs : WCS
            Object whose ``pixelToSky`` maps the pixels of ``bbox`` to the
            sky.
        """
        centerCoord = wcs.pixelToSky(geom.Box2D(bbox).getCenter())
        skyCorners = getCornerCoords(wcs=wcs, bbox=bbox)
        self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord)

    def assertPolygonOk(self, polygon, vertexList, centerCoord):
        """Assert that an on-sky polygon covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon
        vertexList : `iterable` of `lsst.geom.SpherePoint`
            Vertices of polygon
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region
        """
        shiftAngle = 0.01*geom.arcseconds
        self.assertTrue(polygon.contains(centerCoord.getVector()))
        for vertex in vertexList:
            # A vertex nudged toward the center must be inside the polygon;
            # the same vertex nudged away from the center must be outside.
            bearingToCenter = vertex.bearingTo(centerCoord)
            cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle)
            cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle)
            self.assertTrue(polygon.contains(cornerShiftedIn.getVector()))
            self.assertFalse(polygon.contains(cornerShiftedOut.getVector()))

440 

441 

def getCornerCoords(wcs, bbox):
    """Return the sky coordinates of the corners of a bounding box.

    Parameters
    ----------
    wcs : WCS
        Object whose ``pixelToSky`` converts pixel positions to sky
        coordinates.
    bbox : bounding box
        Pixel box; it is converted with the `lsst.geom.Box2D` constructor
        before its corners are taken.

    Returns
    -------
    coords
        Result of ``wcs.pixelToSky`` applied to the list of box corners.
    """
    return wcs.pixelToSky(geom.Box2D(bbox).getCorners())