Coverage for python/lsst/afw/geom/testUtils.py : 10%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# LSST Data Management System
3# Copyright 2016 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
23__all__ = ["BoxGrid", "makeSipIwcToPixel", "makeSipPixelToIwc"]
25import itertools
26import math
27import os
28import pickle
30import astshim as ast
31import numpy as np
32from numpy.testing import assert_allclose, assert_array_equal
33from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap
34from lsst.afw.geom.wcsUtils import getCdMatrixFromMetadata
36import lsst.geom
37import lsst.afw.geom as afwGeom
38from lsst.pex.exceptions import InvalidParameterError
39import lsst.utils
40import lsst.utils.tests
class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region

    The sub-boxes will be of the same type as `box` and will exactly tile
    `box`; they will also all be the same size, to the extent possible (some
    variation is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        the box to subdivide; the boxes in the grid will be of the same type
    numColRow : pair of `int`
        number of columns and rows

    Raises
    ------
    RuntimeError
        If `numColRow` does not contain exactly two items, or if `box` is
        neither an `lsst.geom.Box2I` nor an `lsst.geom.Box2D`
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(f"numColRow={numColRow!r}; must be a sequence of two integers")
        self._numColRow = tuple(int(val) for val in numColRow)

        # stopDelta is added to the box maximum when computing the division
        # edges and subtracted again when building the end point of each
        # sub-box (see __getitem__)
        if isinstance(box, lsst.geom.Box2I):
            stopDelta = 1
        elif isinstance(box, lsst.geom.Box2D):
            stopDelta = 0
        else:
            raise RuntimeError(f"Unknown class {type(box)} of box {box}")
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        self.pointClass = type(minPoint)
        # use the same numeric type for the division edges as the box corners
        dtype = np.array(minPoint).dtype

        # one array of numColRow[i] + 1 edges per axis
        self._divList = [np.linspace(start=box.getMin()[i],
                                     stop=box.getMax()[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """Number of columns and rows, as a tuple of two `int`
        """
        return self._numColRow

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index

        Parameters
        ----------
        indXY : pair of `ints`
            the x,y index to return

        Returns
        -------
        subBox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        """
        beg = self.pointClass(*[self._divList[i][indXY[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][indXY[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns times rows)
        """
        numCols, numRows = self._numColRow
        return numCols*numRows

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]
class FrameSetInfo:
    """Summary of the base and current frames of a FrameSet

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to describe

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of the base frame
    currInd : `int`
        Index of the current frame
    isBaseSkyFrame : `bool`
        True if the class name of the base frame is "SkyFrame"
    isCurrSkyFrame : `bool`
        True if the class name of the current frame is "SkyFrame"
    """
    def __init__(self, frameSet):
        baseInd = frameSet.base
        currInd = frameSet.current
        self.baseInd = baseInd
        self.currInd = currInd
        self.isBaseSkyFrame = self._isSkyFrame(frameSet, baseInd)
        self.isCurrSkyFrame = self._isSkyFrame(frameSet, currInd)

    @staticmethod
    def _isSkyFrame(frameSet, frameInd):
        """Return True if the frame at index `frameInd` is a SkyFrame
        """
        return frameSet.getFrame(frameInd).className == "SkyFrame"
def makeSipPolyMapCoeffs(metadata, name):
    """Build the ast.PolyMap coefficient list for one SIP matrix

    The coefficients describe the function:

        f(dxy) = dxy + sipPolynomial(dxy)

    where dxy = pixelPosition - pixelOrigin and sipPolynomial is a
    polynomial whose terms are named ``<name>_n_m`` for x^n y^m
    (e.g. A_2_0 is the coefficient for x^2 y^0).

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with the specified SIP coefficients
    name : str
        The desired SIP terms: one of A, B, AP, BP

    Returns
    -------
    list
        One ``[coefficient, outAxis, xPower, yPower]`` entry per term;
        the first entry is the term for out = in and the remaining
        entries are the SIP terms found in `metadata`

    Raises
    ------
    RuntimeError
        If `name` is not a supported SIP name, or if `metadata` contains
        no coefficients for `name`

    Note
    ----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc
    """
    # A and AP produce output axis 1; B and BP produce output axis 2
    outAxis = {"A": 1, "B": 2, "AP": 1, "BP": 2}.get(name)
    if outAxis is None:
        raise RuntimeError(f"{name} not a supported SIP name")
    numPowers = metadata.getAsInt(f"{name}_ORDER") + 1

    # begin with the term for out = in: x^1 y^0 for axis 1, x^0 y^1 for axis 2
    identityPowers = [1, 0] if outAxis == 1 else [0, 1]
    coeffs = [[1.0, outAxis] + identityPowers]

    # append one entry for each SIP distortion term present in the metadata
    for xPower, yPower in itertools.product(range(numPowers), repeat=2):
        key = f"{name}_{xPower}_{yPower}"
        if metadata.exists(key):
            coeffs.append([metadata.getAsDouble(key), outAxis, xPower, yPower])

    # only the out = in term is present, so no SIP terms were found
    if len(coeffs) == 1:
        raise RuntimeError(f"No {name} coefficients found")
    return coeffs
def makeSipIwcToPixel(metadata):
    """Make an IWC to pixel transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with inverse SIP coefficients

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is:

        pixelPosition = pixel origin + uv + inverseSipPolynomial(uv)
        where uv = inverseCdMatrix * iwcPosition
    """
    # pixel origin: CRPIX1, CRPIX2 each reduced by 1
    pixelOrigin = tuple(metadata.getScalar(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    addOriginMap = ast.ShiftMap(pixelOrigin)

    cdMatrixMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())

    # AP terms give the first output axis and BP terms the second
    inverseSipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "AP") + makeSipPolyMapCoeffs(metadata, "BP"),
        dtype=float)
    inverseSipMap = ast.PolyMap(inverseSipCoeffs, 2, "IterInverse=0")

    # inverse CD matrix, then the inverse SIP polynomial, then add the pixel origin
    iwcToPixelMap = cdMatrixMap.inverted().then(inverseSipMap).then(addOriginMap)
    return afwGeom.TransformPoint2ToPoint2(iwcToPixelMap)
def makeSipPixelToIwc(metadata):
    """Make a pixel to IWC transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with forward SIP coefficients

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is:

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))
        where dxy = pixelPosition - pixelOrigin
    """
    # pixel origin: CRPIX1, CRPIX2 each reduced by 1
    pixelOrigin = tuple(metadata.getScalar(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    # the inverse of a shift by the origin subtracts the origin
    subtractOriginMap = ast.ShiftMap(pixelOrigin).inverted()

    cdMatrixMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())

    # A terms give the first output axis and B terms the second
    sipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "A") + makeSipPolyMapCoeffs(metadata, "B"),
        dtype=float)
    sipMap = ast.PolyMap(sipCoeffs, 2, "IterInverse=0")

    # subtract the pixel origin, then the SIP polynomial, then the CD matrix
    pixelToIwcMap = subtractOriginMap.then(sipMap).then(cdMatrixMap)
    return afwGeom.TransformPoint2ToPoint2(pixelToIwcMap)
class PermutedFrameSet:
    """A copy of a FrameSet whose base and/or current frame axes may be
    swapped, together with flags describing what was done

    Only two-axis frames will be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet you wish to permute. It is copied with
        ``frameSet.copy()``; the argument itself is not modified here.
    permuteBase : `bool`
        Permute the base frame's axes?
    permuteCurr : `bool`
        Permute the current frame's axes?

    Raises
    ------
    RuntimeError
        If you try to permute a frame that does not have 2 axes

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The local copy of the FrameSet, permuted as requested.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Are the base frame axes permuted?
    isCurrPermuted : `bool`
        Are the current frame axes permuted?
    """
    def __init__(self, frameSet, permuteBase, permuteCurr):
        permutedCopy = frameSet.copy()
        self.frameSet = permutedCopy
        info = FrameSetInfo(permutedCopy)
        self.isBaseSkyFrame = info.isBaseSkyFrame
        self.isCurrSkyFrame = info.isCurrSkyFrame

        if permuteBase:
            numBaseAxes = permutedCopy.getFrame(info.baseInd).nAxes
            if numBaseAxes != 2:
                raise RuntimeError("Base frame has {} axes; 2 required to permute".format(numBaseAxes))
            # make the base frame current while permuting, then restore
            # the original current frame
            permutedCopy.current = info.baseInd
            permutedCopy.permAxes([2, 1])
            permutedCopy.current = info.currInd

        if permuteCurr:
            numCurrAxes = permutedCopy.getFrame(info.currInd).nAxes
            if numCurrAxes != 2:
                raise RuntimeError("Current frame has {} axes; 2 required to permute".format(numCurrAxes))
            # re-check of the condition tested just above
            assert permutedCopy.getFrame(info.currInd).nAxes == 2
            permutedCopy.permAxes([2, 1])

        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr
331class TransformTestBaseClass(lsst.utils.tests.TestCase):
332 """Base class for unit tests of Transform<X>To<Y>
334 Subclasses must call `TransformTestBaseClass.setUp(self)`
335 if they provide their own version.
337 If a package other than afw uses this class then it must
338 override the `getTestDir` method to avoid writing into
339 afw's test directory.
340 """
def getTestDir(self):
    """Return a directory where temporary test files can be written

    This default implementation returns the ``tests`` subdirectory of the
    `afw` package directory.

    A subclass used by a test in a package other than `afw` must
    override this method.
    """
    afwPackageDir = lsst.utils.getPackageDir("afw")
    return os.path.join(afwPackageDir, "tests")
def setUp(self):
    """Set up a test

    Subclasses that override setUp should call this method.
    """
    # have unittest append the msg argument of asserts to its own error
    # message instead of replacing it
    self.longMessage = True

    # endpoint class name prefixes; the full name is prefix + "Endpoint"
    self.endpointPrefixes = ("Generic", "Point2", "SpherePoint")

    twoAxesOnly = (2,)
    notTwoAxes = (1, 3, 4)

    # endpoint class name prefix -> tuple of valid numbers of axes
    self.goodNAxes = dict(
        Generic=(1, 2, 3, 4),  # all numbers of axes are valid for GenericEndpoint
        Point2=twoAxesOnly,
        SpherePoint=twoAxesOnly,
    )

    # endpoint class name prefix -> tuple of invalid numbers of axes
    self.badNAxes = dict(
        Generic=(),  # all numbers of axes are valid for GenericEndpoint
        Point2=notTwoAxes,
        SpherePoint=notTwoAxes,
    )

    # frame index -> identity name for frames created by makeFrameSet
    frameNames = ("baseFrame", "frame2", "frame3", "currFrame")
    self.frameIdentDict = dict(enumerate(frameNames, start=1))
@staticmethod
def makeRawArrayData(nPoints, nAxes, delta=0.123):
    """Make an array of generic point data

    The data will be suitable for spherical points

    Parameters
    ----------
    nPoints : `int`
        Number of points in the array
    nAxes : `int`
        Number of axes in the point
    delta : `float`, optional
        Increment between axis values of one point, before any scaling

    Returns
    -------
    np.array of floats with shape (nAxes, nPoints)
        The values are as follows; if nAxes != 2:
            The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]`
            The Nth point has those values + N
        if nAxes == 2 then the data is scaled so that the max value of axis 1
        is a bit less than pi/2
    """
    # NOTE: `delta` used to be overwritten with 0.123 here, which made the
    # argument a no-op; the caller's value is now honored.
    # oneAxis = [0, 1, 2, ...nPoints-1]
    oneAxis = np.arange(nPoints, dtype=float)
    # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...]
    rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float)
    if nAxes == 2:
        # scale rawData so that max value of 2nd axis is a bit less than pi/2,
        # thus making the data safe for SpherePoint
        maxLatitude = np.max(rawData[1])
        rawData *= math.pi * 0.4999 / maxLatitude
    return rawData
