Coverage for tests/helper/skyMapTestCase.py: 9%
230 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-05 17:57 -0800
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-05 17:57 -0800
1#
2# LSST Data Management System
3# Copyright 2008-2017 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22import itertools
23import pickle
25import numpy as np
27import lsst.geom as geom
28import lsst.utils.tests
30from lsst.skymap import skyMapRegistry
def checkDm14809(testcase, skymap):
    """Test that DM-14809 has been fixed.

    The observed behaviour was:

        skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712

    and

        skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord()

    In order to be thorough, we generalise these over the entire skymap.

    Parameters
    ----------
    testcase : test case object
        Provides ``assertListEqual`` and ``assertEqual`` used for the checks
        (normally the calling `lsst.utils.tests.TestCase`).
    skymap : `lsst.skymap.BaseSkyMap`
        Sky map to check; iterated over its tracts, sized with ``len`` and
        queried with ``findTract``.
    """
    # Check that the tract found for the central coordinate of a tract is
    # that tract.
    expect = [tract.getId() for tract in skymap]
    got = [skymap.findTract(tract.getCtrCoord()).getId() for tract in skymap]
    testcase.assertListEqual(got, expect)

    # Check that the tract central coordinates are unique.
    # Truncate to integer arcminutes so that differences are relatively
    # immune to small numerical inaccuracies.
    ctrCoords = [tract.getCtrCoord() for tract in skymap]
    centers = {(int(coord.getRa().asArcminutes()), int(coord.getDec().asArcminutes()))
               for coord in ctrCoords}
    testcase.assertEqual(len(centers), len(skymap))
