Coverage for python / lsst / images / _transforms / _transform.py: 29%
174 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:13 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:13 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
# Public names exported by this module (used by ``from ... import *``).
__all__ = (
    "Transform",
    "TransformCompositionError",
    "TransformSerializationModel",
)
20import textwrap
21from collections.abc import Iterable
22from typing import TYPE_CHECKING, Any, TypeVar, final
24import astropy.io.fits.header
25import astropy.units as u
26import numpy as np
27import pydantic
29from .._concrete_bounds import SerializableBounds
30from .._geom import XY, Bounds, Box
31from ..serialization import ArchiveReadError, ArchiveTree, InputArchive, OutputArchive
32from . import _ast as astshim
33from ._frames import Frame, SerializableFrame, SkyFrame
35if TYPE_CHECKING:
36 from ._projection import Projection
38 try:
39 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform
40 except ImportError:
41 type LegacyTransform = Any # type: ignore[no-redef]
# These pre-python-3.12 declarations are needed by Sphinx (probably the
# autodoc-typehints plugin). They mirror the PEP 695 type parameters of
# `Transform` (input and output frame types, both bounded by `Frame`).
I = TypeVar("I", bound=Frame)  # noqa: E741
O = TypeVar("O", bound=Frame)  # noqa: E741
class TransformCompositionError(RuntimeError):
    """Raised when two transforms cannot be chained together.

    `Transform.then` raises this when the output frame of the first
    transform is not the same as the input frame of the second.
    """
