Coverage for tests/helper/skyMapTestCase.py: 9%

241 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-09 02:58 -0700

1# 

2# LSST Data Management System 

3# Copyright 2008-2017 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22import itertools 

23import pickle 

24import copy 

25 

26import numpy as np 

27 

28import lsst.geom as geom 

29import lsst.utils.tests 

30 

31from lsst.skymap import skyMapRegistry 

32 

33 

34def checkDm14809(testcase, skymap): 

35 """Test that DM-14809 has been fixed 

36 

37 The observed behaviour was: 

38 

39 skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712 

40 

41 and 

42 

43 skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord() 

44 

45 In order to be thorough, we generalise these over the entire skymap. 

46 """ 

47 # Check that the tract found for central coordinate of a tract is that tract 

48 expect = [tract.getId() for tract in skymap] 

49 got = [skymap.findTract(tract.getCtrCoord()).getId() for tract in skymap] 

50 testcase.assertListEqual(got, expect) 

51 

52 # Check that the tract central coordinates are unique 

53 # Round to integer arcminutes so differences are relatively immune to small numerical inaccuracies 

54 centers = set([(int(coord.getRa().asArcminutes()), int(coord.getDec().asArcminutes())) for 

55 coord in (tract.getCtrCoord() for tract in skymap)]) 

56 testcase.assertEqual(len(centers), len(skymap)) 

57 

58 

59class SkyMapTestCase(lsst.utils.tests.TestCase): 

60 """An abstract base class for testing a SkyMap. 

61 

62 To use, subclass and call `setAttributes` from `setUp` 

63 """ 

64 def setAttributes(self, *, 

65 SkyMapClass, 

66 name, 

67 numTracts, 

68 config=None, 

69 neighborAngularSeparation=None, 

70 numNeighbors=None): 

71 """Initialize the test (call from setUp in the subclass) 

72 

73 Parameters 

74 ---------- 

75 SkyMapClass : subclass of `lsst.skymap.BaseSkyMap` 

76 Class of sky map to test 

77 name : `str` 

78 Name of sky map in sky map registry 

79 numTracts : `int` 

80 Number of tracts in the default configuration 

81 config : subclass of `lsst.skymap.SkyMapConfig`, optional 

82 Default configuration used by `getSkyMap`; 

83 if None use SkyMapClass.ConfigClass() 

84 neighborAngularSeparation : `lsst.geom.Angle`, optional 

85 Expected angular separation between tracts; 

86 if None then angular separation is not tested unless your 

87 subclass of SkyMapTestCaseoverrides `testTractSeparation`. 

88 numNeighbors : `int` or `None` 

89 Number of neighbors that should be within 

90 ``neighborAngularSeparation``; 

91 Required if ``neighborAngularSeparation`` is not None; 

92 ignored otherwise. 

93 """ 

94 self.SkyMapClass = SkyMapClass 

95 self.config = config 

96 self.name = name 

97 self.numTracts = numTracts 

98 self.neighborAngularSeparation = neighborAngularSeparation 

99 self.numNeighbors = numNeighbors 

100 np.random.seed(47) 

101 

102 def getSkyMap(self, config=None): 

103 """Provide an instance of the skymap""" 

104 if config is None: 

105 config = self.getConfig() 

106 config.validate() 

107 return self.SkyMapClass(config=config) 

108 

109 def getConfig(self): 

110 """Provide an instance of the configuration class""" 

111 if self.config is None: 

112 return self.SkyMapClass.ConfigClass() 

113 # Want to return a copy of self.config, so it can be modified. 

114 return copy.copy(self.config) 

115 

116 def testRegistry(self): 

117 """Confirm that the skymap can be retrieved from the registry""" 

118 self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass) 

119 

120 def testBasicAttributes(self): 

121 """Confirm that constructor attributes are available 

122 """ 

123 defaultSkyMap = self.getSkyMap() 

124 for tractOverlap in (0.0, 0.01, 0.1): # degrees 

125 config = self.getConfig() 

126 config.tractOverlap = tractOverlap 

127 skyMap = self.getSkyMap(config) 

128 for tractInfo in skyMap: 

129 self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap) 

130 self.assertEqual(len(skyMap), self.numTracts) 

131 self.assertNotEqual(skyMap, defaultSkyMap) 

132 

133 if defaultSkyMap.config.tractBuilder.name == 'cells': 

134 # The following tests are not appropriate for cells 

135 # see test_ringsSkyMapCells.py for "cell" tract testing. 

136 return 

137 

138 for patchBorder in (0, 101): 

139 config = self.getConfig() 

140 config.patchBorder = patchBorder 

141 skyMap = self.getSkyMap(config) 

142 for tractInfo in skyMap: 

143 self.assertEqual(tractInfo.getPatchBorder(), patchBorder) 

144 self.assertEqual(len(skyMap), self.numTracts) 

145 self.assertNotEqual(skyMap, defaultSkyMap) 

146 

147 for xInnerDim in (1005, 5062): 

148 for yInnerDim in (2032, 5431): 

149 config = self.getConfig() 

150 config.patchInnerDimensions = (xInnerDim, yInnerDim) 

151 skyMap = self.getSkyMap(config) 

152 for tractInfo in skyMap: 

153 self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), (xInnerDim, yInnerDim)) 

154 self.assertEqual(len(skyMap), self.numTracts) 

155 self.assertNotEqual(skyMap, defaultSkyMap) 

156 

157 # Compare a few patches 

158 defaultSkyMap = self.getSkyMap() 

159 tractInfo = defaultSkyMap[0] 

160 numPatches = tractInfo.getNumPatches() 

161 patchBorder = defaultSkyMap.config.patchBorder 

162 for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1): 

163 for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1): 

164 patchInfo = tractInfo.getPatchInfo((xInd, yInd)) 

165 

166 # check inner and outer bbox 

167 innerBBox = patchInfo.getInnerBBox() 

168 outerBBox = patchInfo.getOuterBBox() 

169 

170 if xInd == 0: 

171 self.assertEqual(innerBBox.getMinX(), outerBBox.getMinX()) 

172 else: 

173 self.assertEqual(innerBBox.getMinX() - patchBorder, outerBBox.getMinX()) 

174 if yInd == 0: 

175 self.assertEqual(innerBBox.getMinY(), outerBBox.getMinY()) 

176 else: 

177 self.assertEqual(innerBBox.getMinY() - patchBorder, outerBBox.getMinY()) 

178 

179 if xInd == numPatches[0] - 1: 

180 self.assertEqual(innerBBox.getMaxX(), outerBBox.getMaxX()) 

181 else: 

182 self.assertEqual(innerBBox.getMaxX() + patchBorder, outerBBox.getMaxX()) 

183 if yInd == numPatches[1] - 1: 

184 self.assertEqual(innerBBox.getMaxY(), outerBBox.getMaxY()) 

185 else: 

186 self.assertEqual(innerBBox.getMaxY() + patchBorder, outerBBox.getMaxY()) 

187 

188 def assertUnpickledTractInfo(self, unpickled, original, patchBorder): 

189 """Assert that an unpickled TractInfo is functionally identical to the original 

190 

191 @param unpickled The unpickled TractInfo 

192 @param original The original TractInfo 

193 @param patchBorder Border around each patch, from SkyMap.config.patchBorder 

194 """ 

195 for getterName in ("getBBox", 

196 "getCtrCoord", 

197 "getId", 

198 "getNumPatches", 

199 "getPatchBorder", 

200 "getPatchInnerDimensions", 

201 "getTractOverlap", 

202 "getVertexList", 

203 "getWcs", 

204 ): 

205 self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)()) 

206 

207 # test WCS at a few locations 

208 wcs = original.getWcs() 

209 unpickledWcs = unpickled.getWcs() 

210 for x in (-1000.0, 0.0, 1000.0): 

211 for y in (-532.5, 0.5, 532.5): 

212 pixelPos = geom.Point2D(x, y) 

213 skyPos = wcs.pixelToSky(pixelPos) 

214 unpickledSkyPos = unpickledWcs.pixelToSky(pixelPos) 

215 self.assertEqual(skyPos, unpickledSkyPos) 

216 

217 # compare a few patches 

218 numPatches = original.getNumPatches() 

219 for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1): 

220 for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1): 

221 patchInfo = original.getPatchInfo((xInd, yInd)) 

222 unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd)) 

223 self.assertEqual(patchInfo, unpickledPatchInfo) 

224 

225 def testPickle(self): 

226 """Test that pickling and unpickling restores the original exactly 

227 """ 

228 skyMap = self.getSkyMap() 

229 pickleStr = pickle.dumps(skyMap) 

230 unpickledSkyMap = pickle.loads(pickleStr) 

231 self.assertEqual(len(skyMap), len(unpickledSkyMap)) 

232 self.assertEqual(unpickledSkyMap.config, skyMap.config) 

233 self.assertEqual(skyMap, unpickledSkyMap) 

234 for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap): 

235 self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder) 

236 

237 def testTractSeparation(self): 

238 """Confirm that each sky tract has the proper distance to other tracts 

239 """ 

240 if self.neighborAngularSeparation is None: 

241 self.skipTest("Not testing angular separation for %s: neighborAngularSeparation is None" % 

242 (self.SkyMapClass.__name__,)) 

243 skyMap = self.getSkyMap() 

244 for tractId, tractInfo in enumerate(skyMap): 

245 self.assertEqual(tractInfo.getId(), tractId) 

246 

247 ctrCoord = tractInfo.getCtrCoord() 

248 distList = [] 

249 for tractInfo1 in skyMap: 

250 otherCtrCoord = tractInfo1.getCtrCoord() 

251 distList.append(ctrCoord.separation(otherCtrCoord)) 

252 distList.sort() 

253 self.assertEqual(distList[0], 0.0) 

254 for dist in distList[1:self.numNeighbors]: 

255 self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation) 

256 

257 def testFindPatchList(self): 

258 """Test TractInfo.findPatchList 

259 """ 

260 skyMap = self.getSkyMap() 

261 # pick two arbitrary tracts 

262 for tractId in np.random.choice(len(skyMap), 2): 

263 tractInfo = skyMap[tractId] 

264 wcs = tractInfo.getWcs() 

265 numPatches = tractInfo.getNumPatches() 

266 border = tractInfo.getPatchBorder() 

267 for patchInd in ((0, 0), 

268 (0, 1), 

269 (5, 0), 

270 (5, 6), 

271 (numPatches[0] - 2, numPatches[1] - 1), 

272 (numPatches[0] - 1, numPatches[1] - 2), 

273 (numPatches[0] - 1, numPatches[1] - 1), 

274 ): 

275 patchInfo = tractInfo.getPatchInfo(patchInd) 

276 patchIndex = patchInfo.getIndex() 

277 bbox = patchInfo.getInnerBBox() 

278 bbox.grow(-(border+1)) 

279 coordList = getCornerCoords(wcs=wcs, bbox=bbox) 

280 patchInfoList = tractInfo.findPatchList(coordList) 

281 self.assertEqual(len(patchInfoList), 1) 

282 self.assertEqual(patchInfoList[0].getIndex(), patchIndex) 

283 

284 # grow to include neighbors and test again 

285 bbox.grow(2) 

286 predFoundIndexSet = set() 

287 for dx in (-1, 0, 1): 

288 nbrX = patchIndex[0] + dx 

289 if not 0 <= nbrX < numPatches[0]: 

290 continue 

291 for dy in (-1, 0, 1): 

292 nbrY = patchIndex[1] + dy 

293 if not 0 <= nbrY < numPatches[1]: 

294 continue 

295 nbrInd = (nbrX, nbrY) 

296 predFoundIndexSet.add(nbrInd) 

297 coordList = getCornerCoords(wcs=wcs, bbox=bbox) 

298 patchInfoList = tractInfo.findPatchList(coordList) 

299 self.assertEqual(len(patchInfoList), len(predFoundIndexSet)) 

300 foundIndexSet = set(patchInfo.getIndex() for patchInfo in patchInfoList) 

301 self.assertEqual(foundIndexSet, predFoundIndexSet) 

302 

303 def testFindTractPatchList(self): 

304 """Test findTractPatchList 

305 

306 Note: this test uses single points for speed and to avoid really large regions. 

307 Note that findPatchList is being tested elsewhere. 

308 """ 

309 skyMap = self.getSkyMap() 

310 # pick 3 arbitrary tracts 

311 for tractId in np.random.choice(len(skyMap), 3): 

312 tractInfo = skyMap[tractId] 

313 self.assertTractPatchListOk( 

314 skyMap=skyMap, 

315 coordList=[tractInfo.getCtrCoord()], 

316 knownTractId=tractId, 

317 ) 

318 self.assertClosestTractPatchList(skyMap, [tractInfo.getCtrCoord()], tractId) 

319 

320 vertices = tractInfo.getVertexList() 

321 if len(vertices) > 0: 

322 self.assertTractPatchListOk( 

323 skyMap=skyMap, 

324 coordList=[tractInfo.getVertexList()[0]], 

325 knownTractId=tractId, 

326 ) 

327 

328 if len(vertices) > 2: 

329 self.assertTractPatchListOk( 

330 skyMap=skyMap, 

331 coordList=[tractInfo.getVertexList()[2]], 

332 knownTractId=tractId, 

333 ) 

334 

335 def testTractContains(self): 

336 """Test that TractInfo.contains works""" 

337 skyMap = self.getSkyMap() 

338 for tract in skyMap: 

339 coord = tract.getCtrCoord() 

340 self.assertTrue(tract.contains(coord)) 

341 opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude()) 

342 self.assertFalse(tract.contains(opposite)) 

343 

344 def testTractInfoGetPolygon(self): 

345 skyMap = self.getSkyMap() 

346 for tractInfo in skyMap: 

347 centerCoord = tractInfo.getCtrCoord() 

348 self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(), 

349 vertexList=tractInfo.getVertexList(), 

350 centerCoord=centerCoord) 

351 self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(), 

352 bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs()) 

353 

354 def testPatchInfoGetPolygon(self): 

355 skyMap = self.getSkyMap() 

356 numPatches = skyMap[0].getNumPatches() 

357 

358 def getIndices(numItems): 

359 """Return up to 3 indices for testing""" 

360 if numItems > 2: 

361 return (0, 1, numItems-1) 

362 elif numItems > 1: 

363 return (0, 1) 

364 return (0,) 

365 

366 for tractInfo in skyMap: 

367 wcs = tractInfo.getWcs() 

368 for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])): 

369 with self.subTest(patchInd=patchInd): 

370 patchInfo = tractInfo.getPatchInfo(patchInd) 

371 self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs), 

372 bbox=patchInfo.getInnerBBox(), wcs=wcs) 

373 self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs), 

374 bbox=patchInfo.getOuterBBox(), wcs=wcs) 

375 

376 def testDm14809(self): 

377 """Generic version of test that DM-14809 has been fixed""" 

378 checkDm14809(self, self.getSkyMap()) 

379 

380 def testNumbering(self): 

381 """Check the numbering of tracts matches the indexing""" 

382 skymap = self.getSkyMap() 

383 expect = list(range(len(skymap))) 

384 got = [tt.getId() for tt in skymap] 

385 self.assertEqual(got, expect) 

386 

387 def assertTractPatchListOk(self, skyMap, coordList, knownTractId): 

388 """Assert that findTractPatchList produces the correct results 

389 

390 @param[in] skyMap: sky map to test 

391 @param[in] coordList: coordList of region to search for 

392 @param[in] knownTractId: this tractId must appear in the found list 

393 """ 

394 tractPatchList = skyMap.findTractPatchList(coordList) 

395 tractPatchDict = dict((tp[0].getId(), tp[1]) for tp in tractPatchList) 

396 self.assertTrue(knownTractId in tractPatchDict) 

397 for tractInfo in skyMap: 

398 tractId = tractInfo.getId() 

399 patchList = tractInfo.findPatchList(coordList) 

400 if patchList: 

401 self.assertTrue(tractId in tractPatchDict) 

402 self.assertTrue(len(patchList) == len(tractPatchDict[tractId])) 

403 else: 

404 self.assertTrue(tractId not in tractPatchDict) 

405 

406 def assertClosestTractPatchList(self, skyMap, coordList, knownTractId): 

407 if not hasattr(skyMap, "findClosestTractPatchList"): 

408 self.skipTest("This skymap doesn't implement findClosestTractPatchList") 

409 tractPatchList = skyMap.findClosestTractPatchList(coordList) 

410 self.assertEqual(len(coordList), len(tractPatchList)) # One tract+patchList per coordinate 

411 for coord, (tract, patchList) in zip(coordList, tractPatchList): 

412 self.assertEqual(tract.getId(), knownTractId) 

413 self.assertEqual(patchList, tract.findPatchList([coord])) 

414 

415 def assertBBoxPolygonOk(self, polygon, bbox, wcs): 

416 """Assert that an on-sky polygon from a pixel bbox 

417 covers the expected region. 

418 

419 Parameters 

420 ---------- 

421 polygon : `lsst.sphgeom.ConvexPolygon` 

422 On-sky polygon 

423 vertexList : `iterable` of `lsst.geom.SpherePoint` 

424 Vertices of polygon 

425 centerCoord : `lsst.geom.SpherePoint` 

426 A coord approximately in the center of the region 

427 """ 

428 bboxd = geom.Box2D(bbox) 

429 centerPixel = bboxd.getCenter() 

430 centerCoord = wcs.pixelToSky(centerPixel) 

431 skyCorners = getCornerCoords(wcs=wcs, bbox=bbox) 

432 self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord) 

433 

434 def assertPolygonOk(self, polygon, vertexList, centerCoord): 

435 """Assert that an on-sky polygon from covers the expected region. 

436 

437 Parameters 

438 ---------- 

439 polygon : `lsst.sphgeom.ConvexPolygon` 

440 On-sky polygon 

441 vertexList : `iterable` of `lsst.geom.SpherePoint` 

442 Vertices of polygon 

443 centerCoord : `lsst.geom.SpherePoint` 

444 A coord approximately in the center of the region 

445 """ 

446 shiftAngle = 0.01*geom.arcseconds 

447 self.assertTrue(polygon.contains(centerCoord.getVector())) 

448 for vertex in vertexList: 

449 bearingToCenter = vertex.bearingTo(centerCoord) 

450 cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle) 

451 cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle) 

452 self.assertTrue(polygon.contains(cornerShiftedIn.getVector())) 

453 self.assertFalse(polygon.contains(cornerShiftedOut.getVector())) 

454 

455 

456def getCornerCoords(wcs, bbox): 

457 """Return the coords of the four corners of a bounding box 

458 """ 

459 cornerPosList = geom.Box2D(bbox).getCorners() 

460 return wcs.pixelToSky(cornerPosList)