@staticmethod
def makeRawPointData(nAxes, delta=0.123):
    """Make the raw data for a single generic point

    Parameters
    ----------
    nAxes : `int`
        Number of axes in the point
    delta : `float`
        Increment between axis values

    Returns
    -------
    A list of `nAxes` floats with values `[0, delta, ..., (nAxes-1)*delta]`
    """
    axisValues = []
    for axisIndex in range(nAxes):
        axisValues.append(axisIndex*delta)
    return axisValues
@staticmethod
def makeEndpoint(name, nAxes=None):
    """Construct the endpoint whose class name is ``name + "Endpoint"``

    Parameters
    ----------
    name : `str`
        Endpoint class name prefix; the full class name is name + "Endpoint"
    nAxes : `int` or `None`, optional
        number of axes; an int is required if `name` == "Generic";
        otherwise ignored

    Returns
    -------
    subclass of `lsst.afw.geom.BaseEndpoint`
        The constructed endpoint

    Raises
    ------
    TypeError
        If `name` == "Generic" and `nAxes` is None
    """
    endpointClass = getattr(afwGeom, name + "Endpoint")
    if name != "Generic":
        # only GenericEndpoint takes a number of axes
        return endpointClass()
    if nAxes is None:
        raise TypeError("nAxes must be an integer for GenericEndpoint")
    return endpointClass(nAxes)
