Coverage for python/lsst/meas/extensions/shapeHSM/_hsm_shape.py: 36%

154 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-17 08:48 +0000

1# This file is part of meas_extensions_shapeHSM. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23 

24import galsim 

25import lsst.afw.image as afwImage 

26import lsst.afw.math as afwMath 

27import lsst.meas.base as measBase 

28import lsst.pex.config as pexConfig 

29from lsst.geom import Point2I 

30 

# Public API of this module: the four registered estimator plugins and their
# matching configuration classes. The shared base classes (HsmShapeConfig and
# HsmShapePlugin) are not listed here.
__all__ = [
    # Bernstein & Jarvis (2002)
    "HsmShapeBjConfig",
    "HsmShapeBjPlugin",
    # Hirata & Seljak (2003) "linear" variant
    "HsmShapeLinearConfig",
    "HsmShapeLinearPlugin",
    # Kaiser, Squires & Broadhurst (1995)
    "HsmShapeKsbConfig",
    "HsmShapeKsbPlugin",
    # Hirata & Seljak (2003) re-Gaussianization
    "HsmShapeRegaussConfig",
    "HsmShapeRegaussPlugin",
]

41 

42 

class HsmShapeConfig(measBase.SingleFramePluginConfig):
    """Shared configuration for the HSM shape-measurement plugins.

    Concrete subclasses pin ``shearType`` to a single estimator; the other
    fields control deblend-parent skipping and pixel masking.
    """

    # Which GalSim PSF-correction estimator to run. REGAUSS, LINEAR and BJ
    # produce e-type distortions; KSB produces a g-type shear.
    shearType = pexConfig.ChoiceField[str](
        doc=(
            "The desired method of PSF correction using GalSim. The first three options return an e-type "
            "distortion, whereas the last option returns a g-type shear."
        ),
        default="REGAUSS",
        allowed={
            "REGAUSS": "Regaussianization method from Hirata & Seljak (2003)",
            "LINEAR": "A modification by Hirata & Seljak (2003) of methods in Bernstein & Jarvis (2002)",
            "BJ": "From Bernstein & Jarvis (2002)",
            "KSB": "From Kaiser, Squires, & Broadhurst (1995)",
        },
    )

    # Name of the schema field holding the deblend child count; an empty
    # string disables the parent-source check.
    deblendNChild = pexConfig.Field[str](
        default="",
        doc="Field name for number of deblend children.",
    )

    # Mask planes whose pixels receive zero weight in the shape fit and are
    # ignored when estimating the sky variance.
    badMaskPlanes = pexConfig.ListField[str](
        default=["BAD", "SAT"],
        doc="Mask planes that indicate pixels that should be excluded from the fit.",
    )

67 

68 

