Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# 

2# LSST Data Management System 

3# Copyright 2016 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22 

# Names exported by ``from <this module> import *``; the other helpers in
# this module must be imported explicitly by name.
__all__ = ["BoxGrid", "makeSipIwcToPixel", "makeSipPixelToIwc"]

24 

25import itertools 

26import math 

27import os 

28import pickle 

29 

30import astshim as ast 

31import numpy as np 

32from numpy.testing import assert_allclose, assert_array_equal 

33from astshim.test import makeForwardPolyMap, makeTwoWayPolyMap 

34from lsst.afw.geom.wcsUtils import getCdMatrixFromMetadata 

35 

36import lsst.geom 

37import lsst.afw.geom as afwGeom 

38from lsst.pex.exceptions import InvalidParameterError 

39import lsst.utils 

40import lsst.utils.tests 

41 

42 

class BoxGrid:
    """Divide a box into nx by ny sub-boxes that tile the region

    The sub-boxes will be of the same type as `box` and will exactly tile `box`;
    they will also all be the same size, to the extent possible (some variation
    is inevitable for integer boxes that cannot be evenly divided).

    Parameters
    ----------
    box : `lsst.geom.Box2I` or `lsst.geom.Box2D`
        the box to subdivide; the boxes in the grid will be of the same type
    numColRow : pair of `int`
        number of columns and rows

    Raises
    ------
    RuntimeError
        If ``numColRow`` does not have exactly two elements, or ``box`` is
        neither a `lsst.geom.Box2I` nor a `lsst.geom.Box2D`.
    """

    def __init__(self, box, numColRow):
        if len(numColRow) != 2:
            raise RuntimeError(f"numColRow={numColRow!r}; must be a sequence of two integers")
        self._numColRow = tuple(int(val) for val in numColRow)

        if isinstance(box, lsst.geom.Box2I):
            # For integer boxes the max is treated as inclusive: the dividers
            # run to max + 1 and each sub-box ends stopDelta before the next
            # divider (see __getitem__).
            stopDelta = 1
        elif isinstance(box, lsst.geom.Box2D):
            stopDelta = 0
        else:
            raise RuntimeError(f"Unknown class {type(box)} of box {box}")
        self.boxClass = type(box)
        self.stopDelta = stopDelta

        minPoint = box.getMin()
        self.pointClass = type(minPoint)
        dtype = np.array(minPoint).dtype

        # For each axis, numColRow[i] + 1 evenly spaced dividers running from
        # the box min to its end (max + stopDelta).
        self._divList = [np.linspace(start=box.getMin()[i],
                                     stop=box.getMax()[i] + self.stopDelta,
                                     num=self._numColRow[i] + 1,
                                     endpoint=True,
                                     dtype=dtype) for i in range(2)]

    @property
    def numColRow(self):
        """Number of columns and rows (`tuple` of two `int`)."""
        return self._numColRow

    def _normalizeIndex(self, indXY):
        """Return ``indXY`` as a list of two in-range, non-negative indices.

        Negative indices count back from the end of each axis, as for Python
        sequences.

        Raises
        ------
        IndexError
            If ``indXY`` does not have two elements or either index is out
            of range.
        """
        if len(indXY) != 2:
            raise IndexError(f"indXY={indXY!r}; must be a pair of integers")
        normalized = []
        for ind, num in zip(indXY, self._numColRow):
            if ind < 0:
                ind += num
            if not 0 <= ind < num:
                raise IndexError(f"indXY={indXY!r} out of range for numColRow={self._numColRow}")
            normalized.append(ind)
        return normalized

    def __getitem__(self, indXY):
        """Return the box at the specified x,y index

        Parameters
        ----------
        indXY : pair of `ints`
            the x,y (column, row) index to return; negative values count
            back from the last column/row

        Returns
        -------
        subBox : `lsst.geom.Box2I` or `lsst.geom.Box2D`

        Raises
        ------
        IndexError
            If either index is out of range.
        """
        ind = self._normalizeIndex(indXY)
        beg = self.pointClass(*[self._divList[i][ind[i]] for i in range(2)])
        end = self.pointClass(
            *[self._divList[i][ind[i] + 1] - self.stopDelta for i in range(2)])
        return self.boxClass(beg, end)

    def __len__(self):
        """Return the total number of sub-boxes (columns * rows)."""
        return self.numColRow[0]*self.numColRow[1]

    def __iter__(self):
        """Return an iterator over all boxes, where column varies most quickly
        """
        for row in range(self.numColRow[1]):
            for col in range(self.numColRow[0]):
                yield self[col, row]

112 

113 

class FrameSetInfo:
    """Summarize the base and current frames of an `ast.FrameSet`

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet to describe

    Notes
    -----
    **Fields**

    baseInd : `int`
        Index of the base frame (``frameSet.base``)
    currInd : `int`
        Index of the current frame (``frameSet.current``)
    isBaseSkyFrame : `bool`
        True if the base frame's ``className`` is "SkyFrame"
    isCurrSkyFrame : `bool`
        True if the current frame's ``className`` is "SkyFrame"
    """
    def __init__(self, frameSet):
        self.baseInd, self.currInd = frameSet.base, frameSet.current
        baseFrame, currFrame = (frameSet.getFrame(ind) for ind in (self.baseInd, self.currInd))
        self.isBaseSkyFrame = baseFrame.className == "SkyFrame"
        self.isCurrSkyFrame = currFrame.className == "SkyFrame"

140 

141 

