Coverage for python / lsst / meas / extensions / shapeHSM / _hsm_shape.py: 40%

147 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 08:32 +0000

1# This file is part of meas_extensions_shapeHSM. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23 

24import galsim 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.meas.base as measBase 

28import lsst.pex.config as pexConfig 

29from lsst.geom import Point2I 

30 

# Public API of this module: one config class and one plugin class per
# HSM shear estimator.
__all__ = [
    "HsmShapeBjConfig",
    "HsmShapeBjPlugin",
    "HsmShapeLinearConfig",
    "HsmShapeLinearPlugin",
    "HsmShapeKsbConfig",
    "HsmShapeKsbPlugin",
    "HsmShapeRegaussConfig",
    "HsmShapeRegaussPlugin",
]

41 

42 

def inherit_doc(ref_class):
    """Build a decorator that copies a docstring from ``ref_class``.

    Parameters
    ----------
    ref_class : `type`
        Class holding an attribute with the same name as the function that
        will be decorated; that attribute's ``__doc__`` is the one copied.

    Returns
    -------
    decorator : callable
        Decorator that overwrites ``__doc__`` on the function it receives
        and returns that same function object.
    """

    def decorator(func):
        # The lookup is by the function's own name, so the decorated function
        # must be named exactly like the attribute it takes its docs from.
        func.__doc__ = getattr(ref_class, func.__name__).__doc__
        return func

    return decorator

49 

50 

class HsmShapeConfig(measBase.SingleFramePluginConfig):
    """Base configuration for HSM shape measurement."""

    # An empty string (the default) leaves the plugin's deblend check off.
    deblendNChild = pexConfig.Field[str](
        doc="Field name for number of deblend children.",
        default="",
    )

    badMaskPlanes = pexConfig.ListField[str](
        doc="Mask planes that indicate pixels that should be excluded from the fit.",
        default=["BAD", "SAT"],
    )

    @property
    def shearType(self):
        """Base class property for the desired method of PSF correction.

        The following options are available through GalSim. The first three
        options return an e-type distortion, whereas the last option returns a
        g-type shear:

        - "REGAUSS": Regaussianization method from Hirata & Seljak (2003).
        - "LINEAR": A modification by Hirata & Seljak (2003) of methods in
          Bernstein & Jarvis (2002).
        - "BJ": The method developed by Bernstein & Jarvis (2002).
        - "KSB": The method from Kaiser, Squires, & Broadhurst (1995).

        Subclasses can override this property, but it cannot be set externally,
        making it effectively read-only.
        """
        raise NotImplementedError("The shearType property must be implemented in subclasses.")

82 

83 

