Coverage for python / lsst / images / _transforms / _transform.py: 29%

174 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 09:16 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

# Public API of this module: the names exported by ``from ... import *``.
__all__ = ("Transform", "TransformCompositionError", "TransformSerializationModel")

19 

20import textwrap 

21from collections.abc import Iterable 

22from typing import TYPE_CHECKING, Any, TypeVar, final 

23 

24import astropy.io.fits.header 

25import astropy.units as u 

26import numpy as np 

27import pydantic 

28 

29from .._concrete_bounds import SerializableBounds 

30from .._geom import XY, Bounds, Box 

31from ..serialization import ArchiveReadError, ArchiveTree, InputArchive, OutputArchive 

32from . import _ast as astshim 

33from ._frames import Frame, SerializableFrame, SkyFrame 

34 

35if TYPE_CHECKING: 

36 from ._projection import Projection 

37 

38 try: 

39 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform 

40 except ImportError: 

41 type LegacyTransform = Any # type: ignore[no-redef] 

42 

# These pre-Python-3.12 TypeVar declarations are still needed by Sphinx
# (probably by the autodoc-typehints plugin), even though the classes and
# functions below declare their type parameters with PEP 695 syntax.
I = TypeVar("I", bound=Frame)  # noqa: E741
O = TypeVar("O", bound=Frame)  # noqa: E741

47 

48 

class TransformCompositionError(RuntimeError):
    """Error raised when two transforms cannot be chained together.

    `Transform.then` raises this when the output frame of the first
    transform is not equal to the input frame of the second.
    """

51 

52 

53@final 

54class Transform[I: Frame, O: Frame]: 

55 """A transform that maps two coordinate frames. 

56 

57 Notes 

58 ----- 

59 The `Transform` class constructor is considered a private implementation 

60 detail. Instead of using this, various factory methods are available: 

61 

62 - `from_fits_wcs` constructs a transform from a FITS WCS, as represented 

63 `astropy.wcs.WCS`; 

64 - `then` composes two transforms; 

65 - `identity` constructs a trivial transform that does nothing; 

66 - `inverted` returns the inverse of a transform; 

67 - `from_legacy` converts an `lsst.afw.geom.Transform` instance. 

68 

69 When applied to celestial coordinate systems, ``x=ra`` and ``y=dec``. 

70 `Projection` provides a more natural interface for pixel-to-sky transforms. 

71 

72 `Transform` is conceptually immutable (the internal AST Mapping should 

73 never be modified in-place after construction), and hence does not need to 

74 be copied when any object that holds it is copied. 

75 """ 

76 

77 def __init__( 

78 self, 

79 in_frame: I, 

80 out_frame: O, 

81 ast_mapping: astshim.Mapping, 

82 in_bounds: Bounds | None = None, 

83 out_bounds: Bounds | None = None, 

84 components: Iterable[Transform[Any, Any]] = (), 

85 ): 

86 self._in_frame = in_frame 

87 self._out_frame = out_frame 

88 self._ast_mapping = ast_mapping 

89 self._in_bounds = in_bounds or getattr(in_frame, "bbox", None) 

90 self._out_bounds = out_bounds or getattr(out_frame, "bbox", None) 

91 self._components = list(components) 

92 

93 @staticmethod 

94 def from_fits_wcs( 

95 fits_wcs: astropy.wcs.WCS, 

96 in_frame: I, 

97 out_frame: O, 

98 in_bounds: Bounds | None = None, 

99 out_bounds: Bounds | None = None, 

100 x0: int = 0, 

101 y0: int = 0, 

102 ) -> Transform[I, O]: 

103 """Construct a transform from a FITS WCS. 

104 

105 Parameters 

106 ---------- 

107 fits_wcs 

108 FITS WCS to convert. 

109 in_frame 

110 Coordinate frame for input points to the forward transform. 

111 out_frame 

112 Coordinate frame for output points from the forward transform. 

113 in_bounds 

114 The region that bounds valid input points. 

115 out_bounds 

116 The region that bounds valid output points. 

117 x0 

118 Logical coordinate of the first column in the array this WCS 

119 relates to world coordinates. 

120 y0 

121 Logical coordinate of the first column in the array this WCS 

122 relates to world coordinates. 

123 

124 Notes 

125 ----- 

126 The ``x0`` and ``y0`` parameters reflect the fact that for FITS, the 

127 first row and column are always labeled ``(1, 1)``, while in Astropy 

128 and most other Python libraries, they are ``(0, 0)``. The `types` in 

129 this package (e.g. `Image`, `Mask`) allow them to be any pair of 

130 integers. 

131 

132 See Also 

133 -------- 

134 Projection.from_fits_wcs 

135 """ 

136 ast_stream = astshim.StringStream(fits_wcs.to_header_string(relax=True)) 

137 ast_fits_chan = astshim.FitsChan(ast_stream, "Encoding=FITS-WCS, SipReplace=0, IWC=1") 

138 ast_frame_set = ast_fits_chan.read() 

139 _prepend_ast_shift(ast_frame_set, x=x0 - 1.0, y=y0 - 1.0, ast_domain="PIXEL") 

140 return Transform( 

141 in_frame, 

142 out_frame, 

143 ast_frame_set, 

144 in_bounds=in_bounds, 

145 out_bounds=out_bounds, 

146 ) 

147 

148 @staticmethod 

149 def identity(frame: I) -> Transform[I, I]: 

150 """Construct a trivial transform that maps a frame to itelf. 

151 

152 Parameters 

153 ---------- 

154 frame 

155 Frame used for both input and output points. 

156 """ 

157 return Transform(frame, frame, astshim.UnitMap(2)) 

158 

159 @property 

160 def in_frame(self) -> I: 

161 """Coordinate frame for input points.""" 

162 return self._in_frame 

163 

164 @property 

165 def out_frame(self) -> O: 

166 """Coordinate frame for output points.""" 

167 return self._out_frame 

168 

169 @property 

170 def in_bounds(self) -> Bounds | None: 

171 """The region that bounds valid input points (`Bounds` | `None`).""" 

172 return self._in_bounds 

173 

174 @property 

175 def out_bounds(self) -> Bounds | None: 

176 """The region that bounds valid output points (`Bounds` | `None`).""" 

177 return self._out_bounds 

178 

179 def show(self, simplified: bool = False, comments: bool = False) -> str: 

180 """Return the AST native representation of the transform. 

181 

182 Parameters 

183 ---------- 

184 simplified 

185 Whether to ask AST to simplify the mapping before showing it. 

186 This will make it much more likely that two equivalent transforms 

187 have the same `show` result. If the internal mapping is actually 

188 a frame set (as needed to round-trip legacy 

189 `lsst.afw.geom.SkyWcs` objects), this will also just show the 

190 mapping with no frame set information. 

191 comments 

192 Whether to include descriptive comments. 

193 """ 

194 ast_mapping = self._ast_mapping 

195 if simplified: 

196 if isinstance(ast_mapping, astshim.FrameSet): 

197 ast_mapping = ast_mapping.getMapping() 

198 ast_mapping = ast_mapping.simplified() 

199 return ast_mapping.show(comments) 

200 

201 def apply_forward[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]: 

202 """Apply the forward transform to one or more points. 

203 

204 Parameters 

205 ---------- 

206 x : `numpy.ndarray` | `float` 

207 ``x`` values of the points to transform. 

208 y : `numpy.ndarray` | `float` 

209 ``y`` values of the points to transform. 

210 

211 Returns 

212 ------- 

213 `XY` [`numpy.ndarray` | `float`] 

214 The transformed point or points. 

215 """ 

216 return _standardize_xy( 

217 _ast_apply( 

218 self._ast_mapping.applyForward, 

219 x=self._in_frame.standardize_x(x), 

220 y=self._in_frame.standardize_y(y), 

221 ), 

222 self._out_frame, 

223 ) 

224 

225 def apply_inverse[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]: 

226 """Apply the inverse transform to one or more points. 

227 

228 Parameters 

229 ---------- 

230 x : `numpy.ndarray` | `float` 

231 ``x`` values of the points to transform. 

232 y : `numpy.ndarray` | `float` 

233 ``y`` values of the points to transform. 

234 

235 Returns 

236 ------- 

237 `XY` [`numpy.ndarray` | `float`] 

238 The transformed point or points. 

239 """ 

240 return _standardize_xy( 

241 _ast_apply( 

242 self._ast_mapping.applyInverse, 

243 x=self._out_frame.standardize_x(x), 

244 y=self._out_frame.standardize_y(y), 

245 ), 

246 self._in_frame, 

247 ) 

248 

249 def apply_forward_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]: 

250 """Apply the forward transform to one or more unit-aware points. 

251 

252 Parameters 

253 ---------- 

254 x 

255 ``x`` values of the points to transform. 

256 y 

257 ``y`` values of the points to transform. 

258 

259 Returns 

260 ------- 

261 `XY` [`astropy.units.Quantity`] 

262 The transformed point or points. 

263 """ 

264 xy = self.apply_forward(x=x.to_value(self._in_frame.unit), y=y.to_value(self._in_frame.unit)) 

265 return XY(xy.x * self._out_frame.unit, xy.y * self._out_frame.unit) 

266 

267 def apply_inverse_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]: 

268 """Apply the inverse transform to one or more unit-aware points. 

269 

270 Parameters 

271 ---------- 

272 x 

273 ``x`` values of the points to transform. 

274 y 

275 ``y`` values of the points to transform. 

276 

277 Returns 

278 ------- 

279 `XY` [`astropy.units.Quantity`] 

280 The transformed point or points. 

281 """ 

282 xy = self.apply_inverse(x=x.to_value(self._out_frame.unit), y=y.to_value(self._out_frame.unit)) 

283 return XY(xy.x * self._in_frame.unit, xy.y * self._in_frame.unit) 

284 

285 def decompose(self) -> list[Transform[Any, Any]]: 

286 """Deconstruct a composed transform into its constituent parts. 

287 

288 Notes 

289 ----- 

290 Most transforms will just return a single-element list holding 

291 ``self``. Identity transform will return an empty list, and 

292 transforms composed with `then` will return the original transforms. 

293 Transforms constructed by `FrameSet` may or may not be decomposable. 

294 """ 

295 if not self._components: 

296 if self.in_frame == self._out_frame: 

297 return [] 

298 else: 

299 return [self] 

300 else: 

301 return list(self._components) 

302 

303 def inverted(self) -> Transform[O, I]: 

304 """Return the inverse of this transform.""" 

305 return Transform[O, I]( 

306 self._out_frame, 

307 self._in_frame, 

308 self._ast_mapping.inverted(), 

309 in_bounds=self.out_bounds, 

310 out_bounds=self.in_bounds, 

311 components=[t.inverted() for t in reversed(self._components)], 

312 ) 

313 

314 def then[F: Frame](self, next: Transform[O, F], remember_components: bool = True) -> Transform[I, F]: 

315 """Compose two transforms into another. 

316 

317 Parameters 

318 ---------- 

319 next 

320 Another transform to apply after ``self``. 

321 remember_components 

322 If `True`, the returned composed transform will remember ``self`` 

323 and ``other`` so they can be returned by `decompose`. 

324 """ 

325 if self._out_frame != next._in_frame: 

326 raise TransformCompositionError( 

327 "Cannot compose transforms that do not share a common intermediate frame: " 

328 f"{self._out_frame} != {next._in_frame}." 

329 ) 

330 components = self.decompose() + next.decompose() if remember_components else () 

331 return Transform( 

332 self._in_frame, 

333 next._out_frame, 

334 self._ast_mapping.then(next._ast_mapping), 

335 in_bounds=self.in_bounds, 

336 out_bounds=next.out_bounds, 

337 components=components, 

338 ) 

339 

340 def as_projection(self: Transform[I, SkyFrame]) -> Projection[I]: 

341 """Return a `Projection` view of this transform. 

342 

343 This is only valid when `out_frame` is `~SkyFrame.ICRS`. 

344 """ 

345 from ._projection import Projection 

346 

347 return Projection(self) 

348 

349 def as_fits_wcs(self, bbox: Box) -> astropy.wcs.WCS | None: 

350 """Return a FITS WCS representation of this transform, if possible. 

351 

352 Parameters 

353 ---------- 

354 bbox 

355 Bounding box of the array the FITS WCS will describe. This 

356 transform object is assumed to work on the same coordinate system 

357 in which ``bbox`` is defined, while the FITS WCS will consider the 

358 first row and column in that box to be ``(0, 0)`` (in Astropy 

359 interfaces) or ``(1, 1)`` (in the FITS representation itself). 

360 

361 Notes 

362 ----- 

363 This method assumes the transform maps pixel coordinates to world 

364 coordinates. 

365 

366 Not all transforms can be represented exactly; when a FITS 

367 represention is not possible, `None` is returned. When the returned 

368 WCS is not `None`, it will have the same functional form, but it may 

369 not evaluate identically due to small implementation differences in 

370 the order of floating-point operations. 

371 """ 

372 ast_frame_set = self._get_ast_frame_set() 

373 _prepend_ast_shift(ast_frame_set, x=1.0 - bbox.x.start, y=1.0 - bbox.y.start, ast_domain="GRID") 

374 ast_stream = astshim.StringStream() 

375 ast_fits_chan = astshim.FitsChan( 

376 ast_stream, "Encoding=FITS-WCS, CDMatrix=1, FitsAxisOrder=<copy>, FitsTol=0.0001" 

377 ) 

378 ast_fits_chan.setFitsI("NAXIS1", bbox.x.size) 

379 ast_fits_chan.setFitsI("NAXIS2", bbox.y.size) 

380 n_writes = ast_fits_chan.write(ast_frame_set) 

381 if not n_writes: 

382 return None 

383 header = astropy.io.fits.Header(astropy.io.fits.Card.fromstring(c) for c in ast_fits_chan) 

384 return astropy.wcs.WCS(header) 

385 

386 def serialize[P: pydantic.BaseModel]( 

387 self, archive: OutputArchive[P], *, use_frame_sets: bool = False 

388 ) -> TransformSerializationModel[P]: 

389 """Serialize a transform to an archive. 

390 

391 Parameters 

392 ---------- 

393 archive 

394 Archive to serialize to. 

395 use_frame_sets 

396 If `True`, decompose the transform and try to reference component 

397 mappings that were already serialized into a `FrameSet` in the 

398 archive. Note that if multiple transforms exist between a pair of 

399 frames (e.g. a `Projection` and its FITS approximation), this may 

400 cause the wrong one to be saved. When this option is used, the 

401 frame set must be saved before the transform, and it must be 

402 deserialized before the transform as well. 

403 

404 Returns 

405 ------- 

406 `TransformSerializationModel` 

407 Serialized form of the transform. 

408 """ 

409 model = TransformSerializationModel[P]() 

410 if use_frame_sets: 

411 for link in self.decompose(): 

412 model.frames.append(link.in_frame.serialize()) 

413 model.bounds.append(link.in_bounds.serialize() if link.in_bounds is not None else None) 

414 for frame_set, pointer in archive.iter_frame_sets(): 

415 if link.in_frame in frame_set and link.out_frame in frame_set: 

416 model.mappings.append(pointer) 

417 break 

418 else: 

419 model.mappings.append(MappingSerializationModel(ast=link._ast_mapping.show())) 

420 else: 

421 model.frames.append(self.in_frame.serialize()) 

422 model.bounds.append(self.in_bounds.serialize() if self.in_bounds is not None else None) 

423 model.mappings.append(MappingSerializationModel(ast=self._ast_mapping.show())) 

424 model.frames.append(self.out_frame.serialize()) 

425 model.bounds.append(self.out_bounds.serialize() if self.out_bounds is not None else None) 

426 return model 

427 

428 @staticmethod 

429 def deserialize[P: pydantic.BaseModel]( 

430 model: TransformSerializationModel[P], archive: InputArchive[P] 

431 ) -> Transform[Any, Any]: 

432 """Deserialize a transform from an archive. 

433 

434 Parameters 

435 ---------- 

436 model 

437 Seralized form of the transform. 

438 archive 

439 Archive to read from. 

440 """ 

441 if len(model.frames) != len(model.bounds): 

442 raise ArchiveReadError( 

443 f"Inconsistent lengths for 'frames' ({len(model.frames)}) and 'bounds' ({len(model.bounds)})." 

444 ) 

445 if len(model.frames) != len(model.mappings) + 1: 

446 raise ArchiveReadError( 

447 f"Inconsistent lengths for 'frames' ({len(model.frames)}) and " 

448 f"'mappings' ({len(model.mappings)}; should be one less)." 

449 ) 

450 # We can't just compose onto an identity Transform if we want to 

451 # preserve the FrameSet-ness of any of these mappings. 

452 transform: Transform | None = None 

453 for n, mapping in enumerate(model.mappings): 

454 match mapping: 

455 case MappingSerializationModel(ast=serialized_mapping): 

456 ast_mapping = astshim.Mapping.fromString(serialized_mapping) 

457 in_bounds = model.bounds[n] 

458 out_bounds = model.bounds[n + 1] 

459 new_transform = Transform( 

460 Frame.deserialize(model.frames[n]), 

461 Frame.deserialize(model.frames[n + 1]), 

462 ast_mapping, 

463 Bounds.deserialize(in_bounds) if in_bounds is not None else None, 

464 Bounds.deserialize(out_bounds) if out_bounds is not None else None, 

465 ) 

466 case reference: 

467 frame_set = archive.get_frame_set(reference) 

468 new_transform = frame_set[ 

469 Frame.deserialize(model.frames[n]), Frame.deserialize(model.frames[n + 1]) 

470 ] 

471 if transform is None: 

472 transform = new_transform 

473 else: 

474 transform = transform.then(new_transform) 

475 if transform is None: 

476 transform = Transform.identity(Frame.deserialize(model.frames[0])) 

477 return transform 

478 

479 @staticmethod 

480 def _get_archive_tree_type[P: pydantic.BaseModel]( 

481 pointer_type: type[P], 

482 ) -> type[TransformSerializationModel[P]]: 

483 """Return the serialization model type for this object for an archive 

484 type that uses the given pointer type. 

485 """ 

486 return TransformSerializationModel[pointer_type] # type: ignore 

487 

488 @staticmethod 

489 def from_legacy( 

490 legacy: LegacyTransform, 

491 in_frame: I, 

492 out_frame: O, 

493 in_bounds: Bounds | None = None, 

494 out_bounds: Bounds | None = None, 

495 ) -> Transform[I, O]: 

496 """Construct a transform from a legacy `lsst.afw.geom.Transform`. 

497 

498 Parameters 

499 ---------- 

500 legacy : `lsst.afw.geom.Transform` 

501 Legacy transform object. 

502 in_frame 

503 Coordinate frame for input points to the forward transform. 

504 out_frame 

505 Coordinate frame for output points from the forward transform. 

506 in_bounds 

507 The region that bounds valid input points. 

508 out_bounds 

509 The region that bounds valid output points. 

510 """ 

511 return Transform( 

512 in_frame, 

513 out_frame, 

514 legacy.getMapping(), 

515 in_bounds=in_bounds, 

516 out_bounds=out_bounds, 

517 ) 

518 

519 def to_legacy(self) -> LegacyTransform: 

520 """Convert to a legacy `lsst.afw.geom.TransformPoint2ToPoint2` 

521 instance. 

522 """ 

523 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform 

524 

525 return LegacyTransform(self._ast_mapping, False) 

526 

527 def _get_ast_frame_set(self) -> Any: 

528 ast_frame_set = astshim.FrameSet(_make_ast_frame(self._in_frame)) 

529 ast_frame_set.addFrame(astshim.FrameSet.BASE, self._ast_mapping, _make_ast_frame(self._out_frame)) 

530 return ast_frame_set 

531 

532 

533def _ast_apply[T: np.ndarray | float](method: Any, *, x: T, y: T) -> XY[T]: 

534 # TODO: add bounds argument and check inputs 

535 # TODO: broadcast arrays with different shapes. 

536 xy_in = np.vstack([x, y]).astype(np.float64) 

537 xy_out = method(xy_in) 

538 return XY(xy_out[0, :], xy_out[1, :]) 

539 

540 

def _prepend_ast_shift(ast_frame_set: Any, x: float, y: float, ast_domain: str) -> None:
    """Make a shifted copy of the base frame the new base of a frame set.

    A new 2-d frame with the given ``Domain`` is attached to the current base
    frame through a ``ShiftMap`` of ``(x, y)``, so its coordinates are the
    old base coordinates plus ``(x, y)``. That new frame then becomes the
    base frame, while the current frame is left as it was.
    """
    previous_current = ast_frame_set.current
    shift_map = astshim.ShiftMap([x, y])
    shifted_frame = astshim.Frame(2, f"Domain={ast_domain}")
    ast_frame_set.addFrame(astshim.FrameSet.BASE, shift_map, shifted_frame)
    # addFrame leaves the newly added frame current: promote it to base and
    # restore whatever frame was current before.
    ast_frame_set.base = ast_frame_set.current
    ast_frame_set.current = previous_current

550 

551 

def _make_ast_frame(frame: Frame) -> Any:
    """Build the AST frame that represents one of this package's frames.

    `SkyFrame.ICRS` maps to a default AST ``SkyFrame``; every other frame
    becomes a 2-d AST ``Frame`` identified by its ``_ast_ident``, with axis
    labels ``x``/``y`` and, if the frame has a unit, that unit (in FITS
    format) on both axes.
    """
    if frame is not SkyFrame.ICRS:
        generic = astshim.Frame(2, f"Ident={frame._ast_ident}")
        unit = frame.unit
        if unit is not None:
            unit_string = unit.to_string(format="fits")
            for axis in (1, 2):
                generic.setUnit(axis, unit_string)
        generic.setLabel(1, "x")
        generic.setLabel(2, "y")
        return generic
    return astshim.SkyFrame("")

563 

564 

565def _standardize_xy[T: np.ndarray | float](xy: XY[T], frame: Frame) -> XY[T]: 

566 return XY(x=frame.standardize_x(xy.x), y=frame.standardize_y(xy.y)) 

567 

568 

class MappingSerializationModel(ArchiveTree):
    """Serialization model for an AST Mapping."""

    # Required field holding the mapping's text form (`Transform.serialize`
    # fills it from the mapping's ``show()`` output).
    ast: str = pydantic.Field(
        ...,
        description="A serialized Starlink AST Mapping, using the AST native encoding.",
    )

573 

574 

575class TransformSerializationModel[P: pydantic.BaseModel](ArchiveTree): 

576 """Serialization model for coordinate transforms.""" 

577 

578 frames: list[SerializableFrame] = pydantic.Field( 

579 default_factory=list, 

580 description=textwrap.dedent( 

581 """ 

582 List of frames that this transform passes through. 

583 

584 All transforms include at least two frames (the endpoints). Others 

585 intermediate frames may be included to facilitate data-sharing 

586 between transforms. 

587 """ 

588 ), 

589 ) 

590 

591 bounds: list[SerializableBounds | None] = pydantic.Field( 

592 default_factory=list, 

593 description=textwrap.dedent( 

594 """ 

595 List of the bounds of the ``frames`` for this transform. 

596 

597 This always has the same number of elements as ``frames``. 

598 """ 

599 ), 

600 ) 

601 

602 mappings: list[P | MappingSerializationModel] = pydantic.Field( 

603 default_factory=list, 

604 description=textwrap.dedent( 

605 """ 

606 The actual mappings between frames, or archive pointers to 

607 serialized FrameSet objects from which they can be obtained. 

608 

609 This always has one fewer element than ``frames``. 

610 """ 

611 ), 

612 )