Coverage for python/lsst/afw/geom/testUtils.py: 11%

450 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-14 02:44 -0700

1# This file is part of afw. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Names exported by ``from lsst.afw.geom.testUtils import *``; the other
# definitions in this module are not listed here.
__all__ = ["BoxGrid", "makeSipIwcToPixel", "makeSipPixelToIwc"]

23 

24import itertools 

25import math 

26import os 

27import pickle 

28 

29import astshim as ast 

30import numpy as np 

31from numpy.testing import assert_allclose, assert_array_equal 

32from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap 

33from ._geom import getCdMatrixFromMetadata 

34 

35import lsst.geom 

36import lsst.afw.geom as afwGeom 

37from lsst.pex.exceptions import InvalidParameterError 

38import lsst.utils 

39import lsst.utils.tests 

40 

41 

class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region.

    The sub-boxes will be of the same type as `box` and will exactly tile
    `box`; they will also all be the same size, to the extent possible (some
    variation is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        The box to subdivide; the boxes in the grid will be of the same type.
    numColRow : pair of `int`
        Number of columns and rows.

    Raises
    ------
    RuntimeError
        Raised if ``numColRow`` does not have exactly two elements, or if
        ``box`` is neither a `lsst.geom.Box2I` nor a `lsst.geom.Box2D`.
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(f"numColRow={numColRow!r}; must be a sequence of two integers")
        self._numColRow = tuple(int(val) for val in numColRow)

        if isinstance(box, lsst.geom.Box2I):
            # Division points run one past getMax() and each sub-box end is
            # pulled back by one, so adjacent integer sub-boxes do not overlap.
            stopDelta = 1
        elif isinstance(box, lsst.geom.Box2D):
            stopDelta = 0
        else:
            raise RuntimeError(f"Unknown class {type(box)} of box {box}")
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        self.pointClass = type(minPoint)
        # Use the point's own dtype so integer boxes get integer division points.
        dtype = np.array(minPoint).dtype

        # _divList[i] holds numColRow[i] + 1 division points along axis i.
        self._divList = [np.linspace(start=box.getMin()[i],
                                     stop=box.getMax()[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """Number of columns and rows, as a tuple of two `int`."""
        return self._numColRow

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index.

        Parameters
        ----------
        indXY : pair of `int`
            The x (column), y (row) index of the box to return; each must
            satisfy ``0 <= index < numColRow[axis]``.

        Returns
        -------
        subBox : `lsst.geom.Box2I` or `lsst.geom.Box2D`
            The requested sub-box.

        Raises
        ------
        IndexError
            Raised if either index is out of range.
        """
        # Reject out-of-range (including negative) indices: ``index + 1``
        # would otherwise pick an unrelated division point and yield a bogus box.
        for ind, num in zip(indXY, self._numColRow):
            if not 0 <= ind < num:
                raise IndexError(f"index {tuple(indXY)} out of range for numColRow={self._numColRow}")
        beg = self.pointClass(*[self._divList[i][indXY[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][indXY[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns times rows)."""
        numCol, numRow = self._numColRow
        return numCol*numRow

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly.
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]

111 

112 

class FrameSetInfo:
    """Summarize the base and current frames of an `ast.FrameSet`.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to describe.

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of the base frame (``frameSet.base``).
    currInd : `int`
        Index of the current frame (``frameSet.current``).
    isBaseSkyFrame : `bool`
        True if the base frame's ``className`` is ``"SkyFrame"``.
    isCurrSkyFrame : `bool`
        True if the current frame's ``className`` is ``"SkyFrame"``.
    """
    def __init__(self, frameSet):
        def isSkyFrame(index):
            return frameSet.getFrame(index).className == "SkyFrame"

        self.baseInd = frameSet.base
        self.currInd = frameSet.current
        self.isBaseSkyFrame = isSkyFrame(self.baseInd)
        self.isCurrSkyFrame = isSkyFrame(self.currInd)

139 

140 

def makeSipPolyMapCoeffs(metadata, name):
    """Return a list of ast.PolyMap coefficients for the specified SIP matrix

    The coefficients describe an ast.PolyMap output axis computing

        f(dxy) = dxy + sipPolynomial(dxy)

    where dxy = pixelPosition - pixelOrigin and sipPolynomial has one term
    per metadata key ``<name>_n_m``, the coefficient of x^n y^m
    (e.g. ``A_2_0`` is the coefficient for x^2 y^0).

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with the specified SIP coefficients;
        must contain ``<name>_ORDER``.
    name : str
        The desired SIP terms: one of A, B, AP, BP. A and AP fill output
        axis 1; B and BP fill output axis 2.

    Returns
    -------
    list
        ``[coefficient, outAxis, xPower, yPower]`` entries: first the identity
        term (out = in), then one entry per SIP coefficient present, ordered
        by xPower and then yPower.

    Raises
    ------
    RuntimeError
        If ``name`` is not a supported SIP name, or no ``<name>_n_m``
        coefficients are present in ``metadata``.

    Notes
    -----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc.
    """
    axisForName = {"A": 1, "B": 2, "AP": 1, "BP": 2}
    if name not in axisForName:
        raise RuntimeError(f"{name} not a supported SIP name")
    outAxis = axisForName[name]

    width = metadata.getAsInt(f"{name}_ORDER") + 1
    # Identity term so the PolyMap returns its input plus the distortion.
    identityTerm = [1.0, outAxis, 1, 0] if outAxis == 1 else [1.0, outAxis, 0, 1]

    candidates = ((xPow, yPow, f"{name}_{xPow}_{yPow}")
                  for xPow, yPow in itertools.product(range(width), repeat=2))
    sipTerms = [[metadata.getAsDouble(key), outAxis, xPow, yPow]
                for xPow, yPow, key in candidates
                if metadata.exists(key)]
    if not sipTerms:
        raise RuntimeError(f"No {name} coefficients found")
    return [identityTerm] + sipTerms

194 

195 

def makeSipIwcToPixel(metadata):
    """Make an IWC to pixel transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with inverse SIP coefficients
        (AP and BP terms).

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is:

        pixelPosition = pixelOrigin + uv + inverseSipPolynomial(uv)
        where uv = inverseCdMatrix * iwcPosition
    """
    # Subtract 1 from CRPIX to get a zero-based pixel origin.
    crpix1 = metadata.getScalar("CRPIX1") - 1
    crpix2 = metadata.getScalar("CRPIX2") - 1
    relativeToAbsolute = ast.ShiftMap((crpix1, crpix2))
    uvToIwc = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())
    inverseSipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "AP") + makeSipPolyMapCoeffs(metadata, "BP"),
        dtype=float,
    )
    # Only forward coefficients are supplied, so this PolyMap has no inverse.
    inverseSipMap = ast.PolyMap(inverseSipCoeffs, 2, "IterInverse=0")

    iwcToPixel = uvToIwc.inverted().then(inverseSipMap).then(relativeToAbsolute)
    return afwGeom.TransformPoint2ToPoint2(iwcToPixel)

233 

234 

def makeSipPixelToIwc(metadata):
    """Make a pixel to IWC transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with forward SIP coefficients
        (A and B terms).

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is:

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))
        where dxy = pixelPosition - pixelOrigin
    """
    # Subtract 1 from CRPIX to get a zero-based pixel origin.
    crpix1 = metadata.getScalar("CRPIX1") - 1
    crpix2 = metadata.getScalar("CRPIX2") - 1
    absoluteToRelative = ast.ShiftMap((crpix1, crpix2)).inverted()
    cdMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())
    sipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "A") + makeSipPolyMapCoeffs(metadata, "B"),
        dtype=float,
    )
    # Only forward coefficients are supplied, so this PolyMap has no inverse.
    sipMap = ast.PolyMap(sipCoeffs, 2, "IterInverse=0")

    pixelToIwc = absoluteToRelative.then(sipMap).then(cdMap)
    return afwGeom.TransformPoint2ToPoint2(pixelToIwc)

271 

272 

class PermutedFrameSet:
    """A copy of a FrameSet whose base and/or current frame axes may be
    swapped, plus information describing what was done.

    Only two-axis frames can be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet you wish to permute. A copy is made; the input is not
        modified.
    permuteBase : `bool`
        Swap the axes of the base frame?
    permuteCurr : `bool`
        Swap the axes of the current frame?

    Raises
    ------
    RuntimeError
        If asked to permute a frame that does not have exactly 2 axes.

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The (possibly permuted) local copy of the FrameSet.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Were the base frame axes permuted?
    isCurrPermuted : `bool`
        Were the current frame axes permuted?
    """
    def __init__(self, frameSet, permuteBase, permuteCurr):
        self.frameSet = frameSet.copy()
        info = FrameSetInfo(self.frameSet)
        self.isBaseSkyFrame = info.isBaseSkyFrame
        self.isCurrSkyFrame = info.isCurrSkyFrame

        if permuteBase:
            nBaseAxes = self.frameSet.getFrame(info.baseInd).nAxes
            if nBaseAxes != 2:
                raise RuntimeError("Base frame has {} axes; 2 required to permute".format(nBaseAxes))
            # permAxes is applied with the base frame temporarily made
            # current; the original current frame is restored afterwards.
            self.frameSet.current = info.baseInd
            self.frameSet.permAxes([2, 1])
            self.frameSet.current = info.currInd

        if permuteCurr:
            nCurrAxes = self.frameSet.getFrame(info.currInd).nAxes
            if nCurrAxes != 2:
                raise RuntimeError("Current frame has {} axes; 2 required to permute".format(nCurrAxes))
            self.frameSet.permAxes([2, 1])

        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr

328 

329 

330class TransformTestBaseClass(lsst.utils.tests.TestCase): 

331 """Base class for unit tests of Transform<X>To<Y> 

332 

333 Subclasses must call `TransformTestBaseClass.setUp(self)` 

334 if they provide their own version. 

335 

336 If a package other than afw uses this class then it must 

337 override the `getTestDir` method to avoid writing into 

338 afw's test directory. 

339 """ 

340 

341 def getTestDir(self): 

342 """Return a directory where temporary test files can be written 

343 

344 The default implementation returns the test directory of the `afw` 

345 package. 

346 

347 If this class is used by a test in a package other than `afw` 

348 then the subclass must override this method. 

349 """ 

350 return os.path.join(lsst.utils.getPackageDir("afw"), "tests") 

351 

352 def setUp(self): 

353 """Set up a test 

354 

355 Subclasses should call this method if they override setUp. 

356 """ 

357 # tell unittest to use the msg argument of asserts as a supplement 

358 # to the error message, rather than as the whole error message 

359 self.longMessage = True 

360 

361 # list of endpoint class name prefixes; the full name is prefix + "Endpoint" 

362 self.endpointPrefixes = ("Generic", "Point2", "SpherePoint") 

363 

364 # GoodNAxes is dict of endpoint class name prefix: 

365 # tuple containing 0 or more valid numbers of axes 

366 self.goodNAxes = { 

367 "Generic": (1, 2, 3, 4), # all numbers of axes are valid for GenericEndpoint 

368 "Point2": (2,), 

369 "SpherePoint": (2,), 

370 } 

