Coverage for tests/helper/skyMapTestCase.py: 8%
241 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-11 03:42 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-11 03:42 -0700
1#
2# LSST Data Management System
3# Copyright 2008-2017 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22import itertools
23import pickle
24import copy
26import numpy as np
28import lsst.geom as geom
29import lsst.utils.tests
31from lsst.skymap import skyMapRegistry
def checkDm14809(testcase, skymap):
    """Test that DM-14809 has been fixed

    The observed behaviour was:

        skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712

    and

        skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord()

    In order to be thorough, we generalise these over the entire skymap.

    Parameters
    ----------
    testcase : `unittest.TestCase`
        Test case used to make the assertions (only ``assertListEqual``
        and ``assertEqual`` are called on it).
    skymap : sky map
        Iterable of tracts that also supports ``len`` and ``findTract``.
    """
    # Check that the tract found for the central coordinate of a tract is
    # that tract.
    expect = [tract.getId() for tract in skymap]
    got = [skymap.findTract(tract.getCtrCoord()).getId() for tract in skymap]
    testcase.assertListEqual(got, expect)

    # Check that the tract central coordinates are unique.
    # Truncate to integer arcminutes so differences are relatively immune to
    # small numerical inaccuracies.
    centers = {
        (int(coord.getRa().asArcminutes()), int(coord.getDec().asArcminutes()))
        for coord in (tract.getCtrCoord() for tract in skymap)
    }
    testcase.assertEqual(len(centers), len(skymap))