def makeSipPolyMapCoeffs(metadata, name):
    """Return `ast.PolyMap` coefficients for one SIP polynomial

    The coefficients describe one output axis of a PolyMap computing::

        f(dxy) = dxy + sipPolynomial(dxy)

    where ``dxy = pixelPosition - pixelOrigin`` and ``sipPolynomial`` has a
    term ``<name>_n_m * x^n * y^m`` for each ``<name>_n_m`` key present in
    ``metadata`` (e.g. ``A_2_0`` is the coefficient for x^2 y^0).

    Parameters
    ----------
    metadata : `lsst.daf.base.PropertySet`
        FITS metadata describing a WCS with the specified SIP coefficients;
        must contain ``<name>_ORDER``
    name : `str`
        The desired SIP terms: one of A, B, AP, BP

    Returns
    -------
    coeffs : `list` of `list`
        One ``[coefficient, outAxis, xPower, yPower]`` entry per term,
        starting with the identity term for out = in. ``outAxis`` is 1 for
        A and AP, 2 for B and BP.

    Raises
    ------
    RuntimeError
        If ``name`` is not a supported SIP name, or no ``<name>_n_m``
        coefficients are found in ``metadata``.

    Notes
    -----
    This is an internal function for use by makeSipIwcToPixel
    and makeSipPixelToIwc
    """
    outAxisByName = {"A": 1, "B": 2, "AP": 1, "BP": 2}
    if name not in outAxisByName:
        raise RuntimeError(f"{name} not a supported SIP name")
    outAxis = outAxisByName[name]
    order = metadata.getAsInt(f"{name}_ORDER")

    # The identity term makes the PolyMap compute out = in + distortion
    identityPowers = (1, 0) if outAxis == 1 else (0, 1)
    coeffs = [[1.0, outAxis, *identityPowers]]

    # Distortion terms: every <name>_<xPower>_<yPower> key that is present
    for xPower, yPower in itertools.product(range(order + 1), repeat=2):
        coeffName = f"{name}_{xPower}_{yPower}"
        if metadata.exists(coeffName):
            coeffs.append([metadata.getAsDouble(coeffName), outAxis, xPower, yPower])

    if len(coeffs) == 1:
        raise RuntimeError(f"No {name} coefficients found")
    return coeffs

195 

196 

def makeSipIwcToPixel(metadata):
    """Make an IWC to pixel transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with inverse SIP coefficients
        (AP and BP terms)

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from IWC position to pixel position (zero-based)
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The inverse SIP terms APn_m, BPn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial inverseSipPolynomial, the returned transformation is::

        pixelPosition = pixelOrigin + uv + inverseSipPolynomial(uv)
        where uv = inverseCdMatrix * iwcPosition

    and pixelOrigin is (CRPIX1 - 1, CRPIX2 - 1).
    """
    # CRPIX is one-based in FITS; subtract 1 to get a zero-based pixel origin
    pixelOrigin = tuple(metadata.getScalar(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    relativeToAbsolute = ast.ShiftMap(pixelOrigin)
    cdMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())
    inverseSipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "AP") + makeSipPolyMapCoeffs(metadata, "BP"),
        dtype=float,
    )
    # Only forward coefficients are supplied, so this PolyMap has no inverse
    inverseSipMap = ast.PolyMap(inverseSipCoeffs, 2, "IterInverse=0")

    iwcToPixel = cdMap.inverted().then(inverseSipMap).then(relativeToAbsolute)
    return afwGeom.TransformPoint2ToPoint2(iwcToPixel)

234 

235 

def makeSipPixelToIwc(metadata):
    """Make a pixel to IWC transform with SIP distortion from FITS-WCS metadata

    This function is primarily intended for unit tests.
    IWC is intermediate world coordinates, as described in the FITS papers.

    Parameters
    ----------
    metadata : lsst.daf.base.PropertySet
        FITS metadata describing a WCS with forward SIP coefficients
        (A and B terms)

    Returns
    -------
    lsst.afw.geom.TransformPoint2ToPoint2
        Transform from pixel position (zero-based) to IWC position
        in the forward direction. The inverse direction is not defined.

    Notes
    -----
    The forward SIP terms An_m, Bn_m are polynomial coefficients x^n y^m
    for computing transformed x, y respectively. If we call the resulting
    polynomial sipPolynomial, the returned transformation is::

        iwcPosition = cdMatrix * (dxy + sipPolynomial(dxy))
        where dxy = pixelPosition - pixelOrigin

    and pixelOrigin is (CRPIX1 - 1, CRPIX2 - 1).
    """
    # CRPIX is one-based in FITS; subtract 1 to get a zero-based pixel origin
    pixelOrigin = tuple(metadata.getScalar(key) - 1 for key in ("CRPIX1", "CRPIX2"))
    absoluteToRelative = ast.ShiftMap(pixelOrigin).inverted()
    cdMap = ast.MatrixMap(getCdMatrixFromMetadata(metadata).copy())
    sipCoeffs = np.array(
        makeSipPolyMapCoeffs(metadata, "A") + makeSipPolyMapCoeffs(metadata, "B"),
        dtype=float,
    )
    # Only forward coefficients are supplied, so this PolyMap has no inverse
    sipMap = ast.PolyMap(sipCoeffs, 2, "IterInverse=0")

    pixelToIwc = absoluteToRelative.then(sipMap).then(cdMap)
    return afwGeom.TransformPoint2ToPoint2(pixelToIwc)

272 

273 

class PermutedFrameSet:
    """A FrameSet with base or current frame possibly permuted, with associated
    information

    Only two-axis frames will be permuted.

    Parameters
    ----------
    frameSet : `ast.FrameSet`
        The FrameSet you wish to permute. A deep copy is made.
    permuteBase : `bool`
        Permute the base frame's axes?
    permuteCurr : `bool`
        Permute the current frame's axes?

    Raises
    ------
    RuntimeError
        If you try to permute a frame that does not have 2 axes

    Notes
    -----
    **Fields**

    frameSet : `ast.FrameSet`
        The FrameSet that may be permuted. A local copy is made.
    isBaseSkyFrame : `bool`
        Is the base frame an `ast.SkyFrame`?
    isCurrSkyFrame : `bool`
        Is the current frame an `ast.SkyFrame`?
    isBasePermuted : `bool`
        Are the base frame axes permuted?
    isCurrPermuted : `bool`
        Are the current frame axes permuted?
    """
    def __init__(self, frameSet, permuteBase, permuteCurr):
        self.frameSet = frameSet.copy()
        fsInfo = FrameSetInfo(self.frameSet)
        self.isBaseSkyFrame = fsInfo.isBaseSkyFrame
        self.isCurrSkyFrame = fsInfo.isCurrSkyFrame
        if permuteBase:
            baseNAxes = self.frameSet.getFrame(fsInfo.baseInd).nAxes
            if baseNAxes != 2:
                raise RuntimeError(f"Base frame has {baseNAxes} axes; 2 required to permute")
            # Temporarily make the base frame current so permAxes applies to
            # it, then restore the original current frame.
            self.frameSet.current = fsInfo.baseInd
            self.frameSet.permAxes([2, 1])
            self.frameSet.current = fsInfo.currInd
        if permuteCurr:
            currNAxes = self.frameSet.getFrame(fsInfo.currInd).nAxes
            if currNAxes != 2:
                raise RuntimeError(f"Current frame has {currNAxes} axes; 2 required to permute")
            self.frameSet.permAxes([2, 1])
        self.isBasePermuted = permuteBase
        self.isCurrPermuted = permuteCurr

329 

330 

331class TransformTestBaseClass(lsst.utils.tests.TestCase): 

332 """Base class for unit tests of Transform<X>To<Y> 

333 

334 Subclasses must call `TransformTestBaseClass.setUp(self)` 

335 if they provide their own version. 

336 

337 If a package other than afw uses this class then it must 

338 override the `getTestDir` method to avoid writing into 

339 afw's test directory. 

340 """ 

341 

342 def getTestDir(self): 

343 """Return a directory where temporary test files can be written 

344 

345 The default implementation returns the test directory of the `afw` 

346 package. 

347 

348 If this class is used by a test in a package other than `afw` 

349 then the subclass must override this method. 

350 """ 

351 return os.path.join(lsst.utils.getPackageDir("afw"), "tests") 

352 

353 def setUp(self): 

354 """Set up a test 

355 

356 Subclasses should call this method if they override setUp. 

357 """ 

358 # tell unittest to use the msg argument of asserts as a supplement 

359 # to the error message, rather than as the whole error message 

360 self.longMessage = True 

361 

362 # list of endpoint class name prefixes; the full name is prefix + "Endpoint" 

363 self.endpointPrefixes = ("Generic", "Point2", "SpherePoint") 

364 

365 # GoodNAxes is dict of endpoint class name prefix: 

366 # tuple containing 0 or more valid numbers of axes 

367 self.goodNAxes = { 

368 "Generic": (1, 2, 3, 4), # all numbers of axes are valid for GenericEndpoint 

369 "Point2": (2,), 

370 "SpherePoint": (2,), 

371 } 

372 

373 # BadAxes is dict of endpoint class name prefix: 

374 # tuple containing 0 or more invalid numbers of axes 

375 self.badNAxes = { 

376 "Generic": (), # all numbers of axes are valid for GenericEndpoint 

377 "Point2": (1, 3, 4), 

378 "SpherePoint": (1, 3, 4), 

379 } 

380 

381 # Dict of frame index: identity name for frames created by makeFrameSet 

382 self.frameIdentDict = { 

383 1: "baseFrame", 

384 2: "frame2", 

385 3: "frame3", 

386 4: "currFrame", 

387 } 

388 

389 @staticmethod 

390 def makeRawArrayData(nPoints, nAxes, delta=0.123): 

391 """Make an array of generic point data 

392 

393 The data will be suitable for spherical points 

394 

395 Parameters 

396 ---------- 

397 nPoints : `int` 

398 Number of points in the array 

399 nAxes : `int` 

400 Number of axes in the point 

401 

402 Returns 

403 ------- 

404 np.array of floats with shape (nAxes, nPoints) 

405 The values are as follows; if nAxes != 2: 

406 The first point has values `[0, delta, 2*delta, ..., (nAxes-1)*delta]` 

407 The Nth point has those values + N 

408 if nAxes == 2 then the data is scaled so that the max value of axis 1 

409 is a bit less than pi/2 

410 """ 

411 delta = 0.123 

412 # oneAxis = [0, 1, 2, ...nPoints-1] 

413 oneAxis = np.arange(nPoints, dtype=float) # [0, 1, 2...] 

414 # rawData = [oneAxis, oneAxis + delta, oneAxis + 2 delta, ...] 

415 rawData = np.array([j * delta + oneAxis for j in range(nAxes)], dtype=float) 

416 if nAxes == 2: 

417 # scale rawData so that max value of 2nd axis is a bit less than pi/2, 

418 # thus making the data safe for SpherePoint 

419 maxLatitude = np.max(rawData[1]) 

420 rawData *= math.pi * 0.4999 / maxLatitude 

421 return rawData 

422 

423 @staticmethod 

424 def makeRawPointData(nAxes, delta=0.123): 

425 """Make one generic point 

426 

427 Parameters 

428 ---------- 

429 nAxes : `int` 

430 Number of axes in the point 

431 delta : `float` 

432 Increment between axis values 

433 

434 Returns 

435 ------- 

436 A list of `nAxes` floats with values `[0, delta, ..., (nAxes-1)*delta] 

437 """ 

438 return [i*delta for i in range(nAxes)] 

439 

440 @staticmethod 

441 def makeEndpoint(name, nAxes=None): 

442 """Make an endpoint 

443 

444 Parameters 

445 ---------- 

446 name : `str` 

447 Endpoint class name prefix; the full class name is name + "Endpoint" 

448 nAxes : `int` or `None`, optional 

449 number of axes; an int is required if `name` == "Generic"; 

450 otherwise ignored 

451 

452 Returns 

453 ------- 

454 subclass of `lsst.afw.geom.BaseEndpoint` 

455 The constructed endpoint 

456 

457 Raises 

458 ------ 

459 TypeError 

460 If `name` == "Generic" and `nAxes` is None or <= 0 

461 """ 

462 EndpointClassName = name + "Endpoint" 

463 EndpointClass = getattr(afwGeom, EndpointClassName) 

464 if name == "Generic": 

465 if nAxes is None: 

466 raise TypeError("nAxes must be an integer for GenericEndpoint") 

467 return EndpointClass(nAxes) 

468 return EndpointClass() 

469 

470 @classmethod 

471 def makeGoodFrame(cls, name, nAxes=None): 

