Coverage for tests/helper/skyMapTestCase.py: 8%
241 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-10 02:59 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-10 02:59 -0800
1#
2# LSST Data Management System
3# Copyright 2008-2017 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22import itertools
23import pickle
24import copy
26import numpy as np
28import lsst.geom as geom
29import lsst.utils.tests
31from lsst.skymap import skyMapRegistry
def checkDm14809(testcase, skymap):
    """Test that DM-14809 has been fixed.

    The observed behaviour was:

        skyMap.findTract(skyMap[9712].getCtrCoord()).getId() != 9712

    and

        skyMap[1].getCtrCoord() == skyMap[11].getCtrCoord()

    In order to be thorough, we generalise these over the entire skymap.

    Parameters
    ----------
    testcase : `unittest.TestCase`
        Test case used to make the assertions.
    skymap : `lsst.skymap.BaseSkyMap`
        Sky map to check.
    """
    # Check that the tract found for central coordinate of a tract is that tract
    expect = [tract.getId() for tract in skymap]
    got = [skymap.findTract(tract.getCtrCoord()).getId() for tract in skymap]
    testcase.assertListEqual(got, expect)

    # Check that the tract central coordinates are unique.
    # Round (rather than truncate) to integer arcminutes so that two
    # numerically near-identical centers straddling a whole arcminute
    # (e.g. 599.9999999 and 600.0000001) fall in the same bucket, and wrap
    # RA so that a value rounding to 360 degrees compares equal to 0.
    arcminPerCircle = 360*60
    centers = {
        (round(coord.getRa().asArcminutes()) % arcminPerCircle, round(coord.getDec().asArcminutes()))
        for coord in (tract.getCtrCoord() for tract in skymap)
    }
    testcase.assertEqual(len(centers), len(skymap))
