Coverage for python/lsst/meas/extensions/shapeHSM/_hsm_higher_moments.py: 23%

141 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-10 13:01 +0000

1# This file is part of meas_extensions_shapeHSM. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: the shared base configuration/plugin and the
# concrete configurations/plugins for sources and for PSF models.
__all__ = (
    "HigherOrderMomentsConfig",
    "HigherOrderMomentsPlugin",
    "HigherOrderMomentsPSFConfig",
    "HigherOrderMomentsPSFPlugin",
    "HigherOrderMomentsSourceConfig",
    "HigherOrderMomentsSourcePlugin",
)

30 

31import lsst.geom as geom 

32import lsst.meas.base as measBase 

33import numpy as np 

34from lsst.pex.config import Field, FieldValidationError, ListField 

35 

36 

class HigherOrderMomentsConfig(measBase.SingleFramePluginConfig):
    """Base configuration for the higher order moments plugins.

    The plugins measure every moment ``M_pq`` with ``p + q = n`` for each
    order ``n`` in the inclusive range [``min_order``, ``max_order``].
    """

    min_order = Field[int](
        doc="Minimum order of the higher order moments to compute",
        default=3,
    )

    max_order = Field[int](
        doc="Maximum order of the higher order moments to compute",
        default=4,
    )

    def validate(self):
        """Validate the configuration.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``min_order`` is negative or greater than
            ``max_order`` (or if a field fails the generic validation).
        """
        # Run the generic field validation first, so that e.g. an unset
        # (None) order is reported as a validation error rather than failing
        # with a TypeError in the comparisons below.
        super().validate()
        # Negative orders are meaningless; the measurement code only
        # tabulates non-negative powers of the coordinates.
        if self.min_order < 0:
            raise FieldValidationError(self.__class__.min_order, self, "min_order must be non-negative")
        if self.min_order > self.max_order:
            raise FieldValidationError(
                self.__class__.min_order, self, "min_order must be less than or equal to max_order"
            )

54 

55 

