Coverage for tests / helper / skyMapTestCase.py: 9%
260 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:51 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:51 +0000
1#
2# LSST Data Management System
3# Copyright 2008-2017 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22import itertools
23import pickle
24import copy
26import numpy as np
28import lsst.geom as geom
29import lsst.sphgeom as sphgeom
30import lsst.utils.tests
32from lsst.skymap import skyMapRegistry
def checkDm14809(testcase, skymap):
    """Verify that the DM-14809 regression stays fixed for ``skymap``.

    DM-14809 showed up as two symptoms:

        skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712

    and

        skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord()

    Rather than checking only the tracts named above, both properties are
    verified for every tract in the sky map.

    Parameters
    ----------
    testcase : `unittest.TestCase`
        Test case whose ``assertListEqual`` / ``assertEqual`` report failures.
    skymap : `lsst.skymap.BaseSkyMap`
        Sky map to check.
    """
    # Looking up the center of each tract must give back that same tract.
    expectedIds = [tractInfo.getId() for tractInfo in skymap]
    foundIds = []
    for tractInfo in skymap:
        foundIds.append(skymap.findTract(tractInfo.getCtrCoord()).getId())
    testcase.assertListEqual(foundIds, expectedIds)

    # No two tracts may share a center. Truncating RA/Dec to whole
    # arcminutes keeps the comparison insensitive to tiny numerical noise.
    def toArcminuteKey(coord):
        return (int(coord.getRa().asArcminutes()), int(coord.getDec().asArcminutes()))

    uniqueCenters = {toArcminuteKey(tractInfo.getCtrCoord()) for tractInfo in skymap}
    testcase.assertEqual(len(uniqueCenters), len(skymap))