class SkyMapTestCase(lsst.utils.tests.TestCase):
    """An abstract base class for testing a SkyMap.

    To use, subclass and call `setAttributes` from `setUp`
    """
    def setAttributes(self, *,
                      SkyMapClass,
                      name,
                      numTracts,
                      config=None,
                      neighborAngularSeparation=None,
                      numNeighbors=None):
        """Initialize the test (call from setUp in the subclass)

        Parameters
        ----------
        SkyMapClass : subclass of `lsst.skymap.BaseSkyMap`
            Class of sky map to test
        name : `str`
            Name of sky map in sky map registry
        numTracts : `int`
            Number of tracts in the default configuration
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Default configuration used by `getSkyMap`;
            if None use SkyMapClass.ConfigClass()
        neighborAngularSeparation : `lsst.geom.Angle`, optional
            Expected angular separation between tracts;
            if None then angular separation is not tested unless your
            subclass of SkyMapTestCase overrides `testTractSeparation`.
        numNeighbors : `int` or `None`
            Number of neighbors that should be within
            ``neighborAngularSeparation``;
            Required if ``neighborAngularSeparation`` is not None;
            ignored otherwise. Note that `testTractSeparation` compares the
            ``numNeighbors - 1`` nearest tracts other than the tract itself.
        """
        self.SkyMapClass = SkyMapClass
        self.config = config
        self.name = name
        self.numTracts = numTracts
        self.neighborAngularSeparation = neighborAngularSeparation
        self.numNeighbors = numNeighbors
        # Seed the global generator so the "arbitrary" tracts picked with
        # np.random.choice in the find* tests are reproducible.
        np.random.seed(47)

    def getSkyMap(self, config=None):
        """Provide an instance of the skymap

        Parameters
        ----------
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Configuration to use; if None use `getConfig`.
            The config is validated before the sky map is constructed.
        """
        if config is None:
            config = self.getConfig()
        config.validate()
        return self.SkyMapClass(config=config)

    def getConfig(self):
        """Provide an instance of the configuration class"""
        if self.config is None:
            return self.SkyMapClass.ConfigClass()
        # Want to return a copy of self.config, so it can be modified.
        return copy.copy(self.config)

    def testRegistry(self):
        """Confirm that the skymap can be retrieved from the registry"""
        self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass)

    def testBasicAttributes(self):
        """Confirm that constructor attributes are available
        """
        defaultSkyMap = self.getSkyMap()
        for tractOverlap in (0.0, 0.01, 0.1):  # degrees
            config = self.getConfig()
            config.tractOverlap = tractOverlap
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        if defaultSkyMap.config.tractBuilder.name == 'cells':
            # The following tests are not appropriate for cells
            # see test_ringsSkyMapCells.py for "cell" tract testing.
            return

        for patchBorder in (0, 101):
            config = self.getConfig()
            config.patchBorder = patchBorder
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tractInfo.getPatchBorder(), patchBorder)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for xInnerDim in (1005, 5062):
            for yInnerDim in (2032, 5431):
                config = self.getConfig()
                config.patchInnerDimensions = (xInnerDim, yInnerDim)
                skyMap = self.getSkyMap(config)
                for tractInfo in skyMap:
                    self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), (xInnerDim, yInnerDim))
                self.assertEqual(len(skyMap), self.numTracts)
                self.assertNotEqual(skyMap, defaultSkyMap)

        # Compare a few patches of the default sky map (built above; there is
        # no need to construct it again).
        tractInfo = defaultSkyMap[0]
        numPatches = tractInfo.getNumPatches()
        patchBorder = defaultSkyMap.config.patchBorder
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = tractInfo.getPatchInfo((xInd, yInd))

                # check inner and outer bbox: the outer bbox extends
                # patchBorder beyond the inner bbox, except on the edges of
                # the tract, where the two coincide.
                innerBBox = patchInfo.getInnerBBox()
                outerBBox = patchInfo.getOuterBBox()

                if xInd == 0:
                    self.assertEqual(innerBBox.getMinX(), outerBBox.getMinX())
                else:
                    self.assertEqual(innerBBox.getMinX() - patchBorder, outerBBox.getMinX())
                if yInd == 0:
                    self.assertEqual(innerBBox.getMinY(), outerBBox.getMinY())
                else:
                    self.assertEqual(innerBBox.getMinY() - patchBorder, outerBBox.getMinY())

                if xInd == numPatches[0] - 1:
                    self.assertEqual(innerBBox.getMaxX(), outerBBox.getMaxX())
                else:
                    self.assertEqual(innerBBox.getMaxX() + patchBorder, outerBBox.getMaxX())
                if yInd == numPatches[1] - 1:
                    self.assertEqual(innerBBox.getMaxY(), outerBBox.getMaxY())
                else:
                    self.assertEqual(innerBBox.getMaxY() + patchBorder, outerBBox.getMaxY())

    def assertUnpickledTractInfo(self, unpickled, original, patchBorder):
        """Assert that an unpickled TractInfo is functionally identical to the
        original.

        Parameters
        ----------
        unpickled : `TractInfo`
            The unpickled `TractInfo`.
        original : `TractInfo`
            The original `TractInfo`.
        patchBorder : `int`
            Border around each patch, from ``SkyMap.config.patchBorder``.
            Not used by the current checks; kept for interface compatibility.
        """
        for getterName in ("getBBox",
                           "getCtrCoord",
                           "getId",
                           "getNumPatches",
                           "getPatchBorder",
                           "getPatchInnerDimensions",
                           "getTractOverlap",
                           "getVertexList",
                           "getWcs",
                           ):
            self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)())

        # test WCS at a few locations
        wcs = original.getWcs()
        unpickledWcs = unpickled.getWcs()
        for x in (-1000.0, 0.0, 1000.0):
            for y in (-532.5, 0.5, 532.5):
                pixelPos = geom.Point2D(x, y)
                skyPos = wcs.pixelToSky(pixelPos)
                unpickledSkyPos = unpickledWcs.pixelToSky(pixelPos)
                self.assertEqual(skyPos, unpickledSkyPos)

        # compare a few patches
        numPatches = original.getNumPatches()
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = original.getPatchInfo((xInd, yInd))
                unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd))
                self.assertEqual(patchInfo, unpickledPatchInfo)

    def testPickle(self):
        """Test that pickling and unpickling restores the original exactly
        """
        skyMap = self.getSkyMap()
        pickleStr = pickle.dumps(skyMap)
        unpickledSkyMap = pickle.loads(pickleStr)
        self.assertEqual(len(skyMap), len(unpickledSkyMap))
        self.assertEqual(unpickledSkyMap.config, skyMap.config)
        self.assertEqual(skyMap, unpickledSkyMap)
        for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap):
            self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder)

    def testTractSeparation(self):
        """Confirm that each sky tract has the proper distance to other tracts
        """
        if self.neighborAngularSeparation is None:
            self.skipTest("Not testing angular separation for %s: neighborAngularSeparation is None" %
                          (self.SkyMapClass.__name__,))
        skyMap = self.getSkyMap()
        for tractId, tractInfo in enumerate(skyMap):
            self.assertEqual(tractInfo.getId(), tractId)

            ctrCoord = tractInfo.getCtrCoord()
            distList = sorted(ctrCoord.separation(otherTract.getCtrCoord()) for otherTract in skyMap)
            # The smallest separation is that of the tract from itself.
            self.assertEqual(distList[0], 0.0)
            # Skipping distList[0], this slice holds numNeighbors - 1 entries.
            # NOTE(review): confirm whether numNeighbors is meant to count the
            # tract itself; the slice is left unchanged for compatibility.
            for dist in distList[1:self.numNeighbors]:
                self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation)

    def testFindPatchList(self):
        """Test TractInfo.findPatchList
        """
        skyMap = self.getSkyMap()
        # pick two arbitrary tracts
        for tractId in np.random.choice(len(skyMap), 2):
            tractInfo = skyMap[tractId]
            wcs = tractInfo.getWcs()
            numPatches = tractInfo.getNumPatches()
            border = tractInfo.getPatchBorder()
            for patchInd in ((0, 0),
                             (0, 1),
                             (5, 0),
                             (5, 6),
                             (numPatches[0] - 2, numPatches[1] - 1),
                             (numPatches[0] - 1, numPatches[1] - 2),
                             (numPatches[0] - 1, numPatches[1] - 1),
                             ):
                patchInfo = tractInfo.getPatchInfo(patchInd)
                patchIndex = patchInfo.getIndex()
                # Shrink the inner bbox so that only this patch is found.
                bbox = patchInfo.getInnerBBox()
                bbox.grow(-(border+1))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), 1)
                self.assertEqual(patchInfoList[0].getIndex(), patchIndex)

                # grow to include neighbors and test again
                bbox.grow(2)
                predFoundIndexSet = set()
                for dx in (-1, 0, 1):
                    nbrX = patchIndex[0] + dx
                    if not 0 <= nbrX < numPatches[0]:
                        continue
                    for dy in (-1, 0, 1):
                        nbrY = patchIndex[1] + dy
                        if not 0 <= nbrY < numPatches[1]:
                            continue
                        predFoundIndexSet.add((nbrX, nbrY))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), len(predFoundIndexSet))
                foundIndexSet = {foundPatch.getIndex() for foundPatch in patchInfoList}
                self.assertEqual(foundIndexSet, predFoundIndexSet)

    def testFindTractPatchList(self):
        """Test findTractPatchList

        Notes
        -----
        This test uses single points for speed and to avoid really large
        regions.
        Note that `findPatchList` is being tested elsewhere.
        """
        skyMap = self.getSkyMap()
        # pick 3 arbitrary tracts
        tractIds = np.random.choice(len(skyMap), 3)
        for tractId in tractIds:
            tractInfo = skyMap[tractId]
            self.assertTractPatchListOk(
                skyMap=skyMap,
                coordList=[tractInfo.getCtrCoord()],
                knownTractId=tractId,
            )

            vertices = tractInfo.getVertexList()
            if len(vertices) > 0:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[0]],
                    knownTractId=tractId,
                )

            if len(vertices) > 2:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[2]],
                    knownTractId=tractId,
                )

        # assertClosestTractPatchList calls skipTest when the sky map lacks
        # findClosestTractPatchList, which ends the whole test; run these
        # checks last so that a skip cannot hide the checks above.
        for tractId in tractIds:
            self.assertClosestTractPatchList(skyMap, [skyMap[tractId].getCtrCoord()], tractId)

    def testTractContains(self):
        """Test that TractInfo.contains works"""
        skyMap = self.getSkyMap()
        for tract in skyMap:
            coord = tract.getCtrCoord()
            self.assertTrue(tract.contains(coord))
            # Point on the opposite side of the sphere from the tract center.
            opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude())
            self.assertFalse(tract.contains(opposite))

    def testTractInfoGetPolygon(self):
        """Test the inner and outer sky polygons of every tract"""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(),
                                 vertexList=tractInfo.getVertexList(),
                                 centerCoord=centerCoord)
            self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(),
                                     bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs())

    def testPatchInfoGetPolygon(self):
        """Test the inner and outer sky polygons of a few patches per tract"""
        skyMap = self.getSkyMap()

        def getIndices(numItems):
            """Return up to 3 indices for testing"""
            if numItems > 2:
                return (0, 1, numItems-1)
            elif numItems > 1:
                return (0, 1)
            return (0,)

        for tractInfo in skyMap:
            wcs = tractInfo.getWcs()
            # Use this tract's own patch grid; tracts need not all have the
            # same number of patches as tract 0.
            numPatches = tractInfo.getNumPatches()
            for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])):
                with self.subTest(patchInd=patchInd):
                    patchInfo = tractInfo.getPatchInfo(patchInd)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getInnerBBox(), wcs=wcs)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getOuterBBox(), wcs=wcs)

    def testDm14809(self):
        """Generic version of test that DM-14809 has been fixed"""
        checkDm14809(self, self.getSkyMap())

    def testNumbering(self):
        """Check the numbering of tracts matches the indexing"""
        skymap = self.getSkyMap()
        expect = list(range(len(skymap)))
        got = [tt.getId() for tt in skymap]
        self.assertEqual(got, expect)

    def assertTractPatchListOk(self, skyMap, coordList, knownTractId):
        """Assert that findTractPatchList produces the correct results.

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Region to search for.
        knownTractId : `int`
            This tractId must appear in the found list.
        """
        tractPatchList = skyMap.findTractPatchList(coordList)
        tractPatchDict = {tractInfo.getId(): patchList for tractInfo, patchList in tractPatchList}
        self.assertIn(knownTractId, tractPatchDict)
        # Every tract whose own findPatchList finds patches must be reported,
        # with the same number of patches; every other tract must be absent.
        for tractInfo in skyMap:
            tractId = tractInfo.getId()
            patchList = tractInfo.findPatchList(coordList)
            if patchList:
                self.assertIn(tractId, tractPatchDict)
                self.assertEqual(len(patchList), len(tractPatchDict[tractId]))
            else:
                self.assertNotIn(tractId, tractPatchDict)

    def assertClosestTractPatchList(self, skyMap, coordList, knownTractId):
        """Assert that findClosestTractPatchList produces the correct results.

        Skips the calling test if the sky map has no
        ``findClosestTractPatchList`` attribute.

        Parameters
        ----------
        skyMap : `BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates to search for; one (tract, patchList) pair is
            expected per coordinate.
        knownTractId : `int`
            Tract id every coordinate is expected to be closest to.
        """
        if not hasattr(skyMap, "findClosestTractPatchList"):
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")
        tractPatchList = skyMap.findClosestTractPatchList(coordList)
        self.assertEqual(len(coordList), len(tractPatchList))  # One tract+patchList per coordinate
        for coord, (tract, patchList) in zip(coordList, tractPatchList):
            self.assertEqual(tract.getId(), knownTractId)
            self.assertEqual(patchList, tract.findPatchList([coord]))

    def assertBBoxPolygonOk(self, polygon, bbox, wcs):
        """Assert that an on-sky polygon from a pixel bbox
        covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon
        bbox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
            Pixel bounding box the polygon is expected to cover; its
            corners are used as the polygon vertices.
        wcs : WCS
            Pixel-to-sky transform for ``bbox`` (as returned by
            ``TractInfo.getWcs``).
        """
        bboxd = geom.Box2D(bbox)
        centerPixel = bboxd.getCenter()
        centerCoord = wcs.pixelToSky(centerPixel)
        skyCorners = getCornerCoords(wcs=wcs, bbox=bbox)
        self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord)

    def assertPolygonOk(self, polygon, vertexList, centerCoord):
        """Assert that an on-sky polygon covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon
        vertexList : `iterable` of `lsst.geom.SpherePoint`
            Vertices of polygon
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region
        """
        shiftAngle = 0.01*geom.arcseconds
        self.assertTrue(polygon.contains(centerCoord.getVector()))
        # A point nudged from each vertex toward the center must be inside
        # the polygon; a point nudged away from the center must be outside.
        for vertex in vertexList:
            bearingToCenter = vertex.bearingTo(centerCoord)
            cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle)
            cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle)
            self.assertTrue(polygon.contains(cornerShiftedIn.getVector()))
            self.assertFalse(polygon.contains(cornerShiftedOut.getVector()))
def getCornerCoords(wcs, bbox):
    """Map the four corners of a pixel bounding box onto the sky.

    Parameters
    ----------
    wcs : WCS
        Pixel-to-sky transform (e.g. as returned by ``TractInfo.getWcs``).
    bbox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        Pixel bounding box; it is converted to ``geom.Box2D`` before its
        corners are taken.

    Returns
    -------
    coords
        Sky positions of the corners, as returned by ``wcs.pixelToSky``
        for a list of pixel points.
    """
    return wcs.pixelToSky(geom.Box2D(bbox).getCorners())