Coverage for python/lsst/meas/extensions/shapeHSM/_hsm_shape.py: 36%
154 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-29 03:39 -0700
1# This file is part of meas_extensions_shapeHSM.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
24import galsim
25import lsst.afw.image as afwImage
26import lsst.afw.math as afwMath
27import lsst.meas.base as measBase
28import lsst.pex.config as pexConfig
29from lsst.geom import Point2I
# Public API: one configuration class and one plugin class per HSM
# PSF-correction estimator. The shared bases (HsmShapeConfig and
# HsmShapePlugin) are not listed here.
__all__ = [
    "HsmShapeBjConfig",
    "HsmShapeBjPlugin",
    "HsmShapeLinearConfig",
    "HsmShapeLinearPlugin",
    "HsmShapeKsbConfig",
    "HsmShapeKsbPlugin",
    "HsmShapeRegaussConfig",
    "HsmShapeRegaussPlugin",
]
class HsmShapeConfig(measBase.SingleFramePluginConfig):
    """Configuration options common to all HSM shape measurement plugins."""

    # Estimator handed to GalSim's shear estimation. REGAUSS, LINEAR and BJ
    # produce e-type distortions; KSB produces a g-type shear.
    shearType = pexConfig.ChoiceField[str](
        default="REGAUSS",
        allowed={
            "REGAUSS": "Regaussianization method from Hirata & Seljak (2003)",
            "LINEAR": "A modification by Hirata & Seljak (2003) of methods in Bernstein & Jarvis (2002)",
            "BJ": "From Bernstein & Jarvis (2002)",
            "KSB": "From Kaiser, Squires, & Broadhurst (1995)",
        },
        doc=(
            "The desired method of PSF correction using GalSim. "
            "The first three options return an e-type distortion, "
            "whereas the last option returns a g-type shear."
        ),
    )

    # Name of a schema field holding the number of deblend children; an
    # empty string disables the lookup.
    deblendNChild = pexConfig.Field[str](
        default="",
        doc="Field name for number of deblend children.",
    )

    # Mask planes whose pixels are given zero weight during the fit.
    badMaskPlanes = pexConfig.ListField[str](
        default=["BAD", "SAT"],
        doc="Mask planes that indicate pixels that should be excluded from the fit.",
    )
class HsmShapePlugin(measBase.SingleFramePlugin):
    """Base plugin for HSM shape measurement.

    Concrete subclasses select the PSF-correction estimator through their
    ``ConfigClass`` (its ``shearType`` is passed to GalSim) and must define
    two class attributes:

    ``measTypeSymbol``
        ``"e"`` or ``"g"``; used in the ellipticity field names
        (``<name>_<symbol>1``, ``<name>_<symbol>2``) and their documentation.
    ``doc``
        Short description of the method, prepended to the documentation of
        the output fields.
    """

    ConfigClass = HsmShapeConfig
    doc = ""

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Define flags for possible issues that might arise during measurement.
        flagDefs = measBase.FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")
        self.NO_PIXELS = flagDefs.add("flag_no_pixels", "No pixels to measure")
        self.NOT_CONTAINED = flagDefs.add(
            "flag_not_contained", "Center not contained in footprint bounding box"
        )
        self.PARENT_SOURCE = flagDefs.add("flag_parent_source", "Parent source, ignored")
        self.GALSIM = flagDefs.add("flag_galsim", "GalSim failure")

        # Embed the flag definitions in the schema using a flag handler.
        self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)

        # Utilize a safe centroid extractor that uses the detection footprint
        # as a fallback if necessary.
        self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)

        self.e1Key = self._addEllipticityField(name, 1, schema, self.doc)
        self.e2Key = self._addEllipticityField(name, 2, schema, self.doc)
        self.sigmaKey = schema.addField(
            schema.join(name, "sigma"),
            type=float,
            doc=f"{self.doc} (shape measurement uncertainty per component)",
        )
        self.resolutionKey = schema.addField(
            schema.join(name, "resolution"), type=float, doc="Resolution factor (0=unresolved, 1=resolved)"
        )

        # A non-empty ``deblendNChild`` names a schema field with the number
        # of deblend children; ``measure`` rejects sources with children.
        self.hasDeblendKey = bool(config.deblendNChild)
        if self.hasDeblendKey:
            self.deblendKey = schema[config.deblendNChild]

        self.log = logging.getLogger(self.logName)

    @classmethod
    def getExecutionOrder(cls):
        return cls.SHAPE_ORDER

    @staticmethod
    def bboxToGalSimBounds(bbox):
        """Convert an LSST integer bounding box into GalSim integer bounds.

        Parameters
        ----------
        bbox : `~lsst.geom.Box2I`
            Box whose inclusive min/max corners are copied.

        Returns
        -------
        bounds : `galsim.BoundsI`
            Bounds with the same inclusive corners.
        """
        xmin, xmax = bbox.getMinX(), bbox.getMaxX()
        ymin, ymax = bbox.getMinY(), bbox.getMaxY()
        return galsim._BoundsI(xmin, xmax, ymin, ymax)

    def _addEllipticityField(self, name, n, schema, doc):
        """Add one ellipticity component field to a measurement schema.

        The field is named ``<name>_<measTypeSymbol><n>``.

        Parameters
        ----------
        name : `str`
            Base name of the field.
        n : `int`
            Component number: 1 (+ component) or 2 (x component).
        schema : `~lsst.afw.table.Schema`
            The schema to which the field is added.
        doc : `str`
            Base documentation string; the component and measurement type
            are appended to it.

        Returns
        -------
        `~lsst.afw.table.KeyD`
            The key associated with the added field in the schema.
        """
        componentLookup = {1: "+ component", 2: "x component"}
        typeLookup = {"e": " of ellipticity", "g": " of estimated shear"}
        fieldName = f"{name}_{self.measTypeSymbol}{n}"
        updatedDoc = f"{doc} ({componentLookup[n]}{typeLookup[self.measTypeSymbol]})"
        return schema.addField(fieldName, type=float, doc=updatedDoc)

    def measure(self, record, exposure):
        """Measure the PSF-corrected shape of a source and store the results
        in the record in place.

        Parameters
        ----------
        record : `~lsst.afw.table.SourceRecord`
            The record where measurement outputs will be stored.
        exposure : `~lsst.afw.image.Exposure`
            The exposure containing the source which needs measurement.

        Raises
        ------
        MeasurementError
            Raised if the source has deblend children, its footprint bounding
            box is empty or does not contain the centroid, or GalSim's shear
            estimation fails.
        """
        # Extract the centroid from the record.
        center = self.centroidExtractor(record, self.flagHandler)

        if self.hasDeblendKey and record.get(self.deblendKey) > 0:
            raise measBase.MeasurementError(self.PARENT_SOURCE.doc, self.PARENT_SOURCE.number)

        # The footprint bounding box defines the pixels that are measured.
        bbox = record.getFootprint().getBBox()
        if bbox.getArea() == 0:
            raise measBase.MeasurementError(self.NO_PIXELS.doc, self.NO_PIXELS.number)
        if not bbox.contains(Point2I(center)):
            raise measBase.MeasurementError(self.NOT_CONTAINED.doc, self.NOT_CONTAINED.number)

        # PSF model image at the source position (origin moved to (0, 0)) and
        # the PSF trace radius, used to seed the size guesses below.
        exposurePsf = exposure.getPsf()
        psfImage = exposurePsf.computeImage(center)
        psfImage.setXY0(0, 0)
        psfSigma = exposurePsf.computeShape(center).getTraceRadius()

        # Wrap the pixel arrays as GalSim images; each matches the dtype of
        # its input array. The source image keeps the bbox's parent
        # coordinates, the same frame as the centroid guess.
        # NOTE: PSF is already restricted to a small image, so no bbox for the
        # PSF is expected.
        bounds = self.bboxToGalSimBounds(bbox)
        psfBounds = self.bboxToGalSimBounds(psfImage.getBBox(afwImage.PARENT))
        image = galsim._Image(exposure.image[bbox].array, bounds, wcs=None)
        psf = galsim._Image(psfImage.array, psfBounds, wcs=None)

        # Weight image: 1 where none of the configured bad mask planes are
        # set ('use pixel'), 0 otherwise. It is built as a new array, so the
        # exposure's mask is not modified, and keeps the mask's integer dtype.
        subMask = exposure.mask[bbox]
        bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
        weightArray = ((subMask.array & bitValue) == 0).astype(subMask.array.dtype)
        weight = galsim._Image(weightArray, bounds, wcs=None)

        # Median sky variance over the bbox, ignoring the bad-plane pixels.
        sctrl = afwMath.StatisticsControl()
        sctrl.setAndMask(bitValue)
        # NOTE: Origin defaults to PARENT in all cases accessible from Python.
        variance = afwImage.Image(
            exposure.variance[bbox],
            dtype=exposure.variance.dtype,
            deep=False,
        )
        stat = afwMath.makeStatistics(variance, subMask, afwMath.MEDIAN, sctrl)
        skyvar = stat.getValue(afwMath.MEDIAN)

        # Result container filled in place by the C++ call; the initial
        # values are placeholders.
        shape = galsim.hsm.ShapeData(
            image_bounds=galsim._BoundsI(0, 0, 1, 1),
            observed_shape=galsim._Shear(0j),
            psf_shape=galsim._Shear(0j),
            moments_centroid=galsim._PositionD(0, 0),
        )

        # Prepare various values for GalSim's EstimateShearView.
        recomputeFlux = "FIT"
        precision = 1.0e-6
        guessCentroid = galsim._PositionD(center.getX(), center.getY())
        hsmparams = galsim.hsm.HSMParams.default

        # Directly use GalSim's C++/Python interface for shear estimation.
        # Arguments are passed positionally to the C++ function; inline
        # comments name the Python-layer equivalent of each argument.
        # TODO: [DM-42047] Change to public API when an optimized version
        # is available.
        try:
            galsim._galsim.EstimateShearView(
                shape._data,  # shape data buffer (not passed in pure Python)
                image._image,  # gal_image
                psf._image,  # PSF_image
                weight._image,  # weight
                float(skyvar),  # sky_var
                self.config.shearType.upper(),  # shear_est
                recomputeFlux.upper(),  # recompute_flux
                float(2.5 * psfSigma),  # guess_sig_gal
                float(psfSigma),  # guess_sig_PSF
                float(precision),  # precision
                guessCentroid._p,  # guess_centroid
                hsmparams._hsmp,  # hsmparams
            )
        except RuntimeError as error:
            # The private binding reports HSM failures as a plain
            # RuntimeError. Only GalSim's public EstimateShear wraps that in
            # GalSimHSMError, which is itself a RuntimeError subclass, so
            # catching GalSimHSMError here would never trigger.
            raise measBase.MeasurementError(str(error), self.GALSIM.number) from error

        # Distortion-type ("e") results report twice GalSim's per-component
        # error; shear-type results report it unchanged.
        if shape.meas_type == "e":
            record.set(self.e1Key, shape.corrected_e1)
            record.set(self.e2Key, shape.corrected_e2)
            record.set(self.sigmaKey, 2.0 * shape.corrected_shape_err)
        else:
            record.set(self.e1Key, shape.corrected_g1)
            record.set(self.e2Key, shape.corrected_g2)
            record.set(self.sigmaKey, shape.corrected_shape_err)

        record.set(self.resolutionKey, shape.resolution_factor)
        # A non-zero correction status from GalSim marks the result as failed.
        self.flagHandler.setValue(record, self.FAILURE.number, shape.correction_status != 0)

    def fail(self, record, error=None):
        # Docstring inherited.
        self.flagHandler.handleFailure(record)
        # The rest is diagnostic logging only; skip it unless it would be
        # emitted.
        if not error or not self.log.isEnabledFor(logging.DEBUG):
            return
        try:
            centroid = self.centroidExtractor(record, self.flagHandler)
        except Exception:
            # The centroid may itself be unusable (possibly why measure()
            # failed); a log message must not raise out of fail().
            self.log.debug("Failed to measure shape for %d: %s", record.getId(), error)
        else:
            self.log.debug(
                "Failed to measure shape for %d at (%f, %f): %s",
                record.getId(),
                centroid.getX(),
                centroid.getY(),
                error,
            )
class HsmShapeBjConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the BJ estimator.

    ``shearType`` defaults to ``"BJ"``; any other value fails validation.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "BJ"

    def validate(self):
        if self.shearType != "BJ":
            # FieldValidationError needs the Field descriptor (reached through
            # the class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'BJ'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeBj")
class HsmShapeBjPlugin(HsmShapePlugin):
    """HSM shape measurement plugin using the Bernstein & Jarvis (BJ) estimator.

    Reports e-type ellipticity components.
    """

    doc = "PSF-corrected shear using Bernstein & Jarvis (2002) method"
    measTypeSymbol = "e"
    ConfigClass = HsmShapeBjConfig
class HsmShapeLinearConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the LINEAR estimator.

    ``shearType`` defaults to ``"LINEAR"``; any other value fails validation.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "LINEAR"

    def validate(self):
        if self.shearType != "LINEAR":
            # FieldValidationError needs the Field descriptor (reached through
            # the class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'LINEAR'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeLinear")
class HsmShapeLinearPlugin(HsmShapePlugin):
    """HSM shape measurement plugin using the Hirata & Seljak LINEAR estimator.

    Reports e-type ellipticity components.
    """

    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'linear' method"
    measTypeSymbol = "e"
    ConfigClass = HsmShapeLinearConfig
class HsmShapeKsbConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the KSB estimator.

    ``shearType`` defaults to ``"KSB"``; any other value fails validation.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "KSB"

    def validate(self):
        if self.shearType != "KSB":
            # FieldValidationError needs the Field descriptor (reached through
            # the class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'KSB'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeKsb")
class HsmShapeKsbPlugin(HsmShapePlugin):
    """HSM shape measurement plugin using the Kaiser, Squires & Broadhurst estimator.

    Reports g-type (shear) components.
    """

    doc = "PSF-corrected shear using Kaiser, Squires, & Broadhurst (1995) method"
    measTypeSymbol = "g"
    ConfigClass = HsmShapeKsbConfig
class HsmShapeRegaussConfig(HsmShapeConfig):
    """Configuration for HSM shape measurement for the REGAUSS estimator.

    ``shearType`` defaults to ``"REGAUSS"``; any other value fails validation.
    """

    def setDefaults(self):
        super().setDefaults()
        self.shearType = "REGAUSS"

    def validate(self):
        if self.shearType != "REGAUSS":
            # FieldValidationError needs the Field descriptor (reached through
            # the class), not the field's current string value.
            raise pexConfig.FieldValidationError(
                type(self).shearType, self, "shearType should be set to 'REGAUSS'."
            )
        super().validate()
@measBase.register("ext_shapeHSM_HsmShapeRegauss")
class HsmShapeRegaussPlugin(HsmShapePlugin):
    """HSM shape measurement plugin using the Hirata & Seljak re-Gaussianization estimator.

    Reports e-type ellipticity components.
    """

    doc = "PSF-corrected shear using Hirata & Seljak (2003) 'regaussianization' method"
    measTypeSymbol = "e"
    ConfigClass = HsmShapeRegaussConfig