class SkyMapTestCase(lsst.utils.tests.TestCase):
    """An abstract base class for testing a SkyMap.

    To use, subclass and call `setAttributes` from `setUp`.
    """

    def setAttributes(self, *,
                      SkyMapClass,
                      name,
                      numTracts,
                      config=None,
                      neighborAngularSeparation=None,
                      numNeighbors=None):
        """Initialize the test (call from setUp in the subclass).

        Parameters
        ----------
        SkyMapClass : subclass of `lsst.skymap.BaseSkyMap`
            Class of sky map to test.
        name : `str`
            Name of sky map in sky map registry.
        numTracts : `int`
            Number of tracts in the default configuration.
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Default configuration used by `getSkyMap`;
            if None use SkyMapClass.ConfigClass().
        neighborAngularSeparation : `lsst.geom.Angle`, optional
            Expected angular separation between tracts;
            if None then angular separation is not tested unless your
            subclass of SkyMapTestCase overrides `testTractSeparation`.
        numNeighbors : `int` or `None`
            Size of the group of nearest tracts examined by
            `testTractSeparation`, counting the tract itself: after sorting
            the distances from a tract, entries ``1`` through
            ``numNeighbors - 1`` must equal ``neighborAngularSeparation``.
            Required if ``neighborAngularSeparation`` is not None;
            ignored otherwise.

        Notes
        -----
        Seeds numpy's global random number generator so that the tracts
        picked by the randomized tests are reproducible.
        """
        self.SkyMapClass = SkyMapClass
        self.config = config
        self.name = name
        self.numTracts = numTracts
        self.neighborAngularSeparation = neighborAngularSeparation
        self.numNeighbors = numNeighbors
        np.random.seed(47)

    def getSkyMap(self, config=None):
        """Provide an instance of the skymap.

        Parameters
        ----------
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Configuration to use; if None, `getConfig` supplies one.
            The config is validated before the sky map is built.
        """
        if config is None:
            config = self.getConfig()
        config.validate()
        return self.SkyMapClass(config=config)

    def getConfig(self):
        """Provide an instance of the configuration class.

        Returns a fresh default config if none was given to `setAttributes`,
        otherwise a copy of that config so callers may modify it freely.
        """
        if self.config is None:
            return self.SkyMapClass.ConfigClass()
        # Want to return a copy of self.config, so it can be modified.
        return copy.copy(self.config)

    def testRegistry(self):
        """Confirm that the skymap can be retrieved from the registry."""
        self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass)

    def testBasicAttributes(self):
        """Confirm that constructor attributes are available."""
        defaultSkyMap = self.getSkyMap()
        for tractOverlap in (0.0, 0.01, 0.1):  # degrees
            config = self.getConfig()
            config.tractOverlap = tractOverlap
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        if defaultSkyMap.config.tractBuilder.name == 'cells':
            # The following tests are not appropriate for cells
            # see test_ringsSkyMapCells.py for "cell" tract testing.
            return

        for patchBorder in (0, 101):
            config = self.getConfig()
            config.patchBorder = patchBorder
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tractInfo.getPatchBorder(), patchBorder)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for xInnerDim in (1005, 5062):
            for yInnerDim in (2032, 5431):
                config = self.getConfig()
                config.patchInnerDimensions = (xInnerDim, yInnerDim)
                skyMap = self.getSkyMap(config)
                for tractInfo in skyMap:
                    self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), (xInnerDim, yInnerDim))
                self.assertEqual(len(skyMap), self.numTracts)
                self.assertNotEqual(skyMap, defaultSkyMap)

        # Compare a few patches of the (unmodified) default sky map: the
        # outer bbox extends the inner bbox by patchBorder on every side,
        # except on sides at the tract edge, where the two coincide.
        tractInfo = defaultSkyMap[0]
        numPatches = tractInfo.getNumPatches()
        patchBorder = defaultSkyMap.config.patchBorder
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = tractInfo.getPatchInfo((xInd, yInd))

                # check inner and outer bbox
                innerBBox = patchInfo.getInnerBBox()
                outerBBox = patchInfo.getOuterBBox()

                if xInd == 0:
                    self.assertEqual(innerBBox.getMinX(), outerBBox.getMinX())
                else:
                    self.assertEqual(innerBBox.getMinX() - patchBorder, outerBBox.getMinX())
                if yInd == 0:
                    self.assertEqual(innerBBox.getMinY(), outerBBox.getMinY())
                else:
                    self.assertEqual(innerBBox.getMinY() - patchBorder, outerBBox.getMinY())

                if xInd == numPatches[0] - 1:
                    self.assertEqual(innerBBox.getMaxX(), outerBBox.getMaxX())
                else:
                    self.assertEqual(innerBBox.getMaxX() + patchBorder, outerBBox.getMaxX())
                if yInd == numPatches[1] - 1:
                    self.assertEqual(innerBBox.getMaxY(), outerBBox.getMaxY())
                else:
                    self.assertEqual(innerBBox.getMaxY() + patchBorder, outerBBox.getMaxY())

    def assertUnpickledTractInfo(self, unpickled, original, patchBorder):
        """Assert that an unpickled TractInfo is functionally identical to the
        original.

        Parameters
        ----------
        unpickled : `TractInfo`
            The unpickled `TractInfo`.
        original : `TractInfo`
            The original `TractInfo`.
        patchBorder : `int`
            Border around each patch, from ``SkyMap.config.patchBorder``.
            Currently not used by the checks below; kept for API
            compatibility.
        """
        for getterName in ("getBBox",
                           "getCtrCoord",
                           "getId",
                           "getNumPatches",
                           "getPatchBorder",
                           "getPatchInnerDimensions",
                           "getTractOverlap",
                           "getVertexList",
                           "getWcs",
                           ):
            self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)())

        # test WCS at a few locations
        wcs = original.getWcs()
        unpickledWcs = unpickled.getWcs()
        for x in (-1000.0, 0.0, 1000.0):
            for y in (-532.5, 0.5, 532.5):
                pixelPos = geom.Point2D(x, y)
                skyPos = wcs.pixelToSky(pixelPos)
                unpickledSkyPos = unpickledWcs.pixelToSky(pixelPos)
                self.assertEqual(skyPos, unpickledSkyPos)

        # compare a few patches
        numPatches = original.getNumPatches()
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = original.getPatchInfo((xInd, yInd))
                unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd))
                self.assertEqual(patchInfo, unpickledPatchInfo)

    def testPickle(self):
        """Test that pickling and unpickling restores the original exactly."""
        skyMap = self.getSkyMap()
        pickleStr = pickle.dumps(skyMap)
        unpickledSkyMap = pickle.loads(pickleStr)
        self.assertEqual(len(skyMap), len(unpickledSkyMap))
        self.assertEqual(unpickledSkyMap.config, skyMap.config)
        self.assertEqual(skyMap, unpickledSkyMap)
        for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap):
            self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder)

    def testTractSeparation(self):
        """Confirm that each sky tract has the proper distance to other tracts."""
        if self.neighborAngularSeparation is None:
            self.skipTest("Not testing angular separation for %s: neighborAngularSeparation is None" %
                          (self.SkyMapClass.__name__,))
        skyMap = self.getSkyMap()
        for tractId, tractInfo in enumerate(skyMap):
            self.assertEqual(tractInfo.getId(), tractId)

            ctrCoord = tractInfo.getCtrCoord()
            distList = []
            for tractInfo1 in skyMap:
                otherCtrCoord = tractInfo1.getCtrCoord()
                distList.append(ctrCoord.separation(otherCtrCoord))
            distList.sort()
            # distList[0] is the tract's distance to itself; the following
            # numNeighbors - 1 entries are its nearest neighbors.
            self.assertEqual(distList[0], 0.0)
            for dist in distList[1:self.numNeighbors]:
                self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation)

    def testFindPatchList(self):
        """Test TractInfo.findPatchList."""
        skyMap = self.getSkyMap()
        # pick two arbitrary tracts (reproducible: setAttributes seeds numpy)
        for tractId in np.random.choice(len(skyMap), 2):
            tractInfo = skyMap[tractId]
            wcs = tractInfo.getWcs()
            numPatches = tractInfo.getNumPatches()
            border = tractInfo.getPatchBorder()
            for patchInd in ((0, 0),
                             (0, 1),
                             (5, 0),
                             (5, 6),
                             (numPatches[0] - 2, numPatches[1] - 1),
                             (numPatches[0] - 1, numPatches[1] - 2),
                             (numPatches[0] - 1, numPatches[1] - 1),
                             ):
                patchInfo = tractInfo.getPatchInfo(patchInd)
                patchIndex = patchInfo.getIndex()
                # Shrink the inner bbox so that its corners are expected to
                # match only this patch.
                bbox = patchInfo.getInnerBBox()
                bbox.grow(-(border+1))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), 1)
                self.assertEqual(patchInfoList[0].getIndex(), patchIndex)

                # grow to include neighbors and test again; every existing
                # patch in the surrounding 3x3 block is expected to be found
                bbox.grow(2)
                predFoundIndexSet = set()
                for dx in (-1, 0, 1):
                    nbrX = patchIndex[0] + dx
                    if not 0 <= nbrX < numPatches[0]:
                        continue
                    for dy in (-1, 0, 1):
                        nbrY = patchIndex[1] + dy
                        if not 0 <= nbrY < numPatches[1]:
                            continue
                        nbrInd = (nbrX, nbrY)
                        predFoundIndexSet.add(nbrInd)
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), len(predFoundIndexSet))
                foundIndexSet = set(patchInfo.getIndex() for patchInfo in patchInfoList)
                self.assertEqual(foundIndexSet, predFoundIndexSet)

    def testFindTractPatchList(self):
        """Test findTractPatchList

        Notes
        -----
        This test uses single points for speed and to avoid really large
        regions.
        Note that `findPatchList` is being tested elsewhere.
        """
        skyMap = self.getSkyMap()
        # pick 3 arbitrary tracts
        for tractId in np.random.choice(len(skyMap), 3):
            tractInfo = skyMap[tractId]
            self.assertTractPatchListOk(
                skyMap=skyMap,
                coordList=[tractInfo.getCtrCoord()],
                knownTractId=tractId,
            )
            self.assertClosestTractPatchList(skyMap, [tractInfo.getCtrCoord()], tractId)

            vertices = tractInfo.getVertexList()
            if len(vertices) > 0:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[0]],
                    knownTractId=tractId,
                )

            if len(vertices) > 2:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[2]],
                    knownTractId=tractId,
                )

    def testTractContains(self):
        """Test that TractInfo.contains works."""
        skyMap = self.getSkyMap()
        for tract in skyMap:
            coord = tract.getCtrCoord()
            self.assertTrue(tract.contains(coord))
            # The antipode of the tract center must not be in the tract.
            opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude())
            self.assertFalse(tract.contains(opposite))

    def testTractInfoGetPolygon(self):
        """Test the inner and outer sky polygons of every tract."""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            # TODO: Remove with DM-44799
            with self.assertWarns(FutureWarning):
                self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(),
                                     vertexList=tractInfo.getVertexList(),
                                     centerCoord=centerCoord)
            self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(),
                                     bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs())

    def testTractInfoGetRegion(self):
        """Test the inner sky region of every tract."""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            region = tractInfo.getInnerSkyRegion()
            if isinstance(region, sphgeom.Box):
                # Vertices are not checked when the region is a Box.
                self.assertRegionOk(region=region,
                                    centerCoord=centerCoord)
            else:
                self.assertRegionOk(region=region,
                                    centerCoord=centerCoord,
                                    vertexList=tractInfo.getVertexList())

    def testPatchInfoGetPolygon(self):
        """Test inner/outer sky polygons of corner and edge patches of every tract."""
        skyMap = self.getSkyMap()

        def getIndices(numItems):
            """Return up to 3 indices for testing"""
            if numItems > 2:
                return (0, 1, numItems-1)
            elif numItems > 1:
                return (0, 1)
            return (0,)

        for tractInfo in skyMap:
            wcs = tractInfo.getWcs()
            # Use this tract's own patch count so the last index is always
            # valid and really is the last patch, even if tracts differ.
            numPatches = tractInfo.getNumPatches()
            for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])):
                with self.subTest(patchInd=patchInd):
                    patchInfo = tractInfo.getPatchInfo(patchInd)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getInnerBBox(), wcs=wcs)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getOuterBBox(), wcs=wcs)

    def testDm14809(self):
        """Generic version of test that DM-14809 has been fixed."""
        checkDm14809(self, self.getSkyMap())

    def testNumbering(self):
        """Check the numbering of tracts matches the indexing."""
        skymap = self.getSkyMap()
        expect = list(range(len(skymap)))
        got = [tt.getId() for tt in skymap]
        self.assertEqual(got, expect)

    def assertTractPatchListOk(self, skyMap, coordList, knownTractId):
        """Assert that findTractPatchList produces the correct results.

        Every tract whose own ``findPatchList`` finds patches must appear in
        the result with the same number of patches; every other tract must
        be absent.

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Region to search for.
        knownTractId : `int`
            This tractId must appear in the found list.
        """
        tractPatchList = skyMap.findTractPatchList(coordList)
        tractPatchDict = {tp[0].getId(): tp[1] for tp in tractPatchList}
        self.assertIn(knownTractId, tractPatchDict)
        for tractInfo in skyMap:
            tractId = tractInfo.getId()
            patchList = tractInfo.findPatchList(coordList)
            if patchList:
                self.assertIn(tractId, tractPatchDict)
                self.assertEqual(len(patchList), len(tractPatchDict[tractId]))
            else:
                self.assertNotIn(tractId, tractPatchDict)

    def assertClosestTractPatchList(self, skyMap, coordList, knownTractId):
        """Assert that findClosestTractPatchList maps each coord to ``knownTractId``.

        If the sky map has no ``findClosestTractPatchList``, the calling test
        is skipped (``skipTest`` ends the current test).

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates to look up; one (tract, patchList) result is
            expected per coordinate.
        knownTractId : `int`
            Tract that every coordinate must be assigned to.
        """
        if not hasattr(skyMap, "findClosestTractPatchList"):
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")
        tractPatchList = skyMap.findClosestTractPatchList(coordList)
        self.assertEqual(len(coordList), len(tractPatchList))  # One tract+patchList per coordinate
        for coord, (tract, patchList) in zip(coordList, tractPatchList):
            self.assertEqual(tract.getId(), knownTractId)
            self.assertEqual(patchList, tract.findPatchList([coord]))

    def assertBBoxPolygonOk(self, polygon, bbox, wcs):
        """Assert that an on-sky polygon from a pixel bbox
        covers the expected region.

        The expected vertices are the sky positions of the bbox corners and
        the expected center is the sky position of the bbox center.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        bbox : pixel bounding box
            Box the polygon was computed from (anything accepted by
            `lsst.geom.Box2D`).
        wcs : WCS
            WCS used to map pixel positions in ``bbox`` to the sky
            (via ``pixelToSky``).
        """
        bboxd = geom.Box2D(bbox)
        centerPixel = bboxd.getCenter()
        centerCoord = wcs.pixelToSky(centerPixel)
        skyCorners = getCornerCoords(wcs=wcs, bbox=bbox)
        self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord)

    def assertPolygonOk(self, polygon, vertexList, centerCoord):
        """Assert that an on-sky polygon covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        vertexList : `iterable` of `lsst.geom.SpherePoint`
            Vertices of polygon.
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region.
        """
        # The containment checks are identical to those for a general region.
        self.assertRegionOk(region=polygon, centerCoord=centerCoord, vertexList=vertexList)

    def assertRegionOk(self, region, centerCoord, vertexList=()):
        """Assert that an on-sky region is appropriate.

        The region must contain ``centerCoord``. For each vertex, a point
        shifted 0.01 arcsec from the vertex toward ``centerCoord`` must be
        inside the region, and a point shifted 0.01 arcsec the other way
        must be outside it.

        Parameters
        ----------
        region : `lsst.sphgeom.Region`
            On-sky region.
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region.
        vertexList : `iterable` of `lsst.geom.SpherePoint`, optional
            Vertices to test; if empty, only ``centerCoord`` is checked.
        """
        shiftAngle = 0.01*geom.arcseconds
        self.assertTrue(region.contains(centerCoord.getVector()))
        for vertex in vertexList:
            bearingToCenter = vertex.bearingTo(centerCoord)
            cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle)
            cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle)
            self.assertTrue(region.contains(cornerShiftedIn.getVector()))
            self.assertFalse(region.contains(cornerShiftedOut.getVector()))
def getCornerCoords(wcs, bbox):
    """Compute the sky positions of the four corners of a pixel bounding box.

    Parameters
    ----------
    wcs : WCS
        Object whose ``pixelToSky`` accepts a list of pixel points.
    bbox : pixel bounding box
        Box whose corners are wanted (anything accepted by `lsst.geom.Box2D`).

    Returns
    -------
    coords
        Whatever ``wcs.pixelToSky`` returns for the list of the four
        floating-point corners of ``bbox``.
    """
    floatBox = geom.Box2D(bbox)
    corners = floatBox.getCorners()
    return wcs.pixelToSky(corners)