371 

372 # BadAxes is dict of endpoint class name prefix: 

373 # tuple containing 0 or more invalid numbers of axes 

374 self.badNAxes = { 

375 "Generic": (), # all numbers of axes are valid for GenericEndpoint 

376 "Point2": (1, 3, 4), 

377 "SpherePoint": (1, 3, 4), 

378 } 

379 

380 # Dict of frame index: identity name for frames created by makeFrameSet 

381 self.frameIdentDict = { 

382 1: "baseFrame", 

383 2: "frame2", 

384 3: "frame3", 

385 4: "currFrame", 

386 } 

387 

388 @staticmethod 

389 def makeRawArrayData(nPoints, nAxes, delta=0.123): 

390 """Make an array of generic point data 

391 

392 The data will be suitable for spherical points 

393 

394 Parameters 

395 ---------- 

396 nPoints : `int` 

397 Number of points in the array 

398 nAxes : `int` 

399 Number of axes in the point 

400 

401 Returns 

402 ------- 

403 np.array of floats with shape (nAxes, nPoints) 

404 The values are as follows; if nAxes != 2: 

405 The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]` 

406 The Nth point has those values + N 

407 if nAxes == 2 then the data is scaled so that the max value of axis 1 

408 is a bit less than pi/2 

409 """ 

410 delta = 0.123 

411 # oneAxis = [0, 1, 2, ...nPoints-1] 

412 oneAxis = np.arange(nPoints, dtype=float) # [0, 1, 2...] 

413 # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...] 

414 rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float) 

415 if nAxes == 2: 

416 # scale rawData so that max value of 2nd axis is a bit less than pi/2, 

417 # thus making the data safe for SpherePoint 

418 maxLatitude = np.max(rawData[1]) 

419 rawData *= math.pi * 0.4999 / maxLatitude 

420 return rawData 

421 

422 @staticmethod 

423 def makeRawPointData(nAxes, delta=0.123): 

424 """Make one generic point 

425 

426 Parameters 

427 ---------- 

428 nAxes : `int` 

429 Number of axes in the point 

430 delta : `float` 

431 Increment between axis values 

432 

433 Returns 

434 ------- 

435 A list of `nAxes` floats with values `[0, delta, ..., (nAxes-1)*delta] 

436 """ 

437 return [i*delta for i in range(nAxes)] 

438 

439 @staticmethod 

440 def makeEndpoint(name, nAxes=None): 

441 """Make an endpoint 

442 

443 Parameters 

444 ---------- 

445 name : `str` 

446 Endpoint class name prefix; the full class name is name + "Endpoint" 

447 nAxes : `int` or `None`, optional 

448 number of axes; an int is required if `name` == "Generic"; 

449 otherwise ignored 

450 

451 Returns 

452 ------- 

453 subclass of `lsst.afw.geom.BaseEndpoint` 

454 The constructed endpoint 

455 

456 Raises 

457 ------ 

458 TypeError 

459 If `name` == "Generic" and `nAxes` is None or <= 0 

460 """ 

461 EndpointClassName = name + "Endpoint" 

462 EndpointClass = getattr(afwGeom, EndpointClassName) 

463 if name == "Generic": 

464 if nAxes is None: 

465 raise TypeError("nAxes must be an integer for GenericEndpoint") 

466 return EndpointClass(nAxes) 

467 return EndpointClass() 

468 

469 @classmethod 

470 def makeGoodFrame(cls, name, nAxes=None): 

471 """Return the appropriate frame for the given name and nAxes 

472 

473 Parameters 

474 ---------- 

475 name : `str` 

476 Endpoint class name prefix; the full class name is name + "Endpoint" 

477 nAxes : `int` or `None`, optional 

478 number of axes; an int is required if `name` == "Generic"; 

479 otherwise ignored 

480 

481 Returns 

482 ------- 

483 `ast.Frame` 

484 The constructed frame 

485 

486 Raises 

487 ------ 

488 TypeError 

489 If `name` == "Generic" and `nAxes` is `None` or <= 0 

490 """ 

491 return cls.makeEndpoint(name, nAxes).makeFrame() 

492 

493 @staticmethod 

494 def makeBadFrames(name): 

495 """Return a list of 0 or more frames that are not a valid match for the 

496 named endpoint 

497 

498 Parameters 

499 ---------- 

500 name : `str` 

501 Endpoint class name prefix; the full class name is name + "Endpoint" 

502 

503 Returns 

504 ------- 

505 Collection of `ast.Frame` 

506 A collection of 0 or more frames 

507 """ 

508 return { 

509 "Generic": [], 

510 "Point2": [ 

511 ast.SkyFrame(), 

512 ast.Frame(1), 

513 ast.Frame(3), 

514 ], 

515 "SpherePoint": [ 

516 ast.Frame(1), 

517 ast.Frame(2), 

518 ast.Frame(3), 

519 ], 

520 }[name] 

521 

522 def makeFrameSet(self, baseFrame, currFrame): 

523 """Make a FrameSet 

524 

525 The FrameSet will contain 4 frames and three transforms connecting them. 

526 The idenity of each frame is provided by self.frameIdentDict 

527 

528 Frame Index Mapping from this frame to the next 

529 `baseFrame` 1 `ast.UnitMap(nIn)` 

530 Frame(nIn) 2 `polyMap` 

531 Frame(nOut) 3 `ast.UnitMap(nOut)` 

532 `currFrame` 4 

533 

534 where: 

535 - `nIn` = `baseFrame.nAxes` 

536 - `nOut` = `currFrame.nAxes` 

537 - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)` 

538 

539 Returns 

540 ------ 

541 `ast.FrameSet` 

542 The FrameSet as described above 

543 

544 Parameters 

545 ---------- 

546 baseFrame : `ast.Frame` 

547 base frame 

548 currFrame : `ast.Frame` 

549 current frame 

550 """ 

551 nIn = baseFrame.nAxes 

552 nOut = currFrame.nAxes 

553 polyMap = makeTwoWayPolyMap(nIn, nOut) 

554 

555 # The only way to set the Ident of a frame in a FrameSet is to set it in advance, 

556 # and I don't want to modify the inputs, so replace the input frames with copies 

557 baseFrame = baseFrame.copy() 

558 baseFrame.ident = self.frameIdentDict[1] 

559 currFrame = currFrame.copy() 

560 currFrame.ident = self.frameIdentDict[4] 

561 

562 frameSet = ast.FrameSet(baseFrame) 

563 frame2 = ast.Frame(nIn) 

564 frame2.ident = self.frameIdentDict[2] 

565 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nIn), frame2) 

566 frame3 = ast.Frame(nOut) 

567 frame3.ident = self.frameIdentDict[3] 

568 frameSet.addFrame(ast.FrameSet.CURRENT, polyMap, frame3) 

569 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nOut), currFrame) 

570 return frameSet 

571 

572 @staticmethod 

573 def permuteFrameSetIter(frameSet): 

574 """Iterator over 0 or more frameSets with SkyFrames axes permuted 

575 

576 Only base and current SkyFrames are permuted. If neither the base nor 

577 the current frame is a SkyFrame then no frames are returned. 

578 

579 Returns 

580 ------- 

581 iterator over `PermutedFrameSet` 

582 """ 

583 

584 fsInfo = FrameSetInfo(frameSet) 

585 if not (fsInfo.isBaseSkyFrame or fsInfo.isCurrSkyFrame): 

586 return 

587 

588 permuteBaseList = [False, True] if fsInfo.isBaseSkyFrame else [False] 

589 permuteCurrList = [False, True] if fsInfo.isCurrSkyFrame else [False] 

590 for permuteBase in permuteBaseList: 

591 for permuteCurr in permuteCurrList: 

592 yield PermutedFrameSet(frameSet, permuteBase, permuteCurr) 

593 

594 @staticmethod 

595 def makeJacobian(nIn, nOut, inPoint): 

596 """Make a Jacobian matrix for the equation described by 

597 `makeTwoWayPolyMap`. 

598 

599 Parameters 

600 ---------- 

601 nIn, nOut : `int` 

602 the dimensions of the input and output data; see makeTwoWayPolyMap 

603 inPoint : `numpy.ndarray` 

604 an array of size `nIn` representing the point at which the Jacobian 

605 is measured 

606 

607 Returns 

608 ------- 

609 J : `numpy.ndarray` 

610 an `nOut` x `nIn` array of first derivatives 

611 """ 

612 basePolyMapCoeff = 0.001 # see makeTwoWayPolyMap 

613 baseCoeff = 2.0 * basePolyMapCoeff 

614 coeffs = np.empty((nOut, nIn)) 

615 for iOut in range(nOut): 

616 coeffOffset = baseCoeff * iOut 

617 for iIn in range(nIn): 

618 coeffs[iOut, iIn] = baseCoeff * (iIn + 1) + coeffOffset 

619 coeffs[iOut, iIn] *= inPoint[iIn] 

620 assert coeffs.ndim == 2 

621 # Avoid spurious errors when comparing to a simplified array 

622 assert coeffs.shape == (nOut, nIn) 

623 return coeffs 

624 

625 def checkTransformation(self, transform, mapping, msg=""): 

626 """Check applyForward and applyInverse for a transform 

627 

628 Parameters 

629 ---------- 

630 transform : `lsst.afw.geom.Transform` 

631 The transform to check 

632 mapping : `ast.Mapping` 

633 The mapping the transform should use. This mapping 

634 must contain valid forward or inverse transformations, 

635 but they need not match if both present. Hence the 

636 mappings returned by make*PolyMap are acceptable. 

637 msg : `str` 

638 Error message suffix describing test parameters 

639 """ 

640 fromEndpoint = transform.fromEndpoint 

641 toEndpoint = transform.toEndpoint 

642 mappingFromTransform = transform.getMapping() 

643 

644 nIn = mapping.nIn 

645 nOut = mapping.nOut 

646 self.assertEqual(nIn, fromEndpoint.nAxes, msg=msg) 

647 self.assertEqual(nOut, toEndpoint.nAxes, msg=msg) 

648 

649 # forward transformation of one point 

650 rawInPoint = self.makeRawPointData(nIn) 

651 inPoint = fromEndpoint.pointFromData(rawInPoint) 

652 

653 # forward transformation of an array of points 

654 nPoints = 7 # arbitrary 

655 rawInArray = self.makeRawArrayData(nPoints, nIn) 

656 inArray = fromEndpoint.arrayFromData(rawInArray) 

657 

658 if mapping.hasForward: 

659 self.assertTrue(transform.hasForward) 

660 outPoint = transform.applyForward(inPoint) 

661 rawOutPoint = toEndpoint.dataFromPoint(outPoint) 

662 assert_allclose(rawOutPoint, mapping.applyForward(rawInPoint), err_msg=msg) 

663 assert_allclose(rawOutPoint, mappingFromTransform.applyForward(rawInPoint), err_msg=msg) 

664 

665 outArray = transform.applyForward(inArray) 

666 rawOutArray = toEndpoint.dataFromArray(outArray) 

667 self.assertFloatsAlmostEqual(rawOutArray, mapping.applyForward(rawInArray), msg=msg) 

668 self.assertFloatsAlmostEqual(rawOutArray, mappingFromTransform.applyForward(rawInArray), msg=msg) 

669 else: 

670 # Need outPoint, but don't need it to be consistent with inPoint 

671 rawOutPoint = self.makeRawPointData(nOut) 

672 outPoint = toEndpoint.pointFromData(rawOutPoint) 

673 rawOutArray = self.makeRawArrayData(nPoints, nOut) 

674 outArray = toEndpoint.arrayFromData(rawOutArray) 

675 

676 self.assertFalse(transform.hasForward) 

677 

678 if mapping.hasInverse: 

679 self.assertTrue(transform.hasInverse) 

680 # inverse transformation of one point; 

681 # remember that the inverse need not give the original values 

682 # (see the description of the `mapping` parameter) 

683 inversePoint = transform.applyInverse(outPoint) 

684 rawInversePoint = fromEndpoint.dataFromPoint(inversePoint) 

685 assert_allclose(rawInversePoint, mapping.applyInverse(rawOutPoint), err_msg=msg) 

686 assert_allclose(rawInversePoint, mappingFromTransform.applyInverse(rawOutPoint), err_msg=msg) 

687 

688 # inverse transformation of an array of points; 

689 # remember that the inverse will not give the original values 

690 # (see the description of the `mapping` parameter) 

691 inverseArray = transform.applyInverse(outArray) 

692 rawInverseArray = fromEndpoint.dataFromArray(inverseArray) 

693 self.assertFloatsAlmostEqual(rawInverseArray, mapping.applyInverse(rawOutArray), msg=msg) 

694 self.assertFloatsAlmostEqual(rawInverseArray, mappingFromTransform.applyInverse(rawOutArray), 

695 msg=msg) 

696 else: 

697 self.assertFalse(transform.hasInverse) 

698 

699 def checkInverseTransformation(self, forward, inverse, msg=""): 

700 """Check that two Transforms are each others' inverses. 

701 

702 Parameters 

703 ---------- 

704 forward : `lsst.afw.geom.Transform` 

705 the reference Transform to test 

706 inverse : `lsst.afw.geom.Transform` 

707 the transform that should be the inverse of `forward` 

708 msg : `str` 

709 error message suffix describing test parameters 

710 """ 

711 fromEndpoint = forward.fromEndpoint 

712 toEndpoint = forward.toEndpoint 

713 forwardMapping = forward.getMapping() 

714 inverseMapping = inverse.getMapping() 

715 

716 # properties 

717 self.assertEqual(forward.fromEndpoint, 

718 inverse.toEndpoint, msg=msg) 

719 self.assertEqual(forward.toEndpoint, 

720 inverse.fromEndpoint, msg=msg) 

721 self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg) 

722 self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg) 

723 

724 # transformations of one point 

725 # we don't care about whether the transformation itself is correct 

726 # (see checkTransformation), so inPoint/outPoint need not be related 

727 rawInPoint = self.makeRawPointData(fromEndpoint.nAxes) 

728 inPoint = fromEndpoint.pointFromData(rawInPoint) 

729 rawOutPoint = self.makeRawPointData(toEndpoint.nAxes) 

730 outPoint = toEndpoint.pointFromData(rawOutPoint) 

731 

732 # transformations of arrays of points 

733 nPoints = 7 # arbitrary 

734 rawInArray = self.makeRawArrayData(nPoints, fromEndpoint.nAxes) 

735 inArray = fromEndpoint.arrayFromData(rawInArray) 

736 rawOutArray = self.makeRawArrayData(nPoints, toEndpoint.nAxes) 

737 outArray = toEndpoint.arrayFromData(rawOutArray) 

738 

739 if forward.hasForward: 

740 self.assertEqual(forward.applyForward(inPoint), 

741 inverse.applyInverse(inPoint), msg=msg) 

742 self.assertEqual(forwardMapping.applyForward(rawInPoint), 

743 inverseMapping.applyInverse(rawInPoint), msg=msg) 

744 # Assertions must work with both lists and numpy arrays 

745 assert_array_equal(forward.applyForward(inArray), 

746 inverse.applyInverse(inArray), 

747 err_msg=msg) 

748 assert_array_equal(forwardMapping.applyForward(rawInArray), 

749 inverseMapping.applyInverse(rawInArray), 

750 err_msg=msg) 

751 

752 if forward.hasInverse: 

753 self.assertEqual(forward.applyInverse(outPoint), 

754 inverse.applyForward(outPoint), msg=msg) 

755 self.assertEqual(forwardMapping.applyInverse(rawOutPoint), 

756 inverseMapping.applyForward(rawOutPoint), msg=msg) 

757 assert_array_equal(forward.applyInverse(outArray), 

758 inverse.applyForward(outArray), 

759 err_msg=msg) 

760 assert_array_equal(forwardMapping.applyInverse(rawOutArray), 

761 inverseMapping.applyForward(rawOutArray), 

762 err_msg=msg) 

763 

764 def checkTransformFromMapping(self, fromName, toName): 

765 """Check Transform_<fromName>_<toName> using the Mapping constructor 

766 

767 Parameters 

768 ---------- 

769 fromName, toName : `str` 

770 Endpoint name prefix for "from" and "to" endpoints, respectively, 

771 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

772 fromAxes, toAxes : `int` 

773 number of axes in fromFrame and toFrame, respectively 

774 """ 

775 transformClassName = "Transform{}To{}".format(fromName, toName) 

776 TransformClass = getattr(afwGeom, transformClassName) 

777 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

778 

779 # check valid numbers of inputs and outputs 

780 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

781 self.goodNAxes[toName]): 