@classmethod
def makeGoodFrame(cls, name, nAxes=None):
    """Return the frame made by the endpoint for the given name and nAxes

    Parameters
    ----------
    name : `str`
        Endpoint class name prefix; the full class name is name + "Endpoint"
    nAxes : `int` or `None`, optional
        number of axes; an int is required if `name` == "Generic";
        otherwise ignored

    Returns
    -------
    `ast.Frame`
        The frame returned by the endpoint's ``makeFrame`` method

    Raises
    ------
    TypeError
        If `name` == "Generic" and `nAxes` is `None`
    """
    endpoint = cls.makeEndpoint(name, nAxes)
    return endpoint.makeFrame()
@staticmethod
def makeBadFrames(name):
    """Return a list of 0 or more frames that are not a valid match for the
    named endpoint

    Parameters
    ----------
    name : `str`
        Endpoint class name prefix; the full class name is name + "Endpoint"

    Returns
    -------
    Collection of `ast.Frame`
        A collection of 0 or more frames

    Raises
    ------
    KeyError
        If `name` is not one of "Generic", "Point2" or "SpherePoint"
    """
    # the frames for every endpoint are built before `name` is looked up
    badFramesByName = dict(
        Generic=[],
        Point2=[ast.SkyFrame(), ast.Frame(1), ast.Frame(3)],
        SpherePoint=[ast.Frame(1), ast.Frame(2), ast.Frame(3)],
    )
    return badFramesByName[name]
