Coverage for tests/test_FlagHandler.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of meas_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
import unittest

import numpy as np

import lsst.utils.tests
import lsst.geom
import lsst.meas.base
import lsst.meas.base.tests
import lsst.afw.image
import lsst.afw.table
from lsst.meas.base import FlagDefinitionList, FlagHandler, MeasurementError
from lsst.meas.base.tests import AlgorithmTestCase

import lsst.pex.config
import lsst.pex.exceptions
from lsst.meas.base.pluginRegistry import register
from lsst.meas.base.sfm import SingleFramePluginConfig, SingleFramePlugin
class PythonPluginConfig(SingleFramePluginConfig):
    """Configuration for a sample plugin with a `FlagHandler`.
    """

    # NOTE(review): not read by the sample plugin's ``measure``; kept so the
    # config interface is unchanged.
    edgeLimit = lsst.pex.config.Field(dtype=int, default=0, optional=False,
                                      doc="How close to the edge can the object be?")
    # The measured box starts as one pixel and is grown by ``size`` on every
    # side, so it is (2*size + 1) pixels on a side.
    size = lsst.pex.config.Field(dtype=int, default=1, optional=False,
                                 doc="Half-width (pixels) of the square box summed around the center; "
                                     "the box is 2*size + 1 pixels on a side")
    # None means "no zero point": the plugin then skips the magnitude. The
    # field must therefore be optional; with optional=False, Config.validate()
    # rejects the None default.
    flux0 = lsst.pex.config.Field(dtype=float, default=None, optional=True,
                                  doc="Flux for zero mag, used to set mag if defined")