class SkyMapTestCase(lsst.utils.tests.TestCase):
    """An abstract base class for testing a SkyMap.

    To use, subclass and call `setAttributes` from `setUp`.
    """

    def setAttributes(self, *,
                      SkyMapClass,
                      name,
                      numTracts,
                      config=None,
                      neighborAngularSeparation=None,
                      numNeighbors=None):
        """Initialize the test (call from setUp in the subclass).

        Parameters
        ----------
        SkyMapClass : subclass of `lsst.skymap.BaseSkyMap`
            Class of sky map to test.
        name : `str`
            Name of sky map in sky map registry.
        numTracts : `int`
            Number of tracts in the default configuration.
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Default configuration used by `getSkyMap`;
            if None use SkyMapClass.ConfigClass().
        neighborAngularSeparation : `lsst.geom.Angle`, optional
            Expected angular separation between tracts;
            if None then angular separation is not tested unless your
            subclass of SkyMapTestCase overrides `testTractSeparation`.
        numNeighbors : `int` or `None`
            Controls how many of the nearest tracts `testTractSeparation`
            checks: after sorting separations from a tract (the first entry
            is the tract itself), entries ``1:numNeighbors`` -- that is,
            ``numNeighbors - 1`` other tracts -- must be separated by
            ``neighborAngularSeparation``.
            Required if ``neighborAngularSeparation`` is not None;
            ignored otherwise.
        """
        self.SkyMapClass = SkyMapClass
        self.config = config
        self.name = name
        self.numTracts = numTracts
        self.neighborAngularSeparation = neighborAngularSeparation
        self.numNeighbors = numNeighbors
        # Seed the global numpy RNG so the "arbitrary" tracts picked with
        # np.random.choice in the tests below are reproducible.
        np.random.seed(47)

    def getSkyMap(self, config=None):
        """Provide an instance of the skymap.

        Parameters
        ----------
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Configuration to use; if None, use `getConfig`.
        """
        if config is None:
            config = self.getConfig()
        return self.SkyMapClass(config=config)

    def getConfig(self):
        """Provide an instance of the configuration class.

        Returns a fresh default config when no config was supplied to
        `setAttributes`, otherwise a copy of that config that callers may
        modify freely.
        """
        if self.config is None:
            return self.SkyMapClass.ConfigClass()
        # Want to return a copy of self.config, so it can be modified.
        # However, there is no Config.copy() method, so this is more
        # complicated than desirable.
        return pickle.loads(pickle.dumps(self.config))

    def testRegistry(self):
        """Confirm that the skymap can be retrieved from the registry."""
        self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass)

    def testBasicAttributes(self):
        """Confirm that constructor attributes are available."""
        defaultSkyMap = self.getSkyMap()
        for tractOverlap in (0.0, 0.01, 0.1):  # degrees
            config = self.getConfig()
            config.tractOverlap = tractOverlap
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for patchBorder in (0, 101):
            config = self.getConfig()
            config.patchBorder = patchBorder
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tractInfo.getPatchBorder(), patchBorder)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for xInnerDim in (1005, 5062):
            for yInnerDim in (2032, 5431):
                config = self.getConfig()
                config.patchInnerDimensions = (xInnerDim, yInnerDim)
                skyMap = self.getSkyMap(config)
                for tractInfo in skyMap:
                    self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), (xInnerDim, yInnerDim))
                self.assertEqual(len(skyMap), self.numTracts)
                self.assertNotEqual(skyMap, defaultSkyMap)

    @staticmethod
    def _getSamplePatchIndices(numItems):
        """Return sorted, unique, in-range sample indices of ``range(numItems)``.

        The candidates are the first two, the middle and the last two
        indices; out-of-range candidates (which occur when ``numItems < 2``)
        are dropped.
        """
        candidates = {0, 1, numItems // 2, numItems - 2, numItems - 1}
        return sorted(ind for ind in candidates if 0 <= ind < numItems)

    def assertUnpickledTractInfo(self, unpickled, original, patchBorder):
        """Assert that an unpickled TractInfo is functionally identical to the original.

        Parameters
        ----------
        unpickled : `lsst.skymap.TractInfo`
            The unpickled TractInfo.
        original : `lsst.skymap.TractInfo`
            The original TractInfo.
        patchBorder : `int`
            Border around each patch, from SkyMap.config.patchBorder.
        """
        for getterName in ("getBBox",
                           "getCtrCoord",
                           "getId",
                           "getNumPatches",
                           "getPatchBorder",
                           "getPatchInnerDimensions",
                           "getTractOverlap",
                           "getVertexList",
                           "getWcs",
                           ):
            self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)())

        # test WCS at a few locations
        wcs = original.getWcs()
        unpickledWcs = unpickled.getWcs()
        for x in (-1000.0, 0.0, 1000.0):
            for y in (-532.5, 0.5, 532.5):
                pixelPos = geom.Point2D(x, y)
                skyPos = wcs.pixelToSky(pixelPos)
                unpickledSkyPos = unpickledWcs.pixelToSky(pixelPos)
                self.assertEqual(skyPos, unpickledSkyPos)

        # Compare a few patches near the start, middle and end of each axis.
        # The index helper drops duplicates and out-of-range (negative)
        # indices, which previously appeared when an axis had < 2 patches.
        numPatches = original.getNumPatches()
        for xInd in self._getSamplePatchIndices(numPatches[0]):
            for yInd in self._getSamplePatchIndices(numPatches[1]):
                patchInfo = original.getPatchInfo((xInd, yInd))
                unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd))
                self.assertEqual(patchInfo, unpickledPatchInfo)

                # check inner and outer bbox (nothing to do with pickle,
                # but a convenient place for the test): the outer bbox
                # extends patchBorder past the inner bbox, except along
                # the edges of the tract.
                innerBBox = patchInfo.getInnerBBox()
                outerBBox = patchInfo.getOuterBBox()

                if xInd == 0:
                    self.assertEqual(innerBBox.getMinX(), outerBBox.getMinX())
                else:
                    self.assertEqual(innerBBox.getMinX() - patchBorder, outerBBox.getMinX())
                if yInd == 0:
                    self.assertEqual(innerBBox.getMinY(), outerBBox.getMinY())
                else:
                    self.assertEqual(innerBBox.getMinY() - patchBorder, outerBBox.getMinY())

                if xInd == numPatches[0] - 1:
                    self.assertEqual(innerBBox.getMaxX(), outerBBox.getMaxX())
                else:
                    self.assertEqual(innerBBox.getMaxX() + patchBorder, outerBBox.getMaxX())
                if yInd == numPatches[1] - 1:
                    self.assertEqual(innerBBox.getMaxY(), outerBBox.getMaxY())
                else:
                    self.assertEqual(innerBBox.getMaxY() + patchBorder, outerBBox.getMaxY())

    def testPickle(self):
        """Test that pickling and unpickling restores the original exactly."""
        skyMap = self.getSkyMap()
        pickleStr = pickle.dumps(skyMap)
        unpickledSkyMap = pickle.loads(pickleStr)
        self.assertEqual(len(skyMap), len(unpickledSkyMap))
        self.assertEqual(unpickledSkyMap.config, skyMap.config)
        self.assertEqual(skyMap, unpickledSkyMap)
        for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap):
            self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder)

    def testTractSeparation(self):
        """Confirm that each sky tract has the proper distance to other tracts."""
        if self.neighborAngularSeparation is None:
            self.skipTest("Not testing angular separation for %s: neighborAngularSeparation is None" %
                          (self.SkyMapClass.__name__,))
        skyMap = self.getSkyMap()
        for tractId, tractInfo in enumerate(skyMap):
            self.assertEqual(tractInfo.getId(), tractId)

            ctrCoord = tractInfo.getCtrCoord()
            distList = sorted(ctrCoord.separation(otherTractInfo.getCtrCoord())
                              for otherTractInfo in skyMap)
            # The closest entry is this tract itself, at zero separation.
            self.assertEqual(distList[0], 0.0)
            # The next numNeighbors - 1 tracts must be at the expected
            # neighbor separation.
            for dist in distList[1:self.numNeighbors]:
                self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation)

    def testFindPatchList(self):
        """Test TractInfo.findPatchList."""
        skyMap = self.getSkyMap()
        # pick two arbitrary tracts (with replacement, so possibly the same)
        for tractId in np.random.choice(len(skyMap), 2):
            tractInfo = skyMap[tractId]
            wcs = tractInfo.getWcs()
            numPatches = tractInfo.getNumPatches()
            border = tractInfo.getPatchBorder()
            for patchInd in ((0, 0),
                             (0, 1),
                             (5, 0),
                             (5, 6),
                             (numPatches[0] - 2, numPatches[1] - 1),
                             (numPatches[0] - 1, numPatches[1] - 2),
                             (numPatches[0] - 1, numPatches[1] - 1),
                             ):
                patchInfo = tractInfo.getPatchInfo(patchInd)
                patchIndex = patchInfo.getIndex()
                # Shrink the inner bbox past the patch border; only this
                # patch is then expected to match.
                bbox = patchInfo.getInnerBBox()
                bbox.grow(-(border+1))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), 1)
                self.assertEqual(patchInfoList[0].getIndex(), patchIndex)

                # grow to include neighbors and test again
                bbox.grow(2)
                predFoundIndexSet = set()
                for dx in (-1, 0, 1):
                    nbrX = patchIndex[0] + dx
                    if not 0 <= nbrX < numPatches[0]:
                        continue
                    for dy in (-1, 0, 1):
                        nbrY = patchIndex[1] + dy
                        if not 0 <= nbrY < numPatches[1]:
                            continue
                        predFoundIndexSet.add((nbrX, nbrY))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), len(predFoundIndexSet))
                foundIndexSet = {foundPatch.getIndex() for foundPatch in patchInfoList}
                self.assertEqual(foundIndexSet, predFoundIndexSet)

    def testFindTractPatchList(self):
        """Test findTractPatchList.

        Note: this test uses single points for speed and to avoid really
        large regions. Note that findPatchList is being tested elsewhere.
        """
        skyMap = self.getSkyMap()
        # findClosestTractPatchList is optional. Check for it here instead of
        # letting assertClosestTractPatchList call skipTest, which would
        # abort the remaining checks (vertices, other tracts) of this test.
        hasClosest = hasattr(skyMap, "findClosestTractPatchList")
        # pick 3 arbitrary tracts
        for tractId in np.random.choice(len(skyMap), 3):
            tractInfo = skyMap[tractId]
            ctrCoord = tractInfo.getCtrCoord()
            self.assertTractPatchListOk(
                skyMap=skyMap,
                coordList=[ctrCoord],
                knownTractId=tractId,
            )
            if hasClosest:
                self.assertClosestTractPatchList(skyMap, [ctrCoord], tractId)

            vertices = tractInfo.getVertexList()
            if len(vertices) > 0:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[0]],
                    knownTractId=tractId,
                )

            if len(vertices) > 2:
                self.assertTractPatchListOk(
                    skyMap=skyMap,
                    coordList=[vertices[2]],
                    knownTractId=tractId,
                )

    def testTractContains(self):
        """Test that TractInfo.contains works."""
        skyMap = self.getSkyMap()
        for tract in skyMap:
            coord = tract.getCtrCoord()
            self.assertTrue(tract.contains(coord))
            # The antipode (RA + 12h, -Dec) of the center must not be inside.
            opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude())
            self.assertFalse(tract.contains(opposite))

    def testTractInfoGetPolygon(self):
        """Test TractInfo inner and outer sky polygons."""
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(),
                                 vertexList=tractInfo.getVertexList(),
                                 centerCoord=centerCoord)
            self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(),
                                     bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs())

    def testPatchInfoGetPolygon(self):
        """Test PatchInfo inner and outer sky polygons for a sample of patches."""
        skyMap = self.getSkyMap()

        def getIndices(numItems):
            """Return up to 3 indices for testing"""
            if numItems > 2:
                return (0, 1, numItems-1)
            elif numItems > 1:
                return (0, 1)
            return (0,)

        for tractInfo in skyMap:
            wcs = tractInfo.getWcs()
            # Read the patch counts per tract (not from skyMap[0]) so the
            # sampled indices are always valid for this tract.
            numPatches = tractInfo.getNumPatches()
            for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])):
                with self.subTest(tractId=tractInfo.getId(), patchInd=patchInd):
                    patchInfo = tractInfo.getPatchInfo(patchInd)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getInnerBBox(), wcs=wcs)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getOuterBBox(), wcs=wcs)

    def testDm14809(self):
        """Generic version of test that DM-14809 has been fixed."""
        checkDm14809(self, self.getSkyMap())

    def testNumbering(self):
        """Check the numbering of tracts matches the indexing."""
        skymap = self.getSkyMap()
        expect = list(range(len(skymap)))
        got = [tt.getId() for tt in skymap]
        self.assertEqual(got, expect)

    def assertTractPatchListOk(self, skyMap, coordList, knownTractId):
        """Assert that findTractPatchList produces the correct results.

        Parameters
        ----------
        skyMap : `lsst.skymap.BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates of the region to search for.
        knownTractId : `int`
            This tract ID must appear in the found list.
        """
        tractPatchList = skyMap.findTractPatchList(coordList)
        tractPatchDict = {tp[0].getId(): tp[1] for tp in tractPatchList}
        self.assertIn(knownTractId, tractPatchDict)
        # Every tract whose own findPatchList finds patches must be reported
        # with the same number of patches; all others must be absent.
        for tractInfo in skyMap:
            tractId = tractInfo.getId()
            patchList = tractInfo.findPatchList(coordList)
            if patchList:
                self.assertIn(tractId, tractPatchDict)
                self.assertEqual(len(patchList), len(tractPatchDict[tractId]))
            else:
                self.assertNotIn(tractId, tractPatchDict)

    def assertClosestTractPatchList(self, skyMap, coordList, knownTractId):
        """Assert that findClosestTractPatchList returns the expected tract.

        Skips the calling test if the sky map lacks
        ``findClosestTractPatchList``.

        Parameters
        ----------
        skyMap : `lsst.skymap.BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates to look up; one (tract, patchList) is expected per
            coordinate.
        knownTractId : `int`
            Tract ID expected for every coordinate.
        """
        if not hasattr(skyMap, "findClosestTractPatchList"):
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")
        tractPatchList = skyMap.findClosestTractPatchList(coordList)
        self.assertEqual(len(coordList), len(tractPatchList))  # One tract+patchList per coordinate
        for coord, (tract, patchList) in zip(coordList, tractPatchList):
            self.assertEqual(tract.getId(), knownTractId)
            self.assertEqual(patchList, tract.findPatchList([coord]))

    def assertBBoxPolygonOk(self, polygon, bbox, wcs):
        """Assert that an on-sky polygon from a pixel bbox
        covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        bbox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
            Pixel bounding box the polygon was derived from.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS mapping ``bbox`` pixel coordinates to the sky.
        """
        bboxd = geom.Box2D(bbox)
        centerPixel = bboxd.getCenter()
        centerCoord = wcs.pixelToSky(centerPixel)
        skyCorners = getCornerCoords(wcs=wcs, bbox=bbox)
        self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord)

    def assertPolygonOk(self, polygon, vertexList, centerCoord):
        """Assert that an on-sky polygon covers the expected region.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        vertexList : `iterable` of `lsst.geom.SpherePoint`
            Vertices of polygon.
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region.
        """
        shiftAngle = 0.01*geom.arcseconds
        self.assertTrue(polygon.contains(centerCoord.getVector()))
        for vertex in vertexList:
            # Nudge each vertex slightly toward the center (must be inside)
            # and slightly away from it (must be outside).
            bearingToCenter = vertex.bearingTo(centerCoord)
            cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle)
            cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle)
            self.assertTrue(polygon.contains(cornerShiftedIn.getVector()))
            self.assertFalse(polygon.contains(cornerShiftedOut.getVector()))
def getCornerCoords(wcs, bbox):
    """Map the four corners of a pixel bounding box onto the sky.

    Parameters
    ----------
    wcs : `lsst.afw.geom.SkyWcs`
        WCS used to convert pixel positions to sky coordinates.
    bbox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        Pixel bounding box; it is converted to a floating-point box first.

    Returns
    -------
    coords : `list` of `lsst.geom.SpherePoint`
        Sky positions of the box corners, as returned by
        ``wcs.pixelToSky`` for a list of points.
    """
    return wcs.pixelToSky(geom.Box2D(bbox).getCorners())