def makeFrameSet(self, baseFrame, currFrame):
    """Make a FrameSet

    The FrameSet will contain 4 frames and three transforms connecting them.
    The identity of each frame is provided by self.frameIdentDict

    Frame       Index   Mapping from this frame to the next
    `baseFrame`   1     `ast.UnitMap(nIn)`
    Frame(nIn)    2     `polyMap`
    Frame(nOut)   3     `ast.UnitMap(nOut)`
    `currFrame`   4

    where:
    - `nIn` = `baseFrame.nAxes`
    - `nOut` = `currFrame.nAxes`
    - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)`

    Parameters
    ----------
    baseFrame : `ast.Frame`
        base frame
    currFrame : `ast.Frame`
        current frame

    Returns
    -------
    `ast.FrameSet`
        The FrameSet as described above
    """
    nIn = baseFrame.nAxes
    nOut = currFrame.nAxes
    polyMap = makeTwoWayPolyMap(nIn, nOut)

    # The ident of a frame is set before the frame goes into the FrameSet;
    # copies are used so that the caller's frames are left unmodified
    baseCopy = baseFrame.copy()
    baseCopy.ident = self.frameIdentDict[1]
    currCopy = currFrame.copy()
    currCopy.ident = self.frameIdentDict[4]

    secondFrame = ast.Frame(nIn)
    secondFrame.ident = self.frameIdentDict[2]
    thirdFrame = ast.Frame(nOut)
    thirdFrame.ident = self.frameIdentDict[3]

    frameSet = ast.FrameSet(baseCopy)
    # each new frame is attached to the then-current frame by its mapping
    for mapping, frame in (
        (ast.UnitMap(nIn), secondFrame),
        (polyMap, thirdFrame),
        (ast.UnitMap(nOut), currCopy),
    ):
        frameSet.addFrame(ast.FrameSet.CURRENT, mapping, frame)
    return frameSet
@staticmethod
def permuteFrameSetIter(frameSet):
    """Iterator over 0 or more frameSets with SkyFrames axes permuted

    Only base and current SkyFrames are permuted. If neither the base nor
    the current frame is a SkyFrame then no frames are returned.
    Otherwise every combination of "permuted" and "not permuted" is
    produced for each frame that is a SkyFrame, starting with the
    combination in which nothing is permuted.

    Returns
    -------
    iterator over `PermutedFrameSet`
    """
    info = FrameSetInfo(frameSet)
    if not info.isBaseSkyFrame and not info.isCurrSkyFrame:
        return

    # a frame that is not a SkyFrame is never permuted
    baseChoices = (False, True) if info.isBaseSkyFrame else (False,)
    currChoices = (False, True) if info.isCurrSkyFrame else (False,)
    for permuteBase, permuteCurr in itertools.product(baseChoices, currChoices):
        yield PermutedFrameSet(frameSet, permuteBase, permuteCurr)
@staticmethod
def makeJacobian(nIn, nOut, inPoint):
    """Make a Jacobian matrix for the equation described by
    `makeTwoWayPolyMap`.

    Parameters
    ----------
    nIn, nOut : `int`
        the dimensions of the input and output data; see makeTwoWayPolyMap
    inPoint : `numpy.ndarray`
        an array of size `nIn` representing the point at which the Jacobian
        is measured

    Returns
    -------
    J : `numpy.ndarray`
        an `nOut` x `nIn` array of first derivatives
    """
    basePolyMapCoeff = 0.001  # see makeTwoWayPolyMap
    baseCoeff = 2.0 * basePolyMapCoeff

    # element [iOut, iIn] = (baseCoeff*(iIn + 1) + baseCoeff*iOut) * inPoint[iIn]
    rows = [
        [(baseCoeff * (iIn + 1) + baseCoeff * iOut) * inPoint[iIn] for iIn in range(nIn)]
        for iOut in range(nOut)
    ]
    # the reshape keeps the result two-dimensional even when nIn or nOut is 0
    jacobian = np.array(rows, dtype=float).reshape(nOut, nIn)
    assert jacobian.ndim == 2
    assert jacobian.shape == (nOut, nIn)
    return jacobian
def checkTransformation(self, transform, mapping, msg=""):
    """Check applyForward and applyInverse for a transform

    Each available direction of `transform` is applied to one point and to
    an array of points, and the results are compared to those of `mapping`
    and of the mapping returned by ``transform.getMapping()``.

    Parameters
    ----------
    transform : `lsst.afw.geom.Transform`
        The transform to check
    mapping : `ast.Mapping`
        The mapping the transform should use. This mapping
        must contain valid forward or inverse transformations,
        but they need not match if both present. Hence the
        mappings returned by make*PolyMap are acceptable.
    msg : `str`
        Error message suffix describing test parameters
    """
    fromEndpoint = transform.fromEndpoint
    toEndpoint = transform.toEndpoint
    # both reference mappings must agree with the transform
    referenceMappings = (mapping, transform.getMapping())

    numIn = mapping.nIn
    numOut = mapping.nOut
    self.assertEqual(numIn, fromEndpoint.nAxes, msg=msg)
    self.assertEqual(numOut, toEndpoint.nAxes, msg=msg)

    # input data: one point and an array of points
    nPoints = 7  # arbitrary
    rawInPoint = self.makeRawPointData(numIn)
    inPoint = fromEndpoint.pointFromData(rawInPoint)
    rawInArray = self.makeRawArrayData(nPoints, numIn)
    inArray = fromEndpoint.arrayFromData(rawInArray)

    if not mapping.hasForward:
        # Output data is still needed for the inverse checks below,
        # but it need not be consistent with the input data
        rawOutPoint = self.makeRawPointData(numOut)
        outPoint = toEndpoint.pointFromData(rawOutPoint)
        rawOutArray = self.makeRawArrayData(nPoints, numOut)
        outArray = toEndpoint.arrayFromData(rawOutArray)

        self.assertFalse(transform.hasForward)
    else:
        self.assertTrue(transform.hasForward)

        # forward transformation of one point
        outPoint = transform.applyForward(inPoint)
        rawOutPoint = toEndpoint.dataFromPoint(outPoint)
        for reference in referenceMappings:
            assert_allclose(rawOutPoint, reference.applyForward(rawInPoint), err_msg=msg)

        # forward transformation of an array of points
        outArray = transform.applyForward(inArray)
        rawOutArray = toEndpoint.dataFromArray(outArray)
        for reference in referenceMappings:
            self.assertFloatsAlmostEqual(rawOutArray, reference.applyForward(rawInArray), msg=msg)

    if not mapping.hasInverse:
        self.assertFalse(transform.hasInverse)
        return

    self.assertTrue(transform.hasInverse)

    # inverse transformation of one point;
    # remember that the inverse need not give the original values
    # (see the description of the `mapping` parameter)
    inversePoint = transform.applyInverse(outPoint)
    rawInversePoint = fromEndpoint.dataFromPoint(inversePoint)
    for reference in referenceMappings:
        assert_allclose(rawInversePoint, reference.applyInverse(rawOutPoint), err_msg=msg)

    # inverse transformation of an array of points;
    # again the inverse need not give the original values
    inverseArray = transform.applyInverse(outArray)
    rawInverseArray = fromEndpoint.dataFromArray(inverseArray)
    for reference in referenceMappings:
        self.assertFloatsAlmostEqual(rawInverseArray, reference.applyInverse(rawOutArray), msg=msg)
