Coverage for tests/helper/skyMapTestCase.py: 8%

241 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-05 02:59 -0700

1# 

2# LSST Data Management System 

3# Copyright 2008-2017 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22import itertools 

23import pickle 

24import copy 

25 

26import numpy as np 

27 

28import lsst.geom as geom 

29import lsst.utils.tests 

30 

31from lsst.skymap import skyMapRegistry 

32 

33 

def checkDm14809(testcase, skymap):
    """Test that DM-14809 has been fixed

    The observed behaviour was:

        skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712

    and

        skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord()

    In order to be thorough, we generalise these over the entire skymap.

    Parameters
    ----------
    testcase : test case object
        Object providing ``assertListEqual`` and ``assertEqual``
        (e.g. a `SkyMapTestCase`); used to make the assertions.
    skymap : sky map
        Sky map to check; must be iterable over its tracts, support
        ``len`` and provide ``findTract``.
    """
    # Check that the tract found for central coordinate of a tract is that
    # tract.
    expect = [tract.getId() for tract in skymap]
    got = [skymap.findTract(tract.getCtrCoord()).getId() for tract in skymap]
    testcase.assertListEqual(got, expect)

    # Check that the tract central coordinates are unique.
    # Truncate (toward zero) to integer arcminutes so differences are
    # relatively immune to small numerical inaccuracies.
    centers = {
        (int(ctr.getRa().asArcminutes()), int(ctr.getDec().asArcminutes()))
        for ctr in (tract.getCtrCoord() for tract in skymap)
    }
    testcase.assertEqual(len(centers), len(skymap))

59 

60 