class HsmShapePlugin(measBase.SingleFramePlugin):
    """Base plugin for HSM shape measurement.

    Concrete subclasses choose the estimator through their ``ConfigClass``
    and must define the class attributes:

    ``measTypeSymbol``
        ``"e"`` or ``"g"``; used to name the ellipticity fields
        (``<name>_e1``/``<name>_e2`` or ``<name>_g1``/``<name>_g2``).
    ``doc``
        Prefix for the documentation of the schema fields added here.
    """

    ConfigClass = HsmShapeConfig
    doc = ""

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Define flags for possible issues that might arise during measurement.
        flagDefs = measBase.FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")
        self.NO_PIXELS = flagDefs.add("flag_no_pixels", "No pixels to measure")
        self.NOT_CONTAINED = flagDefs.add(
            "flag_not_contained", "Center not contained in footprint bounding box"
        )
        self.PARENT_SOURCE = flagDefs.add("flag_parent_source", "Parent source, ignored")
        self.GALSIM = flagDefs.add("flag_galsim", "GalSim failure")

        # Embed the flag definitions in the schema using a flag handler.
        self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)

        # Utilize a safe centroid extractor that uses the detection footprint
        # as a fallback if necessary.
        self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)

        self.e1Key = self._addEllipticityField(name, 1, schema, self.doc)
        self.e2Key = self._addEllipticityField(name, 2, schema, self.doc)
        self.sigmaKey = schema.addField(
            schema.join(name, "sigma"),
            type=float,
            doc=f"{self.doc} (shape measurement uncertainty per component)",
        )
        self.resolutionKey = schema.addField(
            schema.join(name, "resolution"), type=float, doc="Resolution factor (0=unresolved, 1=resolved)"
        )

        # An empty deblendNChild disables the parent-source check in measure().
        self.hasDeblendKey = len(config.deblendNChild) > 0

        if self.hasDeblendKey:
            # NOTE(review): this stores whatever ``schema[...]`` returns and
            # later passes it to ``record.get``; confirm that is a usable key.
            self.deblendKey = schema[config.deblendNChild]

        self.log = logging.getLogger(self.logName)

    @classmethod
    def getExecutionOrder(cls):
        return cls.SHAPE_ORDER

    @staticmethod
    def bboxToGalSimBounds(bbox):
        """Convert an integer bounding box into GalSim integer bounds.

        Parameters
        ----------
        bbox : `~lsst.geom.Box2I`
            Box whose inclusive min/max corners are copied.

        Returns
        -------
        bounds : `galsim.BoundsI`
            Bounds with the same inclusive x and y ranges.
        """
        xmin, xmax = bbox.getMinX(), bbox.getMaxX()
        ymin, ymax = bbox.getMinY(), bbox.getMaxY()
        return galsim._BoundsI(xmin, xmax, ymin, ymax)

    def _addEllipticityField(self, name, n, schema, doc):
        """
        Helper function to add an ellipticity field to a measurement schema.

        Parameters
        ----------
        name : `str`
            Base name of the field.
        n : `int`
            Specifies whether the field is for the first (1) or second (2)
            component.
        schema : `~lsst.afw.table.Schema`
            The schema to which the field is added.
        doc : `str`
            The documentation string that needs to be updated to reflect the
            type and component of the measurement.

        Returns
        -------
        `~lsst.afw.table.KeyD`
            The key associated with the added field in the schema.
        """
        componentLookup = {1: "+ component", 2: "x component"}
        typeLookup = {"e": " of ellipticity", "g": " of estimated shear"}
        name = f"{name}_{self.measTypeSymbol}{n}"
        updatedDoc = f"{doc} ({componentLookup[n]}{typeLookup[self.measTypeSymbol]})"
        return schema.addField(name, type=float, doc=updatedDoc)

    def measure(self, record, exposure):
        """
        Measure the shape of sources given an exposure and set the results in
        the record in place.

        Parameters
        ----------
        record : `~lsst.afw.table.SourceRecord`
            The record where measurement outputs will be stored.
        exposure : `~lsst.afw.image.Exposure`
            The exposure containing the source which needs measurement.

        Raises
        ------
        MeasurementError
            Raised when the source is a deblend parent, its footprint bounding
            box is empty, its centroid lies outside that box, or GalSim's
            shear estimation fails (``flag_galsim``).
        """
        # Extract the centroid from the record.
        center = self.centroidExtractor(record, self.flagHandler)

        if self.hasDeblendKey and record.get(self.deblendKey) > 0:
            raise measBase.MeasurementError(self.PARENT_SOURCE.doc, self.PARENT_SOURCE.number)

        # Get the bounding box of the source's footprint.
        bbox = record.getFootprint().getBBox()

        # Check that the bounding box has non-zero area.
        if bbox.getArea() == 0:
            raise measBase.MeasurementError(self.NO_PIXELS.doc, self.NO_PIXELS.number)

        # Ensure that the centroid is within the bounding box.
        if not bbox.contains(Point2I(center)):
            raise measBase.MeasurementError(self.NOT_CONTAINED.doc, self.NOT_CONTAINED.number)

        # Evaluate the PSF model once at the centroid: its image (re-origined
        # to (0, 0)) and its trace radius, used as the size guesses below.
        psfModel = exposure.getPsf()
        psfImage = psfModel.computeImage(center)
        psfImage.setXY0(0, 0)
        psfSigma = psfModel.computeShape(center).getTraceRadius()

        # Turn the source and PSF bounding boxes into GalSim bounds.
        bounds = self.bboxToGalSimBounds(bbox)
        psfBounds = self.bboxToGalSimBounds(psfImage.getBBox(afwImage.PARENT))

        # Each GalSim image below will match whatever dtype the input array is.
        # NOTE: PSF is already restricted to a small image, so no bbox for the
        # PSF is expected.
        image = galsim._Image(exposure.image[bbox].array, bounds, wcs=None)
        psf = galsim._Image(psfImage.array, psfBounds, wcs=None)

        # Build the weight image: 1 = use pixel, 0 = pixel has a bit from one
        # of the configured bad mask planes. The expression allocates a new
        # array, so the exposure's mask is never modified.
        subMask = exposure.mask[bbox]
        bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
        maskArray = subMask.array
        weightArray = ((maskArray & bitValue) == 0).astype(maskArray.dtype)
        weight = galsim._Image(weightArray, bounds, wcs=None)

        # Median sky variance over the footprint box, excluding the same bad
        # mask planes that were zero-weighted above.
        sctrl = afwMath.StatisticsControl()
        sctrl.setAndMask(bitValue)
        # NOTE: Origin defaults to PARENT in all cases accessible from Python.
        variance = afwImage.Image(
            exposure.variance[bbox],
            dtype=exposure.variance.dtype,
            deep=False,
        )
        stat = afwMath.makeStatistics(variance, subMask, afwMath.MEDIAN, sctrl)
        skyvar = stat.getValue(afwMath.MEDIAN)

        # ShapeData instance that the C++ layer fills in place.
        shape = galsim.hsm.ShapeData(
            image_bounds=galsim._BoundsI(0, 0, 1, 1),
            observed_shape=galsim._Shear(0j),
            psf_shape=galsim._Shear(0j),
            moments_centroid=galsim._PositionD(0, 0),
        )

        # Prepare various values for GalSim's EstimateShearView.
        recomputeFlux = "FIT"
        precision = 1.0e-6
        guessCentroid = galsim._PositionD(center.getX(), center.getY())
        hsmparams = galsim.hsm.HSMParams.default

        # Estimate shear using GalSim. Arguments are passed positionally to
        # the C++ function. Inline comments specify the Python layer
        # equivalent of each argument for clarity.
        # TODO: [DM-42047] Change to public API when an optimized version
        # is available.
        try:
            galsim._galsim.EstimateShearView(
                shape._data,  # shape data buffer (not passed in pure Python)
                image._image,  # gal_image
                psf._image,  # PSF_image
                weight._image,  # weight
                float(skyvar),  # sky_var
                self.config.shearType.upper(),  # shear_est
                recomputeFlux.upper(),  # recompute_flux
                float(2.5 * psfSigma),  # guess_sig_gal
                float(psfSigma),  # guess_sig_PSF
                float(precision),  # precision
                guessCentroid._p,  # guess_centroid
                hsmparams._hsmp,  # hsmparams
            )
        except (galsim.hsm.GalSimHSMError, RuntimeError) as error:
            # The private C++ binding reports HSM failures as plain
            # RuntimeError (GalSimHSMError is only raised by GalSim's public
            # Python wrapper), so both must map onto flag_galsim.
            raise measBase.MeasurementError(str(error), self.GALSIM.number) from error

        # Set ellipticity and error values based on measurement type.
        if shape.meas_type == "e":
            record.set(self.e1Key, shape.corrected_e1)
            record.set(self.e2Key, shape.corrected_e2)
            record.set(self.sigmaKey, 2.0 * shape.corrected_shape_err)
        else:
            record.set(self.e1Key, shape.corrected_g1)
            record.set(self.e2Key, shape.corrected_g2)
            record.set(self.sigmaKey, shape.corrected_shape_err)

        record.set(self.resolutionKey, shape.resolution_factor)
        # A non-zero correction status from GalSim marks the result as failed
        # even though values were written.
        self.flagHandler.setValue(record, self.FAILURE.number, shape.correction_status != 0)

    def fail(self, record, error=None):
        # Docstring inherited.
        self.flagHandler.handleFailure(record)
        if error is not None:
            centroid = self.centroidExtractor(record, self.flagHandler)
            self.log.debug(
                "Failed to measure shape for %d at (%f, %f): %s",
                record.getId(),
                centroid.getX(),
                centroid.getY(),
                error,
            )