def checkInverseTransformation(self, forward, inverse, msg=""):
    """Check that two Transforms are each others' inverses.

    Parameters
    ----------
    forward : `lsst.afw.geom.Transform`
        the reference Transform to test
    inverse : `lsst.afw.geom.Transform`
        the transform that should be the inverse of `forward`
    msg : `str`
        error message suffix describing test parameters
    """
    srcEndpoint = forward.fromEndpoint
    dstEndpoint = forward.toEndpoint
    fwdMapping = forward.getMapping()
    invMapping = inverse.getMapping()

    # endpoints and available directions must be swapped between the two
    self.assertEqual(forward.fromEndpoint, inverse.toEndpoint, msg=msg)
    self.assertEqual(forward.toEndpoint, inverse.fromEndpoint, msg=msg)
    self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg)
    self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg)

    # Single-point test data. Correctness of the transformation itself is
    # the job of checkTransformation, so the "in" and "out" data need not
    # be related to each other
    rawInPoint = self.makeRawPointData(srcEndpoint.nAxes)
    inPoint = srcEndpoint.pointFromData(rawInPoint)
    rawOutPoint = self.makeRawPointData(dstEndpoint.nAxes)
    outPoint = dstEndpoint.pointFromData(rawOutPoint)

    # array test data
    nPoints = 7  # arbitrary
    rawInArray = self.makeRawArrayData(nPoints, srcEndpoint.nAxes)
    inArray = srcEndpoint.arrayFromData(rawInArray)
    rawOutArray = self.makeRawArrayData(nPoints, dstEndpoint.nAxes)
    outArray = dstEndpoint.arrayFromData(rawOutArray)

    if forward.hasForward:
        # forward of `forward` must equal inverse of `inverse`
        self.assertEqual(forward.applyForward(inPoint), inverse.applyInverse(inPoint), msg=msg)
        self.assertEqual(fwdMapping.applyForward(rawInPoint),
                         invMapping.applyInverse(rawInPoint), msg=msg)
        # assert_array_equal accepts both lists and numpy arrays
        assert_array_equal(forward.applyForward(inArray), inverse.applyInverse(inArray), err_msg=msg)
        assert_array_equal(fwdMapping.applyForward(rawInArray),
                           invMapping.applyInverse(rawInArray), err_msg=msg)

    if forward.hasInverse:
        # inverse of `forward` must equal forward of `inverse`
        self.assertEqual(forward.applyInverse(outPoint), inverse.applyForward(outPoint), msg=msg)
        self.assertEqual(fwdMapping.applyInverse(rawOutPoint),
                         invMapping.applyForward(rawOutPoint), msg=msg)
        assert_array_equal(forward.applyInverse(outArray), inverse.applyForward(outArray), err_msg=msg)
        assert_array_equal(fwdMapping.applyInverse(rawOutArray),
                           invMapping.applyForward(rawOutArray), err_msg=msg)
def checkTransformFromMapping(self, fromName, toName):
    """Check Transform_<fromName>_<toName> using the Mapping constructor

    For every valid combination of numbers of axes, construct the transform
    from a two-way, a forward-only and an inverse-only polynomial mapping
    and check it with `checkTransformation`. For every combination that
    includes an invalid number of axes, check that construction raises
    `lsst.pex.exceptions.InvalidParameterError`.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    transformClassName = "Transform{}To{}".format(fromName, toName)
    TransformClass = getattr(afwGeom, transformClassName)
    baseMsg = "TransformClass={}".format(TransformClass.__name__)

    # check valid numbers of inputs and outputs
    for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                       self.goodNAxes[toName]):
        msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
        polyMap = makeTwoWayPolyMap(nIn, nOut)
        transform = TransformClass(polyMap)

        # desired output from `str(transform)`
        desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
        self.assertEqual("{}".format(transform), desStr)
        self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

        self.checkTransformation(transform, polyMap, msg=msg)

        # Forward transform but no inverse
        polyMap = makeForwardPolyMap(nIn, nOut)
        transform = TransformClass(polyMap)
        self.checkTransformation(transform, polyMap, msg=msg)

        # Inverse transform but no forward
        polyMap = makeForwardPolyMap(nOut, nIn).inverted()
        transform = TransformClass(polyMap)
        self.checkTransformation(transform, polyMap, msg=msg)

    # check invalid # of output against valid # of inputs
    for nIn, badNOut in itertools.product(self.goodNAxes[fromName],
                                          self.badNAxes[toName]):
        badPolyMap = makeTwoWayPolyMap(nIn, badNOut)
        msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut)
        with self.assertRaises(InvalidParameterError, msg=msg):
            TransformClass(badPolyMap)

    # check invalid # of inputs against valid and invalid # of outputs
    for badNIn, nOut in itertools.product(self.badNAxes[fromName],
                                          self.goodNAxes[toName] + self.badNAxes[toName]):
        badPolyMap = makeTwoWayPolyMap(badNIn, nOut)
        # report the number of inputs actually under test (badNIn); this
        # message previously formatted `nIn`, a variable left over from the
        # loops above
        msg = "{}, badNIn={}, nOut={}".format(baseMsg, badNIn, nOut)
        with self.assertRaises(InvalidParameterError, msg=msg):
            TransformClass(badPolyMap)
def checkTransformFromFrameSet(self, fromName, toName):
    """Check Transform_<fromName>_<toName> using the FrameSet constructor

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    transformClassName = "Transform{}To{}".format(fromName, toName)
    TransformClass = getattr(afwGeom, transformClassName)
    baseMsg = "TransformClass={}".format(TransformClass.__name__)
    goodAxesPairs = itertools.product(self.goodNAxes[fromName], self.goodNAxes[toName])
    for nIn, nOut in goodAxesPairs:
        msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)

        goodBaseFrame = self.makeGoodFrame(fromName, nIn)
        goodCurrFrame = self.makeGoodFrame(toName, nOut)
        frameSet = self.makeFrameSet(goodBaseFrame, goodCurrFrame)
        self.assertEqual(frameSet.nFrame, 4)

        # construct 0 or more frame sets that are invalid for this transform
        # class: bad base frame, then bad base and bad current frame
        for badBaseFrame in self.makeBadFrames(fromName):
            badFrameSet = self.makeFrameSet(badBaseFrame, goodCurrFrame)
            with self.assertRaises(InvalidParameterError):
                TransformClass(badFrameSet)
            for badCurrFrame in self.makeBadFrames(toName):
                reallyBadFrameSet = self.makeFrameSet(badBaseFrame, badCurrFrame)
                with self.assertRaises(InvalidParameterError):
                    TransformClass(reallyBadFrameSet)
        # good base frame with bad current frame
        for badCurrFrame in self.makeBadFrames(toName):
            badFrameSet = self.makeFrameSet(goodBaseFrame, badCurrFrame)
            with self.assertRaises(InvalidParameterError):
                TransformClass(badFrameSet)

        transform = TransformClass(frameSet)

        desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut)
        self.assertEqual("{}".format(transform), desStr)
        self.assertEqual(repr(transform), "lsst.afw.geom." + desStr)

        self.checkPersistence(transform)

        # a transform rebuilt from the mapping has the same type
        mappingFromTransform = transform.getMapping()
        transformCopy = TransformClass(mappingFromTransform)
        self.assertEqual(type(transform), type(transformCopy))
        self.assertEqual(transform.getMapping(), mappingFromTransform)

        polyMap = makeTwoWayPolyMap(nIn, nOut)

        self.checkTransformation(transform, mapping=polyMap, msg=msg)

        # If the base and/or current frame of frameSet is a SkyFrame,
        # try permuting that frame (in place, so the connected mappings are
        # correctly updated). The Transform constructor should undo the permutation,
        # (via SpherePointEndpoint.normalizeFrame) in its internal copy of frameSet,
        # forcing the axes of the SkyFrame into standard (longitude, latitude) order
        for permutedFS in self.permuteFrameSetIter(frameSet):
            skyFrameChecks = (
                (permutedFS.isBaseSkyFrame, ast.FrameSet.BASE, permutedFS.isBasePermuted),
                (permutedFS.isCurrSkyFrame, ast.FrameSet.CURRENT, permutedFS.isCurrPermuted),
            )
            for isSkyFrame, frameIndex, isPermuted in skyFrameChecks:
                if not isSkyFrame:
                    continue
                skyFrame = permutedFS.frameSet.getFrame(frameIndex)
                # desired longitude axis: 2 if the axes were swapped, else 1
                desLonAxis = 2 if isPermuted else 1
                self.assertEqual(skyFrame.lonAxis, desLonAxis)

            permTransform = TransformClass(permutedFS.frameSet)
            self.checkTransformation(permTransform, mapping=polyMap, msg=msg)
def checkInverted(self, fromName, toName):
    """Test Transform<fromName>To<toName>.inverted

    For every valid combination of numbers of axes, `checkInverseMapping`
    is run with a two-way, a forward-only and an inverse-only mapping.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    transformClassName = "Transform{}To{}".format(fromName, toName)
    TransformClass = getattr(afwGeom, transformClassName)
    baseMsg = "TransformClass={}".format(TransformClass.__name__)

    # (label, factory) pairs; each mapping is built just before it is checked
    mappingFactories = (
        ("TwoWay", lambda numIn, numOut: makeTwoWayPolyMap(numIn, numOut)),
        ("Forward", lambda numIn, numOut: makeForwardPolyMap(numIn, numOut)),
        ("Inverse", lambda numIn, numOut: makeForwardPolyMap(numOut, numIn).inverted()),
    )

    for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                       self.goodNAxes[toName]):
        msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
        for label, makeMapping in mappingFactories:
            self.checkInverseMapping(
                TransformClass,
                makeMapping(nIn, nOut),
                "{}, Map={}".format(msg, label))
def checkInverseMapping(self, TransformClass, mapping, msg):
    """Test Transform<fromName>To<toName>.inverted for a specific
    mapping.

    The transform is inverted twice: each inversion is checked against
    the transform it was made from, and the doubly inverted transform
    is checked against the original mapping.

    Parameters
    ----------
    TransformClass : `type`
        The class of transform to test, such as TransformPoint2ToPoint2
    mapping : `ast.Mapping`
        The mapping to use for the transform
    msg : `str`
        Error message suffix
    """
    original = TransformClass(mapping)
    onceInverted = original.inverted()
    twiceInverted = onceInverted.inverted()

    for reference, candidate in ((original, onceInverted), (onceInverted, twiceInverted)):
        self.checkInverseTransformation(reference, candidate, msg=msg)
    # inverting twice should give back the behavior of the original mapping
    self.checkTransformation(twiceInverted, mapping, msg=msg)
def checkGetJacobian(self, fromName, toName):
    """Test Transform<fromName>To<toName>.getJacobian

    For every valid (nIn, nOut) pair of axis counts, the Jacobian of a
    forward polynomial transform is compared with ``self.makeJacobian``
    at two different input points.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    TransformClass = getattr(
        afwGeom, "Transform{}To{}".format(fromName, toName))
    classMsg = "TransformClass={}".format(TransformClass.__name__)

    # Extra positional arguments for makeRawPointData; using more than one
    # input point ensures the Jacobian has the correct functional form
    extraPointArgs = ((), (0.111,))

    axesPairs = itertools.product(self.goodNAxes[fromName],
                                  self.goodNAxes[toName])
    for nIn, nOut in axesPairs:
        msg = "{}, nIn={}, nOut={}".format(classMsg, nIn, nOut)
        transform = TransformClass(makeForwardPolyMap(nIn, nOut))
        fromEndpoint = transform.fromEndpoint

        for extraArgs in extraPointArgs:
            rawPoint = self.makeRawPointData(nIn, *extraArgs)
            point = fromEndpoint.pointFromData(rawPoint)
            measured = transform.getJacobian(point)
            assert_allclose(measured,
                            self.makeJacobian(nIn, nOut, rawPoint),
                            err_msg=msg)
def checkThen(self, fromName, midName, toName):
    """Test Transform<fromName>To<midName>.then(Transform<midName>To<toName>)

    For every valid combination of axis counts the concatenated transform
    must agree, in both directions, with applying the two transforms one
    after the other. Concatenation must raise `InvalidParameterError` when
    the transforms disagree on the number of axes at the join (checked
    only when ``midName`` is "Generic") or on the endpoint type at the
    join (checked only when ``fromName`` differs from ``midName``).

    Parameters
    ----------
    fromName : `str`
        the prefix of the starting endpoint (e.g., "Point2" for a
        Point2Endpoint) for the final, concatenated Transform
    midName : `str`
        the prefix for the shared endpoint where two Transforms will be
        concatenated
    toName : `str`
        the prefix of the ending endpoint for the final, concatenated
        Transform
    """
    FirstClass = getattr(afwGeom,
                         "Transform{}To{}".format(fromName, midName))
    SecondClass = getattr(afwGeom,
                          "Transform{}To{}".format(midName, toName))
    baseMsg = "{}.then({})".format(FirstClass.__name__,
                                   SecondClass.__name__)

    axesTriples = itertools.product(self.goodNAxes[fromName],
                                    self.goodNAxes[midName],
                                    self.goodNAxes[toName])
    for nIn, nMid, nOut in axesTriples:
        msg = "{}, nIn={}, nMid={}, nOut={}".format(
            baseMsg, nIn, nMid, nOut)
        first = FirstClass(makeTwoWayPolyMap(nIn, nMid))
        second = SecondClass(makeTwoWayPolyMap(nMid, nOut))
        chained = first.then(second)

        fromEndpoint = first.fromEndpoint
        toEndpoint = second.toEndpoint

        # Forward: the chained transform vs. the two steps in sequence
        inPoint = fromEndpoint.pointFromData(self.makeRawPointData(nIn))
        outChained = chained.applyForward(inPoint)
        outStepwise = second.applyForward(first.applyForward(inPoint))
        assert_allclose(toEndpoint.dataFromPoint(outChained),
                        toEndpoint.dataFromPoint(outStepwise),
                        err_msg=msg)

        # Inverse: the chained transform vs. the two steps in reverse
        outPoint = toEndpoint.pointFromData(self.makeRawPointData(nOut))
        inChained = chained.applyInverse(outPoint)
        inStepwise = first.applyInverse(second.applyInverse(outPoint))
        assert_allclose(fromEndpoint.dataFromPoint(inChained),
                        fromEndpoint.dataFromPoint(inStepwise),
                        err_msg=msg)

    # Mismatched number of axes should fail
    if midName == "Generic":
        nIn = self.goodNAxes[fromName][0]
        nOut = self.goodNAxes[toName][0]
        # first ends with 3 axes but second starts with 2
        first = FirstClass(makeTwoWayPolyMap(nIn, 3))
        second = SecondClass(makeTwoWayPolyMap(2, nOut))
        with self.assertRaises(InvalidParameterError):
            first.then(second)

    # Mismatched types of endpoints should fail; this can only be set up
    # when the "from" and "mid" endpoint types differ
    if fromName == midName:
        return

    # Use FirstClass for both args to keep test logic simple
    outName = midName
    joinNAxes = set(self.goodNAxes[fromName]).intersection(
        self.goodNAxes[outName])

    axesTriples = itertools.product(self.goodNAxes[fromName],
                                    joinNAxes,
                                    self.goodNAxes[outName])
    for nIn, nMid, nOut in axesTriples:
        first = FirstClass(makeTwoWayPolyMap(nIn, nMid))
        second = FirstClass(makeTwoWayPolyMap(nMid, nOut))
        with self.assertRaises(InvalidParameterError):
            first.then(second)