782 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

783 polyMap = makeTwoWayPolyMap(nIn, nOut) 

784 transform = TransformClass(polyMap) 

785 

786 # desired output from `str(transform)` 

787 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut) 

788 self.assertEqual("{}".format(transform), desStr) 

789 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr) 

790 

791 self.checkTransformation(transform, polyMap, msg=msg) 

792 

793 # Forward transform but no inverse 

794 polyMap = makeForwardPolyMap(nIn, nOut) 

795 transform = TransformClass(polyMap) 

796 self.checkTransformation(transform, polyMap, msg=msg) 

797 

798 # Inverse transform but no forward 

799 polyMap = makeForwardPolyMap(nOut, nIn).inverted() 

800 transform = TransformClass(polyMap) 

801 self.checkTransformation(transform, polyMap, msg=msg) 

802 

803 # check invalid # of output against valid # of inputs 

804 for nIn, badNOut in itertools.product(self.goodNAxes[fromName], 

805 self.badNAxes[toName]): 

806 badPolyMap = makeTwoWayPolyMap(nIn, badNOut) 

807 msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut) 

808 with self.assertRaises(InvalidParameterError, msg=msg): 

809 TransformClass(badPolyMap) 

810 

811 # check invalid # of inputs against valid and invalid # of outputs 

812 for badNIn, nOut in itertools.product(self.badNAxes[fromName], 

813 self.goodNAxes[toName] + self.badNAxes[toName]): 

814 badPolyMap = makeTwoWayPolyMap(badNIn, nOut) 

815 msg = "{}, badNIn={}, nOut={}".format(baseMsg, nIn, nOut) 

816 with self.assertRaises(InvalidParameterError, msg=msg): 

817 TransformClass(badPolyMap) 

818 

819 def checkTransformFromFrameSet(self, fromName, toName): 

820 """Check Transform_<fromName>_<toName> using the FrameSet constructor 

821 

822 Parameters 

823 ---------- 

824 fromName, toName : `str` 

825 Endpoint name prefix for "from" and "to" endpoints, respectively, 

826 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

827 """ 

828 transformClassName = "Transform{}To{}".format(fromName, toName) 

829 TransformClass = getattr(afwGeom, transformClassName) 

830 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

831 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

832 self.goodNAxes[toName]): 

833 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

834 

835 baseFrame = self.makeGoodFrame(fromName, nIn) 

836 currFrame = self.makeGoodFrame(toName, nOut) 

837 frameSet = self.makeFrameSet(baseFrame, currFrame) 

838 self.assertEqual(frameSet.nFrame, 4) 

839 

840 # construct 0 or more frame sets that are invalid for this transform class 

841 for badBaseFrame in self.makeBadFrames(fromName): 

842 badFrameSet = self.makeFrameSet(badBaseFrame, currFrame) 

843 with self.assertRaises(InvalidParameterError): 

844 TransformClass(badFrameSet) 

845 for badCurrFrame in self.makeBadFrames(toName): 

846 reallyBadFrameSet = self.makeFrameSet(badBaseFrame, badCurrFrame) 

847 with self.assertRaises(InvalidParameterError): 

848 TransformClass(reallyBadFrameSet) 

849 for badCurrFrame in self.makeBadFrames(toName): 

850 badFrameSet = self.makeFrameSet(baseFrame, badCurrFrame) 

851 with self.assertRaises(InvalidParameterError): 

852 TransformClass(badFrameSet) 

853 

854 transform = TransformClass(frameSet) 

855 

856 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut) 

857 self.assertEqual("{}".format(transform), desStr) 

858 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr) 

859 

860 self.checkPersistence(transform) 

861 

862 mappingFromTransform = transform.getMapping() 

863 transformCopy = TransformClass(mappingFromTransform) 

864 self.assertEqual(type(transform), type(transformCopy)) 

865 self.assertEqual(transform.getMapping(), mappingFromTransform) 

866 

867 polyMap = makeTwoWayPolyMap(nIn, nOut) 

868 

869 self.checkTransformation(transform, mapping=polyMap, msg=msg) 

870 

871 # If the base and/or current frame of frameSet is a SkyFrame, 

872 # try permuting that frame (in place, so the connected mappings are 

873 # correctly updated). The Transform constructor should undo the permutation, 

874 # (via SpherePointEndpoint.normalizeFrame) in its internal copy of frameSet, 

875 # forcing the axes of the SkyFrame into standard (longitude, latitude) order 

876 for permutedFS in self.permuteFrameSetIter(frameSet): 

877 if permutedFS.isBaseSkyFrame: 

878 baseFrame = permutedFS.frameSet.getFrame(ast.FrameSet.BASE) 

879 # desired base longitude axis 

880 desBaseLonAxis = 2 if permutedFS.isBasePermuted else 1 

881 self.assertEqual(baseFrame.lonAxis, desBaseLonAxis) 

882 if permutedFS.isCurrSkyFrame: 

883 currFrame = permutedFS.frameSet.getFrame(ast.FrameSet.CURRENT) 

884 # desired current base longitude axis 

885 desCurrLonAxis = 2 if permutedFS.isCurrPermuted else 1 

886 self.assertEqual(currFrame.lonAxis, desCurrLonAxis) 

887 

888 permTransform = TransformClass(permutedFS.frameSet) 

889 self.checkTransformation(permTransform, mapping=polyMap, msg=msg) 

890 