53@final
54class Transform[I: Frame, O: Frame]:
55 """A transform that maps two coordinate frames.
57 Notes
58 -----
59 The `Transform` class constructor is considered a private implementation
60 detail. Instead of using this, various factory methods are available:
62 - `from_fits_wcs` constructs a transform from a FITS WCS, as represented
63 `astropy.wcs.WCS`;
64 - `then` composes two transforms;
65 - `identity` constructs a trivial transform that does nothing;
66 - `inverted` returns the inverse of a transform;
67 - `from_legacy` converts an `lsst.afw.geom.Transform` instance.
69 When applied to celestial coordinate systems, ``x=ra`` and ``y=dec``.
70 `Projection` provides a more natural interface for pixel-to-sky transforms.
72 `Transform` is conceptually immutable (the internal AST Mapping should
73 never be modified in-place after construction), and hence does not need to
74 be copied when any object that holds it is copied.
75 """
77 def __init__(
78 self,
79 in_frame: I,
80 out_frame: O,
81 ast_mapping: astshim.Mapping,
82 in_bounds: Bounds | None = None,
83 out_bounds: Bounds | None = None,
84 components: Iterable[Transform[Any, Any]] = (),
85 ):
86 self._in_frame = in_frame
87 self._out_frame = out_frame
88 self._ast_mapping = ast_mapping
89 self._in_bounds = in_bounds or getattr(in_frame, "bbox", None)
90 self._out_bounds = out_bounds or getattr(out_frame, "bbox", None)
91 self._components = list(components)
93 @staticmethod
94 def from_fits_wcs(
95 fits_wcs: astropy.wcs.WCS,
96 in_frame: I,
97 out_frame: O,
98 in_bounds: Bounds | None = None,
99 out_bounds: Bounds | None = None,
100 x0: int = 0,
101 y0: int = 0,
102 ) -> Transform[I, O]:
103 """Construct a transform from a FITS WCS.
105 Parameters
106 ----------
107 fits_wcs
108 FITS WCS to convert.
109 in_frame
110 Coordinate frame for input points to the forward transform.
111 out_frame
112 Coordinate frame for output points from the forward transform.
113 in_bounds
114 The region that bounds valid input points.
115 out_bounds
116 The region that bounds valid output points.
117 x0
118 Logical coordinate of the first column in the array this WCS
119 relates to world coordinates.
120 y0
121 Logical coordinate of the first column in the array this WCS
122 relates to world coordinates.
124 Notes
125 -----
126 The ``x0`` and ``y0`` parameters reflect the fact that for FITS, the
127 first row and column are always labeled ``(1, 1)``, while in Astropy
128 and most other Python libraries, they are ``(0, 0)``. The `types` in
129 this package (e.g. `Image`, `Mask`) allow them to be any pair of
130 integers.
132 See Also
133 --------
134 Projection.from_fits_wcs
135 """
136 ast_stream = astshim.StringStream(fits_wcs.to_header_string(relax=True))
137 ast_fits_chan = astshim.FitsChan(ast_stream, "Encoding=FITS-WCS, SipReplace=0, IWC=1")
138 ast_frame_set = ast_fits_chan.read()
139 _prepend_ast_shift(ast_frame_set, x=x0 - 1.0, y=y0 - 1.0, ast_domain="PIXEL")
140 return Transform(
141 in_frame,
142 out_frame,
143 ast_frame_set,
144 in_bounds=in_bounds,
145 out_bounds=out_bounds,
146 )
148 @staticmethod
149 def identity(frame: I) -> Transform[I, I]:
150 """Construct a trivial transform that maps a frame to itelf.
152 Parameters
153 ----------
154 frame
155 Frame used for both input and output points.
156 """
157 return Transform(frame, frame, astshim.UnitMap(2))
159 @property
160 def in_frame(self) -> I:
161 """Coordinate frame for input points."""
162 return self._in_frame
164 @property
165 def out_frame(self) -> O:
166 """Coordinate frame for output points."""
167 return self._out_frame
169 @property
170 def in_bounds(self) -> Bounds | None:
171 """The region that bounds valid input points (`Bounds` | `None`)."""
172 return self._in_bounds
174 @property
175 def out_bounds(self) -> Bounds | None:
176 """The region that bounds valid output points (`Bounds` | `None`)."""
177 return self._out_bounds
179 def show(self, simplified: bool = False, comments: bool = False) -> str:
180 """Return the AST native representation of the transform.
182 Parameters
183 ----------
184 simplified
185 Whether to ask AST to simplify the mapping before showing it.
186 This will make it much more likely that two equivalent transforms
187 have the same `show` result. If the internal mapping is actually
188 a frame set (as needed to round-trip legacy
189 `lsst.afw.geom.SkyWcs` objects), this will also just show the
190 mapping with no frame set information.
191 comments
192 Whether to include descriptive comments.
193 """
194 ast_mapping = self._ast_mapping
195 if simplified:
196 if isinstance(ast_mapping, astshim.FrameSet):
197 ast_mapping = ast_mapping.getMapping()
198 ast_mapping = ast_mapping.simplified()
199 return ast_mapping.show(comments)
201 def apply_forward[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]:
202 """Apply the forward transform to one or more points.
204 Parameters
205 ----------
206 x : `numpy.ndarray` | `float`
207 ``x`` values of the points to transform.
208 y : `numpy.ndarray` | `float`
209 ``y`` values of the points to transform.
211 Returns
212 -------
213 `XY` [`numpy.ndarray` | `float`]
214 The transformed point or points.
215 """
216 return _standardize_xy(
217 _ast_apply(
218 self._ast_mapping.applyForward,
219 x=self._in_frame.standardize_x(x),
220 y=self._in_frame.standardize_y(y),
221 ),
222 self._out_frame,
223 )
225 def apply_inverse[T: np.ndarray | float](self, *, x: T, y: T) -> XY[T]:
226 """Apply the inverse transform to one or more points.
228 Parameters
229 ----------
230 x : `numpy.ndarray` | `float`
231 ``x`` values of the points to transform.
232 y : `numpy.ndarray` | `float`
233 ``y`` values of the points to transform.
235 Returns
236 -------
237 `XY` [`numpy.ndarray` | `float`]
238 The transformed point or points.
239 """
240 return _standardize_xy(
241 _ast_apply(
242 self._ast_mapping.applyInverse,
243 x=self._out_frame.standardize_x(x),
244 y=self._out_frame.standardize_y(y),
245 ),
246 self._in_frame,
247 )
249 def apply_forward_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]:
250 """Apply the forward transform to one or more unit-aware points.
252 Parameters
253 ----------
254 x
255 ``x`` values of the points to transform.
256 y
257 ``y`` values of the points to transform.
259 Returns
260 -------
261 `XY` [`astropy.units.Quantity`]
262 The transformed point or points.
263 """
264 xy = self.apply_forward(x=x.to_value(self._in_frame.unit), y=y.to_value(self._in_frame.unit))
265 return XY(xy.x * self._out_frame.unit, xy.y * self._out_frame.unit)
267 def apply_inverse_q(self, *, x: u.Quantity, y: u.Quantity) -> XY[u.Quantity]:
268 """Apply the inverse transform to one or more unit-aware points.
270 Parameters
271 ----------
272 x
273 ``x`` values of the points to transform.
274 y
275 ``y`` values of the points to transform.
277 Returns
278 -------
279 `XY` [`astropy.units.Quantity`]
280 The transformed point or points.
281 """
282 xy = self.apply_inverse(x=x.to_value(self._out_frame.unit), y=y.to_value(self._out_frame.unit))
283 return XY(xy.x * self._in_frame.unit, xy.y * self._in_frame.unit)
285 def decompose(self) -> list[Transform[Any, Any]]:
286 """Deconstruct a composed transform into its constituent parts.
288 Notes
289 -----
290 Most transforms will just return a single-element list holding
291 ``self``. Identity transform will return an empty list, and
292 transforms composed with `then` will return the original transforms.
293 Transforms constructed by `FrameSet` may or may not be decomposable.
294 """
295 if not self._components:
296 if self.in_frame == self._out_frame:
297 return []
298 else:
299 return [self]
300 else:
301 return list(self._components)
303 def inverted(self) -> Transform[O, I]:
304 """Return the inverse of this transform."""
305 return Transform[O, I](
306 self._out_frame,
307 self._in_frame,
308 self._ast_mapping.inverted(),
309 in_bounds=self.out_bounds,
310 out_bounds=self.in_bounds,
311 components=[t.inverted() for t in reversed(self._components)],
312 )
314 def then[F: Frame](self, next: Transform[O, F], remember_components: bool = True) -> Transform[I, F]:
315 """Compose two transforms into another.
317 Parameters
318 ----------
319 next
320 Another transform to apply after ``self``.
321 remember_components
322 If `True`, the returned composed transform will remember ``self``
323 and ``other`` so they can be returned by `decompose`.
324 """
325 if self._out_frame != next._in_frame:
326 raise TransformCompositionError(
327 "Cannot compose transforms that do not share a common intermediate frame: "
328 f"{self._out_frame} != {next._in_frame}."
329 )
330 components = self.decompose() + next.decompose() if remember_components else ()
331 return Transform(
332 self._in_frame,
333 next._out_frame,
334 self._ast_mapping.then(next._ast_mapping),
335 in_bounds=self.in_bounds,
336 out_bounds=next.out_bounds,
337 components=components,
338 )
340 def as_projection(self: Transform[I, SkyFrame]) -> Projection[I]:
341 """Return a `Projection` view of this transform.
343 This is only valid when `out_frame` is `~SkyFrame.ICRS`.
344 """
345 from ._projection import Projection
347 return Projection(self)
349 def as_fits_wcs(self, bbox: Box) -> astropy.wcs.WCS | None:
350 """Return a FITS WCS representation of this transform, if possible.
352 Parameters
353 ----------
354 bbox
355 Bounding box of the array the FITS WCS will describe. This
356 transform object is assumed to work on the same coordinate system
357 in which ``bbox`` is defined, while the FITS WCS will consider the
358 first row and column in that box to be ``(0, 0)`` (in Astropy
359 interfaces) or ``(1, 1)`` (in the FITS representation itself).
361 Notes
362 -----
363 This method assumes the transform maps pixel coordinates to world
364 coordinates.
366 Not all transforms can be represented exactly; when a FITS
367 represention is not possible, `None` is returned. When the returned
368 WCS is not `None`, it will have the same functional form, but it may
369 not evaluate identically due to small implementation differences in
370 the order of floating-point operations.
371 """
372 ast_frame_set = self._get_ast_frame_set()
373 _prepend_ast_shift(ast_frame_set, x=1.0 - bbox.x.start, y=1.0 - bbox.y.start, ast_domain="GRID")
374 ast_stream = astshim.StringStream()
375 ast_fits_chan = astshim.FitsChan(
376 ast_stream, "Encoding=FITS-WCS, CDMatrix=1, FitsAxisOrder=<copy>, FitsTol=0.0001"
377 )
378 ast_fits_chan.setFitsI("NAXIS1", bbox.x.size)
379 ast_fits_chan.setFitsI("NAXIS2", bbox.y.size)
380 n_writes = ast_fits_chan.write(ast_frame_set)
381 if not n_writes:
382 return None
383 header = astropy.io.fits.Header(astropy.io.fits.Card.fromstring(c) for c in ast_fits_chan)
384 return astropy.wcs.WCS(header)
386 def serialize[P: pydantic.BaseModel](
387 self, archive: OutputArchive[P], *, use_frame_sets: bool = False
388 ) -> TransformSerializationModel[P]:
389 """Serialize a transform to an archive.
391 Parameters
392 ----------
393 archive
394 Archive to serialize to.
395 use_frame_sets
396 If `True`, decompose the transform and try to reference component
397 mappings that were already serialized into a `FrameSet` in the
398 archive. Note that if multiple transforms exist between a pair of
399 frames (e.g. a `Projection` and its FITS approximation), this may
400 cause the wrong one to be saved. When this option is used, the
401 frame set must be saved before the transform, and it must be
402 deserialized before the transform as well.
404 Returns
405 -------
406 `TransformSerializationModel`
407 Serialized form of the transform.
408 """
409 model = TransformSerializationModel[P]()
410 if use_frame_sets:
411 for link in self.decompose():
412 model.frames.append(link.in_frame.serialize())
413 model.bounds.append(link.in_bounds.serialize() if link.in_bounds is not None else None)
414 for frame_set, pointer in archive.iter_frame_sets():
415 if link.in_frame in frame_set and link.out_frame in frame_set:
416 model.mappings.append(pointer)
417 break
418 else:
419 model.mappings.append(MappingSerializationModel(ast=link._ast_mapping.show()))
420 else:
421 model.frames.append(self.in_frame.serialize())
422 model.bounds.append(self.in_bounds.serialize() if self.in_bounds is not None else None)
423 model.mappings.append(MappingSerializationModel(ast=self._ast_mapping.show()))
424 model.frames.append(self.out_frame.serialize())
425 model.bounds.append(self.out_bounds.serialize() if self.out_bounds is not None else None)
426 return model
428 @staticmethod
429 def deserialize[P: pydantic.BaseModel](
430 model: TransformSerializationModel[P], archive: InputArchive[P]
431 ) -> Transform[Any, Any]:
432 """Deserialize a transform from an archive.
434 Parameters
435 ----------
436 model
437 Seralized form of the transform.
438 archive
439 Archive to read from.
440 """
441 if len(model.frames) != len(model.bounds):
442 raise ArchiveReadError(
443 f"Inconsistent lengths for 'frames' ({len(model.frames)}) and 'bounds' ({len(model.bounds)})."
444 )
445 if len(model.frames) != len(model.mappings) + 1:
446 raise ArchiveReadError(
447 f"Inconsistent lengths for 'frames' ({len(model.frames)}) and "
448 f"'mappings' ({len(model.mappings)}; should be one less)."
449 )
450 # We can't just compose onto an identity Transform if we want to
451 # preserve the FrameSet-ness of any of these mappings.
452 transform: Transform | None = None
453 for n, mapping in enumerate(model.mappings):
454 match mapping:
455 case MappingSerializationModel(ast=serialized_mapping):
456 ast_mapping = astshim.Mapping.fromString(serialized_mapping)
457 in_bounds = model.bounds[n]
458 out_bounds = model.bounds[n + 1]
459 new_transform = Transform(
460 Frame.deserialize(model.frames[n]),
461 Frame.deserialize(model.frames[n + 1]),
462 ast_mapping,
463 Bounds.deserialize(in_bounds) if in_bounds is not None else None,
464 Bounds.deserialize(out_bounds) if out_bounds is not None else None,
465 )
466 case reference:
467 frame_set = archive.get_frame_set(reference)
468 new_transform = frame_set[
469 Frame.deserialize(model.frames[n]), Frame.deserialize(model.frames[n + 1])
470 ]
471 if transform is None:
472 transform = new_transform
473 else:
474 transform = transform.then(new_transform)
475 if transform is None:
476 transform = Transform.identity(Frame.deserialize(model.frames[0]))
477 return transform
479 @staticmethod
480 def _get_archive_tree_type[P: pydantic.BaseModel](
481 pointer_type: type[P],
482 ) -> type[TransformSerializationModel[P]]:
483 """Return the serialization model type for this object for an archive
484 type that uses the given pointer type.
485 """
486 return TransformSerializationModel[pointer_type] # type: ignore
488 @staticmethod
489 def from_legacy(
490 legacy: LegacyTransform,
491 in_frame: I,
492 out_frame: O,
493 in_bounds: Bounds | None = None,
494 out_bounds: Bounds | None = None,
495 ) -> Transform[I, O]:
496 """Construct a transform from a legacy `lsst.afw.geom.Transform`.
498 Parameters
499 ----------
500 legacy : `lsst.afw.geom.Transform`
501 Legacy transform object.
502 in_frame
503 Coordinate frame for input points to the forward transform.
504 out_frame
505 Coordinate frame for output points from the forward transform.
506 in_bounds
507 The region that bounds valid input points.
508 out_bounds
509 The region that bounds valid output points.
510 """
511 return Transform(
512 in_frame,
513 out_frame,
514 legacy.getMapping(),
515 in_bounds=in_bounds,
516 out_bounds=out_bounds,
517 )
519 def to_legacy(self) -> LegacyTransform:
520 """Convert to a legacy `lsst.afw.geom.TransformPoint2ToPoint2`
521 instance.
522 """
523 from lsst.afw.geom import TransformPoint2ToPoint2 as LegacyTransform
525 return LegacyTransform(self._ast_mapping, False)
527 def _get_ast_frame_set(self) -> Any:
528 ast_frame_set = astshim.FrameSet(_make_ast_frame(self._in_frame))
529 ast_frame_set.addFrame(astshim.FrameSet.BASE, self._ast_mapping, _make_ast_frame(self._out_frame))
530 return ast_frame_set
533def _ast_apply[T: np.ndarray | float](method: Any, *, x: T, y: T) -> XY[T]:
534 # TODO: add bounds argument and check inputs
535 # TODO: broadcast arrays with different shapes.
536 xy_in = np.vstack([x, y]).astype(np.float64)
537 xy_out = method(xy_in)
538 return XY(xy_out[0, :], xy_out[1, :])
def _prepend_ast_shift(ast_frame_set: Any, x: float, y: float, ast_domain: str) -> None:
    """Give an AST FrameSet a new, shifted base frame, in place.

    A 2-D frame with ``Domain=<ast_domain>`` is attached to the current base
    frame through ``ShiftMap([x, y])`` (which, per AST's ``addFrame``
    convention, maps old-base coordinates into the new frame). That new
    frame then becomes the base, while the previously current frame is
    restored as current.
    """
    previous_current = ast_frame_set.current
    shift = astshim.ShiftMap([x, y])
    shifted_frame = astshim.Frame(2, f"Domain={ast_domain}")
    # addFrame leaves the newly added frame as the current one; promote it
    # to base, then put the original current frame back.
    ast_frame_set.addFrame(astshim.FrameSet.BASE, shift, shifted_frame)
    ast_frame_set.base = ast_frame_set.current
    ast_frame_set.current = previous_current