@register("test_PythonPlugin")
class PythonPlugin(SingleFramePlugin):
    """Example Python measurement plugin using a `FlagHandler`.

    This is a sample Python plugin which shows how to create and use a
    `FlagHandler`. The `FlagHandler` defines the known failures which can
    occur when the plugin is called, and should be tested after `measure` to
    detect any potential problems.

    This plugin is a very simple flux measurement algorithm which sums the
    pixel values in a square box around the center point. The box is
    ``2*size + 1`` pixels on a side, where ``size`` is
    `PythonPluginConfig.size`.

    Note that to properly set the error flags when a `MeasurementError` occurs,
    the plugin must implement the `fail` method as shown below. The `fail`
    method should set both the general error flag, and any specific flag as
    designated in the `MeasurementError`.

    This example also demonstrates the use of the `SafeCentroidExtractor`. The
    `SafeCentroidExtractor` and `SafeShapeExtractor` can be used to get some
    reasonable estimate of the centroid or shape in cases where the centroid
    or shape slot has failed on a particular source.
    """

    ConfigClass = PythonPluginConfig

    @classmethod
    def getExecutionOrder(cls):
        """Run this plugin at the flux-measurement stage."""
        return cls.FLUX_ORDER

    def __init__(self, config, name, schema, metadata):
        SingleFramePlugin.__init__(self, config, name, schema, metadata)
        # Declare the general failure flag plus one flag per known failure
        # mode; the returned definitions carry the flag number and doc used
        # when raising MeasurementError in `measure`.
        flagDefs = FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag()
        self.CONTAINS_NAN = flagDefs.add("flag_containsNan", "Measurement area contains a nan")
        self.EDGE = flagDefs.add("flag_edge", "Measurement area over edge")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)
        self.centroidExtractor = lsst.meas.base.SafeCentroidExtractor(schema, name)
        self.instFluxKey = schema.addField(name + "_instFlux", "F", doc="flux")
        self.magKey = schema.addField(name + "_mag", "F", doc="mag")

    def measure(self, measRecord, exposure):
        """Perform measurement.

        The measure method is called by the measurement framework when
        `run` is called. If a `MeasurementError` is raised during this method,
        the `fail` method will be called to set the error flags.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record to measure and in which the results are stored.
        exposure : `lsst.afw.image.Exposure`
            Exposure whose image pixels are summed.

        Raises
        ------
        MeasurementError
            If the box extends outside the exposure (``flag_edge``) or the
            summed flux is NaN (``flag_containsNan``).
        ZeroDivisionError
            If ``flux0`` is configured as zero.
        """
        # Call the SafeCentroidExtractor to get a centroid, even if one has
        # not been supplied by the centroid slot. Normally, the centroid is
        # supplied by the centroid slot, but if that fails, the footprint is
        # used as a fallback. If the fallback is needed, the fail flag will
        # be set on this record.
        center = self.centroidExtractor(measRecord, self.flagHandler)

        # Create a square box of 2*size + 1 pixels on a side around the
        # center: start from a single pixel and grow it by `size` on every
        # side.
        # NOTE(review): int() truncates toward zero rather than rounding to
        # the nearest pixel; acceptable for this toy plugin.
        centerPoint = lsst.geom.Point2I(int(center.getX()), int(center.getY()))
        bbox = lsst.geom.Box2I(centerPoint, lsst.geom.Extent2I(1, 1))
        bbox.grow(self.config.size)

        # If the measurement box falls outside the exposure, raise the edge
        # MeasurementError
        if not exposure.getBBox().contains(bbox):
            raise MeasurementError(self.EDGE.doc, self.EDGE.number)

        # Sum the pixels inside the bounding box
        instFlux = lsst.afw.image.ImageF(exposure.getMaskedImage().getImage(), bbox).getArray().sum()
        measRecord.set(self.instFluxKey, instFlux)

        # A NaN pixel anywhere inside the box makes the sum NaN.
        if np.isnan(instFlux):
            raise MeasurementError(self.CONTAINS_NAN.doc, self.CONTAINS_NAN.number)

        if self.config.flux0 is not None:
            # Deliberately raised as a non-MeasurementError so the framework
            # treats it as an unexpected failure.
            if self.config.flux0 == 0:
                raise ZeroDivisionError("self.config.flux0 is zero in divisor")
            mag = -2.5 * np.log10(instFlux/self.config.flux0)
            measRecord.set(self.magKey, mag)

    def fail(self, measRecord, error=None):
        """Handle measurement failures.

        If the exception is a `MeasurementError`, the error will be passed to
        the fail method by the measurement framework. If ``error`` is not
        `None`, ``error.cpp`` should correspond to a specific error and the
        appropriate error flag will be set.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record on which to set the failure flags.
        error : `MeasurementError`, optional
            The error that caused the failure. If `None`, only the general
            failure flag is set.
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)
class FlagHandlerTestCase(AlgorithmTestCase, lsst.utils.tests.TestCase):
    """Tests of `FlagHandler`, used directly and through `PythonPlugin`.
    """

    def setUp(self):
        """Set up a configuration and data source used by the plugin tests."""
        self.algName = "test_PythonPlugin"
        self.bbox = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(100, 100))
        self.dataset = lsst.meas.base.tests.TestDataset(self.bbox)
        self.dataset.addSource(instFlux=1E5, centroid=lsst.geom.Point2D(25, 26))
        config = lsst.meas.base.SingleFrameMeasurementConfig()
        # Only the test plugin is run, so no measurement slots are configured.
        config.plugins = [self.algName]
        config.slots.centroid = None
        config.slots.apFlux = None
        config.slots.calibFlux = None
        config.slots.gaussianFlux = None
        config.slots.modelFlux = None
        config.slots.psfFlux = None
        config.slots.shape = None
        config.slots.psfShape = None
        self.config = config

    def tearDown(self):
        del self.config
        del self.dataset
        del self.bbox

    def testFlagHandler(self):
        """Test creation and invocation of `FlagHandler`.
        """
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()

        # This is a FlagDefinition structure like a plugin might have
        flagDefs = FlagDefinitionList()
        FAILURE = flagDefs.addFailureFlag()
        FIRST = flagDefs.add("1st error", "this is the first failure type")
        SECOND = flagDefs.add("2nd error", "this is the second failure type")
        fh = FlagHandler.addFields(schema, "test", flagDefs)
        # Check to be sure that the FlagHandler was correctly initialized
        for index in range(len(flagDefs)):
            self.assertEqual(flagDefs.getDefinition(index).name, fh.getFlagName(index))

        catalog = lsst.afw.table.SourceCatalog(schema)

        # Now check to be sure that all of the known failures set the bits
        # correctly
        record = catalog.addNew()
        fh.handleFailure(record)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(FAILURE.doc, FAILURE.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(FIRST.doc, FIRST.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertTrue(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(SECOND.doc, SECOND.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertTrue(fh.getValue(record, SECOND.number))

    def testNoFailureFlag(self):
        """Test with no failure flag.
        """
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()

        # This is a FlagDefinition structure like a plugin might have
        flagDefs = FlagDefinitionList()
        FIRST = flagDefs.add("1st error", "this is the first failure type")
        SECOND = flagDefs.add("2nd error", "this is the second failure type")
        fh = FlagHandler.addFields(schema, "test", flagDefs)
        # Check to be sure that the FlagHandler was correctly initialized
        for index in range(len(flagDefs)):
            self.assertEqual(flagDefs.getDefinition(index).name, fh.getFlagName(index))

        catalog = lsst.afw.table.SourceCatalog(schema)

        # Now check to be sure that all of the known failures set the bits
        # correctly
        record = catalog.addNew()
        fh.handleFailure(record)
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(FIRST.doc, FIRST.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(SECOND.doc, SECOND.number)
        fh.handleFailure(record, error.cpp)
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertTrue(fh.getValue(record, SECOND.number))

    # This and the following tests use the toy plugin, and demonstrate how
    # the FlagHandler is used.

    def testPluginNoError(self):
        """Test that the sample plugin can be run without errors.
        """
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=0)
        task.run(cat, exposure)
        source = cat[0]
        self.assertFalse(source.get(self.algName + "_flag"))
        self.assertFalse(source.get(self.algName + "_flag_containsNan"))
        self.assertFalse(source.get(self.algName + "_flag_edge"))

    def testPluginUnexpectedError(self):
        """Test that unexpected non-fatal errors set the failure flag.

        An unexpected error is a non-fatal error which is not caught by the
        algorithm itself. However, such errors are caught by the measurement
        framework in task.run, and result in the failure flag being set, but
        no other specific flags.
        """
        self.config.plugins[self.algName].flux0 = 0.0  # divide by zero
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=1)
        task.log.setLevel(task.log.FATAL)
        task.run(cat, exposure)
        source = cat[0]
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertFalse(source.get(self.algName + "_flag_containsNan"))
        self.assertFalse(source.get(self.algName + "_flag_edge"))

    def testPluginContainsNan(self):
        """Test that the ``containsNan`` error can be triggered.
        """
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=2)
        source = cat[0]
        # Put a NaN on the pixel under the source so that it lies inside the
        # measured box.
        exposure.getMaskedImage().getImage().getArray()[int(source.getY()), int(source.getX())] = np.nan
        # The failure is expected; keep it out of the test log.
        task.log.setLevel(task.log.FATAL)
        task.run(cat, exposure)
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertTrue(source.get(self.algName + "_flag_containsNan"))
        self.assertFalse(source.get(self.algName + "_flag_edge"))

    def testPluginEdgeError(self):
        """Test that the ``edge`` error can be triggered.
        """
        # Set the size large enough to trigger the edge error. This is done
        # before the task (and hence the plugin) is constructed, rather than
        # relying on later mutation of a config the plugin already holds.
        self.config.plugins[self.algName].size = self.bbox.getHeight()//2
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=3)
        task.log.setLevel(task.log.FATAL)
        task.run(cat, exposure)
        source = cat[0]
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertFalse(source.get(self.algName + "_flag_containsNan"))
        self.assertTrue(source.get(self.algName + "_flag_edge"))

    def testSafeCentroider(self):
        """Test `SafeCentroidExtractor` correctly runs and sets flags.
        """
        # Normal case should use the centroid slot to get the center, which
        # should succeed
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        task.log.setLevel(task.log.FATAL)
        exposure, cat = self.dataset.realize(noise=0.0, schema=schema, randomSeed=4)
        source = cat[0]
        task.run(cat, exposure)
        self.assertFalse(source.get(self.algName + "_flag"))
        instFlux = source.get("test_PythonPlugin_instFlux")
        self.assertFalse(np.isnan(instFlux))

        # If one of the center coordinates is nan and the centroid slot error
        # flag has not been set, the SafeCentroidExtractor will fail.
        source.set('truth_x', np.nan)
        source.set('truth_flag', False)
        source.set("test_PythonPlugin_instFlux", np.nan)
        source.set(self.algName + "_flag", False)
        task.run(cat, exposure)
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertTrue(np.isnan(source.get("test_PythonPlugin_instFlux")))

        # But if the same conditions occur and the centroid slot error flag is
        # set to true, the SafeCentroidExtractor will succeed and the
        # algorithm will complete. However, the failure flag will also be
        # set.
        source.set('truth_x', np.nan)
        source.set('truth_flag', True)
        source.set("test_PythonPlugin_instFlux", np.nan)
        source.set(self.algName + "_flag", False)
        task.run(cat, exposure)
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertEqual(source.get("test_PythonPlugin_instFlux"), instFlux)
class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Run the shared `lsst.utils.tests.MemoryTestCase` checks for this module;
    no additional tests are defined here.
    """
def setup_module(module):
    """Module-level pytest hook: initialize `lsst.utils.tests` before any
    test in this module runs.

    Parameters
    ----------
    module : `module`
        This test module (supplied by pytest; unused).
    """
    lsst.utils.tests.init()
# Allow running this file directly; initialize the LSST test utilities first,
# as the pytest ``setup_module`` hook does.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()