def checkInverted(self, fromName, toName):
    """Exercise ``inverted`` for every Transform<fromName>To<toName> variant.

    For each supported pair of input/output axis counts, the inverse of a
    transform is checked for three kinds of polynomial mapping: a two-way
    map, a forward-only map, and a forward-only map built in the opposite
    direction and then inverted.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    TransformClass = getattr(afwGeom, f"Transform{fromName}To{toName}")
    baseMsg = f"TransformClass={TransformClass.__name__}"

    # (label used in failure messages, factory building the mapping for
    # a given number of input and output axes)
    mappingFactories = (
        ("TwoWay", lambda nIn, nOut: makeTwoWayPolyMap(nIn, nOut)),
        ("Forward", lambda nIn, nOut: makeForwardPolyMap(nIn, nOut)),
        ("Inverse", lambda nIn, nOut: makeForwardPolyMap(nOut, nIn).inverted()),
    )

    axisPairs = itertools.product(self.goodNAxes[fromName], self.goodNAxes[toName])
    for nIn, nOut in axisPairs:
        msg = f"{baseMsg}, nIn={nIn}, nOut={nOut}"
        for mapLabel, makeMapping in mappingFactories:
            self.checkInverseMapping(
                TransformClass,
                makeMapping(nIn, nOut),
                f"{msg}, Map={mapLabel}",
            )

918 

def checkInverseMapping(self, TransformClass, mapping, msg):
    """Test ``inverted`` for a transform built from a specific mapping.

    Builds ``transform = TransformClass(mapping)`` and then checks three
    things:

    - ``transform.inverted()`` is the inverse of ``transform``;
    - ``transform.inverted().inverted()`` is the inverse of
      ``transform.inverted()``;
    - the double inverse behaves like ``mapping``.

    Parameters
    ----------
    TransformClass : `type`
        The class of transform to test, such as TransformPoint2ToPoint2
    mapping : `ast.Mapping`
        The mapping to use for the transform
    msg : `str`
        Error message suffix
    """
    transform = TransformClass(mapping)
    inverse = transform.inverted()
    inverseInverse = inverse.inverted()

    self.checkInverseTransformation(transform, inverse, msg=msg)
    self.checkInverseTransformation(inverse, inverseInverse, msg=msg)
    # Inverting twice should give back the behavior of the original mapping
    self.checkTransformation(inverseInverse, mapping, msg=msg)

942 

def checkGetJacobian(self, fromName, toName):
    """Compare Transform<fromName>To<toName>.getJacobian with expectations.

    For every supported (nIn, nOut) pair, a transform is built from a
    forward polynomial map. Its Jacobian is evaluated at two different
    points, and each result is compared with ``self.makeJacobian``.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefix for "from" and "to" endpoints, respectively,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`
    """
    TransformClass = getattr(afwGeom, f"Transform{fromName}To{toName}")
    baseMsg = f"TransformClass={TransformClass.__name__}"
    for nIn, nOut in itertools.product(self.goodNAxes[fromName],
                                       self.goodNAxes[toName]):
        msg = f"{baseMsg}, nIn={nIn}, nOut={nOut}"
        transform = TransformClass(makeForwardPolyMap(nIn, nOut))
        fromEndpoint = transform.fromEndpoint

        # Evaluate at more than one point, so that a Jacobian which happens
        # to be right at a single location cannot pass by accident.
        for extraPointArgs in ((), (0.111,)):
            rawInPoint = self.makeRawPointData(nIn, *extraPointArgs)
            actual = transform.getJacobian(fromEndpoint.pointFromData(rawInPoint))
            assert_allclose(actual, self.makeJacobian(nIn, nOut, rawInPoint),
                            err_msg=msg)

974 

def checkThen(self, fromName, midName, toName):
    """Test Transform<fromName>To<midName>.then(Transform<midName>To<toName>)

    Checks that the concatenated transform gives the same results as
    applying the two transforms one after the other, in both the forward
    and inverse directions. Also checks that ``then`` raises
    `InvalidParameterError` when the shared endpoint has a mismatched
    number of axes (only when ``midName`` is "Generic") or a mismatched
    endpoint type (only when ``fromName != midName``).

    Parameters
    ----------
    fromName : `str`
        the prefix of the starting endpoint (e.g., "Point2" for a
        Point2Endpoint) for the final, concatenated Transform
    midName : `str`
        the prefix for the shared endpoint where two Transforms will be
        concatenated
    toName : `str`
        the prefix of the ending endpoint for the final, concatenated
        Transform
    """
    TransformClass1 = getattr(afwGeom,
                              "Transform{}To{}".format(fromName, midName))
    TransformClass2 = getattr(afwGeom,
                              "Transform{}To{}".format(midName, toName))
    baseMsg = "{}.then({})".format(TransformClass1.__name__,
                                   TransformClass2.__name__)
    for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                             self.goodNAxes[midName],
                                             self.goodNAxes[toName]):
        msg = "{}, nIn={}, nMid={}, nOut={}".format(
            baseMsg, nIn, nMid, nOut)
        transform1 = TransformClass1(makeTwoWayPolyMap(nIn, nMid))
        transform2 = TransformClass2(makeTwoWayPolyMap(nMid, nOut))
        transform = transform1.then(transform2)

        fromEndpoint = transform1.fromEndpoint
        toEndpoint = transform2.toEndpoint

        # Forward: merged transform == transform1 followed by transform2
        inPoint = fromEndpoint.pointFromData(self.makeRawPointData(nIn))
        outPointMerged = transform.applyForward(inPoint)
        outPointSeparate = transform2.applyForward(
            transform1.applyForward(inPoint))
        assert_allclose(toEndpoint.dataFromPoint(outPointMerged),
                        toEndpoint.dataFromPoint(outPointSeparate),
                        err_msg=msg)

        # Inverse: merged inverse == inverse of transform2, then transform1
        outPoint = toEndpoint.pointFromData(self.makeRawPointData(nOut))
        inPointMerged = transform.applyInverse(outPoint)
        inPointSeparate = transform1.applyInverse(
            transform2.applyInverse(outPoint))
        assert_allclose(
            fromEndpoint.dataFromPoint(inPointMerged),
            fromEndpoint.dataFromPoint(inPointSeparate),
            err_msg=msg)

    # Mismatched number of axes should fail: transform1 outputs 3 axes,
    # but transform2 expects 2 inputs.
    if midName == "Generic":
        nIn = self.goodNAxes[fromName][0]
        nOut = self.goodNAxes[toName][0]
        transform1 = TransformClass1(makeTwoWayPolyMap(nIn, 3))
        transform2 = TransformClass2(makeTwoWayPolyMap(2, nOut))
        msg = "{}, mismatched axes: nIn={}, nMid=3 vs 2, nOut={}".format(
            baseMsg, nIn, nOut)
        with self.assertRaises(InvalidParameterError, msg=msg):
            transform1.then(transform2)

    # Mismatched types of endpoints should fail
    if fromName != midName:
        # Use TransformClass1 for both transforms to keep the test logic
        # simple. transform2 then has a "from" endpoint of type fromName,
        # which does not match the midName "to" endpoint of transform1.
        # nMid is restricted to axis counts valid for both endpoint types,
        # so only the endpoint type (not the axis count) is mismatched.
        sharedNAxes = set(self.goodNAxes[fromName]).intersection(
            self.goodNAxes[midName])

        for nIn, nMid, nOut in itertools.product(self.goodNAxes[fromName],
                                                 sharedNAxes,
                                                 self.goodNAxes[midName]):
            transform1 = TransformClass1(makeTwoWayPolyMap(nIn, nMid))
            transform2 = TransformClass1(makeTwoWayPolyMap(nMid, nOut))
            msg = "{}, mismatched endpoint types: nIn={}, nMid={}, nOut={}".format(
                baseMsg, nIn, nMid, nOut)
            with self.assertRaises(InvalidParameterError, msg=msg):
                transform1.then(transform2)

1054 

def assertTransformsEqual(self, transform1, transform2):
    """Assert that two transforms are equal.

    The transforms must have the same type, endpoints and mapping. They
    must also produce matching results (within the default
    `numpy.testing.assert_allclose` tolerance) on the points returned by
    ``self.makeRawArrayData``. This is checked in each direction that the
    mapping of ``transform1`` supports (forward and/or inverse).
    """
    self.assertEqual(type(transform1), type(transform2))
    self.assertEqual(transform1.fromEndpoint, transform2.fromEndpoint)
    self.assertEqual(transform1.toEndpoint, transform2.toEndpoint)
    self.assertEqual(transform1.getMapping(), transform2.getMapping())

    fromEndpoint, toEndpoint = transform1.fromEndpoint, transform1.toEndpoint
    mapping = transform1.getMapping()
    nIn, nOut = mapping.nIn, mapping.nOut
    nPoints = 7  # arbitrary

    def assertSameResults(srcEndpoint, dstEndpoint, nSrcAxes, apply1, apply2):
        # Feed identical input points to both transforms and compare the
        # raw output data.
        srcArray = srcEndpoint.arrayFromData(self.makeRawArrayData(nPoints, nSrcAxes))
        data1 = dstEndpoint.dataFromArray(apply1(srcArray))
        data2 = dstEndpoint.dataFromArray(apply2(srcArray))
        assert_allclose(data1, data2)

    if mapping.hasForward:
        assertSameResults(fromEndpoint, toEndpoint, nIn,
                          transform1.applyForward, transform2.applyForward)

    if mapping.hasInverse:
        assertSameResults(toEndpoint, fromEndpoint, nOut,
                          transform1.applyInverse, transform2.applyInverse)

1087 

def checkPersistence(self, transform):
    """Check persistence of a transform.

    Covers ``writeString``/``readString``, including rejection of a bad
    version number and a bad class name. Also covers
    `lsst.afw.geom.transformFromString`, pickling, and a FITS round trip
    via ``writeFits``/``readFits``. Each restored transform must compare
    equal to the original according to ``assertTransformsEqual``.

    Parameters
    ----------
    transform : transform instance
        The transform to check, e.g. an instance of
        `lsst.afw.geom.TransformPoint2ToPoint2`.
    """
    className = type(transform).__name__

    # check writeString and readString; the serialized string is
    # "<version> <className> <rest>"
    transformStr = transform.writeString()
    serialVersion, serialClassName, serialRest = transformStr.split(" ", 2)
    self.assertEqual(int(serialVersion), 1)
    self.assertEqual(serialClassName, className)

    badVersionStr = " ".join(["2", serialClassName, serialRest])
    with self.assertRaises(InvalidParameterError,
                           msg="readString accepted unsupported version 2"):
        transform.readString(badVersionStr)

    badClassName = "x" + serialClassName
    badClassStr = " ".join(["1", badClassName, serialRest])
    with self.assertRaises(InvalidParameterError,
                           msg="readString accepted wrong class name {!r}".format(badClassName)):
        transform.readString(badClassStr)

    transformFromStr1 = transform.readString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr1)

    # check transformFromString
    transformFromStr2 = afwGeom.transformFromString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr2)

    # Check pickling
    self.assertTransformsEqual(transform, pickle.loads(pickle.dumps(transform)))

    # Check afw::table::io persistence round-trip
    with lsst.utils.tests.getTempFilePath(".fits") as filename:
        transform.writeFits(filename)
        self.assertTransformsEqual(transform, type(transform).readFits(filename))