class SkyMapTestCase(lsst.utils.tests.TestCase):
    """An abstract base class for testing a SkyMap.

    To use, subclass and call `setAttributes` from `setUp`.
    """

    def setAttributes(self, *,
                      SkyMapClass,
                      name,
                      numTracts,
                      config=None,
                      neighborAngularSeparation=None,
                      numNeighbors=None):
        """Initialize the test (call from setUp in the subclass).

        Parameters
        ----------
        SkyMapClass : subclass of `lsst.skymap.BaseSkyMap`
            Class of sky map to test.
        name : `str`
            Name of sky map in sky map registry.
        numTracts : `int`
            Number of tracts in the default configuration.
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Default configuration used by `getSkyMap`;
            if None use ``SkyMapClass.ConfigClass()``.
        neighborAngularSeparation : `lsst.geom.Angle`, optional
            Expected angular separation between neighboring tracts;
            if None then angular separation is not tested unless your
            subclass of SkyMapTestCase overrides `testTractSeparation`.
        numNeighbors : `int` or `None`
            Size of the neighborhood checked by `testTractSeparation`,
            counting the tract itself: the ``numNeighbors - 1`` closest
            other tracts must be ``neighborAngularSeparation`` away.
            If None while ``neighborAngularSeparation`` is set, every other
            tract is checked. Ignored if ``neighborAngularSeparation`` is None.
        """
        self.SkyMapClass = SkyMapClass
        self.config = config
        self.name = name
        self.numTracts = numTracts
        self.neighborAngularSeparation = neighborAngularSeparation
        self.numNeighbors = numNeighbors
        # Seed the global numpy RNG so the "arbitrary" tracts chosen with
        # np.random.choice in the find* tests are reproducible.
        np.random.seed(47)

    def getSkyMap(self, config=None):
        """Provide an instance of the skymap.

        Parameters
        ----------
        config : subclass of `lsst.skymap.SkyMapConfig`, optional
            Configuration to use; if None, use `getConfig`.

        Returns
        -------
        skyMap : ``self.SkyMapClass``
            Sky map built from the validated configuration.
        """
        if config is None:
            config = self.getConfig()
        config.validate()
        return self.SkyMapClass(config=config)

    def getConfig(self):
        """Provide an instance of the configuration class.

        Returns
        -------
        config : subclass of `lsst.skymap.SkyMapConfig`
            A new default config if no config was given to `setAttributes`,
            otherwise a copy of that config.
        """
        if self.config is None:
            return self.SkyMapClass.ConfigClass()
        # Want to return a copy of self.config, so it can be modified.
        return copy.copy(self.config)

    def testRegistry(self):
        """Confirm that the skymap can be retrieved from the registry."""
        self.assertEqual(skyMapRegistry[self.name], self.SkyMapClass)

    def testBasicAttributes(self):
        """Confirm that constructor attributes are available.
        """
        defaultSkyMap = self.getSkyMap()
        for tractOverlap in (0.0, 0.01, 0.1):  # degrees
            config = self.getConfig()
            config.tractOverlap = tractOverlap
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertAlmostEqual(tractInfo.getTractOverlap().asDegrees(), tractOverlap)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        if defaultSkyMap.config.tractBuilder.name == 'cells':
            # The following tests are not appropriate for cells
            # see test_ringsSkyMapCells.py for "cell" tract testing.
            return

        for patchBorder in (0, 101):
            config = self.getConfig()
            config.patchBorder = patchBorder
            skyMap = self.getSkyMap(config)
            for tractInfo in skyMap:
                self.assertEqual(tractInfo.getPatchBorder(), patchBorder)
            self.assertEqual(len(skyMap), self.numTracts)
            self.assertNotEqual(skyMap, defaultSkyMap)

        for xInnerDim in (1005, 5062):
            for yInnerDim in (2032, 5431):
                config = self.getConfig()
                config.patchInnerDimensions = (xInnerDim, yInnerDim)
                skyMap = self.getSkyMap(config)
                for tractInfo in skyMap:
                    self.assertEqual(tuple(tractInfo.getPatchInnerDimensions()), (xInnerDim, yInnerDim))
                self.assertEqual(len(skyMap), self.numTracts)
                self.assertNotEqual(skyMap, defaultSkyMap)

        # Compare a few patches of the first tract of the default sky map:
        # the outer bbox must extend the inner bbox by patchBorder on every
        # side except the sides lying on the edge of the tract.
        tractInfo = defaultSkyMap[0]
        numPatches = tractInfo.getNumPatches()
        patchBorder = defaultSkyMap.config.patchBorder
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = tractInfo.getPatchInfo((xInd, yInd))

                # check inner and outer bbox
                innerBBox = patchInfo.getInnerBBox()
                outerBBox = patchInfo.getOuterBBox()

                if xInd == 0:
                    self.assertEqual(innerBBox.getMinX(), outerBBox.getMinX())
                else:
                    self.assertEqual(innerBBox.getMinX() - patchBorder, outerBBox.getMinX())
                if yInd == 0:
                    self.assertEqual(innerBBox.getMinY(), outerBBox.getMinY())
                else:
                    self.assertEqual(innerBBox.getMinY() - patchBorder, outerBBox.getMinY())

                if xInd == numPatches[0] - 1:
                    self.assertEqual(innerBBox.getMaxX(), outerBBox.getMaxX())
                else:
                    self.assertEqual(innerBBox.getMaxX() + patchBorder, outerBBox.getMaxX())
                if yInd == numPatches[1] - 1:
                    self.assertEqual(innerBBox.getMaxY(), outerBBox.getMaxY())
                else:
                    self.assertEqual(innerBBox.getMaxY() + patchBorder, outerBBox.getMaxY())

    def assertUnpickledTractInfo(self, unpickled, original, patchBorder):
        """Assert that an unpickled TractInfo is functionally identical to the original.

        Parameters
        ----------
        unpickled : `lsst.skymap.TractInfo`
            The unpickled TractInfo.
        original : `lsst.skymap.TractInfo`
            The original TractInfo.
        patchBorder : `int`
            Border around each patch, from ``SkyMap.config.patchBorder``.
            Accepted for interface compatibility; not used by this method.
        """
        for getterName in ("getBBox",
                           "getCtrCoord",
                           "getId",
                           "getNumPatches",
                           "getPatchBorder",
                           "getPatchInnerDimensions",
                           "getTractOverlap",
                           "getVertexList",
                           "getWcs",
                           ):
            self.assertEqual(getattr(original, getterName)(), getattr(unpickled, getterName)())

        # test WCS at a few locations
        wcs = original.getWcs()
        unpickledWcs = unpickled.getWcs()
        for x in (-1000.0, 0.0, 1000.0):
            for y in (-532.5, 0.5, 532.5):
                pixelPos = geom.Point2D(x, y)
                skyPos = wcs.pixelToSky(pixelPos)
                unpickledSkyPos = unpickledWcs.pixelToSky(pixelPos)
                self.assertEqual(skyPos, unpickledSkyPos)

        # compare a few patches
        numPatches = original.getNumPatches()
        for xInd in (0, 1, numPatches[0]//2, numPatches[0]-2, numPatches[0]-1):
            for yInd in (0, 1, numPatches[1]//2, numPatches[1]-2, numPatches[1]-1):
                patchInfo = original.getPatchInfo((xInd, yInd))
                unpickledPatchInfo = unpickled.getPatchInfo((xInd, yInd))
                self.assertEqual(patchInfo, unpickledPatchInfo)

    def testPickle(self):
        """Test that pickling and unpickling restores the original exactly.
        """
        skyMap = self.getSkyMap()
        # Round-trips an object built locally by this test, not external data.
        pickleStr = pickle.dumps(skyMap)
        unpickledSkyMap = pickle.loads(pickleStr)
        self.assertEqual(len(skyMap), len(unpickledSkyMap))
        self.assertEqual(unpickledSkyMap.config, skyMap.config)
        self.assertEqual(skyMap, unpickledSkyMap)
        for tractInfo, unpickledTractInfo in zip(skyMap, unpickledSkyMap):
            self.assertUnpickledTractInfo(unpickledTractInfo, tractInfo, skyMap.config.patchBorder)

    def testTractSeparation(self):
        """Confirm that each sky tract has the proper distance to other tracts.
        """
        if self.neighborAngularSeparation is None:
            self.skipTest("Not testing angular separation for %s: neighborAngularSeparation is None" %
                          (self.SkyMapClass.__name__,))
        skyMap = self.getSkyMap()
        for tractId, tractInfo in enumerate(skyMap):
            self.assertEqual(tractInfo.getId(), tractId)

            ctrCoord = tractInfo.getCtrCoord()
            distList = []
            for tractInfo1 in skyMap:
                otherCtrCoord = tractInfo1.getCtrCoord()
                distList.append(ctrCoord.separation(otherCtrCoord))
            distList.sort()
            # The smallest separation is the tract's distance to itself.
            self.assertEqual(distList[0], 0.0)
            # The slice end is exclusive, so numNeighbors counts the tract
            # itself; numNeighbors=None checks every other tract.
            for dist in distList[1:self.numNeighbors]:
                self.assertAnglesAlmostEqual(dist, self.neighborAngularSeparation)

    def testFindPatchList(self):
        """Test TractInfo.findPatchList.
        """
        skyMap = self.getSkyMap()
        # pick two arbitrary tracts (np.random.choice samples with
        # replacement, so the same tract may be picked twice)
        for tractId in np.random.choice(len(skyMap), 2):
            tractInfo = skyMap[tractId]
            wcs = tractInfo.getWcs()
            numPatches = tractInfo.getNumPatches()
            border = tractInfo.getPatchBorder()
            for patchInd in ((0, 0),
                             (0, 1),
                             (5, 0),
                             (5, 6),
                             (numPatches[0] - 2, numPatches[1] - 1),
                             (numPatches[0] - 1, numPatches[1] - 2),
                             (numPatches[0] - 1, numPatches[1] - 1),
                             ):
                patchInfo = tractInfo.getPatchInfo(patchInd)
                patchIndex = patchInfo.getIndex()
                # Shrink the inner bbox by more than the patch border so its
                # corners are expected to match this patch only (presumably
                # findPatchList matches against the bordered outer bboxes,
                # which reach ``border`` pixels into neighboring patches).
                bbox = patchInfo.getInnerBBox()
                bbox.grow(-(border+1))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), 1)
                self.assertEqual(patchInfoList[0].getIndex(), patchIndex)

                # grow to include neighbors and test again: every patch of
                # the 3x3 neighborhood that exists in the tract should be found
                bbox.grow(2)
                predFoundIndexSet = set()
                for dx in (-1, 0, 1):
                    nbrX = patchIndex[0] + dx
                    if not 0 <= nbrX < numPatches[0]:
                        continue
                    for dy in (-1, 0, 1):
                        nbrY = patchIndex[1] + dy
                        if not 0 <= nbrY < numPatches[1]:
                            continue
                        predFoundIndexSet.add((nbrX, nbrY))
                coordList = getCornerCoords(wcs=wcs, bbox=bbox)
                patchInfoList = tractInfo.findPatchList(coordList)
                self.assertEqual(len(patchInfoList), len(predFoundIndexSet))
                foundIndexSet = {foundPatchInfo.getIndex() for foundPatchInfo in patchInfoList}
                self.assertEqual(foundIndexSet, predFoundIndexSet)

    def testFindTractPatchList(self):
        """Test findTractPatchList, and findClosestTractPatchList if the
        sky map implements it.

        Note: this test uses single points for speed and to avoid really large regions.
        Note that findPatchList is being tested elsewhere.
        """
        skyMap = self.getSkyMap()
        # Decide up front whether findClosestTractPatchList can be checked, so
        # that a sky map lacking it still gets every findTractPatchList check
        # below before the test is reported as skipped.
        hasClosest = hasattr(skyMap, "findClosestTractPatchList")
        # pick 3 arbitrary tracts (sampled with replacement)
        for tractId in np.random.choice(len(skyMap), 3):
            tractInfo = skyMap[tractId]
            ctrCoord = tractInfo.getCtrCoord()
            self.assertTractPatchListOk(
                skyMap=skyMap,
                coordList=[ctrCoord],
                knownTractId=tractId,
            )
            if hasClosest:
                self.assertClosestTractPatchList(skyMap, [ctrCoord], tractId)

            # Also try the first and third vertices, when the tract has them.
            vertices = tractInfo.getVertexList()
            for vertexInd in (0, 2):
                if len(vertices) > vertexInd:
                    self.assertTractPatchListOk(
                        skyMap=skyMap,
                        coordList=[vertices[vertexInd]],
                        knownTractId=tractId,
                    )

        if not hasClosest:
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")

    def testTractContains(self):
        """Test that TractInfo.contains works."""
        skyMap = self.getSkyMap()
        for tract in skyMap:
            coord = tract.getCtrCoord()
            self.assertTrue(tract.contains(coord))
            # The antipode of the tract center must not be in the tract.
            opposite = geom.SpherePoint(coord.getLongitude() + 12*geom.hours, -1*coord.getLatitude())
            self.assertFalse(tract.contains(opposite))

    def testTractInfoGetPolygon(self):
        """Test TractInfo.getInnerSkyPolygon (against the tract vertices) and
        TractInfo.getOuterSkyPolygon (against the tract bbox) for every tract.
        """
        skyMap = self.getSkyMap()
        for tractInfo in skyMap:
            centerCoord = tractInfo.getCtrCoord()
            self.assertPolygonOk(polygon=tractInfo.getInnerSkyPolygon(),
                                 vertexList=tractInfo.getVertexList(),
                                 centerCoord=centerCoord)
            self.assertBBoxPolygonOk(polygon=tractInfo.getOuterSkyPolygon(),
                                     bbox=tractInfo.getBBox(), wcs=tractInfo.getWcs())

    def testPatchInfoGetPolygon(self):
        """Test PatchInfo.getInnerSkyPolygon and getOuterSkyPolygon for the
        first, second and last patch along each axis of every tract.
        """
        skyMap = self.getSkyMap()

        def getIndices(numItems):
            """Return up to 3 indices for testing"""
            if numItems > 2:
                return (0, 1, numItems-1)
            elif numItems > 1:
                return (0, 1)
            return (0,)

        for tractInfo in skyMap:
            wcs = tractInfo.getWcs()
            # Use this tract's own patch counts; nothing here guarantees every
            # tract has as many patches as tract 0.
            numPatches = tractInfo.getNumPatches()
            for patchInd in itertools.product(getIndices(numPatches[0]), getIndices(numPatches[1])):
                with self.subTest(tractId=tractInfo.getId(), patchInd=patchInd):
                    patchInfo = tractInfo.getPatchInfo(patchInd)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getInnerSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getInnerBBox(), wcs=wcs)
                    self.assertBBoxPolygonOk(polygon=patchInfo.getOuterSkyPolygon(tractWcs=wcs),
                                             bbox=patchInfo.getOuterBBox(), wcs=wcs)

    def testDm14809(self):
        """Generic version of test that DM-14809 has been fixed."""
        checkDm14809(self, self.getSkyMap())

    def testNumbering(self):
        """Check the numbering of tracts matches the indexing."""
        skymap = self.getSkyMap()
        expect = list(range(len(skymap)))
        got = [tt.getId() for tt in skymap]
        self.assertEqual(got, expect)

    def assertTractPatchListOk(self, skyMap, coordList, knownTractId):
        """Assert that findTractPatchList produces the correct results.

        The result must contain ``knownTractId``. It must also agree with
        `TractInfo.findPatchList`, tract by tract: a tract appears exactly
        when its findPatchList is non-empty, with the same number of patches.

        Parameters
        ----------
        skyMap : `lsst.skymap.BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates of the region to search for.
        knownTractId : `int`
            This tract ID must appear in the found list.
        """
        tractPatchList = skyMap.findTractPatchList(coordList)
        tractPatchDict = {tp[0].getId(): tp[1] for tp in tractPatchList}
        self.assertIn(knownTractId, tractPatchDict)
        for tractInfo in skyMap:
            tractId = tractInfo.getId()
            patchList = tractInfo.findPatchList(coordList)
            if patchList:
                self.assertIn(tractId, tractPatchDict)
                self.assertEqual(len(patchList), len(tractPatchDict[tractId]))
            else:
                self.assertNotIn(tractId, tractPatchDict)

    def assertClosestTractPatchList(self, skyMap, coordList, knownTractId):
        """Assert that findClosestTractPatchList returns, for each coordinate,
        the known tract and the patches that tract's findPatchList finds.

        Skips the test if ``skyMap`` has no findClosestTractPatchList method.

        Parameters
        ----------
        skyMap : `lsst.skymap.BaseSkyMap`
            Sky map to test.
        coordList : `list` of `lsst.geom.SpherePoint`
            Coordinates to search for.
        knownTractId : `int`
            ID of the tract expected for every coordinate.
        """
        if not hasattr(skyMap, "findClosestTractPatchList"):
            self.skipTest("This skymap doesn't implement findClosestTractPatchList")
        tractPatchList = skyMap.findClosestTractPatchList(coordList)
        self.assertEqual(len(coordList), len(tractPatchList))  # One tract+patchList per coordinate
        for coord, (tract, patchList) in zip(coordList, tractPatchList):
            self.assertEqual(tract.getId(), knownTractId)
            self.assertEqual(patchList, tract.findPatchList([coord]))

    def assertBBoxPolygonOk(self, polygon, bbox, wcs):
        """Assert that an on-sky polygon from a pixel bbox
        covers the expected region.

        The expected vertices are the bbox corners mapped to the sky with
        ``wcs``, and the expected center is the mapped bbox center.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        bbox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
            Pixel bounding box the polygon was made from.
        wcs : `lsst.afw.geom.SkyWcs`
            WCS mapping the bbox pixels to the sky.
        """
        bboxd = geom.Box2D(bbox)
        centerPixel = bboxd.getCenter()
        centerCoord = wcs.pixelToSky(centerPixel)
        skyCorners = getCornerCoords(wcs=wcs, bbox=bbox)
        self.assertPolygonOk(polygon=polygon, vertexList=skyCorners, centerCoord=centerCoord)

    def assertPolygonOk(self, polygon, vertexList, centerCoord):
        """Assert that an on-sky polygon covers the expected region.

        The polygon must contain ``centerCoord``. For each vertex, a point
        nudged slightly toward the center must be inside the polygon, and a
        point nudged slightly away from the center must be outside it.

        Parameters
        ----------
        polygon : `lsst.sphgeom.ConvexPolygon`
            On-sky polygon.
        vertexList : `iterable` of `lsst.geom.SpherePoint`
            Vertices of polygon.
        centerCoord : `lsst.geom.SpherePoint`
            A coord approximately in the center of the region.
        """
        shiftAngle = 0.01*geom.arcseconds
        self.assertTrue(polygon.contains(centerCoord.getVector()))
        for vertex in vertexList:
            bearingToCenter = vertex.bearingTo(centerCoord)
            cornerShiftedIn = vertex.offset(bearing=bearingToCenter, amount=shiftAngle)
            cornerShiftedOut = vertex.offset(bearing=bearingToCenter, amount=-shiftAngle)
            self.assertTrue(polygon.contains(cornerShiftedIn.getVector()))
            self.assertFalse(polygon.contains(cornerShiftedOut.getVector()))
def getCornerCoords(wcs, bbox):
    """Map the four corners of a pixel bounding box onto the sky.

    Parameters
    ----------
    wcs : `lsst.afw.geom.SkyWcs`
        WCS whose ``pixelToSky`` is applied to the list of corner positions.
    bbox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        Pixel bounding box; it is converted to a `lsst.geom.Box2D` first.

    Returns
    -------
    coords
        Whatever ``wcs.pixelToSky`` returns for the list of corners,
        in the order given by `lsst.geom.Box2D.getCorners`.
    """
    return wcs.pixelToSky(geom.Box2D(bbox).getCorners())