472 """Return the appropriate frame for the given name and nAxes 

473 

474 Parameters 

475 ---------- 

476 name : `str` 

477 Endpoint class name prefix; the full class name is name + "Endpoint" 

478 nAxes : `int` or `None`, optional 

479 number of axes; an int is required if `name` == "Generic"; 

480 otherwise ignored 

481 

482 Returns 

483 ------- 

484 `ast.Frame` 

485 The constructed frame 

486 

487 Raises 

488 ------ 

489 TypeError 

490 If `name` == "Generic" and `nAxes` is `None` or <= 0 

491 """ 

492 return cls.makeEndpoint(name, nAxes).makeFrame() 

493 

494 @staticmethod 

495 def makeBadFrames(name): 

496 """Return a list of 0 or more frames that are not a valid match for the 

497 named endpoint 

498 

499 Parameters 

500 ---------- 

501 name : `str` 

502 Endpoint class name prefix; the full class name is name + "Endpoint" 

503 

504 Returns 

505 ------- 

506 Collection of `ast.Frame` 

507 A collection of 0 or more frames 

508 """ 

509 return { 

510 "Generic": [], 

511 "Point2": [ 

512 ast.SkyFrame(), 

513 ast.Frame(1), 

514 ast.Frame(3), 

515 ], 

516 "SpherePoint": [ 

517 ast.Frame(1), 

518 ast.Frame(2), 

519 ast.Frame(3), 

520 ], 

521 }[name] 

522 

523 def makeFrameSet(self, baseFrame, currFrame): 

524 """Make a FrameSet 

525 

526 The FrameSet will contain 4 frames and three transforms connecting them. 

527 The idenity of each frame is provided by self.frameIdentDict 

528 

529 Frame Index Mapping from this frame to the next 

530 `baseFrame` 1 `ast.UnitMap(nIn)` 

531 Frame(nIn) 2 `polyMap` 

532 Frame(nOut) 3 `ast.UnitMap(nOut)` 

533 `currFrame` 4 

534 

535 where: 

536 - `nIn` = `baseFrame.nAxes` 

537 - `nOut` = `currFrame.nAxes` 

538 - `polyMap` = `makeTwoWayPolyMap(nIn, nOut)` 

539 

540 Returns 

541 ------ 

542 `ast.FrameSet` 

543 The FrameSet as described above 

544 

545 Parameters 

546 ---------- 

547 baseFrame : `ast.Frame` 

548 base frame 

549 currFrame : `ast.Frame` 

550 current frame 

551 """ 

552 nIn = baseFrame.nAxes 

553 nOut = currFrame.nAxes 

554 polyMap = makeTwoWayPolyMap(nIn, nOut) 

555 

556 # The only way to set the Ident of a frame in a FrameSet is to set it in advance, 

557 # and I don't want to modify the inputs, so replace the input frames with copies 

558 baseFrame = baseFrame.copy() 

559 baseFrame.ident = self.frameIdentDict[1] 

560 currFrame = currFrame.copy() 

561 currFrame.ident = self.frameIdentDict[4] 

562 

563 frameSet = ast.FrameSet(baseFrame) 

564 frame2 = ast.Frame(nIn) 

565 frame2.ident = self.frameIdentDict[2] 

566 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nIn), frame2) 

567 frame3 = ast.Frame(nOut) 

568 frame3.ident = self.frameIdentDict[3] 

569 frameSet.addFrame(ast.FrameSet.CURRENT, polyMap, frame3) 

570 frameSet.addFrame(ast.FrameSet.CURRENT, ast.UnitMap(nOut), currFrame) 

571 return frameSet 

572 

573 @staticmethod 

574 def permuteFrameSetIter(frameSet): 

575 """Iterator over 0 or more frameSets with SkyFrames axes permuted 

576 

577 Only base and current SkyFrames are permuted. If neither the base nor 

578 the current frame is a SkyFrame then no frames are returned. 

579 

580 Returns 

581 ------- 

582 iterator over `PermutedFrameSet` 

583 """ 

584 

585 fsInfo = FrameSetInfo(frameSet) 

586 if not (fsInfo.isBaseSkyFrame or fsInfo.isCurrSkyFrame): 

587 return 

588 

589 permuteBaseList = [False, True] if fsInfo.isBaseSkyFrame else [False] 

590 permuteCurrList = [False, True] if fsInfo.isCurrSkyFrame else [False] 

591 for permuteBase in permuteBaseList: 

592 for permuteCurr in permuteCurrList: 

593 yield PermutedFrameSet(frameSet, permuteBase, permuteCurr) 

594 

595 @staticmethod 

596 def makeJacobian(nIn, nOut, inPoint): 

597 """Make a Jacobian matrix for the equation described by 

598 `makeTwoWayPolyMap`. 

599 

600 Parameters 

601 ---------- 

602 nIn, nOut : `int` 

603 the dimensions of the input and output data; see makeTwoWayPolyMap 

604 inPoint : `numpy.ndarray` 

605 an array of size `nIn` representing the point at which the Jacobian 

606 is measured 

607 

608 Returns 

609 ------- 

610 J : `numpy.ndarray` 

611 an `nOut` x `nIn` array of first derivatives 

612 """ 

613 basePolyMapCoeff = 0.001 # see makeTwoWayPolyMap 

614 baseCoeff = 2.0 * basePolyMapCoeff 

615 coeffs = np.empty((nOut, nIn)) 

616 for iOut in range(nOut): 

617 coeffOffset = baseCoeff * iOut 

618 for iIn in range(nIn): 

619 coeffs[iOut, iIn] = baseCoeff * (iIn + 1) + coeffOffset 

620 coeffs[iOut, iIn] *= inPoint[iIn] 

621 assert coeffs.ndim == 2 

622 # Avoid spurious errors when comparing to a simplified array 

623 assert coeffs.shape == (nOut, nIn) 

624 return coeffs 

625 

626 def checkTransformation(self, transform, mapping, msg=""): 

627 """Check applyForward and applyInverse for a transform 

628 

629 Parameters 

630 ---------- 

631 transform : `lsst.afw.geom.Transform` 

632 The transform to check 

633 mapping : `ast.Mapping` 

634 The mapping the transform should use. This mapping 

635 must contain valid forward or inverse transformations, 

636 but they need not match if both present. Hence the 

637 mappings returned by make*PolyMap are acceptable. 

638 msg : `str` 

639 Error message suffix describing test parameters 

640 """ 

641 fromEndpoint = transform.fromEndpoint 

642 toEndpoint = transform.toEndpoint 

643 mappingFromTransform = transform.getMapping() 

644 

645 nIn = mapping.nIn 

646 nOut = mapping.nOut 

647 self.assertEqual(nIn, fromEndpoint.nAxes, msg=msg) 

648 self.assertEqual(nOut, toEndpoint.nAxes, msg=msg) 

649 

650 # forward transformation of one point 

651 rawInPoint = self.makeRawPointData(nIn) 

652 inPoint = fromEndpoint.pointFromData(rawInPoint) 

653 

654 # forward transformation of an array of points 

655 nPoints = 7 # arbitrary 

656 rawInArray = self.makeRawArrayData(nPoints, nIn) 

657 inArray = fromEndpoint.arrayFromData(rawInArray) 

658 

659 if mapping.hasForward: 

660 self.assertTrue(transform.hasForward) 

661 outPoint = transform.applyForward(inPoint) 

662 rawOutPoint = toEndpoint.dataFromPoint(outPoint) 

663 assert_allclose(rawOutPoint, mapping.applyForward(rawInPoint), err_msg=msg) 

664 assert_allclose(rawOutPoint, mappingFromTransform.applyForward(rawInPoint), err_msg=msg) 

665 

666 outArray = transform.applyForward(inArray) 

667 rawOutArray = toEndpoint.dataFromArray(outArray) 

668 self.assertFloatsAlmostEqual(rawOutArray, mapping.applyForward(rawInArray), msg=msg) 

669 self.assertFloatsAlmostEqual(rawOutArray, mappingFromTransform.applyForward(rawInArray), msg=msg) 

670 else: 

671 # Need outPoint, but don't need it to be consistent with inPoint 

672 rawOutPoint = self.makeRawPointData(nOut) 

673 outPoint = toEndpoint.pointFromData(rawOutPoint) 

674 rawOutArray = self.makeRawArrayData(nPoints, nOut) 

675 outArray = toEndpoint.arrayFromData(rawOutArray) 

676 

677 self.assertFalse(transform.hasForward) 

678 

679 if mapping.hasInverse: 

680 self.assertTrue(transform.hasInverse) 

681 # inverse transformation of one point; 

682 # remember that the inverse need not give the original values 

683 # (see the description of the `mapping` parameter) 

684 inversePoint = transform.applyInverse(outPoint) 

685 rawInversePoint = fromEndpoint.dataFromPoint(inversePoint) 

686 assert_allclose(rawInversePoint, mapping.applyInverse(rawOutPoint), err_msg=msg) 

687 assert_allclose(rawInversePoint, mappingFromTransform.applyInverse(rawOutPoint), err_msg=msg) 

688 

689 # inverse transformation of an array of points; 

690 # remember that the inverse will not give the original values 

691 # (see the description of the `mapping` parameter) 

692 inverseArray = transform.applyInverse(outArray) 

693 rawInverseArray = fromEndpoint.dataFromArray(inverseArray) 

694 self.assertFloatsAlmostEqual(rawInverseArray, mapping.applyInverse(rawOutArray), msg=msg) 

695 self.assertFloatsAlmostEqual(rawInverseArray, mappingFromTransform.applyInverse(rawOutArray), 

696 msg=msg) 

697 else: 

698 self.assertFalse(transform.hasInverse) 

699 

    def checkInverseTransformation(self, forward, inverse, msg=""):
        """Check that two Transforms are each others' inverses.

        The endpoints and the hasForward/hasInverse flags must mirror each
        other. Each available direction of ``forward`` must give exactly the
        same result (exact equality, not approximate) as the opposite
        direction of ``inverse``. This is checked both for the Transforms and
        for their underlying mappings.

        Parameters
        ----------
        forward : `lsst.afw.geom.Transform`
            the reference Transform to test
        inverse : `lsst.afw.geom.Transform`
            the transform that should be the inverse of `forward`
        msg : `str`
            error message suffix describing test parameters
        """
        fromEndpoint = forward.fromEndpoint
        toEndpoint = forward.toEndpoint
        forwardMapping = forward.getMapping()
        inverseMapping = inverse.getMapping()

        # properties: endpoints and directions must be swapped
        self.assertEqual(forward.fromEndpoint,
                         inverse.toEndpoint, msg=msg)
        self.assertEqual(forward.toEndpoint,
                         inverse.fromEndpoint, msg=msg)
        self.assertEqual(forward.hasForward, inverse.hasInverse, msg=msg)
        self.assertEqual(forward.hasInverse, inverse.hasForward, msg=msg)

        # transformations of one point
        # we don't care about whether the transformation itself is correct
        # (see checkTransformation), so inPoint/outPoint need not be related
        rawInPoint = self.makeRawPointData(fromEndpoint.nAxes)
        inPoint = fromEndpoint.pointFromData(rawInPoint)
        rawOutPoint = self.makeRawPointData(toEndpoint.nAxes)
        outPoint = toEndpoint.pointFromData(rawOutPoint)

        # transformations of arrays of points
        nPoints = 7  # arbitrary
        rawInArray = self.makeRawArrayData(nPoints, fromEndpoint.nAxes)
        inArray = fromEndpoint.arrayFromData(rawInArray)
        rawOutArray = self.makeRawArrayData(nPoints, toEndpoint.nAxes)
        outArray = toEndpoint.arrayFromData(rawOutArray)

        if forward.hasForward:
            self.assertEqual(forward.applyForward(inPoint),
                             inverse.applyInverse(inPoint), msg=msg)
            self.assertEqual(forwardMapping.applyForward(rawInPoint),
                             inverseMapping.applyInverse(rawInPoint), msg=msg)
            # Array results are compared element-wise with assert_array_equal
            # (plain assertEqual cannot compare multi-element numpy arrays)
            assert_array_equal(forward.applyForward(inArray),
                               inverse.applyInverse(inArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyForward(rawInArray),
                               inverseMapping.applyInverse(rawInArray),
                               err_msg=msg)

        if forward.hasInverse:
            self.assertEqual(forward.applyInverse(outPoint),
                             inverse.applyForward(outPoint), msg=msg)
            self.assertEqual(forwardMapping.applyInverse(rawOutPoint),
                             inverseMapping.applyForward(rawOutPoint), msg=msg)
            assert_array_equal(forward.applyInverse(outArray),
                               inverse.applyForward(outArray),
                               err_msg=msg)
            assert_array_equal(forwardMapping.applyInverse(rawOutArray),
                               inverseMapping.applyForward(rawOutArray),
                               err_msg=msg)

764 

765 def checkTransformFromMapping(self, fromName, toName): 

766 """Check Transform_<fromName>_<toName> using the Mapping constructor 

767 

768 Parameters 

769 ---------- 

770 fromName, toName : `str` 

771 Endpoint name prefix for "from" and "to" endpoints, respectively, 

772 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

773 fromAxes, toAxes : `int` 

774 number of axes in fromFrame and toFrame, respectively 

775 """ 

776 transformClassName = "Transform{}To{}".format(fromName, toName) 

777 TransformClass = getattr(afwGeom, transformClassName) 

778 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

779 

780 # check valid numbers of inputs and outputs 

781 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

782 self.goodNAxes[toName]): 