class HigherOrderMomentsPlugin(measBase.SingleFramePlugin):
    """Base plugin for higher moments measurement.

    Subclasses provide the image, the center and the 2x2 second moment
    matrix; this class builds the elliptical Gaussian weight from them and
    evaluates the standardized moments ``M_pq`` for every ``(p, q)`` with
    ``min_order <= p + q <= max_order``.
    """

    ConfigClass = HigherOrderMomentsConfig

    def __init__(self, config, name, schema, metadata, logName=None):
        super().__init__(config, name, schema, metadata, logName=logName)

        # Define flags for possible issues that might arise during measurement.
        flagDefs = measBase.FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")

        # Embed the flag definitions in the schema using a flag handler.
        self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)

        self.pqlist = self._get_pq_full()

    @classmethod
    def getExecutionOrder(cls):
        # Docstring inherited.
        return cls.FLUX_ORDER

    def fail(self, record, error=None):
        # Docstring inherited.
        # Only one (general) failure flag is defined, so ``error`` is not
        # needed to decide which flag to set.
        self.flagHandler.handleFailure(record)

    def _get_pq_full(self):
        """Get the orders of the moments to measure.

        Returns
        -------
        pqlist : `list` [`tuple` [`int`, `int`]]
            Tuples ``(p, q)`` where ``p`` and ``q`` denote the order in the x
            and y direction. They are grouped by total order ``n = p + q``,
            ascending from ``min_order`` to ``max_order``, and within one
            order ``p`` increases from 0 to ``n``.
        """
        return [
            (p, n - p)
            for n in range(self.config.min_order, self.config.max_order + 1)
            for p in range(n + 1)
        ]

    def _generate_suffixes(self):
        """Generator of suffixes 'pq', in the same order as ``self.pqlist``."""
        for p, q in self.pqlist:
            yield f"{p}{q}"

    def _generate_powers_of_standard_positions(self, std_x, std_y):
        """Tabulate integer powers of the standardized coordinates.

        Parameters
        ----------
        std_x, std_y : `~numpy.ndarray`
            Standardized x and y coordinates of the pixels.

        Returns
        -------
        std_x_powers, std_y_powers : `dict` [`int`, `~numpy.ndarray`]
            Mappings from each power ``k`` (0 and 1 always, then up to
            ``max_order``) to ``std_x**k`` and ``std_y**k``.
        """
        std_x_powers, std_y_powers = {0: 1.0, 1: std_x}, {0: 1.0, 1: std_y}

        # Build the powers by repeated multiplication instead of calling
        # ``**`` for each moment.
        for p in range(2, self.config.max_order + 1):
            std_x_powers[p] = std_x_powers[p - 1] * std_x

        for q in range(2, self.config.max_order + 1):
            std_y_powers[q] = std_y_powers[q - 1] * std_y

        return std_x_powers, std_y_powers

    @staticmethod
    def _sqrt_inverse_second_moments(M, use_linear_algebra=False):
        """Compute the symmetric square root of the inverse of ``M``.

        Parameters
        ----------
        M : `~numpy.ndarray`
            A symmetric 2x2 second moment matrix.
        use_linear_algebra : `bool`
            If True, use an eigendecomposition of ``M^-1``; otherwise use the
            closed-form expression for 2x2 matrices.

        Returns
        -------
        sqrt_inv_M : `~numpy.ndarray`
            A symmetric 2x2 array ``S`` such that ``S @ S`` equals
            ``inv(M)`` up to rounding.
        """
        if use_linear_algebra:
            inv_M = np.linalg.inv(M)
            # inv_M is symmetric, so eigh yields real eigenvalues and an
            # orthonormal eigenbasis whose inverse is simply its transpose.
            evalues, evectors = np.linalg.eigh(inv_M)
            return (evectors * np.sqrt(evalues)) @ evectors.T

        # This is the implementation of Eq. 6 in Hirata & Seljak (2003):
        # https://arxiv.org/pdf/astro-ph/0301054.pdf
        D = M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0]
        sqrt_D = D**0.5
        sqrt_eta = (D * (M[0, 0] + M[1, 1] + 2 * sqrt_D)) ** 0.5
        return (1 / sqrt_eta) * np.array([[M[1, 1] + sqrt_D, -M[0, 1]], [-M[1, 0], M[0, 0] + sqrt_D]])

    def _calculate_higher_order_moments(
        self,
        image,
        center,
        M,
        badpix=None,
        set_masked_pixels_to_zero=False,
        use_linear_algebra=False,
    ):
        """
        Calculate the higher order moments of an image.

        Parameters
        ----------
        image : `~lsst.afw.image.Image`
            Image from which the moments need to be measured (source or PSF).
        center: `~lsst.geom.Point2D`
            First order moments of ``image``. This is used as the peak of the
            Gaussian weight image.
        M : `~numpy.ndarray`
            A 2x2 numpy array representing the second order moments of
            ``image``. This is used to generate the Gaussian weight image.
        badpix : `~numpy.ndarray` or None
            A 2D boolean array having the same shape and orientation as
            ``image.array`` that denotes which pixels are bad and should not
            be accounted for when computing the moments.
        set_masked_pixels_to_zero: `bool`
            Whether pixels flagged in ``badpix`` should be set to zero, or
            replaced by a scaled version of the weight image. This is ignored
            if ``badpix`` is None.
        use_linear_algebra: `bool`
            Use linear algebra operations (eigen decomposition and inverse) to
            calculate the moments? If False, use the specialized formulae for
            2x2 matrix.

        Returns
        -------
        results : `dict`
            A dictionary mapping the order of the moments expressed as tuples
            to the corresponding higher order moments.

        Raises
        ------
        ValueError
            Raised if every pixel is flagged in ``badpix``, since no moment
            can then be computed.
        """

        bbox = image.getBBox()
        image_array = image.array

        y, x = np.mgrid[: image_array.shape[0], : image_array.shape[1]]

        sqrt_inv_M = self._sqrt_inverse_second_moments(M, use_linear_algebra)

        # Offsets of every pixel from ``center``; ``center`` is expressed in
        # the same frame as ``bbox``, so shift it into array-index space.
        pos = np.array([x - (center.getX() - bbox.getMinX()), y - (center.getY() - bbox.getMinY())])

        # Standardized positions and the Gaussian weight
        # exp(-0.5 * pos^T M^-1 pos) built from them.
        std_pos = np.einsum("ij,jqp->iqp", sqrt_inv_M, pos)
        weight = np.exp(-0.5 * np.einsum("ijk,ijk->jk", std_pos, std_pos))

        # ``weight * image_array`` is a new array, so the bad pixel handling
        # below edits it without touching the pixel values of ``image``.
        image_weight = weight * image_array

        if badpix is not None and badpix.any():
            if badpix.all():
                # Both treatments below would otherwise produce 0/0 = NaN
                # moments without any failure being reported.
                raise ValueError("All pixels are masked; cannot compute the higher order moments.")
            if set_masked_pixels_to_zero:
                # This is how HSM treats bad pixels to compute the quadrupole
                # moments.
                image_weight[badpix] = 0.0
            else:
                # This is how Piff treats bad pixels to compute the
                # higher-order moments: bad pixel values are replaced by the
                # weight scaled to the flux of the good pixels.
                scale = image_array[~badpix].sum() / weight[~badpix].sum()
                image_weight[badpix] = (weight[badpix] ** 2) * scale

        normalization = np.sum(image_weight)

        std_x, std_y = std_pos
        std_x_powers, std_y_powers = self._generate_powers_of_standard_positions(std_x, std_y)

        results = {}
        for p, q in self.pqlist:
            results[(p, q)] = np.sum(std_x_powers[p] * std_y_powers[q] * image_weight) / normalization

        return results

