Coverage for python / lsst / images / fields / _spline.py: 25%

139 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 09:07 +0000

1# This file is part of lsst-images. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12from __future__ import annotations 

13 

14__all__ = ("SplineField", "SplineFieldSerializationModel") 

15 

16from typing import TYPE_CHECKING, Any, Literal, final 

17 

18import astropy.units 

19import numpy as np 

20import pydantic 

21from scipy.interpolate import Akima1DInterpolator 

22 

23from .._concrete_bounds import SerializableBounds, deserialize_bounds 

24from .._geom import Bounds, Box 

25from .._image import Image 

26from ..serialization import ArchiveTree, ArrayReferenceModel, InlineArray, InputArchive, OutputArchive, Unit 

27from ._base import BaseField 

28 

29if TYPE_CHECKING: 

30 try: 

31 from lsst.afw.math import BackgroundMI as LegacyBackground 

32 except ImportError: 

33 type LegacyBackground = Any # type: ignore[no-redef] 

34 

35 

36@final 

37class SplineField(BaseField): 

38 """A 2-d Akima spline interpolation of data on a regular grid. 

39 

40 Parameters 

41 ---------- 

42 bounds 

43 The region where this field can be evaluated. 

44 data 

45 The data points to be interpolated. Missing values (indicated by NaN) 

46 are allowed. Will be set to read-only in place. 

47 y 

48 Coordinates for the first dimension of ``data``. Will be set to 

49 read-only in place. 

50 x 

51 Coordinates for the second dimension of ``data``. Will be set to 

52 read-only in place. 

53 unit 

54 Units of the field. 

55 

56 Notes 

57 ----- 

58 This field is much faster to evaluate on a grid via `render` than at 

59 arbitrary points via the function-call operator. 

60 """ 

61 

62 def __init__( 

63 self, 

64 bounds: Bounds, 

65 data: np.ndarray, 

66 *, 

67 y: np.ndarray, 

68 x: np.ndarray, 

69 unit: astropy.units.UnitBase | None = None, 

70 ): 

71 if isinstance(data, astropy.units.Quantity): 

72 if unit is not None: 

73 raise TypeError("If 'data' is a Quantity, 'unit' cannot be provided separately.") 

74 unit = data.unit 

75 data = data.to_value() 

76 if data.ndim != 2: 

77 raise ValueError("'data' must be 2-d.") 

78 if y.ndim != 1: 

79 raise ValueError("'y' must be 1-d.") 

80 if x.ndim != 1: 

81 raise ValueError("'x' must be 1-d.") 

82 if data.shape != y.shape + x.shape: 

83 raise ValueError( 

84 f"Shape of 2-d 'data' {data.shape} does not match " 

85 f"expected 1-d 'y' {y.shape} and/or 'x' {x.shape}." 

86 ) 

87 self._bounds = bounds 

88 self._data = data 

89 self._data.flags.writeable = False 

90 self._x = x 

91 self._x.flags.writeable = False 

92 self._y = y 

93 self._y.flags.writeable = False 

94 self._unit = unit 

95 

96 @property 

97 def bounds(self) -> Bounds: 

98 return self._bounds 

99 

100 @property 

101 def unit(self) -> astropy.units.UnitBase | None: 

102 return self._unit 

103 

104 @property 

105 def data(self) -> np.ndarray: 

106 """The data points to be interpolated (`numpy.ndarray`). 

107 

108 May have missing values indicated by NaNs. 

109 """ 

110 return self._data 

111 

112 @property 

113 def x(self) -> np.ndarray: 

114 """Coordinates for the second dimension of `data` (`numpy.ndarray`).""" 

115 return self._x 

116 

117 @property 

118 def y(self) -> np.ndarray: 

119 """Coordinates for the first dimension of `data` (`numpy.ndarray`).""" 

120 return self._y 

121 

122 def evaluate( 

123 self, *, x: np.ndarray, y: np.ndarray, quantity: bool = False 

124 ) -> np.ndarray | astropy.units.Quantity: 

125 y, x = np.broadcast_arrays(y, x) 

126 xg = self._x 

127 y_render = np.zeros(xg.shape + y.shape, dtype=np.float64) 

128 mask = np.zeros(xg.size, dtype=bool) 

129 for j in range(xg.size): 

130 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

131 y_render[j, ...] = y_interpolator(y) 

132 mask[j] = True 

133 if not np.all(mask): 

134 y_render = y_render[mask, ...] 

135 xg = xg[mask] 

136 result = np.zeros(y.shape, dtype=np.float64) 

137 # There doesn't seem to be a way to avoid looping in Python here; 

138 # maybe someday we'll push this down to a compiled language. 

139 for i, xv in np.ndenumerate(x): 

140 if (x_interpolator := self._make_1d_interpolator(xg, y_render[:, *i])) is None: 

141 raise ValueError("No valid data points.") 

142 v = x_interpolator(xv) 

143 result[*i] = v 

144 if quantity: 

145 return astropy.units.Quantity(result, self._unit) 

146 return result 

147 

148 def render(self, bbox: Box | None = None, *, dtype: np.typing.DTypeLike | None = None) -> Image: 

149 if bbox is None: 

150 bbox = self.bounds.bbox 

151 xg = self._x 

152 y_render = np.zeros((xg.size, bbox.y.size), dtype=dtype) 

153 mask = np.zeros(xg.size, dtype=bool) 

154 for j in range(xg.size): # we have to loop, but only over bins, not evaluation points. 

155 if (y_interpolator := self._make_y_interpolator(j)) is not None: 

156 y_render[j, :] = y_interpolator(bbox.y.arange) 

157 mask[j] = True 

158 if not np.all(mask): 

159 y_render = y_render[mask, :] 

160 xg = xg[mask] 

161 if (x_interpolator := self._make_1d_interpolator(xg, y_render)) is None: 

162 raise ValueError("No valid data points.") 

163 rendered_array = x_interpolator(bbox.x.arange) 

164 return Image(rendered_array.transpose().copy(), bbox=bbox, unit=self._unit, dtype=dtype) 

165 

166 def multiply_constant( 

167 self, factor: float | astropy.units.Quantity | astropy.units.UnitBase 

168 ) -> SplineField: 

169 factor, unit = self._handle_factor_units(factor) 

170 return SplineField(self._bounds, self._data * factor, y=self._y, x=self._x, unit=unit) 

171 

172 def serialize(self, archive: OutputArchive[Any]) -> SplineFieldSerializationModel: 

173 """Serialize the spline field to an output archive.""" 

174 return SplineFieldSerializationModel( 

175 bounds=self.bounds.serialize(), 

176 data=archive.add_array(self._data, name="data"), 

177 y=self._y, 

178 x=self._x, 

179 unit=self._unit, 

180 ) 

181 

182 @staticmethod 

183 def deserialize(model: SplineFieldSerializationModel, archive: InputArchive[Any]) -> SplineField: 

184 """Deserialize the spline field from an input archive.""" 

185 return SplineField( 

186 deserialize_bounds(model.bounds), 

187 archive.get_array(model.data), 

188 y=model.y, 

189 x=model.x, 

190 unit=model.unit, 

191 ) 

192 

193 @staticmethod 

194 def _get_archive_tree_type( 

195 pointer_type: type[Any], 

196 ) -> type[SplineFieldSerializationModel]: 

197 """Return the serialization model type for this object for an archive 

198 type that uses the given pointer type. 

199 """ 

200 return SplineFieldSerializationModel 

201 

202 @staticmethod 

203 def from_legacy_background( 

204 legacy_background: LegacyBackground, 

205 unit: astropy.units.UnitBase | None = None, 

206 ) -> SplineField: 

207 """Convert from a legacy `lsst.afw.math.BackgroundMI` instance. 

208 

209 Notes 

210 ----- 

211 `SplineField.render` and the `lsst.afw` background interpolator both 

212 use Akima splines, but with slightly different boundary conditions. 

213 They should produce equivalent to single-precision round-off error 

214 when evaluated within the region enclosed by bin centers (i.e. where 

215 no extrapolation is necessary) and when there are five or more 

216 points to be interpolated in each row and column. 

217 """ 

218 from lsst.afw.math import ApproximateControl, Interpolate 

219 

220 bg_control = legacy_background.getBackgroundControl() 

221 approx_control = bg_control.getApproximateControl() 

222 stats_image = legacy_background.getStatsImage() 

223 if approx_control.getStyle() != ApproximateControl.UNKNOWN: 

224 raise TypeError("Legacy background uses Chebyshev approximation, not splines.") 

225 if bg_control.getInterpStyle() != Interpolate.AKIMA_SPLINE: 

226 raise TypeError("Legacy background does not use Akima spline interpolation.") 

227 x = legacy_background.getBinCentersX() 

228 y = legacy_background.getBinCentersY() 

229 return SplineField( 

230 Box.from_legacy(legacy_background.getImageBBox()), stats_image.image.array, x=x, y=y, unit=unit 

231 ) 

232 

233 def _make_1d_interpolator(self, loc: np.ndarray, val: np.ndarray) -> Akima1DInterpolator | None: 

234 match len(loc): 

235 case 0: 

236 return None 

237 case 1: 

238 # SciPy can handle only two points by downgrading to linear 

239 # interpolation, but it raises if given only one. Mock up 

240 # two for the nearest-neighbor fallback. 

241 return Akima1DInterpolator(np.array([loc[0], loc[0]]), np.array([val[0], val[0]])) 

242 case _: 

243 return Akima1DInterpolator(loc, val, extrapolate=True) 

244 

245 def _make_y_interpolator(self, j: int) -> Akima1DInterpolator | None: 

246 y = self._y 

247 z = self._data[:, j] 

248 mask = np.isfinite(z) 

249 if not np.all(mask): 

250 y = y[mask] 

251 z = z[mask] 

252 del mask 

253 return self._make_1d_interpolator(y, z) 

254 

255 

class SplineFieldSerializationModel(ArchiveTree):
    """Serialization model for `SplineField`.

    The 2-d ``data`` array is written to the archive and referenced from this
    model, while the 1-d ``y`` and ``x`` coordinate arrays are stored inline.
    """

    # Region over which the serialized field may be evaluated.
    bounds: SerializableBounds = pydantic.Field(
        description="The region where this field can be evaluated.",
    )

    # Reference to the archived 2-d array of data points (may contain NaNs).
    data: ArrayReferenceModel = pydantic.Field(
        description="2-d data to interpolate. NaNs indicate missing values.",
    )

    # Coordinates for the first (row) and second (column) axes of ``data``.
    y: InlineArray = pydantic.Field(
        description="Row positions of the data points.",
    )

    x: InlineArray = pydantic.Field(
        description="Column positions of the data points.",
    )

    # Optional physical units of the field values.
    unit: Unit | None = pydantic.Field(
        default=None,
        description="Units of the field.",
    )

    # Fixed tag for this model type; presumably used to tell field models
    # apart when deserializing (TODO confirm against the archive machinery).
    field_type: Literal["SPLINE"] = "SPLINE"

    def finish_deserialize(self, archive: InputArchive[Any]) -> SplineField:
        """Rebuild the `SplineField` described by this model.

        Parameters
        ----------
        archive
            Archive from which the referenced ``data`` array is read.

        Returns
        -------
        SplineField
            The reconstructed field.
        """
        rebuild = SplineField.deserialize
        return rebuild(self, archive)