Coverage for python / lsst / images / fields / _spline.py: 25%
139 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:01 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 09:01 +0000
1# This file is part of lsst-images.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12from __future__ import annotations
14__all__ = ("SplineField", "SplineFieldSerializationModel")
16from typing import TYPE_CHECKING, Any, Literal, final
18import astropy.units
19import numpy as np
20import pydantic
21from scipy.interpolate import Akima1DInterpolator
23from .._concrete_bounds import SerializableBounds, deserialize_bounds
24from .._geom import Bounds, Box
25from .._image import Image
26from ..serialization import ArchiveTree, ArrayReferenceModel, InlineArray, InputArchive, OutputArchive, Unit
27from ._base import BaseField
29if TYPE_CHECKING:
30 try:
31 from lsst.afw.math import BackgroundMI as LegacyBackground
32 except ImportError:
33 type LegacyBackground = Any # type: ignore[no-redef]
36@final
37class SplineField(BaseField):
38 """A 2-d Akima spline interpolation of data on a regular grid.
40 Parameters
41 ----------
42 bounds
43 The region where this field can be evaluated.
44 data
45 The data points to be interpolated. Missing values (indicated by NaN)
46 are allowed. Will be set to read-only in place.
47 y
48 Coordinates for the first dimension of ``data``. Will be set to
49 read-only in place.
50 x
51 Coordinates for the second dimension of ``data``. Will be set to
52 read-only in place.
53 unit
54 Units of the field.
56 Notes
57 -----
58 This field is much faster to evaluate on a grid via `render` than at
59 arbitrary points via the function-call operator.
60 """
62 def __init__(
63 self,
64 bounds: Bounds,
65 data: np.ndarray,
66 *,
67 y: np.ndarray,
68 x: np.ndarray,
69 unit: astropy.units.UnitBase | None = None,
70 ):
71 if isinstance(data, astropy.units.Quantity):
72 if unit is not None:
73 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.")
74 unit = data.unit
75 data = data.to_value()
76 if data.ndim != 2:
77 raise ValueError("'data' must be 2-d.")
78 if y.ndim != 1:
79 raise ValueError("'y' must be 1-d.")
80 if x.ndim != 1:
81 raise ValueError("'x' must be 1-d.")
82 if data.shape != y.shape + x.shape:
83 raise ValueError(
84 f"Shape of 2-d 'data' {data.shape} does not match "
85 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}."
86 )
87 self._bounds = bounds
88 self._data = data
89 self._data.flags.writeable = False
90 self._x = x
91 self._x.flags.writeable = False
92 self._y = y
93 self._y.flags.writeable = False
94 self._unit = unit
96 @property
97 def bounds(self) -> Bounds:
98 return self._bounds
100 @property
101 def unit(self) -> astropy.units.UnitBase | None:
102 return self._unit
104 @property
105 def data(self) -> np.ndarray:
106 """The data points to be interpolated (`numpy.ndarray`).
108 May have missing values indicated by NaNs.
109 """
110 return self._data
112 @property
113 def x(self) -> np.ndarray:
114 """Coordinates for the second dimension of `data` (`numpy.ndarray`)."""
115 return self._x
117 @property
118 def y(self) -> np.ndarray:
119 """Coordinates for the first dimension of `data` (`numpy.ndarray`)."""
120 return self._y
122 def evaluate(
123 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False
124 ) -> np.ndarray | astropy.units.Quantity:
125 y, x = np.broadcast_arrays(y, x)
126 xg = self._x
127 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64)
128 mask = np.zeros(xg.size, dtype=bool)
129 for j in range(xg.size):
130 if (y_interpolator := self._make_y_interpolator(j)) is not None:
131 y_render[j, ...] = y_interpolator(y)
132 mask[j] = True
133 if not np.all(mask):
134 y_render = y_render[mask, ...]
135 xg = xg[mask]
136 result = np.zeros(y.shape, dtype=np.float64)
137 # There doesn't seem to be a way to avoid looping in Python here;
138 # maybe someday we'll push this down to a compiled language.
139 for i, xv in np.ndenumerate(x):
140 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None:
141 raise ValueError("No valid data points.")
142 v = x_interpolator(xv)
143 result[*i] = v
144 if quantity:
145 return astropy.units.Quantity(result, self._unit)
146 return result
148 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image:
149 if bbox is None:
150 bbox = self.bounds.bbox
151 xg = self._x
152 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype)
153 mask = np.zeros(xg.size, dtype=bool)
154 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points.
155 if (y_interpolator := self._make_y_interpolator(j)) is not None:
156 y_render[j, :] = y_interpolator(bbox.y.arange)
157 mask[j] = True
158 if not np.all(mask):
159 y_render = y_render[mask, :]
160 xg = xg[mask]
161 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None:
162 raise ValueError("No valid data points.")
163 rendered_array = x_interpolator(bbox.x.arange)
164 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype)
166 def multiply_constant(
167 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase
168 ) -> SplineField:
169 factor, unit = self._handle_factor_units(factor)
170 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit)
172 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel:
173 """Serialize the spline field to an output archive."""
174 return SplineFieldSerializationModel(
175 bounds=self.bounds.serialize(),
176 data=archive.add_array(self._data, name="data"),
177 y=self._y,
178 x=self._x,
179 unit=self._unit,
180 )
182 @staticmethod
183 def deserialize(model: SplineFieldSerializationModel, archive: InputArchive[Any]) -> SplineField:
184 """Deserialize the spline field from an input archive."""
185 return SplineField(
186 deserialize_bounds(model.bounds),
187 archive.get_array(model.data),
188 y=model.y,
189 x=model.x,
190 unit=model.unit,
191 )
193 @staticmethod
194 def _get_archive_tree_type(
195 pointer_type: type[Any],
196 ) -> type[SplineFieldSerializationModel]:
197 """Return the serialization model type for this object for an archive
198 type that uses the given pointer type.
199 """
200 return SplineFieldSerializationModel
202 @staticmethod
203 def from_legacy_background(
204 legacy_background: LegacyBackground,
205 unit: astropy.units.UnitBase | None = None,
206 ) -> SplineField:
207 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance.
209 Notes
210 -----
211 `SplineField.render` and the `lsst.afw` background interpolator both
212 use Akima splines, but with slightly different boundary conditions.
213 They should produce equivalent to single-precision round-off error
214 when evaluated within the region enclosed by bin centers (i.e. where
215 no extrapolation is necessary) and when there are five or more
216 points to be interpolated in each row and column.
217 """
218 from lsst.afw.math import ApproximateControl, Interpolate
220 bg_control = legacy_background.getBackgroundControl()
221 approx_control = bg_control.getApproximateControl()
222 stats_image = legacy_background.getStatsImage()
223 if approx_control.getStyle() != ApproximateControl.UNKNOWN:
224 raise TypeError("Legacy background uses Chebyshev approximation, not splines.")
225 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE:
226 raise TypeError("Legacy background does not use Akima spline interpolation.")
227 x = legacy_background.getBinCentersX()
228 y = legacy_background.getBinCentersY()
229 return SplineField(
230 Box.from_legacy(legacy_background.getImageBBox()), stats_image.image.array, x=x, y=y, unit=unit
231 )
233 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None:
234 match len(loc):
235 case 0:
236 return None
237 case 1:
238 # SciPy can handle only two points by downgrading to linear
239 # interpolation, but it raises if given only one. Mock up
240 # two for the nearest-neighbor fallback.
241 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]]))
242 case _:
243 return Akima1DInterpolator(loc, val, extrapolate=True)
245 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None:
246 y = self._y
247 z = self._data[:, j]
248 mask = np.isfinite(z)
249 if not np.all(mask):
250 y = y[mask]
251 z = z[mask]
252 del mask
253 return self._make_1d_interpolator(y, z)
class SplineFieldSerializationModel(ArchiveTree):
    """Serialization model for `SplineField`."""

    # NOTE: pydantic preserves field declaration order in serialized output
    # and in the generated schema, so reordering these fields changes the
    # on-disk form.
    bounds: SerializableBounds = pydantic.Field(description=("The region where this field can be evaluated."))

    # Stored as a reference into the archive: `SplineField.serialize` fills
    # this via `archive.add_array`, and `SplineField.deserialize` reads it
    # back via `archive.get_array`.
    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values."
    )

    # Grid coordinates are placed directly in the model (not added to the
    # archive like `data`).
    y: InlineArray = pydantic.Field(description="Row positions of the data points.")

    x: InlineArray = pydantic.Field(description="Column positions of the data points.")

    unit: Unit | None = pydantic.Field(default=None, description="Units of the field.")

    # Constant tag naming this field type; presumably used to discriminate
    # between field serialization models — confirm against the consumer.
    field_type: Literal["SPLINE"] = "SPLINE"

    def finish_deserialize(self, archive: InputArchive) -> SplineField:
        """Reconstruct the `SplineField` described by this model.

        The ``data`` array referenced by this model is read from ``archive``.
        """
        return SplineField.deserialize(self, archive)