class HsmShapePlugin(measBase.SingleFramePlugin):
    """Base plugin for HSM shape measurement.

    Notes
    -----
    This base class does not define ``measTypeSymbol``; the constructor reads
    it (via `_addEllipticityField`), so concrete subclasses must set it to
    ``"e"`` or ``"g"``. Subclasses are also expected to override ``doc``,
    which is used as the prefix of the schema field documentation.
    """

    ConfigClass = HsmShapeConfig
    # Prefix for the documentation of the output fields; empty in the base.
    doc = ""

    def __init__(self, config, name, schema, metadata, logName=None):
        """Register the flag and output fields of the plugin in ``schema``.

        Parameters
        ----------
        config : `HsmShapeConfig`
            Plugin configuration; ``deblendNChild`` is read here.
        name : `str`
            Plugin name, used as the prefix of every field added.
        schema : `~lsst.afw.table.Schema`
            Schema that receives the flag, ellipticity, sigma and resolution
            fields.
        metadata
            Forwarded unchanged to the base-class constructor.
        logName : `str`, optional
            Logger name; this module's ``__name__`` is used when `None`.
        """
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Define flags for possible issues that might arise during measurement.
        flagDefs = measBase.FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")
        self.NO_PIXELS = flagDefs.add("flag_no_pixels", "No pixels to measure")
        self.NOT_CONTAINED = flagDefs.add(
            "flag_not_contained", "Center not contained in footprint bounding box"
        )
        self.PARENT_SOURCE = flagDefs.add("flag_parent_source", "Parent source, ignored")
        self.GALSIM = flagDefs.add("flag_galsim", "GalSim failure")

        # Embed the flag definitions in the schema using a flag handler.
        self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)

        # Utilize a safe centroid extractor that uses the detection footprint
        # as a fallback if necessary.
        self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)

        self.e1Key = self._addEllipticityField(name, 1, schema, self.doc)
        self.e2Key = self._addEllipticityField(name, 2, schema, self.doc)
        self.sigmaKey = schema.addField(
            schema.join(name, "sigma"),
            type=float,
            doc=f"{self.doc} (shape measurement uncertainty per component)",
        )
        self.resolutionKey = schema.addField(
            schema.join(name, "resolution"), type=float, doc="Resolution factor (0=unresolved, 1=resolved)"
        )

        # The deblend key is only looked up when a field name was configured;
        # ``self.deblendKey`` does not exist otherwise.
        self.hasDeblendKey = len(config.deblendNChild) > 0

        if self.hasDeblendKey:
            self.deblendKey = schema[config.deblendNChild].asKey()

        self.log = logging.getLogger(self.logName)

    @classmethod
    def getExecutionOrder(cls):
        """Return the ``SHAPE_ORDER`` attribute of the class."""
        return cls.SHAPE_ORDER

    @staticmethod
    def bboxToGalSimBounds(bbox):
        """Convert a bounding box into GalSim integer bounds.

        Parameters
        ----------
        bbox
            Object exposing ``getMinX``, ``getMaxX``, ``getMinY`` and
            ``getMaxY``.

        Returns
        -------
        bounds : `galsim.BoundsI`
            Result of ``galsim._BoundsI(xmin, xmax, ymin, ymax)``.
        """
        xmin, xmax = bbox.getMinX(), bbox.getMaxX()
        ymin, ymax = bbox.getMinY(), bbox.getMaxY()
        return galsim._BoundsI(xmin, xmax, ymin, ymax)

    def _addEllipticityField(self, name, n, schema, doc):
        """
        Helper function to add an ellipticity field to a measurement schema.

        Parameters
        ----------
        name : `str`
            Base name of the field.
        n : `int`
            Specifies whether the field is for the first (1) or second (2)
            component.
        schema : `~lsst.afw.table.Schema`
            The schema to which the field is added.
        doc : `str`
            The documentation string that needs to be updated to reflect the
            type and component of the measurement.

        Returns
        -------
        `~lsst.afw.table.KeyD`
            The key associated with the added field in the schema.
        """
        componentLookup = {1: "+ component", 2: "x component"}
        typeLookup = {"e": " of ellipticity", "g": " of estimated shear"}
        # ``measTypeSymbol`` comes from the concrete subclass; the field is
        # named ``<name>_e1``/``<name>_g2`` etc. accordingly.
        name = f"{name}_{self.measTypeSymbol}{n}"
        updatedDoc = f"{doc} ({componentLookup[n]}{typeLookup[self.measTypeSymbol]})"
        return schema.addField(name, type=float, doc=updatedDoc)

    def measure(self, record, exposure):
        """
        Measure the shape of sources given an exposure and set the results in
        the record in place.

        Parameters
        ----------
        record : `~lsst.afw.table.SourceRecord`
            The record where measurement outputs will be stored.
        exposure : `~lsst.afw.image.Exposure`
            The exposure containing the source which needs measurement.

        Raises
        ------
        MeasurementError
            Raised for errors in measurement: the record is a deblend parent,
            the footprint bounding box is empty, the centroid lies outside
            that bounding box, or GalSim raises ``GalSimHSMError``.
        """
        # Extract the centroid from the record.
        center = self.centroidExtractor(record, self.flagHandler)

        if self.hasDeblendKey and record.get(self.deblendKey) > 0:
            raise measBase.MeasurementError(self.PARENT_SOURCE.doc, self.PARENT_SOURCE.number)

        # Get the bounding box of the source's footprint.
        bbox = record.getFootprint().getBBox()

        # Check that the bounding box has non-zero area.
        if bbox.getArea() == 0:
            raise measBase.MeasurementError(self.NO_PIXELS.doc, self.NO_PIXELS.number)

        # Ensure that the centroid is within the bounding box.
        if not bbox.contains(Point2I(center)):
            raise measBase.MeasurementError(self.NOT_CONTAINED.doc, self.NOT_CONTAINED.number)

        # Get the PSF image evaluated at the source centroid.
        psfImage = exposure.getPsf().computeImage(center)
        psfImage.setXY0(0, 0)

        # Get the trace radius of the PSF.
        psfSigma = exposure.getPsf().computeShape(center).getTraceRadius()

        # Turn bounding box corners into GalSim bounds.
        bounds = self.bboxToGalSimBounds(bbox)

        # Get the bounding box of the PSF in the parent coordinate system.
        psfBBox = psfImage.getBBox(afwImage.PARENT)

        # Turn the PSF bounding box corners into GalSim bounds.
        psfBounds = self.bboxToGalSimBounds(psfBBox)

        # Each GalSim image below will match whatever dtype the input array is.
        # NOTE: PSF is already restricted to a small image, so no bbox for the
        # PSF is expected.
        image = galsim._Image(exposure.image[bbox].array, bounds, wcs=None)
        psf = galsim._Image(psfImage.array, psfBounds, wcs=None)

        # Get the `lsst.meas.base` mask for bad pixels.
        subMask = exposure.mask[bbox]
        badpix = subMask.array.copy()  # Copy it since badpix gets modified.
        bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
        badpix &= bitValue

        # Turn badpix to weight where elements set to 1 indicate 'use pixel'
        # and those set to 0 mean 'do not use pixel'. Now, weight will assume
        # the role of badpix, and we will no longer use badpix in our call to
        # EstimateShear().
        gd = badpix == 0
        badpix[gd] = 1
        badpix[~gd] = 0
        weight = galsim._Image(badpix, bounds, wcs=None)

        # Get the statistics control object for sky variance estimation.
        sctrl = afwMath.StatisticsControl()
        sctrl.setAndMask(bitValue)

        # Create a variance image from the exposure.
        # NOTE: Origin defaults to PARENT in all cases accessible from Python.
        variance = afwImage.Image(
            exposure.variance[bbox],
            dtype=exposure.variance.dtype,
            deep=False,
        )

        # Calculate median sky variance for use in shear estimation.
        stat = afwMath.makeStatistics(variance, subMask, afwMath.MEDIAN, sctrl)
        skyvar = stat.getValue(afwMath.MEDIAN)

        # Prepare various values for the GalSim's EstimateShear call.
        recomputeFlux = "FIT"
        precision = 1.0e-6
        guessCentroid = galsim._PositionD(center.getX(), center.getY())

        try:
            # Estimate shear using GalSim.
            shape = galsim.hsm.EstimateShear(
                image,
                psf,
                weight=weight,
                badpix=None,  # Already incorporated into `weight_image`.
                sky_var=skyvar,
                shear_est=self.config.shearType.upper(),
                recompute_flux=recomputeFlux.upper(),
                guess_sig_gal=2.5 * psfSigma,
                guess_sig_PSF=psfSigma,
                precision=precision,
                guess_centroid=guessCentroid,
                strict=True,  # Raises GalSimHSMError if estimation fails.
                check=False,  # This speeds up the code!
                hsmparams=None,
            )
        except galsim.hsm.GalSimHSMError as error:
            # Chain explicitly so the GalSim error is recorded as the cause.
            raise measBase.MeasurementError(str(error), self.GALSIM.number) from error

        # Set ellipticity and error values based on measurement type.
        if shape.meas_type == "e":
            record.set(self.e1Key, shape.corrected_e1)
            record.set(self.e2Key, shape.corrected_e2)
            # NOTE(review): the factor of 2 presumably converts GalSim's
            # shear-scale uncertainty to distortion scale; confirm against
            # the GalSim ShapeData documentation.
            record.set(self.sigmaKey, 2.0 * shape.corrected_shape_err)
        else:
            record.set(self.e1Key, shape.corrected_g1)
            record.set(self.e2Key, shape.corrected_g2)
            record.set(self.sigmaKey, shape.corrected_shape_err)

        record.set(self.resolutionKey, shape.resolution_factor)
        # Any non-zero correction status is reported through the failure flag.
        self.flagHandler.setValue(record, self.FAILURE.number, shape.correction_status != 0)

    def fail(self, record, error=None):
        # Docstring inherited.
        self.flagHandler.handleFailure(record)
        if error:
            # Only a truthy error is logged, at debug level, together with the
            # record id and the centroid returned by the extractor.
            centroid = self.centroidExtractor(record, self.flagHandler)
            self.log.debug(
                "Failed to measure shape for %d at (%f, %f): %s",
                record.getId(),
                centroid.getX(),
                centroid.getY(),
                error,
            )

305 

306 

class HsmShapeBjConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the BJ estimator."""

    # Replace only the getter of the base-class property; the docstring is
    # copied from ``HsmShapeConfig.shearType`` by ``inherit_doc``.
    @HsmShapeConfig.shearType.getter
    @inherit_doc(HsmShapeConfig)
    def shearType(self):
        # Docstring inherited.
        return "BJ"

315 

316 

@measBase.register("ext_shapeHSM_HsmShapeBj")
class HsmShapeBjPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the BJ estimator."""

    ConfigClass = HsmShapeBjConfig
    # "e" makes the base class name the outputs ``<name>_e1``/``<name>_e2``
    # and document them as ellipticity components.
    measTypeSymbol = "e"
    # Prefix of the documentation of the fields added to the schema.
    doc = "PSF-corrected shear using Bernstein & Jarvis (2002) method"

324 

325 

class HsmShapeLinearConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the LINEAR estimator."""

    # Replace only the getter of the base-class property; the docstring is
    # copied from ``HsmShapeConfig.shearType`` by ``inherit_doc``.
    @HsmShapeConfig.shearType.getter
    @inherit_doc(HsmShapeConfig)
    def shearType(self):
        # Docstring inherited.
        return "LINEAR"

334 

335 

@measBase.register("ext_shapeHSM_HsmShapeLinear")
class HsmShapeLinearPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the LINEAR estimator."""

    ConfigClass = HsmShapeLinearConfig
    # "e" makes the base class name the outputs ``<name>_e1``/``<name>_e2``
    # and document them as ellipticity components.
    measTypeSymbol = "e"
    # Prefix of the documentation of the fields added to the schema.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'linear' method"