214 

215 

class HigherOrderMomentsSourceConfig(HigherOrderMomentsConfig):
    """Configuration for measuring the higher order moments of objects.

    Adds, on top of the base orders, options controlling which mask planes
    mark bad pixels and how those pixels enter the moments.
    """

    # Pixels having any of these mask planes set are treated as bad.
    badMaskPlanes = ListField[str](
        default=["BAD", "SAT"],
        doc="Mask planes used to reject bad pixels.",
    )

    # Selects between zeroing bad pixels and replacing them with the
    # scaled weight image.
    setMaskedPixelsToZero = Field[bool](
        default=False,
        doc="Set masked pixels to zero? "
        "If False, they are replaced by the scaled version of the adaptive weights.",
    )

229 

230 

@measBase.register("ext_shapeHSM_HigherOrderMomentsSource")
class HigherOrderMomentsSourcePlugin(HigherOrderMomentsPlugin):
    """Plugin for Higher Order Moments measurement of objects.

    The moments are measured in normalized coordinates, where the normalized x
    axis is along the major axis and the normalized y axis along the minor.
    The moments are dependent only on the light profile, and do not scale
    with the size or orientation of the object.

    For any well-sampled image, the zeroth order moment is 1,
    the first order moments are 0, and the second order moments are 0.5 for xx
    and yy and 0 for xy. For a symmetric profile, the moments are zeros if
    either of the indices is odd.

    Notes
    -----
    This plugin requires the `ext_shapeHSM_HsmSourceMoments` plugin to be
    enabled in order to measure the higher order moments, and raises a
    FatalAlgorithmError otherwise. For accurate results, the weight function
    used must match those used for first and second order moments. Hence, this
    plugin does not use slots for centroids and shapes, but instead uses those
    measured by the `ext_shapeHSM_HsmSourceMoments` explicitly.

    The only known failure mode of this plugin is if
    `ext_shapeHSM_HsmSourceMoments` measurement failed. The flags of that
    plugin are informative here as well and should be used to filter out
    unreliable measurements.
    """

    ConfigClass = HigherOrderMomentsSourceConfig

    def __init__(self, config, name, schema, metadata, logName=None):
        super().__init__(config, name, schema, metadata, logName=logName)

        # Keep the schema key of each output column, indexed by its (p, q)
        # order, so that measure() does not rebuild and look up column names
        # for every record.
        self._momentKeys = {}
        for pq, suffix in zip(self.pqlist, self._generate_suffixes()):
            self._momentKeys[pq] = schema.addField(
                schema.join(name, suffix),
                type=float,
                doc=f"Higher order moments M_{suffix} for source",
            )

    def measure(self, record, exposure):
        # Docstring inherited.
        M = np.zeros((2, 2))
        try:
            center = geom.Point2D(
                record["ext_shapeHSM_HsmSourceMoments_x"],
                record["ext_shapeHSM_HsmSourceMoments_y"],
            )
            M[0, 0] = record["ext_shapeHSM_HsmSourceMoments_xx"]
            M[1, 1] = record["ext_shapeHSM_HsmSourceMoments_yy"]
            M[0, 1] = M[1, 0] = record["ext_shapeHSM_HsmSourceMoments_xy"]
        except KeyError as err:
            raise measBase.FatalAlgorithmError(
                "'ext_shapeHSM_HsmSourceMoments' plugin must be enabled."
            ) from err

        # Obtain the bounding box of the source footprint
        bbox = record.getFootprint().getBBox()

        bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
        badpix = (exposure.mask[bbox].array & bitValue) != 0

        # Measure all the moments together to save time
        try:
            hm_measurement = self._calculate_higher_order_moments(
                exposure.image[bbox],
                center,
                M,
                badpix,
                set_masked_pixels_to_zero=self.config.setMaskedPixelsToZero,
            )
        except Exception as e:
            raise measBase.MeasurementError(e) from e

        # Record the moments
        for pq, M_pq in hm_measurement.items():
            record.set(self._momentKeys[pq], M_pq)

308 

309 

class HigherOrderMomentsPSFConfig(HigherOrderMomentsConfig):
    """Configuration for measuring the higher order moments of PSF models."""

    # Selects, in HigherOrderMomentsPSFPlugin.measure, between evaluating the
    # PSF image at the source centroid (shifted by the HSM PSF moments
    # offset) and using the PSF kernel image centered on its own center.
    useSourceCentroidOffset = Field[bool](
        default=False,
        doc="Use source centroid offset?",
    )

317 

318 

@measBase.register("ext_shapeHSM_HigherOrderMomentsPSF")
class HigherOrderMomentsPSFPlugin(HigherOrderMomentsPlugin):
    """Plugin for Higher Order Moments measurement of PSF models.

    The moments are measured in normalized coordinates, where the normalized x
    axis is along the major axis and the normalized y axis along the minor.
    The moments are dependent only on the light profile, and do not scale
    with the size or orientation of the object.

    For any well-sampled image, the zeroth order moment is 1,
    the first order moments are 0, and the second order moments are 0.5 for xx
    and yy and 0 for xy. For a symmetric profile, the moments are zeros if
    either of the indices is odd.

    Notes
    -----
    This plugin requires the `ext_shapeHSM_HsmPsfMoments` plugin to be
    enabled in order to measure the higher order moments, and raises a
    FatalAlgorithmError otherwise. The weight function is parametrized by the
    shape measured from `ext_shapeHSM_HsmPsfMoments` but for efficiency
    reasons, uses the slot centroid to evaluate the PSF model.

    The only known failure mode of this plugin is if
    `ext_shapeHSM_HsmPsfMoments` measurement failed. The flags of that
    plugin are informative here as well and should be used to filter out
    unreliable measurements.
    """

    ConfigClass = HigherOrderMomentsPSFConfig

    def __init__(self, config, name, schema, metadata, logName=None):
        super().__init__(config, name, schema, metadata, logName=logName)
        # Use the standard slot centroid to use the shared PSF model with
        # other plugins.
        self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)

        # Keep the schema key of each output column, indexed by its (p, q)
        # order, so that measure() does not rebuild and look up column names
        # for every record.
        self._momentKeys = {}
        for pq, suffix in zip(self.pqlist, self._generate_suffixes()):
            self._momentKeys[pq] = schema.addField(
                schema.join(name, suffix),
                type=float,
                doc=f"Higher order moments M_{suffix} for PSF",
            )

    def measure(self, record, exposure):
        # Docstring inherited.
        M = np.zeros((2, 2))
        try:
            M[0, 0] = record["ext_shapeHSM_HsmPsfMoments_xx"]
            M[1, 1] = record["ext_shapeHSM_HsmPsfMoments_yy"]
            M[0, 1] = M[1, 0] = record["ext_shapeHSM_HsmPsfMoments_xy"]
        except KeyError as err:
            raise measBase.FatalAlgorithmError("'ext_shapeHSM_HsmPsfMoments' plugin must be enabled.") from err

        psf = exposure.getPsf()

        centroid = self.centroidExtractor(record, self.flagHandler)
        if self.config.useSourceCentroidOffset:
            psfImage = psf.computeImage(centroid)
            # Undo what subtractCenter config did.
            # This operation assumes subtractCenter was set to True (default)
            # in the ext_shapeHSM_HsmPsfMomentsConfig and does not have
            # access to it.
            # Build a new point instead of shifting ``centroid`` in place, so
            # the point returned by the centroid extractor is not mutated.
            psfCenter = geom.Point2D(
                centroid.getX() + record["ext_shapeHSM_HsmPsfMoments_x"],
                centroid.getY() + record["ext_shapeHSM_HsmPsfMoments_y"],
            )
        else:
            psfImage = psf.computeKernelImage(centroid)
            # Move the kernel image so that its origin lands on the (integer)
            # centroid pixel, then center the weight on the image center.
            center0 = geom.Point2I(centroid)
            xy0 = geom.Point2I(center0.x + psfImage.getX0(), center0.y + psfImage.getY0())
            psfImage.setXY0(xy0)
            psfBBox = psfImage.getBBox()
            psfCenter = geom.Point2D(psfBBox.getMin() + psfBBox.getDimensions() // 2)

        # Measure all the moments together to save time
        try:
            hm_measurement = self._calculate_higher_order_moments(psfImage, psfCenter, M)
        except Exception as e:
            raise measBase.MeasurementError(e) from e

        # Record the moments
        for pq, M_pq in hm_measurement.items():
            record.set(self._momentKeys[pq], M_pq)