783 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

784 polyMap = makeTwoWayPolyMap(nIn, nOut) 

785 transform = TransformClass(polyMap) 

786 

787 # desired output from `str(transform)` 

788 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut) 

789 self.assertEqual("{}".format(transform), desStr) 

790 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr) 

791 

792 self.checkTransformation(transform, polyMap, msg=msg) 

793 

794 # Forward transform but no inverse 

795 polyMap = makeForwardPolyMap(nIn, nOut) 

796 transform = TransformClass(polyMap) 

797 self.checkTransformation(transform, polyMap, msg=msg) 

798 

799 # Inverse transform but no forward 

800 polyMap = makeForwardPolyMap(nOut, nIn).inverted() 

801 transform = TransformClass(polyMap) 

802 self.checkTransformation(transform, polyMap, msg=msg) 

803 

804 # check invalid # of output against valid # of inputs 

805 for nIn, badNOut in itertools.product(self.goodNAxes[fromName], 

806 self.badNAxes[toName]): 

807 badPolyMap = makeTwoWayPolyMap(nIn, badNOut) 

808 msg = "{}, nIn={}, badNOut={}".format(baseMsg, nIn, badNOut) 

809 with self.assertRaises(InvalidParameterError, msg=msg): 

810 TransformClass(badPolyMap) 

811 

812 # check invalid # of inputs against valid and invalid # of outputs 

813 for badNIn, nOut in itertools.product(self.badNAxes[fromName], 

814 self.goodNAxes[toName] + self.badNAxes[toName]): 

815 badPolyMap = makeTwoWayPolyMap(badNIn, nOut) 

816 msg = "{}, badNIn={}, nOut={}".format(baseMsg, nIn, nOut) 

817 with self.assertRaises(InvalidParameterError, msg=msg): 

818 TransformClass(badPolyMap) 

819 

820 def checkTransformFromFrameSet(self, fromName, toName): 

821 """Check Transform_<fromName>_<toName> using the FrameSet constructor 

822 

823 Parameters 

824 ---------- 

825 fromName, toName : `str` 

826 Endpoint name prefix for "from" and "to" endpoints, respectively, 

827 e.g. "Point2" for `lsst.afw.geom.Point2Endpoint` 

828 """ 

829 transformClassName = "Transform{}To{}".format(fromName, toName) 

830 TransformClass = getattr(afwGeom, transformClassName) 

831 baseMsg = "TransformClass={}".format(TransformClass.__name__) 

832 for nIn, nOut in itertools.product(self.goodNAxes[fromName], 

833 self.goodNAxes[toName]): 

834 msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut) 

835 

836 baseFrame = self.makeGoodFrame(fromName, nIn) 

837 currFrame = self.makeGoodFrame(toName, nOut) 

838 frameSet = self.makeFrameSet(baseFrame, currFrame) 

839 self.assertEqual(frameSet.nFrame, 4) 

840 

841 # construct 0 or more frame sets that are invalid for this transform class 

842 for badBaseFrame in self.makeBadFrames(fromName): 

843 badFrameSet = self.makeFrameSet(badBaseFrame, currFrame) 

844 with self.assertRaises(InvalidParameterError): 

845 TransformClass(badFrameSet) 

846 for badCurrFrame in self.makeBadFrames(toName): 

847 reallyBadFrameSet = self.makeFrameSet(badBaseFrame, badCurrFrame) 

848 with self.assertRaises(InvalidParameterError): 

849 TransformClass(reallyBadFrameSet) 

850 for badCurrFrame in self.makeBadFrames(toName): 

851 badFrameSet = self.makeFrameSet(baseFrame, badCurrFrame) 

852 with self.assertRaises(InvalidParameterError): 

853 TransformClass(badFrameSet) 

854 

855 transform = TransformClass(frameSet) 

856 

857 desStr = "{}[{}->{}]".format(transformClassName, nIn, nOut) 

858 self.assertEqual("{}".format(transform), desStr) 

859 self.assertEqual(repr(transform), "lsst.afw.geom." + desStr) 

860 

861 self.checkPersistence(transform) 

862 

863 mappingFromTransform = transform.getMapping() 

864 transformCopy = TransformClass(mappingFromTransform) 

865 self.assertEqual(type(transform), type(transformCopy)) 

866 self.assertEqual(transform.getMapping(), mappingFromTransform) 

867 

868 polyMap = makeTwoWayPolyMap(nIn, nOut) 

869 

870 self.checkTransformation(transform, mapping=polyMap, msg=msg) 

871 

872 # If the base and/or current frame of frameSet is a SkyFrame, 

873 # try permuting that frame (in place, so the connected mappings are 

874 # correctly updated). The Transform constructor should undo the permutation, 

875 # (via SpherePointEndpoint.normalizeFrame) in its internal copy of frameSet, 