def assertTransformsEqual(self, transform1, transform2):
    """Assert that two transforms are equal

    The transforms must have the same type, equal endpoints and equal
    mappings. In addition, for each direction supported by the mapping of
    ``transform1``, both transforms must give matching results on the
    same array of test points.

    Parameters
    ----------
    transform1, transform2
        The two transforms to compare
    """
    self.assertEqual(type(transform1), type(transform2))
    for endpointName in ("fromEndpoint", "toEndpoint"):
        self.assertEqual(getattr(transform1, endpointName),
                         getattr(transform2, endpointName))
    self.assertEqual(transform1.getMapping(), transform2.getMapping())

    fromEndpoint = transform1.fromEndpoint
    toEndpoint = transform1.toEndpoint
    mapping = transform1.getMapping()
    nIn = mapping.nIn
    nOut = mapping.nOut
    nPoints = 7  # arbitrary

    # One row per direction: mapping flag that enables the check, number
    # of axes of the test data, endpoint the data starts at, endpoint the
    # data ends at, and the name of the transform method to apply
    directions = (
        ("hasForward", nIn, fromEndpoint, toEndpoint, "applyForward"),
        ("hasInverse", nOut, toEndpoint, fromEndpoint, "applyInverse"),
    )
    for flagName, nAxes, startEndpoint, endEndpoint, applyName in directions:
        if not getattr(mapping, flagName):
            continue
        rawArray = self.makeRawArrayData(nPoints, nAxes)
        startArray = startEndpoint.arrayFromData(rawArray)
        endArray1 = getattr(transform1, applyName)(startArray)
        endData1 = endEndpoint.dataFromArray(endArray1)
        endArray2 = getattr(transform2, applyName)(startArray)
        endData2 = endEndpoint.dataFromArray(endArray2)
        assert_allclose(endData1, endData2)
def checkPersistence(self, transform):
    """Check persistence of a transform

    Exercises string serialization (``writeString``/``readString`` and
    ``afwGeom.transformFromString``), pickling, and a FITS
    ``writeFits``/``readFits`` round trip. After each round trip the
    restored transform must equal the original according to
    ``assertTransformsEqual``.

    Parameters
    ----------
    transform
        The transform to test; it must provide ``writeString``,
        ``readString``, ``writeFits`` and ``readFits`` and be picklable
    """
    className = type(transform).__name__

    # check writeString and readString
    transformStr = transform.writeString()
    # the serialized form is "<version> <class name> <remainder>"
    serialVersion, serialClassName, serialRest = transformStr.split(" ", 2)
    self.assertEqual(int(serialVersion), 1)
    self.assertEqual(serialClassName, className)

    # an unsupported version number must be rejected
    badStr1 = " ".join(["2", serialClassName, serialRest])
    with self.assertRaises(InvalidParameterError):
        transform.readString(badStr1)

    # a class name that does not match the transform must be rejected
    badClassName = "x" + serialClassName
    badStr2 = " ".join(["1", badClassName, serialRest])
    with self.assertRaises(InvalidParameterError):
        transform.readString(badStr2)

    transformFromStr1 = transform.readString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr1)

    # check transformFromString
    transformFromStr2 = afwGeom.transformFromString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr2)

    # Check pickling
    self.assertTransformsEqual(transform, pickle.loads(pickle.dumps(transform)))

    # Check afw::table::io persistence round-trip
    with lsst.utils.tests.getTempFilePath(".fits") as filename:
        transform.writeFits(filename)
        self.assertTransformsEqual(transform, type(transform).readFits(filename))