343 

344 

class HsmShapeKsbConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the KSB estimator."""

    # Replace only the getter of the base-class property; the docstring is
    # copied from ``HsmShapeConfig.shearType`` by ``inherit_doc``.
    @HsmShapeConfig.shearType.getter
    @inherit_doc(HsmShapeConfig)
    def shearType(self):
        # Docstring inherited.
        return "KSB"

353 

354 

@measBase.register("ext_shapeHSM_HsmShapeKsb")
class HsmShapeKsbPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the KSB estimator."""

    ConfigClass = HsmShapeKsbConfig
    # "g" makes the base class name the outputs ``<name>_g1``/``<name>_g2``
    # and document them as estimated-shear components.
    measTypeSymbol = "g"
    # Prefix of the documentation of the fields added to the schema.
    doc = "PSF-corrected shear using Kaiser, Squires, & Broadhurst (1995) method"

362 

363 

class HsmShapeRegaussConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the REGAUSS estimator."""

    # Replace only the getter of the base-class property; the docstring is
    # copied from ``HsmShapeConfig.shearType`` by ``inherit_doc``.
    @HsmShapeConfig.shearType.getter
    @inherit_doc(HsmShapeConfig)
    def shearType(self):
        # Docstring inherited.
        return "REGAUSS"

372 

373 

@measBase.register("ext_shapeHSM_HsmShapeRegauss")
class HsmShapeRegaussPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the REGAUSS estimator."""

    ConfigClass = HsmShapeRegaussConfig
    # "e" makes the base class name the outputs ``<name>_e1``/``<name>_e2``
    # and document them as ellipticity components.
    measTypeSymbol = "e"
    # Prefix of the documentation of the fields added to the schema.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'regaussianization' method"