Coverage for tests / helper / skyMapTestCase.py: 9%

260 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 08:22 +0000

1# 

2# LSST Data Management System 

3# Copyright 2008-2017 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22import itertools 

23import pickle 

24import copy 

25 

26import numpy as np 

27 

28import lsst.geom as geom 

29import lsst.sphgeom as sphgeom 

30import lsst.utils.tests 

31 

32from lsst.skymap import skyMapRegistry 

33 

34 

def checkDm14809(testcase, skymap):
    """Test that DM-14809 has been fixed.

    The observed behaviour was:

        skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712

    and

        skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord()

    In order to be thorough, we generalise these over the entire skymap.

    Parameters
    ----------
    testcase : `unittest.TestCase`
        Test case whose ``assertListEqual`` and ``assertEqual`` methods are
        used to report failures.
    skymap : `lsst.skymap.BaseSkyMap`
        Sky map to check; iterated over, measured with ``len`` and queried
        with ``findTract``.
    """
    # Check that the tract found for the central coordinate of each tract is
    # that same tract.
    expect = [tract.getId() for tract in skymap]
    got = [skymap.findTract(tract.getCtrCoord()).getId() for tract in skymap]
    testcase.assertListEqual(got, expect)

    # Check that the tract central coordinates are unique.
    # Truncate (``int`` truncates toward zero) to integer arcminutes so that
    # differences are relatively immune to small numerical inaccuracies.
    centers = {
        (int(coord.getRa().asArcminutes()), int(coord.getDec().asArcminutes()))
        for coord in (tract.getCtrCoord() for tract in skymap)
    }
    testcase.assertEqual(len(centers), len(skymap))

60 

61 

class SkyMapTestCase(lsst.utils.tests.TestCase):
    """An abstract base class for testing a SkyMap.

    To use, subclass and call `setAttributes` from `setUp`.
    """

    def setAttributes(self, *,
                      SkyMapClass,
                      name,
                      numTracts,
                      config=None,
                      neighborAngularSeparation=None,
                      numNeighbors=None):
        """Initialize the test (call from setUp in the subclass).

        Parameters
        ----------
        SkyMapClass : subclass of `lsst.skymap.BaseSkyMap`
            Class of sky map to test.
        name : `str`
            Name of sky map in sky map registry.
        numTracts : `int`
            Number of tracts in the default configuration.
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Default configuration used by `getSkyMap`;
            if None use ``SkyMapClass.ConfigClass()``.
        neighborAngularSeparation : `lsst.geom.Angle`, optional
            Expected angular separation between tracts;
            if None then angular separation is not tested unless your
            subclass of SkyMapTestCase overrides `testTractSeparation`.
        numNeighbors : `int` or `None`
            Number of neighbors that should be within
            ``neighborAngularSeparation``;
            required if ``neighborAngularSeparation`` is not None;
            ignored otherwise. `testTractSeparation` checks the sorted
            separations at indices 1 up to (but not including)
            ``numNeighbors``; index 0 is the tract itself, so this count
            effectively includes the tract itself.

        Notes
        -----
        This seeds NumPy's global random number generator, which
        `testFindPatchList` and `testFindTractPatchList` use to pick tracts.
        """
        self.SkyMapClass = SkyMapClass
        self.config = config
        self.name = name
        self.numTracts = numTracts
        self.neighborAngularSeparation = neighborAngularSeparation
        self.numNeighbors = numNeighbors
        # Seed the global RNG so the np.random.choice tract picks repeat.
        np.random.seed(47)

    def getSkyMap(self, config=None):
        """Provide an instance of the skymap.

        Parameters
        ----------
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Configuration to use; if None, use `getConfig`.

        Returns
        -------
        skyMap : ``SkyMapClass``
            Sky map built from ``config`` after ``config.validate()``.
        """
        if config is None:
            config = self.getConfig()
        config.validate()
        return self.SkyMapClass(config=config)

    def getConfig(self):
        """Provide an instance of the configuration class.

        Returns
        -------
        config : subclass of `lsst.skymap.SkyMapConfig`
            A new ``SkyMapClass.ConfigClass()`` if no default config was
            given to `setAttributes`, otherwise a copy of that default.
        """
        if self.config is None:
            return self.SkyMapClass.ConfigClass()
        # Want to return a copy of self.config, so it can be modified.
        # NOTE(review): copy.copy is shallow; confirm the config class copies
        # its field storage so edits do not leak back into self.config.
        return copy.copy(self.config)

    def testRegistry(self):
        """Confirm that the skymap can be retrieved from the registry."""
        self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass)

    def testBasicAttributes(self):
        """Confirm that constructor attributes are available."""
        defaultSkyMap = self.getSkyMap()
        for tractOverlap in (0.0, 0.01, 0.1):  # degrees
            config = self.getConfig()
            config.tractOverlap = tractOverlap
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        if defaultSkyMap.config.tractBuilder.name == 'cells':
            # The following tests are not appropriate for cells
            # see test_ringsSkyMapCells.py for "cell" tract testing.
            return

        for patchBorder in (0, 101):
            config = self.getConfig()
            config.patchBorder = patchBorder
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tractInfo.getPatchBorder(), patchBorder)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for xInnerDim in (1005, 5062):
            for yInnerDim in (2032, 5431):
                config = self.getConfig()
                config.patchInnerDimensions = (xInnerDim, yInnerDim)
                skyMap = self.getSkyMap(config)
                for tractInfo in skyMap:
                    self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), (xInnerDim, yInnerDim))
                self.assertEqual(len(skyMap), self.numTracts)
                self.assertNotEqual(skyMap, defaultSkyMap)

        # Compare a few patches
        defaultSkyMap = self.getSkyMap()
        tractInfo = defaultSkyMap[0]
        numPatches = tractInfo.getNumPatches()
        patchBorder = defaultSkyMap.config.patchBorder
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = tractInfo.getPatchInfo((xInd, yInd))

                # check inner and outer bbox: the outer bbox extends the
                # inner one by patchBorder except along the tract edges.
                innerBBox = patchInfo.getInnerBBox()
                outerBBox = patchInfo.getOuterBBox()

                if xInd == 0:
                    self.assertEqual(innerBBox.getMinX(), outerBBox.getMinX())
                else:
                    self.assertEqual(innerBBox.getMinX() - patchBorder, outerBBox.getMinX())
                if yInd == 0:
                    self.assertEqual(innerBBox.getMinY(), outerBBox.getMinY())
                else:
                    self.assertEqual(innerBBox.getMinY() - patchBorder, outerBBox.getMinY())

                if xInd == numPatches[0] - 1:
                    self.assertEqual(innerBBox.getMaxX(), outerBBox.getMaxX())
                else:
                    self.assertEqual(innerBBox.getMaxX() + patchBorder, outerBBox.getMaxX())
                if yInd == numPatches[1] - 1:
                    self.assertEqual(innerBBox.getMaxY(), outerBBox.getMaxY())
                else:
                    self.assertEqual(innerBBox.getMaxY() + patchBorder, outerBBox.getMaxY())

    def assertUnpickledTractInfo(self, unpickled, original, patchBorder):
        """Assert that an unpickled TractInfo is functionally identical to the
        original.

        Parameters
        ----------
        unpickled : `TractInfo`
            The unpickled `TractInfo`.
        original : `TractInfo`
            The original `TractInfo`.
        patchBorder : `int`
            Border around each patch, from ``SkyMap.config.patchBorder``.
            Not currently used by this method; kept for interface
            compatibility.
        """
        for getterName in ("getBBox",
                           "getCtrCoord",
                           "getId",
                           "getNumPatches",
                           "getPatchBorder",
                           "getPatchInnerDimensions",
                           "getTractOverlap",
                           "getVertexList",
                           "getWcs",
                           ):
            self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)())

        # test WCS at a few locations
        wcs = original.getWcs()
        unpickledWcs = unpickled.getWcs()
        for x in (-1000.0, 0.0, 1000.0):
            for y in (-532.5, 0.5, 532.5):
                pixelPos = geom.Point2D(x, y)
                skyPos = wcs.pixelToSky(pixelPos)
                unpickledSkyPos = unpickledWcs.pixelToSky(pixelPos)
                self.assertEqual(skyPos, unpickledSkyPos)

        # compare a few patches
        numPatches = original.getNumPatches()
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = original.getPatchInfo((xInd, yInd))
                unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd))
                self.assertEqual(patchInfo, unpickledPatchInfo)

    def testPickle(self):
        """Test that pickling and unpickling restores the original exactly."""
        skyMap = self.getSkyMap()
        pickleStr = pickle.dumps(skyMap)
        unpickledSkyMap = pickle.loads(pickleStr)
        self.assertEqual(len(skyMap), len(unpickledSkyMap))
        self.assertEqual(unpickledSkyMap.config, skyMap.config)
        self.assertEqual(skyMap, unpickledSkyMap)
        for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap):
            self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder)

    def testTractSeparation(self):
        """Confirm that each sky tract has the proper distance to other
        tracts.
        """
        if self.neighborAngularSeparation is None:
            self.skipTest(f"Not testing angular separation for {self.SkyMapClass.__name__}: "
                          "neighborAngularSeparation is None")
        skyMap = self.getSkyMap()
        for tractId, tractInfo in enumerate(skyMap):
            self.assertEqual(tractInfo.getId(), tractId)

            ctrCoord = tractInfo.getCtrCoord()
            # Separation from this tract to every tract (itself included),
            # nearest first.
            distList = sorted(ctrCoord.separation(other.getCtrCoord()) for other in skyMap)
            self.assertEqual(distList[0], 0.0)
            for dist in distList[1:self.numNeighbors]:
                self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation)

    def testFindPatchList(self):
        """Test TractInfo.findPatchList."""
        skyMap = self.getSkyMap()
        # pick two arbitrary tracts (sampled with replacement)
        for tractId in np.random.choice(len(skyMap), 2):
            tractInfo = skyMap[tractId]
            wcs = tractInfo.getWcs()
            numPatches = tractInfo.getNumPatches()
            border = tractInfo.getPatchBorder()
            for patchInd in ((0, 0),
                             (0, 1),
                             (5, 0),
                             (5, 6),
                             (numPatches[0] - 2, numPatches[1] - 1),
                             (numPatches[0] - 1, numPatches[1] - 2),
                             (numPatches[0] - 1, numPatches[1] - 1),
                             ):
                patchInfo = tractInfo.getPatchInfo(patchInd)
                patchIndex = patchInfo.getIndex()
                bbox = patchInfo.getInnerBBox()
                # Shrink well inside the patch so only this patch is found.
                bbox.grow(-(border+1))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), 1)
                self.assertEqual(patchInfoList[0].getIndex(), patchIndex)

                # grow to include neighbors and test again
                bbox.grow(2)
                predFoundIndexSet = set()
                for dx in (-1, 0, 1):
                    nbrX = patchIndex[0] + dx
                    if not 0 <= nbrX < numPatches[0]:
                        continue
                    for dy in (-1, 0, 1):
                        nbrY = patchIndex[1] + dy
                        if not 0 <= nbrY < numPatches[1]:
                            continue
                        predFoundIndexSet.add((nbrX, nbrY))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), len(predFoundIndexSet))
                foundIndexSet = {foundPatch.getIndex() for foundPatch in patchInfoList}
                self.assertEqual(foundIndexSet, predFoundIndexSet)

    def testFindTractPatchList(self):
        """Test findTractPatchList.

        Notes
        -----
        This test uses single points for speed and to avoid really large
        regions.
        Note that `findPatchList` is being tested elsewhere.
        If the sky map has no ``findClosestTractPatchList``, the other checks
        still run for every picked tract and the test is skipped at the end.
        """
        skyMap = self.getSkyMap()
        # Decide once, so a missing findClosestTractPatchList does not abort
        # the loop part-way through (skipTest raises).
        hasClosest = hasattr(skyMap, "findClosestTractPatchList")
        # pick 3 arbitrary tracts (sampled with replacement)
        for tractId in np.random.choice(len(skyMap), 3):
            tractInfo = skyMap[tractId]
            ctrCoord = tractInfo.getCtrCoord()
            self.assertTractPatchListOk(
                skyMap=skyMap,
                coordList=[ctrCoord],
                knownTractId=tractId,
            )
            if hasClosest:
                self.assertClosestTractPatchList(skyMap, [ctrCoord], tractId)

            # Also check the first and third vertices, when present.
            vertices = tractInfo.getVertexList()
            for vertexIndex in (0, 2):
                if len(vertices) > vertexIndex:
                    self.assertTractPatchListOk(
                        skyMap=skyMap,
                        coordList=[vertices[vertexIndex]],
                        knownTractId=tractId,
                    )

        if not hasClosest:
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")

    def testTractContains(self):
        """Test that TractInfo.contains works."""
        skyMap = self.getSkyMap()
        for tract in skyMap:
            coord = tract.getCtrCoord()
            self.assertTrue(tract.contains(coord))
            # The antipode of the center must lie outside the tract.
            opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude())
            self.assertFalse(tract.contains(opposite))

    def testTractInfoGetPolygon(self):
        """Test TractInfo.getInnerSkyPolygon and getOuterSkyPolygon."""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            # TODO: Remove with DM-44799
            with self.assertWarns(FutureWarning):
                self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(),
                                     vertexList=tractInfo.getVertexList(),
                                     centerCoord=centerCoord)
            self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(),
                                     bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs())

    def testTractInfoGetRegion(self):
        """Test TractInfo.getInnerSkyRegion."""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            region = tractInfo.getInnerSkyRegion()
            if isinstance(region, sphgeom.Box):
                # Vertices are not checked against a Box region.
                self.assertRegionOk(region=region,
                                    centerCoord=centerCoord)
            else:
                self.assertRegionOk(region=region,
                                    centerCoord=centerCoord,
                                    vertexList=tractInfo.getVertexList())

    def testPatchInfoGetPolygon(self):
        """Test PatchInfo.getInnerSkyPolygon and getOuterSkyPolygon."""
        skyMap = self.getSkyMap()

        def getIndices(numItems):
            """Return up to 3 indices for testing"""
            if numItems > 2:
                return (0, 1, numItems-1)
            elif numItems > 1:
                return (0, 1)
            return (0,)

        for tractInfo in skyMap:
            wcs = tractInfo.getWcs()
            # Use this tract's own patch grid rather than assuming every
            # tract has the same number of patches as tract 0.
            numPatches = tractInfo.getNumPatches()
            for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])):
                with self.subTest(patchInd=patchInd):
                    patchInfo = tractInfo.getPatchInfo(patchInd)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getInnerBBox(), wcs=wcs)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getOuterBBox(), wcs=wcs)

    def testDm14809(self):
        """Generic version of test that DM-14809 has been fixed."""
        checkDm14809(self, self.getSkyMap())

    def testNumbering(self):
        """Check the numbering of tracts matches the indexing."""
        skymap = self.getSkyMap()
        expect = list(range(len(skymap)))
        got = [tt.getId() for tt in skymap]
        self.assertEqual(got, expect)

    def assertTractPatchListOk(self, skyMap, coordList, knownTractId):
        """Assert that findTractPatchList produces the correct results.

        Every tract whose ``findPatchList(coordList)`` is non-empty must
        appear in the result with the same number of patches, and every
        other tract must be absent.

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Region to search for.
        knownTractId : `int`
            This tractId must appear in the found list.
        """
        tractPatchList = skyMap.findTractPatchList(coordList)
        tractPatchDict = {tp[0].getId(): tp[1] for tp in tractPatchList}
        self.assertIn(knownTractId, tractPatchDict)
        for tractInfo in skyMap:
            tractId = tractInfo.getId()
            patchList = tractInfo.findPatchList(coordList)
            if patchList:
                self.assertIn(tractId, tractPatchDict)
                self.assertEqual(len(patchList), len(tractPatchDict[tractId]))
            else:
                self.assertNotIn(tractId, tractPatchDict)

    def assertClosestTractPatchList(self, skyMap, coordList, knownTractId):
        """Assert that findClosestTractPatchList produces the correct results.

        Skips the current test if ``skyMap`` has no
        ``findClosestTractPatchList``.

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates to search for; one (tract, patchList) entry is
            expected per coordinate.
        knownTractId : `int`
            Every returned tract must have this ID.
        """
        if not hasattr(skyMap, "findClosestTractPatchList"):
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")
        tractPatchList = skyMap.findClosestTractPatchList(coordList)
        self.assertEqual(len(coordList), len(tractPatchList))  # One tract+patchList per coordinate
        for coord, (tract, patchList) in zip(coordList, tractPatchList):
            self.assertEqual(tract.getId(), knownTractId)
            self.assertEqual(patchList, tract.findPatchList([coord]))

    def assertBBoxPolygonOk(self, polygon, bbox, wcs):
        """Assert that an on-sky polygon from a pixel bbox
        covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        bbox : pixel bounding box
            Pixel bounding box; converted to `lsst.geom.Box2D`. Its center
            and corners, mapped through ``wcs``, are the expected center and
            vertices of ``polygon``.
        wcs : WCS
            Object whose ``pixelToSky`` maps pixel positions to the sky.
        """
        bboxd = geom.Box2D(bbox)
        centerCoord = wcs.pixelToSky(bboxd.getCenter())
        skyCorners = getCornerCoords(wcs=wcs, bbox=bbox)
        self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord)

    def assertPolygonOk(self, polygon, vertexList, centerCoord):
        """Assert that an on-sky polygon covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        vertexList : `iterable` of `lsst.geom.SpherePoint`
            Vertices of polygon.
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region.
        """
        self._assertShiftedVerticesOk(polygon, centerCoord, vertexList)

    def assertRegionOk(self, region, centerCoord, vertexList=None):
        """Assert that an on-sky region is appropriate.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            On-sky region.
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region.
        vertexList : `iterable` of `lsst.geom.SpherePoint`, optional
            Vertices to test; if None, only ``centerCoord`` is checked.
        """
        if vertexList is None:
            vertexList = ()
        self._assertShiftedVerticesOk(region, centerCoord, vertexList)

    def _assertShiftedVerticesOk(self, region, centerCoord, vertexList):
        """Assert ``region`` contains ``centerCoord`` and each vertex nudged
        slightly toward the center, but not each vertex nudged away from it.
        """
        shiftAngle = 0.01*geom.arcseconds
        self.assertTrue(region.contains(centerCoord.getVector()))
        for vertex in vertexList:
            bearingToCenter = vertex.bearingTo(centerCoord)
            cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle)
            cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle)
            self.assertTrue(region.contains(cornerShiftedIn.getVector()))
            self.assertFalse(region.contains(cornerShiftedOut.getVector()))

505 

506 

def getCornerCoords(wcs, bbox):
    """Map the corners of a pixel bounding box onto the sky.

    Parameters
    ----------
    wcs : WCS
        Object whose ``pixelToSky`` is called once with the whole list of
        corner positions.
    bbox : pixel bounding box
        Converted to `lsst.geom.Box2D` before its corners are taken.

    Returns
    -------
    coords
        Whatever ``wcs.pixelToSky`` returns for the list of corners
        (used elsewhere in this file as an iterable of sky coordinates).
    """
    return wcs.pixelToSky(geom.Box2D(bbox).getCorners())