def _make_ast_frame(frame: Frame) -> Any:
    """Create the AST frame that stands in for ``frame`` in a FrameSet.

    `SkyFrame.ICRS` maps to a default AST ``SkyFrame``. Every other frame
    maps to a generic 2-D AST ``Frame`` identified by its ``_ast_ident``,
    with axes labeled ``x`` and ``y``; if the frame has a unit, both axes
    are given it in FITS notation.
    """
    if frame is SkyFrame.ICRS:
        return astshim.SkyFrame("")
    generic = astshim.Frame(2, f"Ident={frame._ast_ident}")
    unit = frame.unit
    if unit is not None:
        unit_text = unit.to_string(format="fits")
        for axis in (1, 2):
            generic.setUnit(axis, unit_text)
    for axis, label in ((1, "x"), (2, "y")):
        generic.setLabel(axis, label)
    return generic
565def _standardize_xy[T: np.ndarray | float](xy: XY[T], frame: Frame) -> XY[T]:
566 return XY(x=frame.standardize_x(xy.x), y=frame.standardize_y(xy.y))
class MappingSerializationModel(ArchiveTree):
    """Archive model that stores a single AST Mapping as native AST text."""

    ast: str = pydantic.Field(
        description="A serialized Starlink AST Mapping, using the AST native encoding.",
    )
575class TransformSerializationModel[P: pydantic.BaseModel](ArchiveTree):
576 """Serialization model for coordinate transforms."""
578 frames: list[SerializableFrame] = pydantic.Field(
579 default_factory=list,
580 description=textwrap.dedent(
581 """
582 List of frames that this transform passes through.
584 All transforms include at least two frames (the endpoints). Others
585 intermediate frames may be included to facilitate data-sharing
586 between transforms.
587 """
588 ),
589 )
591 bounds: list[SerializableBounds | None] = pydantic.Field(
592 default_factory=list,
593 description=textwrap.dedent(
594 """
595 List of the bounds of the ``frames`` for this transform.
597 This always has the same number of elements as ``frames``.
598 """
599 ),
600 )
602 mappings: list[P | MappingSerializationModel] = pydantic.Field(
603 default_factory=list,
604 description=textwrap.dedent(
605 """
606 The actual mappings between frames, or archive pointers to
607 serialized FrameSet objects from which they can be obtained.
609 This always has one fewer element than ``frames``.
610 """
611 ),
612 )