876 # forcing the axes of the SkyFrame into standard (longitude, latitude) order 

877 for permutedFS in self.permuteFrameSetIter(frameSet): 

878 if permutedFS.isBaseSkyFrame: 

879 baseFrame = permutedFS.frameSet.getFrame(ast.FrameSet.BASE) 

880 # desired base longitude axis 

881 desBaseLonAxis = 2 if permutedFS.isBasePermuted else 1 

882 self.assertEqual(baseFrame.lonAxis, desBaseLonAxis) 

883 if permutedFS.isCurrSkyFrame: 

884 currFrame = permutedFS.frameSet.getFrame(ast.FrameSet.CURRENT) 

885 # desired current base longitude axis 

886 desCurrLonAxis = 2 if permutedFS.isCurrPermuted else 1 

887 self.assertEqual(currFrame.lonAxis, desCurrLonAxis) 

888 

889 permTransform = TransformClass(permutedFS.frameSet) 

890 self.checkTransformation(permTransform, mapping=polyMap, msg=msg) 

891 

def checkInverted(self, fromName, toName):
    """Exercise ``inverted`` on Transform<fromName>To<toName>.

    For every combination of supported input/output axis counts, three
    polynomial mappings are tried in turn ("TwoWay", "Forward" and
    "Inverse", the last being an inverted forward polynomial map), and
    each is handed to `checkInverseMapping`.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefixes for the "from" and "to" endpoints,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`.
    """
    TransformClass = getattr(afwGeom, "Transform{}To{}".format(fromName, toName))
    baseMsg = "TransformClass={}".format(TransformClass.__name__)
    axisPairs = itertools.product(self.goodNAxes[fromName], self.goodNAxes[toName])
    for nIn, nOut in axisPairs:
        prefix = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
        # Factories (not prebuilt maps) so each map is constructed right
        # before it is checked; they are called within this iteration, so
        # late binding of nIn/nOut is harmless.
        mapFactories = (
            ("TwoWay", lambda: makeTwoWayPolyMap(nIn, nOut)),
            ("Forward", lambda: makeForwardPolyMap(nIn, nOut)),
            ("Inverse", lambda: makeForwardPolyMap(nOut, nIn).inverted()),
        )
        for label, buildMap in mapFactories:
            self.checkInverseMapping(TransformClass, buildMap(),
                                     "{}, Map={}".format(prefix, label))

919 

def checkInverseMapping(self, TransformClass, mapping, msg):
    """Test ``inverted`` for a transform built from one specific mapping.

    Construct ``TransformClass(mapping)``, invert it once and then again,
    and check each consecutive pair with `checkInverseTransformation`.
    Finally, check with `checkTransformation` that the doubly-inverted
    transform behaves like ``mapping`` itself.

    Parameters
    ----------
    TransformClass : `type`
        The class of transform to test, such as TransformPoint2ToPoint2
    mapping : `ast.Mapping`
        The mapping to use for the transform
    msg : `str`
        Error message suffix
    """
    transform = TransformClass(mapping)
    inverse = transform.inverted()
    inverseInverse = inverse.inverted()

    self.checkInverseTransformation(transform, inverse, msg=msg)
    self.checkInverseTransformation(inverse, inverseInverse, msg=msg)
    # Inverting twice must round-trip back to the original mapping's behavior
    self.checkTransformation(inverseInverse, mapping, msg=msg)

943 

def checkGetJacobian(self, fromName, toName):
    """Exercise ``getJacobian`` on Transform<fromName>To<toName>.

    For every combination of supported input/output axis counts, build the
    transform from ``makeForwardPolyMap(nIn, nOut)``. Then compare its
    Jacobian with ``self.makeJacobian`` at two different input points.

    Parameters
    ----------
    fromName, toName : `str`
        Endpoint name prefixes for the "from" and "to" endpoints,
        e.g. "Point2" for `lsst.afw.geom.Point2Endpoint`.
    """
    TransformClass = getattr(afwGeom, "Transform{}To{}".format(fromName, toName))
    baseMsg = "TransformClass={}".format(TransformClass.__name__)
    axisPairs = itertools.product(self.goodNAxes[fromName], self.goodNAxes[toName])
    for nIn, nOut in axisPairs:
        msg = "{}, nIn={}, nOut={}".format(baseMsg, nIn, nOut)
        transform = TransformClass(makeForwardPolyMap(nIn, nOut))
        fromEndpoint = transform.fromEndpoint

        # Evaluate at more than one point so that the functional form of the
        # Jacobian is tested, not just its value at a single location.
        for extraArgs in ((), (0.111,)):
            rawInPoint = self.makeRawPointData(nIn, *extraArgs)
            actual = transform.getJacobian(fromEndpoint.pointFromData(rawInPoint))
            expected = self.makeJacobian(nIn, nOut, rawInPoint)
            assert_allclose(actual, expected, err_msg=msg)

975 

def checkThen(self, fromName, midName, toName):
    """Exercise Transform<fromName>To<midName>.then(Transform<midName>To<toName>).

    For every combination of supported axis counts, the concatenated
    transform must agree with applying the two pieces one after the other,
    in both the forward and the inverse direction. Two failure cases are
    also checked; each must raise `InvalidParameterError`:

    - if ``midName`` is "Generic", the two pieces are given disagreeing
      shared axis counts (3 out of the first, 2 into the second);
    - if ``fromName`` differs from ``midName``, the first transform class is
      chained with itself, so the shared endpoint types do not match.

    Parameters
    ----------
    fromName : `str`
        Prefix of the starting endpoint of the concatenated transform
        (e.g. "Point2" for a Point2Endpoint).
    midName : `str`
        Prefix of the shared endpoint at which the two transforms join.
    toName : `str`
        Prefix of the ending endpoint of the concatenated transform.
    """
    FirstClass = getattr(afwGeom, "Transform{}To{}".format(fromName, midName))
    SecondClass = getattr(afwGeom, "Transform{}To{}".format(midName, toName))
    baseMsg = "{}.then({})".format(FirstClass.__name__, SecondClass.__name__)

    axisTriples = itertools.product(self.goodNAxes[fromName],
                                    self.goodNAxes[midName],
                                    self.goodNAxes[toName])
    for nIn, nMid, nOut in axisTriples:
        msg = "{}, nIn={}, nMid={}, nOut={}".format(baseMsg, nIn, nMid, nOut)
        first = FirstClass(makeTwoWayPolyMap(nIn, nMid))
        second = SecondClass(makeTwoWayPolyMap(nMid, nOut))
        combined = first.then(second)

        startEndpoint = first.fromEndpoint
        endEndpoint = second.toEndpoint

        # Forward: combined == second after first
        startPoint = startEndpoint.pointFromData(self.makeRawPointData(nIn))
        combinedOut = combined.applyForward(startPoint)
        chainedOut = second.applyForward(first.applyForward(startPoint))
        assert_allclose(endEndpoint.dataFromPoint(combinedOut),
                        endEndpoint.dataFromPoint(chainedOut),
                        err_msg=msg)

        # Inverse: combined^-1 == first^-1 after second^-1
        endPoint = endEndpoint.pointFromData(self.makeRawPointData(nOut))
        combinedIn = combined.applyInverse(endPoint)
        chainedIn = first.applyInverse(second.applyInverse(endPoint))
        assert_allclose(startEndpoint.dataFromPoint(combinedIn),
                        startEndpoint.dataFromPoint(chainedIn),
                        err_msg=msg)

    if midName == "Generic":
        # Disagreeing axis counts at the join (3 out, 2 in) must be rejected
        nIn = self.goodNAxes[fromName][0]
        nOut = self.goodNAxes[toName][0]
        first = FirstClass(makeTwoWayPolyMap(nIn, 3))
        second = SecondClass(makeTwoWayPolyMap(2, nOut))
        with self.assertRaises(InvalidParameterError):
            first.then(second)

    if fromName != midName:
        # Chain FirstClass with itself: its "to" endpoint type (midName)
        # differs from its "from" endpoint type (fromName), so then() must
        # refuse. Only axis counts valid for both endpoint types are usable
        # at the join.
        sharedNAxes = set(self.goodNAxes[fromName]).intersection(self.goodNAxes[midName])
        badTriples = itertools.product(self.goodNAxes[fromName],
                                       sharedNAxes,
                                       self.goodNAxes[midName])
        for nIn, nMid, nOut in badTriples:
            first = FirstClass(makeTwoWayPolyMap(nIn, nMid))
            second = FirstClass(makeTwoWayPolyMap(nMid, nOut))
            with self.assertRaises(InvalidParameterError):
                first.then(second)

1055 

def assertTransformsEqual(self, transform1, transform2):
    """Assert that two transforms are equal.

    The two transforms must have the same Python type, equal endpoints and
    equal mappings. They must also give matching results (per
    `numpy.testing.assert_allclose`) on a sample of 7 points in every
    direction (forward and/or inverse) that ``transform1``'s mapping
    supports.

    Parameters
    ----------
    transform1, transform2
        The transforms to compare.
    """
    self.assertEqual(type(transform1), type(transform2))
    self.assertEqual(transform1.fromEndpoint, transform2.fromEndpoint)
    self.assertEqual(transform1.toEndpoint, transform2.toEndpoint)
    self.assertEqual(transform1.getMapping(), transform2.getMapping())

    fromEndpoint = transform1.fromEndpoint
    toEndpoint = transform1.toEndpoint
    mapping = transform1.getMapping()
    nPoints = 7  # arbitrary sample size

    def compareOutputs(nAxes, srcEndpoint, dstEndpoint, apply1, apply2):
        # Feed the same input array through both transforms and compare
        srcArray = srcEndpoint.arrayFromData(self.makeRawArrayData(nPoints, nAxes))
        data1 = dstEndpoint.dataFromArray(apply1(srcArray))
        data2 = dstEndpoint.dataFromArray(apply2(srcArray))
        assert_allclose(data1, data2)

    if mapping.hasForward:
        compareOutputs(mapping.nIn, fromEndpoint, toEndpoint,
                       transform1.applyForward, transform2.applyForward)
    if mapping.hasInverse:
        compareOutputs(mapping.nOut, toEndpoint, fromEndpoint,
                       transform1.applyInverse, transform2.applyInverse)

1088 

def checkPersistence(self, transform):
    """Check that a transform round-trips through every persistence path.

    Paths exercised:

    - ``writeString``/``readString``. The string header must be version 1
      followed by the transform's class name. ``readString`` must reject a
      string with version 2 or with a wrong class name by raising
      `InvalidParameterError`.
    - `lsst.afw.geom.transformFromString`.
    - Pickling.
    - FITS round-trip through ``writeFits`` and ``type(transform).readFits``.

    Each restored transform is compared with ``transform`` using
    `assertTransformsEqual`.

    Parameters
    ----------
    transform
        The transform to persist, e.g. a ``TransformPoint2ToPoint2``.
    """
    className = type(transform).__name__

    # check writeString and readString; the string is
    # "<version> <className> <rest>"
    transformStr = transform.writeString()
    serialVersion, serialClassName, serialRest = transformStr.split(" ", 2)
    self.assertEqual(int(serialVersion), 1)
    self.assertEqual(serialClassName, className)

    # An unsupported version must be rejected
    badStr1 = " ".join(["2", serialClassName, serialRest])
    with self.assertRaises(InvalidParameterError):
        transform.readString(badStr1)
    # A class name that does not match must be rejected
    badClassName = "x" + serialClassName
    badStr2 = " ".join(["1", badClassName, serialRest])
    with self.assertRaises(InvalidParameterError):
        transform.readString(badStr2)
    transformFromStr1 = transform.readString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr1)

    # check transformFromString
    transformFromStr2 = afwGeom.transformFromString(transformStr)
    self.assertTransformsEqual(transform, transformFromStr2)

    # Check pickling (data is produced locally, so unpickling is safe here)
    self.assertTransformsEqual(transform, pickle.loads(pickle.dumps(transform)))

    # Check afw::table::io persistence round-trip
    with lsst.utils.tests.getTempFilePath(".fits") as filename:
        transform.writeFits(filename)
        self.assertTransformsEqual(transform, type(transform).readFits(filename))