Coverage for python/lsst/meas/extensions/shapeHSM/_hsm_shape.py: 36%

154 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-18 10:50 +0000

1# This file is part of meas_extensions_shapeHSM. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23 

24import galsim 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.meas.base as measBase 

28import lsst.pex.config as pexConfig 

29from lsst.geom import Point2I 

30 

# Public API of this module: one config/plugin pair per HSM estimator.
# The shared HsmShapeConfig / HsmShapePlugin base classes are not listed.
__all__ = [
    # Bernstein & Jarvis (2002)
    "HsmShapeBjConfig",
    "HsmShapeBjPlugin",
    # Hirata & Seljak (2003), 'linear' variant
    "HsmShapeLinearConfig",
    "HsmShapeLinearPlugin",
    # Kaiser, Squires, & Broadhurst (1995)
    "HsmShapeKsbConfig",
    "HsmShapeKsbPlugin",
    # Hirata & Seljak (2003), regaussianization
    "HsmShapeRegaussConfig",
    "HsmShapeRegaussPlugin",
]

41 

42 

class HsmShapeConfig(measBase.SingleFramePluginConfig):
    """Base configuration for HSM shape measurement.

    Holds the options shared by every HSM shape plugin: which GalSim
    PSF-correction estimator to run, an optional deblend-children field used
    to skip parent sources, and the mask planes whose pixels are excluded.
    """

    # Estimator name handed to GalSim's HSM shear estimation. REGAUSS, LINEAR
    # and BJ yield e-type distortions; KSB yields a g-type shear.
    shearType = pexConfig.ChoiceField[str](
        doc="The desired method of PSF correction using GalSim. "
        "The first three options return an e-type distortion, whereas the last option returns a g-type shear.",
        allowed=dict(
            REGAUSS="Regaussianization method from Hirata & Seljak (2003)",
            LINEAR="A modification by Hirata & Seljak (2003) of methods in Bernstein & Jarvis (2002)",
            BJ="From Bernstein & Jarvis (2002)",
            KSB="From Kaiser, Squires, & Broadhurst (1995)",
        ),
        default="REGAUSS",
    )

    # Schema field holding the number of deblend children; an empty string
    # disables the parent-source check.
    deblendNChild = pexConfig.Field[str](doc="Field name for number of deblend children.", default="")

    # Pixels with any of these mask bits set get zero weight in the fit.
    badMaskPlanes = pexConfig.ListField[str](
        doc="Mask planes that indicate pixels that should be excluded from the fit.", default=["BAD", "SAT"]
    )

69 

70 

