Coverage for python/lsst/meas/extensions/shapeHSM/_hsm_shape.py: 36%
154 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-17 08:48 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-17 08:48 +0000
1# This file is part of meas_extensions_shapeHSM.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
24import galsim
25import lsst.afw.image as afwImage
26import lsst.afw.math as afwMath
27import lsst.meas.base as measBase
28import lsst.pex.config as pexConfig
29from lsst.geom import Point2I
# Public names exported by this module: one configuration class and one
# registered measurement plugin per HSM estimator. The shared base classes
# HsmShapeConfig and HsmShapePlugin are not listed here.
__all__ = [
    "HsmShapeBjConfig",
    "HsmShapeBjPlugin",
    "HsmShapeLinearConfig",
    "HsmShapeLinearPlugin",
    "HsmShapeKsbConfig",
    "HsmShapeKsbPlugin",
    "HsmShapeRegaussConfig",
    "HsmShapeRegaussPlugin",
]
class HsmShapeConfig(measBase.SingleFramePluginConfig):
    """Base configuration for HSM shape measurement.

    The estimator-specific configuration classes in this module pin
    ``shearType`` to a single value in ``setDefaults`` and check it in
    ``validate``.
    """

    # PSF-correction method; the measuring plugin passes it (upper-cased) to
    # GalSim's shear estimator.
    shearType = pexConfig.ChoiceField[str](
        doc="The desired method of PSF correction using GalSim. The first three options return an e-type "
        "distortion, whereas the last option returns a g-type shear.",
        allowed={
            "REGAUSS": "Regaussianization method from Hirata & Seljak (2003)",
            "LINEAR": "A modification by Hirata & Seljak (2003) of methods in Bernstein & Jarvis (2002)",
            "BJ": "From Bernstein & Jarvis (2002)",
            "KSB": "From Kaiser, Squires, & Broadhurst (1995)",
        },
        default="REGAUSS",
    )

    # Name of an existing schema field holding the number of deblend
    # children. When non-empty, sources whose value is > 0 are treated as
    # parents and not measured. An empty string disables that check.
    deblendNChild = pexConfig.Field[str](
        doc="Field name for number of deblend children.",
        default="",
    )

    # Pixels with any of these mask bits set get zero weight in the fit and
    # are also excluded from the median sky-variance statistic.
    badMaskPlanes = pexConfig.ListField[str](
        doc="Mask planes that indicate pixels that should be excluded from the fit.",
        default=["BAD", "SAT"],
    )
class HsmShapePlugin(measBase.SingleFramePlugin):
    """Base plugin for HSM shape measurement.

    Concrete subclasses supply ``ConfigClass``, ``doc`` and
    ``measTypeSymbol`` (``"e"`` or ``"g"``). The latter two are read in
    ``__init__`` to name and document the ellipticity fields. This base
    class defines no ``measTypeSymbol`` and therefore cannot be constructed
    on its own.
    """

    ConfigClass = HsmShapeConfig
    doc = ""

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Define flags for possible issues that might arise during measurement.
        flagDefs = measBase.FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")
        self.NO_PIXELS = flagDefs.add("flag_no_pixels", "No pixels to measure")
        self.NOT_CONTAINED = flagDefs.add(
            "flag_not_contained", "Center not contained in footprint bounding box"
        )
        self.PARENT_SOURCE = flagDefs.add("flag_parent_source", "Parent source, ignored")
        self.GALSIM = flagDefs.add("flag_galsim", "GalSim failure")

        # Embed the flag definitions in the schema using a flag handler.
        self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)

        # Utilize a safe centroid extractor that uses the detection footprint
        # as a fallback if necessary.
        self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)

        self.e1Key = self._addEllipticityField(name, 1, schema, self.doc)
        self.e2Key = self._addEllipticityField(name, 2, schema, self.doc)
        self.sigmaKey = schema.addField(
            schema.join(name, "sigma"),
            type=float,
            doc=f"{self.doc} (shape measurement uncertainty per component)",
        )
        self.resolutionKey = schema.addField(
            schema.join(name, "resolution"), type=float, doc="Resolution factor (0=unresolved, 1=resolved)"
        )
        # The parent-source check in ``measure`` is enabled only when a
        # deblend-children field name is configured.
        self.hasDeblendKey = len(config.deblendNChild) > 0

        if self.hasDeblendKey:
            self.deblendKey = schema[config.deblendNChild]

        self.log = logging.getLogger(self.logName)

    @classmethod
    def getExecutionOrder(cls):
        return cls.SHAPE_ORDER

    @staticmethod
    def bboxToGalSimBounds(bbox):
        """Convert an integer bounding box into GalSim bounds.

        Parameters
        ----------
        bbox : `~lsst.geom.Box2I`
            Box whose minimum and maximum x/y coordinates are copied.

        Returns
        -------
        bounds : `galsim.BoundsI`
            Bounds built with GalSim's private ``_BoundsI`` constructor
            from ``(xmin, xmax, ymin, ymax)``.
        """
        xmin, xmax = bbox.getMinX(), bbox.getMaxX()
        ymin, ymax = bbox.getMinY(), bbox.getMaxY()
        return galsim._BoundsI(xmin, xmax, ymin, ymax)

    def _addEllipticityField(self, name, n, schema, doc):
        """
        Helper function to add an ellipticity field to a measurement schema.

        Parameters
        ----------
        name : `str`
            Base name of the field.
        n : `int`
            Specifies whether the field is for the first (1) or second (2)
            component.
        schema : `~lsst.afw.table.Schema`
            The schema to which the field is added.
        doc : `str`
            The documentation string that needs to be updated to reflect the
            type and component of the measurement.

        Returns
        -------
        `~lsst.afw.table.KeyD`
            The key associated with the added field in the schema.
        """
        componentLookup = {1: "+ component", 2: "x component"}
        typeLookup = {"e": " of ellipticity", "g": " of estimated shear"}
        name = f"{name}_{self.measTypeSymbol}{n}"
        updatedDoc = f"{doc} ({componentLookup[n]}{typeLookup[self.measTypeSymbol]})"
        return schema.addField(name, type=float, doc=updatedDoc)

    def measure(self, record, exposure):
        """
        Measure the shape of sources given an exposure and set the results in
        the record in place.

        Parameters
        ----------
        record : `~lsst.afw.table.SourceRecord`
            The record where measurement outputs will be stored.
        exposure : `~lsst.afw.image.Exposure`
            The exposure containing the source which needs measurement.

        Raises
        ------
        MeasurementError
            Raised for errors in measurement. Cases are a deblended parent
            source, a footprint bounding box with no pixels, a centroid
            outside that bounding box, or a failure reported by GalSim.
            Each case carries the corresponding flag bit.
        """
        # Extract the centroid from the record.
        center = self.centroidExtractor(record, self.flagHandler)

        if self.hasDeblendKey and record.get(self.deblendKey) > 0:
            raise measBase.MeasurementError(self.PARENT_SOURCE.doc, self.PARENT_SOURCE.number)

        # Get the bounding box of the source's footprint.
        bbox = record.getFootprint().getBBox()

        # Check that the bounding box has non-zero area.
        if bbox.getArea() == 0:
            raise measBase.MeasurementError(self.NO_PIXELS.doc, self.NO_PIXELS.number)

        # Ensure that the centroid is within the bounding box.
        if not bbox.contains(Point2I(center)):
            raise measBase.MeasurementError(self.NOT_CONTAINED.doc, self.NOT_CONTAINED.number)

        # Evaluate the PSF model once at the source centroid. We need its
        # image (origin reset to (0, 0)) and its trace radius.
        psfModel = exposure.getPsf()
        psfImage = psfModel.computeImage(center)
        psfImage.setXY0(0, 0)
        psfSigma = psfModel.computeShape(center).getTraceRadius()

        # Turn the source footprint box and the PSF image box (parent
        # coordinates) into GalSim bounds.
        bounds = self.bboxToGalSimBounds(bbox)
        psfBounds = self.bboxToGalSimBounds(psfImage.getBBox(afwImage.PARENT))

        # Each GalSim image below will match whatever dtype the input array is.
        # NOTE: PSF is already restricted to a small image, so no bbox for the
        # PSF is expected.
        image = galsim._Image(exposure.image[bbox].array, bounds, wcs=None)
        psf = galsim._Image(psfImage.array, psfBounds, wcs=None)

        # Build the weight image: 1 means "use this pixel", 0 means the pixel
        # has at least one of the configured bad mask bits set. The
        # expression creates a new array with the mask's integer dtype, so
        # the exposure's mask is never modified.
        subMask = exposure.mask[bbox]
        bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
        weightArray = ((subMask.array & bitValue) == 0).astype(subMask.array.dtype)
        weight = galsim._Image(weightArray, bounds, wcs=None)

        # Estimate the sky variance as the median of the variance plane over
        # the footprint box, ignoring pixels with the bad mask bits set.
        sctrl = afwMath.StatisticsControl()
        sctrl.setAndMask(bitValue)
        # NOTE: Origin defaults to PARENT in all cases accessible from Python.
        variance = afwImage.Image(
            exposure.variance[bbox],
            dtype=exposure.variance.dtype,
            deep=False,
        )
        stat = afwMath.makeStatistics(variance, subMask, afwMath.MEDIAN, sctrl)
        skyvar = stat.getValue(afwMath.MEDIAN)

        # Placeholder ShapeData whose underlying buffer (``shape._data``) is
        # filled in place by the C++ estimator below.
        shape = galsim.hsm.ShapeData(
            image_bounds=galsim._BoundsI(0, 0, 1, 1),
            observed_shape=galsim._Shear(0j),
            psf_shape=galsim._Shear(0j),
            moments_centroid=galsim._PositionD(0, 0),
        )

        # Prepare various values for GalSim's EstimateShearView.
        recomputeFlux = "FIT"
        precision = 1.0e-6
        guessCentroid = galsim._PositionD(center.getX(), center.getY())
        hsmparams = galsim.hsm.HSMParams.default

        # Estimate shear using GalSim. Arguments are passed positionally
        # to the C++ function. Inline comments specify the Python layer
        # equivalent of each argument for clarity.
        # TODO: [DM-42047] Change to public API when an optimized version
        # is available.
        try:
            galsim._galsim.EstimateShearView(
                shape._data,  # shape data buffer (not passed in pure Python)
                image._image,  # gal_image
                psf._image,  # PSF_image
                weight._image,  # weight
                float(skyvar),  # sky_var
                self.config.shearType.upper(),  # shear_est
                recomputeFlux.upper(),  # recompute_flux
                float(2.5 * psfSigma),  # guess_sig_gal
                float(psfSigma),  # guess_sig_PSF
                float(precision),  # precision
                guessCentroid._p,  # guess_centroid
                hsmparams._hsmp,  # hsmparams
            )
        except (galsim.hsm.GalSimHSMError, RuntimeError) as error:
            # The private C++ entry point reports HSM failures as a plain
            # RuntimeError. Only GalSim's public wrapper converts them to
            # GalSimHSMError. Catch both so that GalSim failures are
            # reported with the flag_galsim bit.
            raise measBase.MeasurementError(str(error), self.GALSIM.number) from error

        # Set ellipticity and error values based on measurement type.
        if shape.meas_type == "e":
            record.set(self.e1Key, shape.corrected_e1)
            record.set(self.e2Key, shape.corrected_e2)
            # For e-type distortions the reported per-component uncertainty
            # is doubled (presumably the distortion ~ 2 x shear convention).
            record.set(self.sigmaKey, 2.0 * shape.corrected_shape_err)
        else:
            record.set(self.e1Key, shape.corrected_g1)
            record.set(self.e2Key, shape.corrected_g2)
            record.set(self.sigmaKey, shape.corrected_shape_err)

        record.set(self.resolutionKey, shape.resolution_factor)
        self.flagHandler.setValue(record, self.FAILURE.number, shape.correction_status != 0)

    def fail(self, record, error=None):
        # Docstring inherited.
        self.flagHandler.handleFailure(record)
        if error is None:
            return
        # handleFailure(record) sets only the general failure flag. Also set
        # the specific flag carried by a MeasurementError (e.g.
        # flag_no_pixels, flag_galsim); otherwise the per-reason flags
        # defined in __init__ would never be populated.
        if isinstance(error, measBase.MeasurementError):
            self.flagHandler.setValue(record, error.getFlagBit(), True)
        centroid = self.centroidExtractor(record, self.flagHandler)
        self.log.debug(
            "Failed to measure shape for %d at (%f, %f): %s",
            record.getId(),
            centroid.getX(),
            centroid.getY(),
            error,
        )
class HsmShapeBjConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the BJ estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "BJ"

    def validate(self):
        """Check that ``shearType`` still selects the BJ estimator.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"BJ"``.
        """
        if self.shearType != "BJ":
            # FieldValidationError needs the Field descriptor (accessed on the
            # class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'BJ'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeBj")
class HsmShapeBjPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the BJ estimator.

    Results are stored as e-type distortion components.
    """

    # Description prefixed to every schema field's documentation.
    doc = "PSF-corrected shear using Bernstein & Jarvis (2002) method"
    # "e": fields are named ``<name>_e1``/``<name>_e2`` and documented as ellipticity.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeBjConfig
class HsmShapeLinearConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the LINEAR estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "LINEAR"

    def validate(self):
        """Check that ``shearType`` still selects the LINEAR estimator.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"LINEAR"``.
        """
        if self.shearType != "LINEAR":
            # FieldValidationError needs the Field descriptor (accessed on the
            # class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'LINEAR'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeLinear")
class HsmShapeLinearPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the LINEAR estimator.

    Results are stored as e-type distortion components.
    """

    # Description prefixed to every schema field's documentation.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'linear' method"
    # "e": fields are named ``<name>_e1``/``<name>_e2`` and documented as ellipticity.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeLinearConfig
class HsmShapeKsbConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the KSB estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "KSB"

    def validate(self):
        """Check that ``shearType`` still selects the KSB estimator.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"KSB"``.
        """
        if self.shearType != "KSB":
            # FieldValidationError needs the Field descriptor (accessed on the
            # class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'KSB'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeKsb")
class HsmShapeKsbPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the KSB estimator.

    Results are stored as g-type shear components.
    """

    # Description prefixed to every schema field's documentation.
    doc = "PSF-corrected shear using Kaiser, Squires, & Broadhurst (1995) method"
    # "g": fields are named ``<name>_g1``/``<name>_g2`` and documented as estimated shear.
    measTypeSymbol = "g"
    ConfigClass = HsmShapeKsbConfig
class HsmShapeRegaussConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the REGAUSS estimator."""

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "REGAUSS"

    def validate(self):
        """Check that ``shearType`` still selects the REGAUSS estimator.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``shearType`` is not ``"REGAUSS"``.
        """
        if self.shearType != "REGAUSS":
            # FieldValidationError needs the Field descriptor (accessed on the
            # class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'REGAUSS'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeRegauss")
class HsmShapeRegaussPlugin(HsmShapePlugin):
    """Plugin for HSM shape measurement for the REGAUSS estimator.

    Results are stored as e-type distortion components.
    """

    # Description prefixed to every schema field's documentation.
    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'regaussianization' method"
    # "e": fields are named ``<name>_e1``/``<name>_e2`` and documented as ellipticity.
    measTypeSymbol = "e"
    ConfigClass = HsmShapeRegaussConfig