Coverage for python/lsst/meas/extensions/shapeHSM/_hsm_shape.py: 36%
154 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 03:47 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-09 03:47 -0700
1# This file is part of meas_extensions_shapeHSM.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
24import galsim
25import lsst.afw.image as afwImage
26import lsst.afw.math as afwMath
27import lsst.meas.base as measBase
28import lsst.pex.config as pexConfig
29from lsst.geom import Point2I
31__all__ = [
32 "HsmShapeBjConfig",
33 "HsmShapeBjPlugin",
34 "HsmShapeLinearConfig",
35 "HsmShapeLinearPlugin",
36 "HsmShapeKsbConfig",
37 "HsmShapeKsbPlugin",
38 "HsmShapeRegaussConfig",
39 "HsmShapeRegaussPlugin",
40]
43class HsmShapeConfig(measBase.SingleFramePluginConfig):
44 """Base configuration for HSM shape measurement."""
46 shearType = pexConfig.ChoiceField[str](
47 doc=(
48 "The desired method of PSF correction using GalSim. The first three options return an e-type "
49 "distortion, whereas the last option returns a g-type shear."
50 ),
51 allowed={
52 "REGAUSS": "Regaussianization method from Hirata & Seljak (2003)",
53 "LINEAR": "A modification by Hirata & Seljak (2003) of methods in Bernstein & Jarvis (2002)",
54 "BJ": "From Bernstein & Jarvis (2002)",
55 "KSB": "From Kaiser, Squires, & Broadhurst (1995)",
56 },
57 default="REGAUSS",
58 )
60 deblendNChild = pexConfig.Field[str](
61 doc="Field name for number of deblend children.",
62 default="",
63 )
65 badMaskPlanes = pexConfig.ListField[str](
66 doc="Mask planes that indicate pixels that should be excluded from the fit.",
67 default=["BAD", "SAT"],
68 )
71class HsmShapePlugin(measBase.SingleFramePlugin):
72 """Base plugin for HSM shape measurement."""
74 ConfigClass = HsmShapeConfig
75 doc = ""
77 def __init__(self, config, name, schema, metadata, logName=None):
78 if logName is None:
79 logName = __name__
80 super().__init__(config, name, schema, metadata, logName=logName)
82 # Define flags for possible issues that might arise during measurement.
83 flagDefs = measBase.FlagDefinitionList()
84 self.FAILURE = flagDefs.addFailureFlag("General failure flag, set if anything went wrong")
85 self.NO_PIXELS = flagDefs.add("flag_no_pixels", "No pixels to measure")
86 self.NOT_CONTAINED = flagDefs.add(
87 "flag_not_contained", "Center not contained in footprint bounding box"
88 )
89 self.PARENT_SOURCE = flagDefs.add("flag_parent_source", "Parent source, ignored")
90 self.GALSIM = flagDefs.add("flag_galsim", "GalSim failure")
92 # Embed the flag definitions in the schema using a flag handler.
93 self.flagHandler = measBase.FlagHandler.addFields(schema, name, flagDefs)
95 # Utilize a safe centroid extractor that uses the detection footprint
96 # as a fallback if necessary.
97 self.centroidExtractor = measBase.SafeCentroidExtractor(schema, name)
99 self.e1Key = self._addEllipticityField(name, 1, schema, self.doc)
100 self.e2Key = self._addEllipticityField(name, 2, schema, self.doc)
101 self.sigmaKey = schema.addField(
102 schema.join(name, "sigma"),
103 type=float,
104 doc=f"{self.doc} (shape measurement uncertainty per component)",
105 )
106 self.resolutionKey = schema.addField(
107 schema.join(name, "resolution"), type=float, doc="Resolution factor (0=unresolved, 1=resolved)"
108 )
109 self.hasDeblendKey = len(config.deblendNChild) > 0
111 if self.hasDeblendKey:
112 self.deblendKey = schema[config.deblendNChild].asKey()
114 self.log = logging.getLogger(self.logName)
116 @classmethod
117 def getExecutionOrder(cls):
118 return cls.SHAPE_ORDER
120 @staticmethod
121 def bboxToGalSimBounds(bbox):
122 xmin, xmax = bbox.getMinX(), bbox.getMaxX()
123 ymin, ymax = bbox.getMinY(), bbox.getMaxY()
124 return galsim._BoundsI(xmin, xmax, ymin, ymax)
126 def _addEllipticityField(self, name, n, schema, doc):
127 """
128 Helper function to add an ellipticity field to a measurement schema.
130 Parameters
131 ----------
132 name : `str`
133 Base name of the field.
134 n : `int`
135 Specifies whether the field is for the first (1) or second (2)
136 component.
137 schema : `~lsst.afw.table.Schema`
138 The schema to which the field is added.
139 doc : `str`
140 The documentation string that needs to be updated to reflect the
141 type and component of the measurement.
143 Returns
144 -------
145 `~lsst.afw.table.KeyD`
146 The key associated with the added field in the schema.
147 """
148 componentLookup = {1: "+ component", 2: "x component"}
149 typeLookup = {"e": " of ellipticity", "g": " of estimated shear"}
150 name = f"{name}_{self.measTypeSymbol}{n}"
151 updatedDoc = f"{doc} ({componentLookup[n]}{typeLookup[self.measTypeSymbol]})"
152 return schema.addField(name, type=float, doc=updatedDoc)
154 def measure(self, record, exposure):
155 """
156 Measure the shape of sources given an exposure and set the results in
157 the record in place.
159 Parameters
160 ----------
161 record : `~lsst.afw.table.SourceRecord`
162 The record where measurement outputs will be stored.
163 exposure : `~lsst.afw.image.Exposure`
164 The exposure containing the source which needs measurement.
166 Raises
167 ------
168 MeasurementError
169 Raised for errors in measurement.
170 """
171 # Extract the centroid from the record.
172 center = self.centroidExtractor(record, self.flagHandler)
174 if self.hasDeblendKey and record.get(self.deblendKey) > 0:
175 raise measBase.MeasurementError(self.PARENT_SOURCE.doc, self.PARENT_SOURCE.number)
177 # Get the bounding box of the source's footprint.
178 bbox = record.getFootprint().getBBox()
180 # Check that the bounding box has non-zero area.
181 if bbox.getArea() == 0:
182 raise measBase.MeasurementError(self.NO_PIXELS.doc, self.NO_PIXELS.number)
184 # Ensure that the centroid is within the bounding box.
185 if not bbox.contains(Point2I(center)):
186 raise measBase.MeasurementError(self.NOT_CONTAINED.doc, self.NOT_CONTAINED.number)
188 # Get the PSF image evaluated at the source centroid.
189 psfImage = exposure.getPsf().computeImage(center)
190 psfImage.setXY0(0, 0)
192 # Get the trace radius of the PSF.
193 psfSigma = exposure.getPsf().computeShape(center).getTraceRadius()
195 # Turn bounding box corners into GalSim bounds.
196 bounds = self.bboxToGalSimBounds(bbox)
198 # Get the bounding box of the PSF in the parent coordinate system.
199 psfBBox = psfImage.getBBox(afwImage.PARENT)
201 # Turn the PSF bounding box corners into GalSim bounds.
202 psfBounds = self.bboxToGalSimBounds(psfBBox)
204 # Each GalSim image below will match whatever dtype the input array is.
205 # NOTE: PSF is already restricted to a small image, so no bbox for the
206 # PSF is expected.
207 image = galsim._Image(exposure.image[bbox].array, bounds, wcs=None)
208 psf = galsim._Image(psfImage.array, psfBounds, wcs=None)
210 # Get the `lsst.meas.base` mask for bad pixels.
211 subMask = exposure.mask[bbox]
212 badpix = subMask.array.copy() # Copy it since badpix gets modified.
213 bitValue = exposure.mask.getPlaneBitMask(self.config.badMaskPlanes)
214 badpix &= bitValue
216 # Turn badpix to weight where elements set to 1 indicate 'use pixel'
217 # and those set to 0 mean 'do not use pixel'. Now, weight will assume
218 # the role of badpix, and we will no longer use badpix in our call to
219 # EstimateShear().
220 gd = badpix == 0
221 badpix[gd] = 1
222 badpix[~gd] = 0
223 weight = galsim._Image(badpix, bounds, wcs=None)
225 # Get the statistics control object for sky variance estimation.
226 sctrl = afwMath.StatisticsControl()
227 sctrl.setAndMask(bitValue)
229 # Create a variance image from the exposure.
230 # NOTE: Origin defaults to PARENT in all cases accessible from Python.
231 variance = afwImage.Image(
232 exposure.variance[bbox],
233 dtype=exposure.variance.dtype,
234 deep=False,
235 )
237 # Calculate median sky variance for use in shear estimation.
238 stat = afwMath.makeStatistics(variance, subMask, afwMath.MEDIAN, sctrl)
239 skyvar = stat.getValue(afwMath.MEDIAN)
241 # Initialize an instance of ShapeData to store the results.
242 shape = galsim.hsm.ShapeData(
243 image_bounds=galsim._BoundsI(0, 0, 1, 1),
244 observed_shape=galsim._Shear(0j),
245 psf_shape=galsim._Shear(0j),
246 moments_centroid=galsim._PositionD(0, 0),
247 )
249 # Prepare various values for the GalSim's EstimateShearView call.
250 recomputeFlux = "FIT"
251 precision = 1.0e-6
252 guessCentroid = galsim._PositionD(center.getX(), center.getY())
253 hsmparams = galsim.hsm.HSMParams.default
255 # Directly use GalSim's C++/Python interface for shear estimation.
256 try:
257 # Estimate shear using GalSim. Arguments are passed positionally
258 # to the C++ function. Inline comments specify the Python layer
259 # equivalent of each argument for clarity.
260 # TODO: [DM-42047] Change to public API when an optimized version
261 # is available.
262 galsim._galsim.EstimateShearView(
263 shape._data, # shape data buffer (not passed in pure Python)
264 image._image, # gal_image
265 psf._image, # PSF_image
266 weight._image, # weight
267 float(skyvar), # sky_var
268 self.config.shearType.upper(), # shear_est
269 recomputeFlux.upper(), # recompute_flux
270 float(2.5 * psfSigma), # guess_sig_gal
271 float(psfSigma), # guess_sig_PSF
272 float(precision), # precision
273 guessCentroid._p, # guess_centroid
274 hsmparams._hsmp, # hsmparams
275 )
276 # GalSim does not raise custom pybind errors as of v2.5, resulting in
277 # all GalSim C++ errors being RuntimeErrors.
278 except (galsim.hsm.GalSimHSMError, RuntimeError) as error:
279 raise measBase.MeasurementError(str(error), self.GALSIM.number)
281 # Set ellipticity and error values based on measurement type.
282 if shape.meas_type == "e":
283 record.set(self.e1Key, shape.corrected_e1)
284 record.set(self.e2Key, shape.corrected_e2)
285 record.set(self.sigmaKey, 2.0 * shape.corrected_shape_err)
286 else:
287 record.set(self.e1Key, shape.corrected_g1)
288 record.set(self.e2Key, shape.corrected_g2)
289 record.set(self.sigmaKey, shape.corrected_shape_err)
291 record.set(self.resolutionKey, shape.resolution_factor)
292 self.flagHandler.setValue(record, self.FAILURE.number, shape.correction_status != 0)
294 def fail(self, record, error=None):
295 # Docstring inherited.
296 self.flagHandler.handleFailure(record)
297 if error:
298 centroid = self.centroidExtractor(record, self.flagHandler)
299 self.log.debug(
300 "Failed to measure shape for %d at (%f, %f): %s",
301 record.getId(),
302 centroid.getX(),
303 centroid.getY(),
304 error,
305 )
308class HsmShapeBjConfig(HsmShapeConfig):
309 """Configuration for HSM shape measurement for the BJ estimator."""
311 def setDefaults(self):
312 super().setDefaults()
313 self.shearType = "BJ"
315 def validate(self):
316 if self.shearType != "BJ":
317 raise pexConfig.FieldValidationError(self.shearType, self, "shearType should be set to 'BJ'.")
318 super().validate()
321@measBase.register("ext_shapeHSM_HsmShapeBj")
322class HsmShapeBjPlugin(HsmShapePlugin):
323 """Plugin for HSM shape measurement for the BJ estimator."""
325 ConfigClass = HsmShapeBjConfig
326 measTypeSymbol = "e"
327 doc = "PSF-corrected shear using Bernstein & Jarvis (2002) method"
330class HsmShapeLinearConfig(HsmShapeConfig):
331 """Configuration for HSM shape measurement for the LINEAR estimator."""
333 def setDefaults(self):
334 super().setDefaults()
335 self.shearType = "LINEAR"
337 def validate(self):
338 if self.shearType != "LINEAR":
339 raise pexConfig.FieldValidationError(self.shearType, self, "shearType should be set to 'LINEAR'.")
340 super().validate()
343@measBase.register("ext_shapeHSM_HsmShapeLinear")
344class HsmShapeLinearPlugin(HsmShapePlugin):
345 """Plugin for HSM shape measurement for the LINEAR estimator."""
347 ConfigClass = HsmShapeLinearConfig
348 measTypeSymbol = "e"
349 doc = "PSF-corrected shear using Hirata & Seljak (2003) 'linear' method"
352class HsmShapeKsbConfig(HsmShapeConfig):
353 """Configuration for HSM shape measurement for the KSB estimator."""
355 def setDefaults(self):
356 super().setDefaults()
357 self.shearType = "KSB"
359 def validate(self):
360 if self.shearType != "KSB":
361 raise pexConfig.FieldValidationError(self.shearType, self, "shearType should be set to 'KSB'.")
362 super().validate()
365@measBase.register("ext_shapeHSM_HsmShapeKsb")
366class HsmShapeKsbPlugin(HsmShapePlugin):
367 """Plugin for HSM shape measurement for the KSB estimator."""
369 ConfigClass = HsmShapeKsbConfig
370 measTypeSymbol = "g"
371 doc = "PSF-corrected shear using Kaiser, Squires, & Broadhurst (1995) method"
374class HsmShapeRegaussConfig(HsmShapeConfig):
375 """Configuration for HSM shape measurement for the REGAUSS estimator."""
377 def setDefaults(self):
378 super().setDefaults()
379 self.shearType = "REGAUSS"
381 def validate(self):
382 if self.shearType != "REGAUSS":
383 raise pexConfig.FieldValidationError(
384 self.shearType, self, "shearType should be set to 'REGAUSS'."
385 )
386 super().validate()
389@measBase.register("ext_shapeHSM_HsmShapeRegauss")
390class HsmShapeRegaussPlugin(HsmShapePlugin):
391 """Plugin for HSM shape measurement for the REGAUSS estimator."""
393 ConfigClass = HsmShapeRegaussConfig
394 measTypeSymbol = "e"
395 doc = "PSF-corrected shear using Hirata & Seljak (2003) 'regaussianization' method"