Coverage for python / lsst / meas / algorithms / gp_interpolation.py: 14%
113 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 08:38 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 08:38 +0000
1# This file is part of meas_algorithms.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23from lsst.meas.algorithms import CloughTocher2DInterpolatorUtils as ctUtils
24from lsst.geom import Box2I, Point2I
25from lsst.afw.geom import SpanSet
26import copy
27import treegp
29import logging
31__all__ = [
32 "InterpolateOverDefectGaussianProcess",
33 "GaussianProcessTreegp",
34]
def updateMaskFromArray(mask, bad_pixel, interpBit):
    """
    Set the given bit in the mask at every bad pixel position.

    The mask is modified in place; nothing is returned.

    Parameters
    ----------
    mask : `lsst.afw.image.Mask`
        The mask to update. Only ``getX0()``, ``getY0()`` and the ``array``
        attribute are used.
    bad_pixel : `np.array`
        An array-like object containing the coordinates of the bad pixels.
        Each row should contain the x and y coordinates of a bad pixel in
        its first two columns; any further columns are ignored.
    interpBit : `int`
        The bit value to set for the bad pixels in the mask.
    """
    bad_pixel = np.asarray(bad_pixel)
    # Nothing to flag; also avoids 2-D indexing into an empty 1-D array.
    if bad_pixel.size == 0:
        return

    # Convert parent coordinates to local array indices. ``astype(int)``
    # truncates toward zero, matching the previous per-pixel ``int()`` cast.
    x = (bad_pixel[:, 0] - mask.getX0()).astype(int)
    y = (bad_pixel[:, 1] - mask.getY0()).astype(int)

    # Bitwise OR is idempotent, so repeated coordinates are harmless here.
    mask.array[y, x] |= interpBit
def median_with_mad_clipping(data, mad_multiplier=2.0):
    """
    Calculate the median of the input data after applying Median Absolute
    Deviation (MAD) clipping.

    The median of the data is calculated first, then the MAD is calculated
    as the median absolute deviation from that median. Values outside the
    range ``median +/- mad_multiplier * MAD`` are clamped to the nearest
    bound of that range (they are not removed from the sample). Finally,
    the median of the clamped data is returned.

    Parameters
    ----------
    data : `np.array`
        Input data array (any array-like is accepted).
    mad_multiplier : `float`, optional
        Multiplier for the MAD value used for clipping. Default is 2.0.

    Returns
    -------
    median_clipped : `float`
        Median value of the clipped data.

    Examples
    --------
    >>> data = [1, 2, 3, 4, 5, 100]
    >>> float(median_with_mad_clipping(data))
    3.5
    """
    # Normalise once so plain Python sequences behave like arrays below.
    data = np.asarray(data)
    median = np.median(data)
    mad = np.median(np.abs(data - median))
    clipping_range = mad_multiplier * mad
    # np.clip saturates outliers at the bounds rather than dropping them.
    clipped_data = np.clip(data, median - clipping_range, median + clipping_range)
    return np.median(clipped_data)
class GaussianProcessTreegp:
    """
    Gaussian Process Treegp class for Gaussian Process interpolation.

    The basic GP regression, which uses Cholesky decomposition.

    Parameters
    ----------
    std : `float`, optional
        Standard deviation of the Gaussian Process kernel. Default is 1.0.
    correlation_length : `float`, optional
        Correlation length of the Gaussian Process kernel. Default is 1.0.
    white_noise : `float`, optional
        White noise level of the Gaussian Process. Default is 0.0.
        A value of exactly 0.0 is replaced by 1e-5 (see Notes).
    mean : `float`, optional
        Mean value of the Gaussian Process. Default is 0.0.

    Notes
    -----
    When no noise is provided the kernel matrix can be numerically
    singular (det(K) ~ 0) even if it is formally invertible. A small
    amount of white noise is therefore added in that case so that a
    numerical solution can be computed. This is not expected to matter
    on real images but helps in unit tests.
    """

    def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
        self.std = std
        self.correlation_length = correlation_length
        self.white_noise = white_noise
        self.mean = mean
        # Set by `fit`; kept as None until then so that `predict` can
        # detect an unfitted model.
        self.gp = None

        # Regularise the noiseless case, see the class Notes.
        if self.white_noise == 0.0:
            self.white_noise = 1e-5

    def fit(self, x_train, y_train):
        """
        Fit the Gaussian Process to the given training data.

        The configured mean is subtracted from the targets before they
        are handed to treegp; `predict` adds it back.

        Parameters
        ----------
        x_train : `np.array`
            Input features for the training data.
        y_train : `np.array`
            Target values for the training data.
        """
        # treegp takes the kernel as a string expression.
        kernel = f"{self.std}**2 * RBF({self.correlation_length})"
        self.gp = treegp.GPInterpolation(
            kernel=kernel,
            optimizer="none",
            normalize=False,
            white_noise=self.white_noise,
        )
        self.gp.initialize(x_train, y_train - self.mean)
        self.gp.solve()

    def predict(self, x_predict):
        """
        Predict the target values for the given input features.

        Parameters
        ----------
        x_predict : `np.array`
            Input features for the prediction.

        Returns
        -------
        y_pred : `np.array`
            Predicted target values.

        Raises
        ------
        RuntimeError
            Raised if `fit` has not been called yet.
        """
        if self.gp is None:
            raise RuntimeError("GaussianProcessTreegp.fit must be called before predict.")
        y_pred = self.gp.predict(x_predict)
        return y_pred + self.mean
class InterpolateOverDefectGaussianProcess:
    """
    InterpolateOverDefectGaussianProcess class performs Gaussian Process
    (GP) interpolation over defects in an image.

    Parameters
    ----------
    masked_image : `lsst.afw.image.MaskedImage`
        The masked image containing the defects to be interpolated.
    defects : `list`[`str`], optional
        The types of defects to be interpolated. Default is ["SAT"].
    fwhm : `float`, optional
        The full width at half maximum (FWHM) of the PSF; used as the
        correlation length of the GP kernel. Default is 5.
    bin_image : `bool`, optional
        Whether to bin the good pixels before fitting the GP.
        Default is True.
    bin_spacing : `int`, optional
        The spacing between bins for good pixel binning. Default is 10.
    threshold_dynamic_binning : `int`, optional
        The threshold for dynamic binning. Default is 1000.
    threshold_subdivide : `int`, optional
        The threshold for sub-dividing the bad pixel array to avoid memory
        error. Default is 20000.
    correlation_length_cut : `int`, optional
        The factor by which to dilate the bounding box around defects.
        Default is 5.
    log : `lsst.log.Log`, `logging.Logger` or `None`, optional
        Logger object used to write out messages. If `None` a default
        logger will be used.
    """

    def __init__(
        self,
        masked_image,
        defects=("SAT",),
        fwhm=5,
        bin_image=True,
        bin_spacing=10,
        threshold_dynamic_binning=1000,
        threshold_subdivide=20000,
        correlation_length_cut=5,
        log=None,
    ):

        self.log = log or logging.getLogger(__name__)

        self.bin_image = bin_image
        self.bin_spacing = bin_spacing
        self.threshold_subdivide = threshold_subdivide
        self.threshold_dynamic_binning = threshold_dynamic_binning

        self.masked_image = masked_image
        # Store a private list so that neither the (immutable) default nor a
        # caller-owned sequence is shared with this instance. `None` and a
        # single plane name given as `str` are kept untouched.
        if defects is None or isinstance(defects, str):
            self.defects = defects
        else:
            self.defects = list(defects)
        self.correlation_length = fwhm
        self.correlation_length_cut = correlation_length_cut

        self.interpBit = self.masked_image.mask.getPlaneBitMask("INTRP")

    def run(self):
        """
        Interpolate over the defects in the image.

        Change self.masked_image in place: every connected group of
        defect pixels is interpolated using the good pixels found in a
        dilated bounding box around it.
        """
        if self.defects is None or len(self.defects) == 0:
            self.log.info("No defects found. No interpolation performed.")
            return

        mask = self.masked_image.getMask()
        bad_pixel_mask = mask.getPlaneBitMask(self.defects)
        bad_mask_span_set = SpanSet.fromMask(mask, bad_pixel_mask).split()

        image_bbox = self.masked_image.getBBox()
        global_xmin, global_xmax = image_bbox.minX, image_bbox.maxX
        global_ymin, global_ymax = image_bbox.minY, image_bbox.maxY

        # Dilate each defect bbox to make sure we have enough good pixels
        # around the defect. For GP with the isotropic kernel, points at the
        # default value of correlation_length_cut=5 have negligible effect
        # on the prediction. dilatedBy needs an integer as input.
        dilation = int(self.correlation_length * self.correlation_length_cut)

        for spanset in bad_mask_span_set:
            bbox = spanset.getBBox().dilatedBy(dilation)
            # Clamp the dilated box to the image footprint.
            xmin, xmax = max(global_xmin, bbox.minX), min(global_xmax, bbox.maxX)
            ymin, ymax = max(global_ymin, bbox.minY), min(global_ymax, bbox.maxY)
            # Box2I built from two Point2I takes the (inclusive) minimum and
            # maximum corners, so the second point must be the max corner
            # and not the box extent.
            localBox = Box2I(Point2I(xmin, ymin), Point2I(xmax, ymax))
            masked_sub_image = self.masked_image[localBox]

            masked_sub_image = self.interpolate_masked_sub_image(masked_sub_image)
            self.masked_image[localBox] = masked_sub_image

    def _good_pixel_binning(self, pixels):
        """
        Performs pixel binning using treegp.meanify

        The bin spacing is the larger of ``self.bin_spacing`` and a dynamic
        spacing ``int(sqrt(n_pixels / self.threshold_dynamic_binning))``.
        When there are fewer pixels than ``self.threshold_dynamic_binning``
        the dynamic spacing is zero and the pixels are returned unbinned.

        Parameters
        ----------
        pixels : `np.array`
            The array of pixels; columns are x, y and pixel value.

        Returns
        -------
        `np.array`
            The binned array of pixels.
        """

        n_pixels = len(pixels[:, 0])
        dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning))
        if dynamic_binning == 0:
            # Too few pixels for dynamic binning. This used to surface as a
            # ZeroDivisionError that the caller swallowed, falling back to
            # the unbinned pixels; make that outcome explicit instead.
            return pixels
        # Equivalent to picking the spacing that yields the fewest bins.
        bin_spacing = max(self.bin_spacing, dynamic_binning)
        binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean")
        binning.add_field(
            pixels[:, :2],
            pixels[:, 2:].T,
        )
        binning.meanify()
        return np.array(
            [binning.coords0[:, 0], binning.coords0[:, 1], binning.params0]
        ).T

    def interpolate_masked_sub_image(self, masked_sub_image):
        """
        Interpolate the masked sub-image.

        Parameters
        ----------
        masked_sub_image : `lsst.afw.image.MaskedImage`
            The sub-masked image to be interpolated.

        Returns
        -------
        `lsst.afw.image.MaskedImage`
            The interpolated sub-masked image. It is the input object,
            updated in place; it is returned unchanged when no usable bad
            or good pixels are found.
        """

        cut = int(
            self.correlation_length * self.correlation_length_cut
        )  # need integer as input.
        bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels(
            masked_sub_image, self.defects, buffer=cut
        )
        # Do nothing if there is nothing to interpolate or to train on.
        if bad_pixel.size == 0 or good_pixel.size == 0:
            self.log.info("No bad or good pixels found. No interpolation performed.")
            return masked_sub_image

        # White noise of the GP is the RMS of the finite variance pixels.
        sub_image_array = masked_sub_image.getVariance().array
        white_noise = np.sqrt(
            np.mean(sub_image_array[np.isfinite(sub_image_array)])
        )

        # Drop every non-finite training value (NaN, +inf and -inf); looking
        # only at the maximum would let -inf values through.
        filter_finite = np.isfinite(good_pixel[:, 2:]).T[0]
        if not filter_finite.all():
            good_pixel = good_pixel[filter_finite]
            if good_pixel.size == 0:
                self.log.info(
                    "No bad or good pixels found. No interpolation performed."
                )
                return masked_sub_image

        # kernel amplitude might be better described by maximum value of
        # good pixel given the data and not really a random gaussian field.
        # NOTE(review): a negative maximum makes the sqrt below NaN --
        # confirm whether inputs can be entirely negative.
        kernel_amplitude = np.max(good_pixel[:, 2:])

        if self.bin_image:
            try:
                good_pixel = self._good_pixel_binning(copy.deepcopy(good_pixel))
            except Exception:
                # Deliberate best effort: binning only reduces the training
                # set size, so fall back to the unbinned pixels.
                self.log.info(
                    "Binning failed, use original good pixel array in interpolation."
                )

        # put this after binning as computing median is O(n*log(n))
        clipped_median = median_with_mad_clipping(good_pixel[:, 2:])

        gp = GaussianProcessTreegp(
            std=np.sqrt(kernel_amplitude),
            correlation_length=self.correlation_length,
            white_noise=white_noise,
            mean=clipped_median,
        )
        gp.fit(good_pixel[:, :2], np.squeeze(good_pixel[:, 2:]))

        # Predict in chunks of at most threshold_subdivide rows; a single
        # chunk is used when the array is small enough.
        n_bad = len(bad_pixel)
        if n_bad > self.threshold_subdivide:
            self.log.info("sub-divide bad pixel array to avoid memory error.")
        for start in range(0, n_bad, self.threshold_subdivide):
            end = min(start + self.threshold_subdivide, n_bad)
            gp_predict = gp.predict(bad_pixel[start:end, :2])
            bad_pixel[start:end, 2:] = gp_predict.reshape(
                np.shape(bad_pixel[start:end, 2:])
            )

        # Update values
        ctUtils.updateImageFromArray(masked_sub_image.image, bad_pixel)
        updateMaskFromArray(masked_sub_image.mask, bad_pixel, self.interpBit)
        return masked_sub_image