class SkyMapTestCase(lsst.utils.tests.TestCase):
    """An abstract base class for testing a SkyMap.

    To use, subclass and call `setAttributes` from `setUp`.
    """
    def setAttributes(self, *,
                      SkyMapClass,
                      name,
                      numTracts,
                      config=None,
                      neighborAngularSeparation=None,
                      numNeighbors=None):
        """Initialize the test (call from setUp in the subclass)

        Parameters
        ----------
        SkyMapClass : subclass of `lsst.skymap.BaseSkyMap`
            Class of sky map to test
        name : `str`
            Name of sky map in sky map registry
        numTracts : `int`
            Number of tracts in the default configuration
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Default configuration used by `getSkyMap`;
            if None use SkyMapClass.ConfigClass()
        neighborAngularSeparation : `lsst.geom.Angle`, optional
            Expected angular separation between tracts;
            if None then angular separation is not tested unless your
            subclass of SkyMapTestCase overrides `testTractSeparation`.
        numNeighbors : `int` or `None`
            Number of neighbors that should be within
            ``neighborAngularSeparation``;
            Required if ``neighborAngularSeparation`` is not None;
            ignored otherwise. See `testTractSeparation` for exactly
            which entries of the sorted separation list are checked.

        Notes
        -----
        Seeds NumPy's global random number generator so that the
        "arbitrary" tracts picked by `testFindPatchList` and
        `testFindTractPatchList` are reproducible.
        """
        self.SkyMapClass = SkyMapClass
        self.config = config
        self.name = name
        self.numTracts = numTracts
        self.neighborAngularSeparation = neighborAngularSeparation
        self.numNeighbors = numNeighbors
        np.random.seed(47)

    def getSkyMap(self, config=None):
        """Provide an instance of the skymap

        Parameters
        ----------
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Configuration to use; if None, use `getConfig`.

        Returns
        -------
        skyMap : instance of ``self.SkyMapClass``
            Sky map constructed from the validated configuration.
        """
        if config is None:
            config = self.getConfig()
        config.validate()
        return self.SkyMapClass(config=config)

    def getConfig(self):
        """Provide an instance of the configuration class

        Returns a fresh ``SkyMapClass.ConfigClass()`` if no default config
        was passed to `setAttributes`, otherwise a copy of that config.
        """
        if self.config is None:
            return self.SkyMapClass.ConfigClass()
        # Want to return a copy of self.config, so it can be modified.
        return copy.copy(self.config)

    def testRegistry(self):
        """Confirm that the skymap can be retrieved from the registry"""
        self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass)

    def testBasicAttributes(self):
        """Confirm that constructor attributes are available
        """
        defaultSkyMap = self.getSkyMap()
        for tractOverlap in (0.0, 0.01, 0.1):  # degrees
            config = self.getConfig()
            config.tractOverlap = tractOverlap
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        if defaultSkyMap.config.tractBuilder.name == 'cells':
            # The following tests are not appropriate for cells
            # see test_ringsSkyMapCells.py for "cell" tract testing.
            return

        for patchBorder in (0, 101):
            config = self.getConfig()
            config.patchBorder = patchBorder
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tractInfo.getPatchBorder(), patchBorder)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for xInnerDim in (1005, 5062):
            for yInnerDim in (2032, 5431):
                config = self.getConfig()
                config.patchInnerDimensions = (xInnerDim, yInnerDim)
                skyMap = self.getSkyMap(config)
                for tractInfo in skyMap:
                    self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), (xInnerDim, yInnerDim))
                self.assertEqual(len(skyMap), self.numTracts)
                self.assertNotEqual(skyMap, defaultSkyMap)

        # Compare a few patches of the first tract of the default sky map.
        # The outer bbox should extend past the inner bbox by patchBorder on
        # every side except where the patch touches the edge of the tract.
        tractInfo = defaultSkyMap[0]
        numPatches = tractInfo.getNumPatches()
        patchBorder = defaultSkyMap.config.patchBorder
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = tractInfo.getPatchInfo((xInd, yInd))

                # check inner and outer bbox
                innerBBox = patchInfo.getInnerBBox()
                outerBBox = patchInfo.getOuterBBox()

                if xInd == 0:
                    self.assertEqual(innerBBox.getMinX(), outerBBox.getMinX())
                else:
                    self.assertEqual(innerBBox.getMinX() - patchBorder, outerBBox.getMinX())
                if yInd == 0:
                    self.assertEqual(innerBBox.getMinY(), outerBBox.getMinY())
                else:
                    self.assertEqual(innerBBox.getMinY() - patchBorder, outerBBox.getMinY())

                if xInd == numPatches[0] - 1:
                    self.assertEqual(innerBBox.getMaxX(), outerBBox.getMaxX())
                else:
                    self.assertEqual(innerBBox.getMaxX() + patchBorder, outerBBox.getMaxX())
                if yInd == numPatches[1] - 1:
                    self.assertEqual(innerBBox.getMaxY(), outerBBox.getMaxY())
                else:
                    self.assertEqual(innerBBox.getMaxY() + patchBorder, outerBBox.getMaxY())

    def assertUnpickledTractInfo(self, unpickled, original, patchBorder):
        """Assert that an unpickled TractInfo is functionally identical to the
        original.

        Parameters
        ----------
        unpickled : `TractInfo`
            The unpickled `TractInfo`.
        original : `TractInfo`
            The original `TractInfo`.
        patchBorder : `int`
            Border around each patch, from ``SkyMap.config.patchBorder``.
            Currently unused; retained for backward compatibility.
        """
        for getterName in ("getBBox",
                           "getCtrCoord",
                           "getId",
                           "getNumPatches",
                           "getPatchBorder",
                           "getPatchInnerDimensions",
                           "getTractOverlap",
                           "getVertexList",
                           "getWcs",
                           ):
            self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)(),
                             msg=getterName)

        # test WCS at a few locations
        wcs = original.getWcs()
        unpickledWcs = unpickled.getWcs()
        for x in (-1000.0, 0.0, 1000.0):
            for y in (-532.5, 0.5, 532.5):
                pixelPos = geom.Point2D(x, y)
                skyPos = wcs.pixelToSky(pixelPos)
                unpickledSkyPos = unpickledWcs.pixelToSky(pixelPos)
                self.assertEqual(skyPos, unpickledSkyPos)

        # compare a few patches
        numPatches = original.getNumPatches()
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = original.getPatchInfo((xInd, yInd))
                unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd))
                self.assertEqual(patchInfo, unpickledPatchInfo)

    def testPickle(self):
        """Test that pickling and unpickling restores the original exactly
        """
        skyMap = self.getSkyMap()
        # Round trip of a locally constructed object; no untrusted data.
        pickleStr = pickle.dumps(skyMap)
        unpickledSkyMap = pickle.loads(pickleStr)
        self.assertEqual(len(skyMap), len(unpickledSkyMap))
        self.assertEqual(unpickledSkyMap.config, skyMap.config)
        self.assertEqual(skyMap, unpickledSkyMap)
        for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap):
            self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder)

    def testTractSeparation(self):
        """Confirm that each sky tract has the proper distance to other tracts
        """
        if self.neighborAngularSeparation is None:
            self.skipTest("Not testing angular separation for %s: neighborAngularSeparation is None" %
                          (self.SkyMapClass.__name__,))
        skyMap = self.getSkyMap()
        for tractId, tractInfo in enumerate(skyMap):
            self.assertEqual(tractInfo.getId(), tractId)

            ctrCoord = tractInfo.getCtrCoord()
            distList = sorted(ctrCoord.separation(other.getCtrCoord()) for other in skyMap)
            # distList[0] is the tract itself (separation 0). The slice
            # [1:numNeighbors] then checks the next ``numNeighbors - 1``
            # closest tracts (or all other tracts if numNeighbors is None).
            self.assertEqual(distList[0], 0.0)
            for dist in distList[1:self.numNeighbors]:
                self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation)

    def testFindPatchList(self):
        """Test TractInfo.findPatchList
        """
        skyMap = self.getSkyMap()
        # pick two arbitrary tracts (sampled with replacement)
        for tractId in np.random.choice(len(skyMap), 2):
            tractInfo = skyMap[tractId]
            wcs = tractInfo.getWcs()
            numPatches = tractInfo.getNumPatches()
            border = tractInfo.getPatchBorder()
            for patchInd in ((0, 0),
                             (0, 1),
                             (5, 0),
                             (5, 6),
                             (numPatches[0] - 2, numPatches[1] - 1),
                             (numPatches[0] - 1, numPatches[1] - 2),
                             (numPatches[0] - 1, numPatches[1] - 1),
                             ):
                patchInfo = tractInfo.getPatchInfo(patchInd)
                patchIndex = patchInfo.getIndex()
                # Shrink the inner bbox (in place) so that its corners should
                # fall in this patch only.
                bbox = patchInfo.getInnerBBox()
                bbox.grow(-(border+1))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), 1)
                self.assertEqual(patchInfoList[0].getIndex(), patchIndex)

                # grow to include neighbors and test again
                bbox.grow(2)
                predFoundIndexSet = set()
                for dx in (-1, 0, 1):
                    nbrX = patchIndex[0] + dx
                    if not 0 <= nbrX < numPatches[0]:
                        continue
                    for dy in (-1, 0, 1):
                        nbrY = patchIndex[1] + dy
                        if not 0 <= nbrY < numPatches[1]:
                            continue
                        predFoundIndexSet.add((nbrX, nbrY))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), len(predFoundIndexSet))
                foundIndexSet = {foundPatch.getIndex() for foundPatch in patchInfoList}
                self.assertEqual(foundIndexSet, predFoundIndexSet)

    def testFindTractPatchList(self):
        """Test findTractPatchList

        Notes
        -----
        This test uses single points for speed and to avoid really large
        regions.
        Note that `findPatchList` is being tested elsewhere.
        """
        skyMap = self.getSkyMap()
        # assertClosestTractPatchList calls skipTest when the sky map lacks
        # findClosestTractPatchList; calling it unconditionally would abort
        # this whole test after the first tract and silently drop the
        # remaining checks, so only call it when the method exists.
        hasClosest = hasattr(skyMap, "findClosestTractPatchList")
        # pick 3 arbitrary tracts (sampled with replacement)
        for tractId in np.random.choice(len(skyMap), 3):
            tractInfo = skyMap[tractId]
            ctrCoord = tractInfo.getCtrCoord()
            self.assertTractPatchListOk(
                skyMap=skyMap,
                coordList=[ctrCoord],
                knownTractId=tractId,
            )
            if hasClosest:
                self.assertClosestTractPatchList(skyMap, [ctrCoord], tractId)

            # Also check the first and third vertices, when present.
            vertices = tractInfo.getVertexList()
            for vertexIndex in (0, 2):
                if len(vertices) > vertexIndex:
                    self.assertTractPatchListOk(
                        skyMap=skyMap,
                        coordList=[vertices[vertexIndex]],
                        knownTractId=tractId,
                    )

    def testTractContains(self):
        """Test that TractInfo.contains works

        Each tract must contain its own center and must not contain the
        antipode of that center (longitude + 12h, negated latitude).
        """
        skyMap = self.getSkyMap()
        for tract in skyMap:
            coord = tract.getCtrCoord()
            self.assertTrue(tract.contains(coord))
            opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude())
            self.assertFalse(tract.contains(opposite))

    def testTractInfoGetPolygon(self):
        """Test TractInfo.getInnerSkyPolygon and getOuterSkyPolygon"""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(),
                                 vertexList=tractInfo.getVertexList(),
                                 centerCoord=centerCoord)
            self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(),
                                     bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs())

    def testPatchInfoGetPolygon(self):
        """Test PatchInfo.getInnerSkyPolygon and getOuterSkyPolygon"""
        skyMap = self.getSkyMap()

        def getIndices(numItems):
            """Return up to 3 indices for testing"""
            if numItems > 2:
                return (0, 1, numItems-1)
            elif numItems > 1:
                return (0, 1)
            return (0,)

        for tractInfo in skyMap:
            wcs = tractInfo.getWcs()
            # Use this tract's own patch counts: tracts need not all have the
            # same number of patches as tract 0.
            numPatches = tractInfo.getNumPatches()
            for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])):
                with self.subTest(tractId=tractInfo.getId(), patchInd=patchInd):
                    patchInfo = tractInfo.getPatchInfo(patchInd)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getInnerBBox(), wcs=wcs)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getOuterBBox(), wcs=wcs)

    def testDm14809(self):
        """Generic version of test that DM-14809 has been fixed"""
        checkDm14809(self, self.getSkyMap())

    def testNumbering(self):
        """Check the numbering of tracts matches the indexing"""
        skymap = self.getSkyMap()
        expect = list(range(len(skymap)))
        got = [tt.getId() for tt in skymap]
        self.assertEqual(got, expect)

    def assertTractPatchListOk(self, skyMap, coordList, knownTractId):
        """Assert that findTractPatchList produces the correct results.

        Every tract whose ``findPatchList`` finds patches for ``coordList``
        must appear in the result with the same number of patches, and no
        other tract may appear.

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Region to search for.
        knownTractId : `int`
            This tractId must appear in the found list.
        """
        tractPatchList = skyMap.findTractPatchList(coordList)
        tractPatchDict = {tp[0].getId(): tp[1] for tp in tractPatchList}
        self.assertIn(knownTractId, tractPatchDict)
        for tractInfo in skyMap:
            tractId = tractInfo.getId()
            patchList = tractInfo.findPatchList(coordList)
            if patchList:
                self.assertIn(tractId, tractPatchDict)
                self.assertEqual(len(patchList), len(tractPatchDict[tractId]))
            else:
                self.assertNotIn(tractId, tractPatchDict)

    def assertClosestTractPatchList(self, skyMap, coordList, knownTractId):
        """Assert that findClosestTractPatchList produces the correct results.

        Skips the calling test (via ``skipTest``) if ``skyMap`` has no
        ``findClosestTractPatchList`` method.

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates to search for; one (tract, patchList) result is
            expected per coordinate.
        knownTractId : `int`
            Every result's tract must have this ID.
        """
        if not hasattr(skyMap, "findClosestTractPatchList"):
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")
        tractPatchList = skyMap.findClosestTractPatchList(coordList)
        self.assertEqual(len(coordList), len(tractPatchList))  # One tract+patchList per coordinate
        for coord, (tract, patchList) in zip(coordList, tractPatchList):
            self.assertEqual(tract.getId(), knownTractId)
            self.assertEqual(patchList, tract.findPatchList([coord]))

    def assertBBoxPolygonOk(self, polygon, bbox, wcs):
        """Assert that an on-sky polygon from a pixel bbox
        covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon
        bbox : pixel bounding box
            Bounding box the polygon was made from; its corners, mapped to
            the sky with ``wcs``, are the expected polygon vertices.
        wcs : WCS
            Object providing ``pixelToSky``, used to map the bbox center
            and corners to the sky.
        """
        bboxd = geom.Box2D(bbox)
        centerPixel = bboxd.getCenter()
        centerCoord = wcs.pixelToSky(centerPixel)
        skyCorners = getCornerCoords(wcs=wcs, bbox=bbox)
        self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord)

    def assertPolygonOk(self, polygon, vertexList, centerCoord):
        """Assert that an on-sky polygon covers the expected region.

        The center must be inside the polygon, and each vertex nudged
        slightly toward the center must be inside while each vertex nudged
        slightly away from the center must be outside.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon
        vertexList : `iterable` of `lsst.geom.SpherePoint`
            Vertices of polygon
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region
        """
        shiftAngle = 0.01*geom.arcseconds
        self.assertTrue(polygon.contains(centerCoord.getVector()))
        for vertex in vertexList:
            bearingToCenter = vertex.bearingTo(centerCoord)
            cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle)
            cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle)
            self.assertTrue(polygon.contains(cornerShiftedIn.getVector()))
            self.assertFalse(polygon.contains(cornerShiftedOut.getVector()))

470 

471 

def getCornerCoords(wcs, bbox):
    """Map the four corners of a pixel bounding box onto the sky.

    Parameters
    ----------
    wcs : WCS
        Object providing ``pixelToSky``; it is called once with the list
        of corner points.
    bbox : pixel bounding box
        Box whose corners are wanted; converted to `lsst.geom.Box2D`
        before its corners are taken.

    Returns
    -------
    coords
        Whatever ``wcs.pixelToSky`` returns for the corner list; callers in
        this module treat it as a sequence of `lsst.geom.SpherePoint`.
    """
    return wcs.pixelToSky(geom.Box2D(bbox).getCorners())