class HsmShapePlugin(measBase.SingleFramePlugin):
    """Base plugin for HSM shape measurement.

    Concrete subclasses select the estimator through their ``ConfigClass``
    and must define two class attributes that are read while building the
    output schema:

    - ``measTypeSymbol``: ``"e"`` or ``"g"``. It selects the output field
      names (``<name>_e1``/``<name>_e2`` or ``<name>_g1``/``<name>_g2``) and
      the wording of their documentation.
    - ``doc``: prefix used in the documentation of the output fields.
    """

    ConfigClass = HsmShapeConfig
    doc = ""

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Define flags for possible issues that might arise during measurement.
        flagDefs = measBase.FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")
        self.NO_PIXELS = flagDefs.add("flag_no_pixels", "No pixels to measure")
        self.NOT_CONTAINED = flagDefs.add(
            "flag_not_contained", "Center not contained in footprint bounding box"
        )
        self.PARENT_SOURCE = flagDefs.add("flag_parent_source", "Parent source, ignored")
        self.GALSIM = flagDefs.add("flag_galsim", "GalSim failure")

        # Embed the flag definitions in the schema using a flag handler.
        self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)

        # Utilize a safe centroid extractor that uses the detection footprint
        # as a fallback if necessary.
        self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)

        # Output fields: two shape components, their uncertainty, and the
        # resolution factor.
        self.e1Key = self._addEllipticityField(name, 1, schema, self.doc)
        self.e2Key = self._addEllipticityField(name, 2, schema, self.doc)
        self.sigmaKey = schema.addField(
            schema.join(name, "sigma"),
            type=float,
            doc=f"{self.doc} (shape measurement uncertainty per component)",
        )
        self.resolutionKey = schema.addField(
            schema.join(name, "resolution"), type=float, doc="Resolution factor (0=unresolved, 1=resolved)"
        )

        # The parent-source check is only active when a deblend-children
        # field name is configured.
        self.hasDeblendKey = len(config.deblendNChild) > 0

        if self.hasDeblendKey:
            self.deblendKey = schema[config.deblendNChild].asKey()

        self.log = logging.getLogger(self.logName)

    @classmethod
    def getExecutionOrder(cls):
        return cls.SHAPE_ORDER

    @staticmethod
    def bboxToGalSimBounds(bbox):
        """Convert an LSST integer bounding box to GalSim integer bounds.

        Parameters
        ----------
        bbox : `~lsst.geom.Box2I`
            Bounding box to convert (inclusive min/max corners).

        Returns
        -------
        bounds : `galsim.BoundsI`
            Bounds with the same inclusive corners, built through GalSim's
            unchecked ``_BoundsI`` constructor.
        """
        xmin, xmax = bbox.getMinX(), bbox.getMaxX()
        ymin, ymax = bbox.getMinY(), bbox.getMaxY()
        return galsim._BoundsI(xmin, xmax, ymin, ymax)

    def _addEllipticityField(self, name, n, schema, doc):
        """
        Helper function to add an ellipticity field to a measurement schema.

        Parameters
        ----------
        name : `str`
            Base name of the field.
        n : `int`
            Specifies whether the field is for the first (1) or second (2)
            component.
        schema : `~lsst.afw.table.Schema`
            The schema to which the field is added.
        doc : `str`
            The documentation string that needs to be updated to reflect the
            type and component of the measurement.

        Returns
        -------
        `~lsst.afw.table.KeyD`
            The key associated with the added field in the schema.

        Raises
        ------
        KeyError
            Raised if ``n`` is not 1 or 2, or if ``self.measTypeSymbol`` is
            not ``"e"`` or ``"g"``.
        """
        componentLookup = {1: "+ component", 2: "x component"}
        typeLookup = {"e": " of ellipticity", "g": " of estimated shear"}
        name = f"{name}_{self.measTypeSymbol}{n}"
        updatedDoc = f"{doc} ({componentLookup[n]}{typeLookup[self.measTypeSymbol]})"
        return schema.addField(name, type=float, doc=updatedDoc)

    def measure(self, record, exposure):
        """
        Measure the shape of sources given an exposure and set the results in
        the record in place.

        Parameters
        ----------
        record : `~lsst.afw.table.SourceRecord`
            The record where measurement outputs will be stored.
        exposure : `~lsst.afw.image.Exposure`
            The exposure containing the source which needs measurement.

        Raises
        ------
        MeasurementError
            Raised if the source is a deblend parent, has no footprint or an
            empty footprint bounding box, has a centroid outside that bounding
            box, or if GalSim fails to estimate the shear. The error carries
            the number of the matching specific flag.
        """
        # Extract the centroid from the record.
        center = self.centroidExtractor(record, self.flagHandler)

        # Sources with deblend children (parents) are not measured.
        if self.hasDeblendKey and record.get(self.deblendKey) > 0:
            raise measBase.MeasurementError(self.PARENT_SOURCE.doc, self.PARENT_SOURCE.number)

        # A record without a footprint has no pixels to measure. Report it via
        # the dedicated flag instead of failing with an AttributeError below.
        footprint = record.getFootprint()
        if footprint is None:
            raise measBase.MeasurementError(self.NO_PIXELS.doc, self.NO_PIXELS.number)

        # Get the bounding box of the source's footprint.
        bbox = footprint.getBBox()

        # Check that the bounding box has non-zero area.
        if bbox.getArea() == 0:
            raise measBase.MeasurementError(self.NO_PIXELS.doc, self.NO_PIXELS.number)

        # Ensure that the centroid is within the bounding box.
        if not bbox.contains(Point2I(center)):
            raise measBase.MeasurementError(self.NOT_CONTAINED.doc, self.NOT_CONTAINED.number)

        # Fetch the PSF model once and evaluate it at the source centroid.
        psfModel = exposure.getPsf()

        # PSF image at the centroid, re-anchored so its origin is (0, 0).
        psfImage = psfModel.computeImage(center)
        psfImage.setXY0(0, 0)

        # Trace radius of the PSF; used below as the initial size guesses.
        psfSigma = psfModel.computeShape(center).getTraceRadius()

        # Turn bounding box corners into GalSim bounds.
        bounds = self.bboxToGalSimBounds(bbox)

        # Get the bounding box of the PSF in the parent coordinate system and
        # turn its corners into GalSim bounds.
        psfBounds = self.bboxToGalSimBounds(psfImage.getBBox(afwImage.PARENT))

        # Each GalSim image below will match whatever dtype the input array is.
        # NOTE: PSF is already restricted to a small image, so no bbox for the
        # PSF is expected.
        image = galsim._Image(exposure.image[bbox].array, bounds, wcs=None)
        psf = galsim._Image(psfImage.array, psfBounds, wcs=None)

        # Get the `lsst.meas.base` mask for bad pixels.
        subMask = exposure.mask[bbox]
        bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
        badpix = subMask.array.copy()  # Copy it since badpix gets modified.
        badpix &= bitValue

        # Turn badpix to weight where elements set to 1 indicate 'use pixel'
        # and those set to 0 mean 'do not use pixel'. Now, weight will assume
        # the role of badpix, and we will no longer use badpix in our call to
        # EstimateShear().
        gd = badpix == 0
        badpix[gd] = 1
        badpix[~gd] = 0
        weight = galsim._Image(badpix, bounds, wcs=None)

        # Get the statistics control object for sky variance estimation;
        # pixels with bad mask bits are excluded from the statistic.
        sctrl = afwMath.StatisticsControl()
        sctrl.setAndMask(bitValue)

        # Create a variance image from the exposure.
        # NOTE: Origin defaults to PARENT in all cases accessible from Python.
        variance = afwImage.Image(
            exposure.variance[bbox],
            dtype=exposure.variance.dtype,
            deep=False,
        )

        # Calculate median sky variance for use in shear estimation.
        stat = afwMath.makeStatistics(variance, subMask, afwMath.MEDIAN, sctrl)
        skyvar = stat.getValue(afwMath.MEDIAN)

        # Initialize an instance of ShapeData to store the results.
        shape = galsim.hsm.ShapeData(
            image_bounds=galsim._BoundsI(0, 0, 1, 1),
            observed_shape=galsim._Shear(0j),
            psf_shape=galsim._Shear(0j),
            moments_centroid=galsim._PositionD(0, 0),
        )

        # Prepare various values for the GalSim's EstimateShearView call.
        recomputeFlux = "FIT"
        precision = 1.0e-6
        guessCentroid = galsim._PositionD(center.getX(), center.getY())
        hsmparams = galsim.hsm.HSMParams.default

        # Directly use GalSim's C++/Python interface for shear estimation.
        try:
            # Estimate shear using GalSim. Arguments are passed positionally
            # to the C++ function. Inline comments specify the Python layer
            # equivalent of each argument for clarity.
            # TODO: [DM-42047] Change to public API when an optimized version
            # is available.
            galsim._galsim.EstimateShearView(
                shape._data,  # shape data buffer (not passed in pure Python)
                image._image,  # gal_image
                psf._image,  # PSF_image
                weight._image,  # weight
                float(skyvar),  # sky_var
                self.config.shearType.upper(),  # shear_est
                recomputeFlux.upper(),  # recompute_flux
                float(2.5 * psfSigma),  # guess_sig_gal
                float(psfSigma),  # guess_sig_PSF
                float(precision),  # precision
                guessCentroid._p,  # guess_centroid
                hsmparams._hsmp,  # hsmparams
            )
        # GalSim does not raise custom pybind errors as of v2.5, resulting in
        # all GalSim C++ errors being RuntimeErrors.
        except (galsim.hsm.GalSimHSMError, RuntimeError) as error:
            # Chain the original error so its traceback is preserved.
            raise measBase.MeasurementError(str(error), self.GALSIM.number) from error

        # Set ellipticity and error values based on measurement type.
        if shape.meas_type == "e":
            record.set(self.e1Key, shape.corrected_e1)
            record.set(self.e2Key, shape.corrected_e2)
            # NOTE(review): the factor 2 presumably converts GalSim's
            # per-component shear error to a distortion error (e ~ 2g for
            # small shapes) -- confirm against GalSim's ShapeData docs.
            record.set(self.sigmaKey, 2.0 * shape.corrected_shape_err)
        else:
            record.set(self.e1Key, shape.corrected_g1)
            record.set(self.e2Key, shape.corrected_g2)
            record.set(self.sigmaKey, shape.corrected_shape_err)

        record.set(self.resolutionKey, shape.resolution_factor)
        # A non-zero correction status from GalSim marks the result as failed.
        self.flagHandler.setValue(record, self.FAILURE.number, shape.correction_status != 0)

    def fail(self, record, error=None):
        """Record a measurement failure on ``record``.

        Parameters
        ----------
        record : `~lsst.afw.table.SourceRecord`
            The record whose failure flags are set.
        error : `MeasurementError`, optional
            The error raised by `measure`, if any. When it exposes its
            wrapped C++ exception (``error.cpp``), that exception is forwarded
            to the flag handler. The specific flag it names (e.g.
            ``flag_no_pixels`` or ``flag_galsim``) is then set in addition to
            the general failure flag.
        """
        # Without the error argument, handleFailure only sets the general
        # failure flag, and the specific flags defined in __init__ would never
        # be recorded. pex_exceptions-wrapped errors expose the C++ exception
        # as ``cpp``; anything else falls back to the general flag only.
        cppError = getattr(error, "cpp", None)
        if cppError is None:
            self.flagHandler.handleFailure(record)
        else:
            self.flagHandler.handleFailure(record, cppError)

        if error is not None:
            centroid = self.centroidExtractor(record, self.flagHandler)
            self.log.debug(
                "Failed to measure shape for %d at (%f, %f): %s",
                record.getId(),
                centroid.getX(),
                centroid.getY(),
                error,
            )

301 record.getId(), 

302 centroid.getX(), 

303 centroid.getY(), 

304 error, 

305 ) 

306 

307 

class HsmShapeBjConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the BJ estimator.

    ``shearType`` defaults to ``"BJ"`` and validation rejects any other
    value.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "BJ"

    def validate(self):
        """Validate the config, requiring ``shearType == "BJ"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"BJ"``.
        """
        if self.shearType != "BJ":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and definition source), not the field's value.
            # Accessing the field on the class returns the descriptor.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'BJ'."
            )
        super().validate()

319 

320 

@measBase.register("ext_shapeHSM_HsmShapeBj")
class HsmShapeBjPlugin(HsmShapePlugin):
    """HSM shape plugin running the Bernstein & Jarvis (BJ) estimator.

    Results are e-type distortions, stored in the ``<name>_e1`` and
    ``<name>_e2`` fields.
    """

    # Prefix for the documentation of the output schema fields.
    doc = "PSF-corrected shear using Bernstein & Jarvis (2002) method"
    # Selects the e-type field names and documentation wording.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeBjConfig

328 

329 

class HsmShapeLinearConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the LINEAR estimator.

    ``shearType`` defaults to ``"LINEAR"`` and validation rejects any other
    value.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "LINEAR"

    def validate(self):
        """Validate the config, requiring ``shearType == "LINEAR"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"LINEAR"``.
        """
        if self.shearType != "LINEAR":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and definition source), not the field's value.
            # Accessing the field on the class returns the descriptor.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'LINEAR'."
            )
        super().validate()

341 

342 

@measBase.register("ext_shapeHSM_HsmShapeLinear")
class HsmShapeLinearPlugin(HsmShapePlugin):
    """HSM shape plugin running the Hirata & Seljak 'linear' (LINEAR) estimator.

    Results are e-type distortions, stored in the ``<name>_e1`` and
    ``<name>_e2`` fields.
    """

    # Prefix for the documentation of the output schema fields.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'linear' method"
    # Selects the e-type field names and documentation wording.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeLinearConfig

350 

351 

class HsmShapeKsbConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the KSB estimator.

    ``shearType`` defaults to ``"KSB"`` and validation rejects any other
    value.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "KSB"

    def validate(self):
        """Validate the config, requiring ``shearType == "KSB"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"KSB"``.
        """
        if self.shearType != "KSB":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and definition source), not the field's value.
            # Accessing the field on the class returns the descriptor.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'KSB'."
            )
        super().validate()

363 

364 

@measBase.register("ext_shapeHSM_HsmShapeKsb")
class HsmShapeKsbPlugin(HsmShapePlugin):
    """HSM shape plugin running the Kaiser, Squires, & Broadhurst (KSB) estimator.

    Unlike the other HSM estimators, results are g-type shears, stored in the
    ``<name>_g1`` and ``<name>_g2`` fields.
    """

    # Prefix for the documentation of the output schema fields.
    doc = "PSF-corrected shear using Kaiser, Squires, & Broadhurst (1995) method"
    # Selects the g-type (shear) field names and documentation wording.
    measTypeSymbol = "g"
    ConfigClass = HsmShapeKsbConfig

372 

373 

class HsmShapeRegaussConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the REGAUSS estimator.

    ``shearType`` defaults to ``"REGAUSS"`` and validation rejects any other
    value.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "REGAUSS"

    def validate(self):
        """Validate the config, requiring ``shearType == "REGAUSS"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"REGAUSS"``.
        """
        if self.shearType != "REGAUSS":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and definition source), not the field's value.
            # Accessing the field on the class returns the descriptor.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'REGAUSS'."
            )
        super().validate()

387 

388 

@measBase.register("ext_shapeHSM_HsmShapeRegauss")
class HsmShapeRegaussPlugin(HsmShapePlugin):
    """HSM shape plugin running the Hirata & Seljak regaussianization (REGAUSS) estimator.

    Results are e-type distortions, stored in the ``<name>_e1`` and
    ``<name>_e2`` fields.
    """

    # Prefix for the documentation of the output schema fields.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'regaussianization' method"
    # Selects the e-type field names and documentation wording.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeRegaussConfig