Coverage for python / lsst / meas / extensions / shapeHSM / _hsm_higher_moments.py: 19%

140 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 09:15 +0000

1# This file is part of meas_extensions_shapeHSM. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API: a configuration class and a plugin class for the shared base,
# the PSF measurement, and the source measurement.
__all__ = (
    "HigherOrderMomentsConfig", "HigherOrderMomentsPlugin",
    "HigherOrderMomentsPSFConfig", "HigherOrderMomentsPSFPlugin",
    "HigherOrderMomentsSourceConfig", "HigherOrderMomentsSourcePlugin",
)

30 

31import lsst.geom as geom 

32import lsst.meas.base as measBase 

33import numpy as np 

34from lsst.pex.config import Field, FieldValidationError, ListField 

35 

36 

class HigherOrderMomentsConfig(measBase.SingleFramePluginConfig):
    """Configuration shared by the higher order moments plugins.

    The plugins measure every moment whose total order ``p + q`` lies between
    ``min_order`` and ``max_order``, inclusive.
    """

    min_order = Field[int](
        doc="Minimum order of the higher order moments to compute",
        default=3,
    )

    max_order = Field[int](
        doc="Maximum order of the higher order moments to compute",
        default=4,
    )

    def validate(self):
        # Negative orders would generate (p, q) pairs with a negative index,
        # for which no power of the standardized coordinates is tabulated;
        # reject them here instead of failing on every measured record.
        if self.min_order < 0:
            raise FieldValidationError(self.__class__.min_order, self, "min_order must be non-negative")
        if self.min_order > self.max_order:
            raise FieldValidationError(
                self.__class__.min_order, self, "min_order must be less than or equal to max_order"
            )
        super().validate()

54 

55 