302 

303 

class HsmShapeBjConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the BJ estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "BJ"

    def validate(self):
        """Require ``shearType`` to be ``"BJ"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` has been changed from ``"BJ"``.
        """
        if self.shearType != "BJ":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and source); class attribute access returns the
            # descriptor, while ``self.shearType`` is only the string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'BJ'."
            )
        super().validate()

315 

316 

@measBase.register("ext_shapeHSM_HsmShapeBj")
class HsmShapeBjPlugin(HsmShapePlugin):
    """HSM shape plugin running the Bernstein & Jarvis (2002) estimator.

    Outputs are e-type distortions.
    """

    # Prefix for the documentation of the schema fields this plugin adds.
    doc = "PSF-corrected shear using Bernstein & Jarvis (2002) method"
    # "e" selects the e1/e2 field names and the e-type documentation.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeBjConfig

324 

325 

class HsmShapeLinearConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the LINEAR estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "LINEAR"

    def validate(self):
        """Require ``shearType`` to be ``"LINEAR"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` has been changed from ``"LINEAR"``.
        """
        if self.shearType != "LINEAR":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and source); class attribute access returns the
            # descriptor, while ``self.shearType`` is only the string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'LINEAR'."
            )
        super().validate()

337 

338 

@measBase.register("ext_shapeHSM_HsmShapeLinear")
class HsmShapeLinearPlugin(HsmShapePlugin):
    """HSM shape plugin running the Hirata & Seljak (2003) "linear" estimator.

    Outputs are e-type distortions.
    """

    # Prefix for the documentation of the schema fields this plugin adds.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'linear' method"
    # "e" selects the e1/e2 field names and the e-type documentation.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeLinearConfig

346 

347 

class HsmShapeKsbConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the KSB estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "KSB"

    def validate(self):
        """Require ``shearType`` to be ``"KSB"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` has been changed from ``"KSB"``.
        """
        if self.shearType != "KSB":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and source); class attribute access returns the
            # descriptor, while ``self.shearType`` is only the string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'KSB'."
            )
        super().validate()

359 

360 

@measBase.register("ext_shapeHSM_HsmShapeKsb")
class HsmShapeKsbPlugin(HsmShapePlugin):
    """HSM shape plugin running the Kaiser, Squires & Broadhurst (1995) estimator.

    Outputs are g-type shear estimates.
    """

    # Prefix for the documentation of the schema fields this plugin adds.
    doc = "PSF-corrected shear using Kaiser, Squires, & Broadhurst (1995) method"
    # "g" selects the g1/g2 field names and the shear-type documentation.
    measTypeSymbol = "g"
    ConfigClass = HsmShapeKsbConfig

368 

369 

class HsmShapeRegaussConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the REGAUSS estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "REGAUSS"

    def validate(self):
        """Require ``shearType`` to be ``"REGAUSS"``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` has been changed from ``"REGAUSS"``.
        """
        if self.shearType != "REGAUSS":
            # FieldValidationError needs the Field descriptor (it reads the
            # field's name and source); class attribute access returns the
            # descriptor, while ``self.shearType`` is only the string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'REGAUSS'."
            )
        super().validate()

383 

384 

@measBase.register("ext_shapeHSM_HsmShapeRegauss")
class HsmShapeRegaussPlugin(HsmShapePlugin):
    """HSM shape plugin running the Hirata & Seljak (2003) re-Gaussianization estimator.

    Outputs are e-type distortions.
    """

    # Prefix for the documentation of the schema fields this plugin adds.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'regaussianization' method"
    # "e" selects the e1/e2 field names and the e-type documentation.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeRegaussConfig