Coverage for python / lsst / meas / algorithms / gp_interpolation.py: 14%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:55 +0000

1# This file is part of meas_algorithms. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23from lsst.meas.algorithms import CloughTocher2DInterpolatorUtils as ctUtils 

24from lsst.geom import Box2I, Point2I 

25from lsst.afw.geom import SpanSet 

26import copy 

27import treegp 

28 

29import logging 

30 

# Public API of this module: the defect interpolator and the GP wrapper it uses.
__all__ = ["InterpolateOverDefectGaussianProcess", "GaussianProcessTreegp"]

35 

36 

def updateMaskFromArray(mask, bad_pixel, interpBit):
    """
    Set ``interpBit`` in ``mask`` at the positions of the given bad pixels.

    Parameters
    ----------
    mask : `lsst.afw.image.Mask`
        The mask to update in place. Its ``array`` is written through, and
        ``getX0()``/``getY0()`` give the origin used to convert coordinates
        to array indices.
    bad_pixel : `np.array`
        Array with one row per bad pixel. Column 0 holds the x and column 1
        the y coordinate (same coordinate system as ``getX0``/``getY0``);
        any further columns are ignored.
    interpBit : `int`
        Bit mask OR-ed into the mask value of every bad pixel.
    """
    bad_pixel = np.asarray(bad_pixel)
    if bad_pixel.size == 0:
        return
    # Convert to array indices; ``astype(int)`` truncates toward zero, the
    # same as calling ``int()`` on each value.
    x = (bad_pixel[:, 0] - mask.getX0()).astype(int)
    y = (bad_pixel[:, 1] - mask.getY0()).astype(int)
    # Repeated (y, x) pairs are harmless: every write stores the same value.
    mask.array[y, x] |= interpBit

58 

59 

def median_with_mad_clipping(data, mad_multiplier=2.0):
    """
    Calculate the median of the input data after Median Absolute Deviation
    (MAD) clipping.

    The median of the data and the MAD (the median of the absolute
    deviations from that median) are computed first. Every value is then
    clipped with `numpy.clip` into ``median +/- mad_multiplier * MAD``:
    outliers are pulled back to the nearest bound rather than removed, so
    the number of samples is unchanged. The median of the clipped data is
    returned.

    Parameters
    ----------
    data : array-like
        Input data. Lists are accepted (converted with `numpy.asarray`);
        multi-dimensional input is flattened by `numpy.median`.
    mad_multiplier : `float`, optional
        Multiplier for the MAD value used for clipping. Default is 2.0.

    Returns
    -------
    median_clipped : `float`
        Median value of the clipped data.

    Examples
    --------
    >>> data = [1, 2, 3, 4, 5, 100]
    >>> float(median_with_mad_clipping(data))
    3.5
    """
    # NOTE(review): clipping is symmetric about the median and preserves the
    # ordering of the values, so for any mad_multiplier >= 0 the result equals
    # np.median(data) up to rounding. If outlier *rejection* was intended,
    # values outside the range would have to be dropped instead — confirm.
    values = np.asarray(data)
    median = np.median(values)
    mad = np.median(np.abs(values - median))
    clipping_range = mad_multiplier * mad
    clipped_data = np.clip(values, median - clipping_range, median + clipping_range)
    return np.median(clipped_data)

93 

94 

class GaussianProcessTreegp:
    """
    Gaussian Process regression backed by `treegp.GPInterpolation`.

    Basic GP regression (solved with a Cholesky decomposition) using the
    kernel ``std**2 * RBF(correlation_length)``. A constant ``mean`` is
    subtracted from the training values before fitting and added back to
    every prediction.

    Parameters
    ----------
    std : `float`, optional
        Standard deviation (amplitude) of the kernel. Default is 1.0.
    correlation_length : `float`, optional
        Correlation length passed to the RBF kernel. Default is 1.0.
    white_noise : `float`, optional
        White-noise level forwarded to treegp. A value of exactly 0.0 is
        replaced by 1e-5. Default is 0.0.
    mean : `float`, optional
        Constant mean of the Gaussian Process. Default is 0.0.
    """

    def __init__(self, std=1.0, correlation_length=1.0, white_noise=0.0, mean=0.0):
        self.std = std
        self.correlation_length = correlation_length
        # Without any noise term the kernel matrix can be numerically singular
        # (det(K) ~ 0) even if it is invertible in exact arithmetic, so a tiny
        # jitter replaces an exact zero. Real images always carry noise; this
        # mainly keeps noiseless unit tests solvable.
        self.white_noise = 1e-5 if white_noise == 0.0 else white_noise
        self.mean = mean

    def fit(self, x_train, y_train):
        """
        Condition the Gaussian Process on training data.

        Parameters
        ----------
        x_train : `np.array`
            Input features (coordinates) of the training points.
        y_train : `np.array`
            Target values of the training points; ``self.mean`` is subtracted
            before they are passed to treegp.
        """
        kernel_spec = f"{self.std}**2 * RBF({self.correlation_length})"
        self.gp = treegp.GPInterpolation(
            kernel=kernel_spec,
            optimizer="none",  # hyper-parameters are used exactly as given
            normalize=False,
            white_noise=self.white_noise,
        )
        centred = y_train - self.mean
        self.gp.initialize(x_train, centred)
        self.gp.solve()

    def predict(self, x_predict):
        """
        Predict target values at new input features.

        Parameters
        ----------
        x_predict : `np.array`
            Input features (coordinates) at which to predict.

        Returns
        -------
        y_pred : `np.array`
            Predicted values, with ``self.mean`` added back.
        """
        prediction = self.gp.predict(x_predict)
        return prediction + self.mean

166 

167 

class InterpolateOverDefectGaussianProcess:
    """
    InterpolateOverDefectGaussianProcess class performs Gaussian Process
    (GP) interpolation over defects in an image.

    Parameters
    ----------
    masked_image : `lsst.afw.image.MaskedImage`
        The masked image containing the defects to be interpolated. It is
        modified in place by `run`.
    defects : `list` [`str`], optional
        Names of the mask planes to interpolate over. Default is ["SAT"].
        If empty or `None`, `run` performs no interpolation.
    fwhm : `float`, optional
        The full width at half maximum (FWHM) of the PSF; used as the GP
        correlation length. Default is 5.
    bin_image : `bool`, optional
        If True, bin the good pixels before fitting the GP. Default is True.
    bin_spacing : `int`, optional
        Minimum spacing between bins for good pixel binning. Default is 10.
    threshold_dynamic_binning : `int`, optional
        Sets the dynamic bin spacing ``int(sqrt(n_good / threshold_dynamic_binning))``;
        with fewer good pixels than this, binning is skipped. Default is 1000.
    threshold_subdivide : `int`, optional
        Maximum number of bad pixels predicted in a single GP call, to avoid
        memory errors. Default is 20000.
    correlation_length_cut : `int`, optional
        Each defect's bounding box is dilated by
        ``fwhm * correlation_length_cut`` pixels, and the same distance is used
        as the buffer for collecting good pixels. Default is 5.
    log : `lsst.log.Log`, `logging.Logger` or `None`, optional
        Logger object used to write out messages. If `None` a default
        logger will be used.
    """

    def __init__(
        self,
        masked_image,
        defects=("SAT",),
        fwhm=5,
        bin_image=True,
        bin_spacing=10,
        threshold_dynamic_binning=1000,
        threshold_subdivide=20000,
        correlation_length_cut=5,
        log=None,
    ):
        self.log = log or logging.getLogger(__name__)

        self.bin_image = bin_image
        self.bin_spacing = bin_spacing
        self.threshold_subdivide = threshold_subdivide
        self.threshold_dynamic_binning = threshold_dynamic_binning

        self.masked_image = masked_image
        # Copy sequences into a fresh list so no default or caller-owned list is
        # shared between instances; a single plane name (str) is kept as-is.
        if defects is not None and not isinstance(defects, str):
            defects = list(defects)
        self.defects = defects
        self.correlation_length = fwhm
        self.correlation_length_cut = correlation_length_cut

        self.interpBit = self.masked_image.mask.getPlaneBitMask("INTRP")

    def run(self):
        """
        Interpolate over the defects in the image.

        ``self.masked_image`` is modified in place: interpolated values are
        written into the image plane and the ``INTRP`` bit is set in the mask
        for every interpolated pixel.
        """
        if not self.defects:
            self.log.info("No defects found. No interpolation performed.")
            return

        mask = self.masked_image.getMask()
        bad_pixel_mask = mask.getPlaneBitMask(self.defects)
        bad_mask_span_sets = SpanSet.fromMask(mask, bad_pixel_mask).split()

        image_bbox = self.masked_image.getBBox()
        # For a GP with an isotropic kernel, points further than
        # correlation_length * correlation_length_cut (default 5) from the
        # defect have negligible effect on the prediction, so each defect is
        # processed in a sub-image dilated by that distance.
        dilation = int(self.correlation_length * self.correlation_length_cut)  # need an int

        for spanset in bad_mask_span_sets:
            defect_bbox = spanset.getBBox().dilatedBy(dilation)
            # Clip the dilated box to the image. Box2I(Point2I, Point2I) takes
            # the inclusive minimum and maximum corners.
            local_box = Box2I(
                Point2I(max(image_bbox.minX, defect_bbox.minX), max(image_bbox.minY, defect_bbox.minY)),
                Point2I(min(image_bbox.maxX, defect_bbox.maxX), min(image_bbox.maxY, defect_bbox.maxY)),
            )
            masked_sub_image = self.masked_image[local_box]
            masked_sub_image = self.interpolate_masked_sub_image(masked_sub_image)
            # Write the result back explicitly (the sub-image is presumably
            # already a view into masked_image).
            self.masked_image[local_box] = masked_sub_image

    def _good_pixel_binning(self, pixels):
        """
        Performs pixel binning using treegp.meanify.

        Parameters
        ----------
        pixels : `np.array`
            Good pixels, one per row: columns 0 and 1 are x and y, the
            remaining column is the pixel value (assumed to be a single
            column, since the binned output has exactly three columns).

        Returns
        -------
        `np.array`
            The binned pixels as rows of (x, y, mean value), or ``pixels``
            unchanged when ``int(sqrt(N / threshold_dynamic_binning)) == 0``,
            i.e. when there are too few pixels to bin.
        """
        n_pixels = len(pixels)
        dynamic_binning = int(np.sqrt(n_pixels / self.threshold_dynamic_binning))
        if dynamic_binning == 0:
            # Too few pixels: the GP is fitted on the unbinned pixels.
            # NOTE(review): confirm that skipping binning (rather than binning
            # with self.bin_spacing) is the intended behaviour here.
            return pixels
        # Use the coarser of the configured and the dynamic spacing.
        bin_spacing = max(self.bin_spacing, dynamic_binning)
        binning = treegp.meanify(bin_spacing=bin_spacing, statistics="mean")
        binning.add_field(
            pixels[:, :2],
            pixels[:, 2:].T,
        )
        binning.meanify()
        return np.array(
            [binning.coords0[:, 0], binning.coords0[:, 1], binning.params0]
        ).T

    def interpolate_masked_sub_image(self, masked_sub_image):
        """
        Interpolate the masked sub-image.

        Parameters
        ----------
        masked_sub_image : `lsst.afw.image.MaskedImage`
            The sub-masked image to be interpolated.

        Returns
        -------
        `lsst.afw.image.MaskedImage`
            The interpolated sub-masked image (the same object, updated in
            place, or returned untouched if there is nothing to do).
        """
        cut = int(self.correlation_length * self.correlation_length_cut)  # need an int
        bad_pixel, good_pixel = ctUtils.findGoodPixelsAroundBadPixels(
            masked_sub_image, self.defects, buffer=cut
        )
        if bad_pixel.size == 0 or good_pixel.size == 0:
            self.log.info("No bad or good pixels found. No interpolation performed.")
            return masked_sub_image

        # Drop good pixels with non-finite values (NaN, +inf or -inf); they
        # would corrupt the kernel amplitude, the mean and the GP fit.
        good_pixel = good_pixel[np.isfinite(good_pixel[:, 2:]).all(axis=1)]
        if good_pixel.size == 0:
            self.log.info("No bad or good pixels found. No interpolation performed.")
            return masked_sub_image

        # White-noise level: square root of the mean finite variance.
        variance = masked_sub_image.getVariance().array
        white_noise = np.sqrt(np.mean(variance[np.isfinite(variance)]))

        # The kernel amplitude might be better described by the maximum good
        # pixel value than by treating the data as a random gaussian field.
        # NOTE(review): np.sqrt below yields NaN if every good pixel is
        # negative — confirm that cannot happen for the images processed here.
        kernel_amplitude = np.max(good_pixel[:, 2:])

        if self.bin_image:
            try:
                good_pixel = self._good_pixel_binning(good_pixel.copy())
            except Exception:
                # Best effort: fall back to the unbinned pixels.
                self.log.info(
                    "Binning failed, use original good pixel array in interpolation."
                )

        # Computed after binning because the median costs O(n*log(n)).
        clipped_median = median_with_mad_clipping(good_pixel[:, 2:])

        gp = GaussianProcessTreegp(
            std=np.sqrt(kernel_amplitude),
            correlation_length=self.correlation_length,
            white_noise=white_noise,
            mean=clipped_median,
        )
        gp.fit(good_pixel[:, :2], np.squeeze(good_pixel[:, 2:]))

        # Predict in chunks of at most threshold_subdivide bad pixels (rows)
        # to bound the memory used by a single prediction.
        n_bad = len(bad_pixel)
        chunk = max(int(self.threshold_subdivide), 1)
        if n_bad > chunk:
            self.log.info("sub-divide bad pixel array to avoid memory error.")
        for start in range(0, n_bad, chunk):
            stop = min(start + chunk, n_bad)
            gp_predict = gp.predict(bad_pixel[start:stop, :2])
            bad_pixel[start:stop, 2:] = gp_predict.reshape(np.shape(bad_pixel[start:stop, 2:]))

        # Write the predicted values and flag the pixels as interpolated.
        ctUtils.updateImageFromArray(masked_sub_image.image, bad_pixel)
        updateMaskFromArray(masked_sub_image.mask, bad_pixel, self.interpBit)
        return masked_sub_image