class HigherOrderMomentsPlugin(measBase.SingleFramePlugin):
    """Base plugin for higher moments measurement.

    Sets up the general failure flag, the list of ``(p, q)`` orders to
    measure, and the moment computation shared by the source and PSF plugins
    in this module.
    """

    ConfigClass = HigherOrderMomentsConfig

    def __init__(self, config, name, schema, metadata, logName=None):
        super().__init__(config, name, schema, metadata, logName=logName)

        # Define flags for possible issues that might arise during measurement.
        flagDefs = measBase.FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")

        # Embed the flag definitions in the schema using a flag handler.
        self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)

        # Orders to measure, fixed from the config at construction time.
        self.pqlist = self._get_pq_full()

    @classmethod
    def getExecutionOrder(cls):
        # Docstring inherited.
        return cls.FLUX_ORDER

    def fail(self, record, error=None):
        # Docstring inherited.
        # ``error`` is not forwarded; the general failure flag is the only
        # flag this plugin defines.
        self.flagHandler.handleFailure(record)

    def _get_pq_full(self):
        """Get a list of the orders to measure as tuples.

        Returns
        -------
        pqlist : `list` [`tuple` [`int`, `int`]]
            A list of tuples of the form (p, q) where p and q denote the order
            in x and y direction. For each total order ``n`` from
            ``config.min_order`` to ``config.max_order`` (inclusive), ``p``
            runs from 0 to ``n`` and ``q = n - p``.
        """
        return [
            (p, n - p)
            for n in range(self.config.min_order, self.config.max_order + 1)
            for p in range(n + 1)
        ]

    def _generate_suffixes(self):
        """Generator of suffixes 'pq', in the same order as ``self.pqlist``."""
        for p, q in self.pqlist:
            yield f"{p}{q}"

    def _generate_powers_of_standard_positions(self, std_x, std_y):
        """Tabulate integer powers of the standardized coordinates.

        Parameters
        ----------
        std_x, std_y : `float` or `~numpy.ndarray`
            Standardized x and y coordinates.

        Returns
        -------
        std_x_powers, std_y_powers : `dict` [`int`, ...]
            Mappings from each exponent ``0 .. max(1, config.max_order)`` to
            the corresponding power of ``std_x`` and ``std_y``.
        """
        std_x_powers, std_y_powers = {0: 1.0, 1: std_x}, {0: 1.0, 1: std_y}

        # Each power is built from the previous one by a single multiplication.
        for k in range(2, self.config.max_order + 1):
            std_x_powers[k] = std_x_powers[k - 1] * std_x
            std_y_powers[k] = std_y_powers[k - 1] * std_y

        return std_x_powers, std_y_powers

    def _calculate_higher_order_moments(
        self,
        image,
        center,
        M,
        badpix=None,
        set_masked_pixels_to_zero=False,
        use_linear_algebra=False,
    ):
        """Calculate the higher order moments of an image.

        Parameters
        ----------
        image : `~lsst.afw.image.Image`
            Image from which the moments need to be measured (source or PSF).
        center : `~lsst.geom.Point2D`
            First order moments of ``image``, in the same coordinate system as
            ``image.getBBox()``. This is used as the peak of the Gaussian
            weight image.
        M : `~numpy.ndarray`
            A 2x2 numpy array representing the second order moments of
            ``image``. This is used to generate the Gaussian weight image.
        badpix : `~numpy.ndarray` or None
            A 2D boolean array having the same shape and orientation as
            ``image.array`` that denotes which pixels are bad and should not
            be accounted for when computing the moments.
        set_masked_pixels_to_zero : `bool`
            If True, the weighted values of the pixels flagged in ``badpix``
            are set to zero. Otherwise each bad pixel's value is replaced by
            the weight at that pixel times the ratio of the summed good-pixel
            values to the summed good-pixel weights. This is ignored if
            ``badpix`` is None.
        use_linear_algebra : `bool`
            Use linear algebra operations (eigen decomposition and inverse) to
            calculate the inverse square root of ``M``? If False, use the
            specialized closed-form expression for a 2x2 matrix.

        Returns
        -------
        results : `dict` [`tuple` [`int`, `int`], `float`]
            A dictionary mapping each ``(p, q)`` order in ``self.pqlist`` to
            the corresponding normalized higher order moment.

        Raises
        ------
        lsst.meas.base.MeasurementError
            Raised if ``M`` is not positive definite (NaN entries included),
            if no good pixel carries any weight when bad pixels are replaced,
            or if the weighted image sums to zero or to a non-finite value.
            In each of these cases every moment would otherwise be NaN or
            infinite.
        """
        # The Gaussian weight is only defined for a positive-definite second
        # moment matrix. Comparisons with NaN are False, so NaN entries are
        # rejected as well.
        det_M = M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0]
        if not (det_M > 0 and M[0, 0] > 0):
            raise measBase.MeasurementError(self.FAILURE.doc, self.FAILURE.number)

        bbox = image.getBBox()
        image_array = image.array

        if use_linear_algebra:
            inv_M = np.linalg.inv(M)
            evalues, evectors = np.linalg.eig(inv_M)
            # Matrix square root of inv_M: V diag(sqrt(lambda)) V^-1.
            sqrt_inv_M = (evectors * np.sqrt(evalues)) @ np.linalg.inv(evectors)
        else:
            # This is the implementation of Eq. 6 in Hirata & Seljak (2003):
            # https://arxiv.org/pdf/astro-ph/0301054.pdf
            sqrt_D = det_M**0.5
            sqrt_eta = (det_M * (M[0, 0] + M[1, 1] + 2 * sqrt_D)) ** 0.5
            sqrt_inv_M = (1 / sqrt_eta) * np.array(
                [[M[1, 1] + sqrt_D, -M[0, 1]], [-M[1, 0], M[0, 0] + sqrt_D]]
            )

        # Pixel offsets from ``center``, expressed in the array's local frame
        # (``center`` is given in the coordinates of ``bbox``).
        y, x = np.mgrid[: image_array.shape[0], : image_array.shape[1]]
        pos = np.array([x - (center.getX() - bbox.getMinX()), y - (center.getY() - bbox.getMinY())])

        # Standardized positions u = M^(-1/2) r and the weight exp(-|u|^2 / 2).
        std_pos = np.einsum("ij,jqp->iqp", sqrt_inv_M, pos)
        weight = np.exp(-0.5 * np.einsum("ijk,ijk->jk", std_pos, std_pos))

        image_weight = weight * image_array

        # Modify only the weighted image, not image_array, since changing the
        # latter would change the pixel values forever!!!
        if badpix is not None and badpix.any():
            if set_masked_pixels_to_zero:
                # This is how HSM treats bad pixels to compute the quadrupole
                # moments.
                image_weight[badpix] = 0.0
            else:
                # This is how Piff treats bad pixels to compute the
                # higher-order moments.
                goodpix = ~badpix
                good_weight_sum = weight[goodpix].sum()
                if good_weight_sum == 0:
                    # No good pixel carries any weight (e.g. every pixel is
                    # bad), so the replacement scale is undefined.
                    raise measBase.MeasurementError(self.FAILURE.doc, self.FAILURE.number)
                scale = image_array[goodpix].sum() / good_weight_sum
                image_weight[badpix] = (weight[badpix] ** 2) * scale

        normalization = np.sum(image_weight)
        if normalization == 0 or not np.isfinite(normalization):
            raise measBase.MeasurementError(self.FAILURE.doc, self.FAILURE.number)

        std_x, std_y = std_pos
        std_x_powers, std_y_powers = self._generate_powers_of_standard_positions(std_x, std_y)

        return {
            (p, q): np.sum(std_x_powers[p] * std_y_powers[q] * image_weight) / normalization
            for p, q in self.pqlist
        }

214 

215 

class HigherOrderMomentsSourceConfig(HigherOrderMomentsConfig):
    """Configuration for the measurement of higher order moments of objects.

    On top of the shared order range, this selects the mask planes that mark
    bad pixels and how those pixels enter the moment computation.
    """

    badMaskPlanes = ListField[str](
        default=["BAD", "SAT"],
        doc="Mask planes used to reject bad pixels.",
    )

    setMaskedPixelsToZero = Field[bool](
        default=False,
        doc=(
            "Set masked pixels to zero? "
            "If False, they are replaced by the scaled version of the adaptive weights."
        ),
    )

229 

230 

@measBase.register("ext_shapeHSM_HigherOrderMomentsSource")
class HigherOrderMomentsSourcePlugin(HigherOrderMomentsPlugin):
    """Plugin for Higher Order Moments measurement of objects.

    The moments are measured in standardized coordinates. These are obtained
    by applying the symmetric inverse square root of the adaptive second
    moment matrix to pixel offsets from the adaptive centroid, so the moments
    do not scale with the size of the object. The transformation is a pure
    stretch without a rotation. As a result, the standardized axes are not
    aligned with the object's major and minor axes, and the individual
    moments can change with the object's orientation.

    For any well-sampled image, the zeroth order moment is 1, the first order
    moments are 0, and the second order moments are 0.5 for xx and yy and 0
    for xy. For a symmetric profile, the moments are zero if either of the
    indices is odd.

    Notes
    -----
    This plugin requires the `ext_shapeHSM_HsmSourceMoments` plugin to be
    enabled in order to measure the higher order moments, and raises a
    FatalAlgorithmError otherwise. For accurate results, the weight function
    used must match those used for first and second order moments. Hence, this
    plugin does not use slots for centroids and shapes, but instead uses those
    measured by the `ext_shapeHSM_HsmSourceMoments` explicitly.

    The main known failure mode of this plugin is if
    `ext_shapeHSM_HsmSourceMoments` measurement failed. The flags of that
    plugin are informative here as well and should be used to filter out
    unreliable measurements.
    """

    ConfigClass = HigherOrderMomentsSourceConfig

    def __init__(self, config, name, schema, metadata, logName=None):
        super().__init__(config, name, schema, metadata, logName=logName)

        # Keep the schema keys so that measure() sets each moment through the
        # same key that was registered here, without a per-record name lookup.
        self._moment_keys = {}
        for pq, suffix in zip(self.pqlist, self._generate_suffixes()):
            self._moment_keys[pq] = schema.addField(
                schema.join(name, suffix),
                type=float,
                doc=f"Higher order moments M_{suffix} for source",
            )

    def measure(self, record, exposure):
        # Docstring inherited.

        # If HsmSourceMoments algorithm failed, then the higher-order moments measurements
        # would surely fail. Raise MeasurementError and move on.
        # If HsmSourceMoments algorithm was not run, then higher-order moments measurements
        # would fail for all entries in the catalog and is fatal.
        try:
            hsm_failed = record["ext_shapeHSM_HsmSourceMoments_flag"]
        except KeyError as err:
            raise measBase.FatalAlgorithmError(
                "'ext_shapeHSM_HsmSourceMoments' plugin must be enabled."
            ) from err
        if hsm_failed:
            raise measBase.MeasurementError(self.FAILURE.doc, self.FAILURE.number)

        center = geom.Point2D(
            record["ext_shapeHSM_HsmSourceMoments_x"],
            record["ext_shapeHSM_HsmSourceMoments_y"],
        )
        M = np.zeros((2, 2))
        M[0, 0] = record["ext_shapeHSM_HsmSourceMoments_xx"]
        M[1, 1] = record["ext_shapeHSM_HsmSourceMoments_yy"]
        M[0, 1] = M[1, 0] = record["ext_shapeHSM_HsmSourceMoments_xy"]

        # Obtain the bounding box of the source footprint
        bbox = record.getFootprint().getBBox()

        # Pixels with any of the configured bad mask planes set.
        bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
        badpix = (exposure.mask[bbox].array & bitValue) != 0

        # Measure all the moments together to save time
        hm_measurement = self._calculate_higher_order_moments(
            exposure.image[bbox],
            center,
            M,
            badpix,
            set_masked_pixels_to_zero=self.config.setMaskedPixelsToZero,
        )

        # Record the moments
        for pq, M_pq in hm_measurement.items():
            record.set(self._moment_keys[pq], M_pq)

313 

314 

class HigherOrderMomentsPSFConfig(HigherOrderMomentsConfig):
    """Configuration for the higher order moments of the PSF.

    Extends the shared order range with a switch that selects whether the
    source centroid offset is used when placing the PSF model.
    """

    useSourceCentroidOffset = Field[bool](
        default=False,
        doc="Use source centroid offset?",
    )

322 

323 

@measBase.register("ext_shapeHSM_HigherOrderMomentsPSF")
class HigherOrderMomentsPSFPlugin(HigherOrderMomentsPlugin):
    """Plugin for Higher Order Moments measurement of PSF models.

    The moments are measured in standardized coordinates. These are obtained
    by applying the symmetric inverse square root of the PSF's adaptive
    second moment matrix to pixel offsets from the PSF center, so the moments
    do not scale with the size of the PSF. The transformation is a pure
    stretch without a rotation. As a result, the standardized axes are not
    aligned with the PSF's major and minor axes, and the individual moments
    can change with its orientation.

    For any well-sampled image, the zeroth order moment is 1, the first order
    moments are 0, and the second order moments are 0.5 for xx and yy and 0
    for xy. For a symmetric profile, the moments are zero if either of the
    indices is odd.

    Notes
    -----
    This plugin requires the `ext_shapeHSM_HsmPsfMoments` plugin to be
    enabled in order to measure the higher order moments, and raises a
    FatalAlgorithmError otherwise. The weight function is parametrized by the
    shape measured from `ext_shapeHSM_HsmPsfMoments` but for efficiency
    reasons, uses the slot centroid to evaluate the PSF model.

    The main known failure mode of this plugin is if
    `ext_shapeHSM_HsmPsfMoments` measurement failed. The flags of that
    plugin are informative here as well and should be used to filter out
    unreliable measurements.
    """

    ConfigClass = HigherOrderMomentsPSFConfig

    def __init__(self, config, name, schema, metadata, logName=None):
        super().__init__(config, name, schema, metadata, logName=logName)
        # Use the standard slot centroid to use the shared PSF model with
        # other plugins.
        self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)

        # Keep the schema keys so that measure() sets each moment through the
        # same key that was registered here, without a per-record name lookup.
        self._moment_keys = {}
        for pq, suffix in zip(self.pqlist, self._generate_suffixes()):
            self._moment_keys[pq] = schema.addField(
                schema.join(name, suffix),
                type=float,
                doc=f"Higher order moments M_{suffix} for PSF",
            )

    def measure(self, record, exposure):
        # Docstring inherited.

        # If the PSF model is not available, this plugin would fail for all entries.
        if (psf := exposure.getPsf()) is None:
            raise measBase.FatalAlgorithmError("No PSF attached to the exposure.")

        # If HsmPsfMoments algorithm failed, then the higher-order PSF moments measurements
        # would surely fail. Raise MeasurementError and move on.
        # If HsmPsfMoments algorithm was not run, then higher-order PSF moments measurements
        # would fail for all entries in the catalog and is fatal.
        try:
            hsm_failed = record["ext_shapeHSM_HsmPsfMoments_flag"]
        except KeyError as err:
            raise measBase.FatalAlgorithmError(
                "'ext_shapeHSM_HsmPsfMoments' plugin must be enabled."
            ) from err
        if hsm_failed:
            raise measBase.MeasurementError(self.FAILURE.doc, self.FAILURE.number)

        M = np.zeros((2, 2))
        M[0, 0] = record["ext_shapeHSM_HsmPsfMoments_xx"]
        M[1, 1] = record["ext_shapeHSM_HsmPsfMoments_yy"]
        M[0, 1] = M[1, 0] = record["ext_shapeHSM_HsmPsfMoments_xy"]

        centroid = self.centroidExtractor(record, self.flagHandler)
        if self.config.useSourceCentroidOffset:
            psfImage = psf.computeImage(centroid)
            # Undo what subtractCenter config did.
            # This operation assumes subtractCenter was set to True (default)
            # in the ext_shapeHSM_HsmPsfMomentsConfig and does not have
            # access to it. A new point is built so ``centroid`` is not
            # modified in place.
            psfCenter = geom.Point2D(
                centroid.getX() + record["ext_shapeHSM_HsmPsfMoments_x"],
                centroid.getY() + record["ext_shapeHSM_HsmPsfMoments_y"],
            )
        else:
            psfImage = psf.computeKernelImage(centroid)
            # Shift the kernel image's origin by the integer pixel derived
            # from the centroid, then weight around the central pixel of the
            # shifted bounding box.
            center0 = geom.Point2I(centroid)
            xy0 = geom.Point2I(center0.x + psfImage.getX0(), center0.y + psfImage.getY0())
            psfImage.setXY0(xy0)
            psfBBox = psfImage.getBBox()
            psfCenter = geom.Point2D(psfBBox.getMin() + psfBBox.getDimensions() // 2)

        # Measure all the moments together to save time
        hm_measurement = self._calculate_higher_order_moments(psfImage, psfCenter, M)

        # Record the moments
        for pq, M_pq in hm_measurement.items():
            record.set(self._moment_keys[pq], M_pq)