Coverage for tests/test_FlagHandler.py: 19%

201 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-19 04:16 -0700

1# This file is part of meas_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import unittest 

23 

24import numpy as np 

25 

26import lsst.utils.tests 

27import lsst.geom 

28import lsst.meas.base 

29import lsst.meas.base.tests 

30import lsst.afw.table 

31from lsst.meas.base import FlagDefinitionList, FlagHandler, MeasurementError 

32from lsst.meas.base.tests import AlgorithmTestCase 

33 

34import lsst.pex.exceptions 

35from lsst.meas.base.pluginRegistry import register 

36from lsst.meas.base.sfm import SingleFramePluginConfig, SingleFramePlugin 

37 

38 

class PythonPluginConfig(SingleFramePluginConfig):
    """Configuration for a sample plugin with a `FlagHandler`.

    Fields
    ------
    edgeLimit : `int`
        Edge-distance limit; not read by `PythonPlugin.measure`.
    size : `int`
        Number of pixels by which the 1x1 box at the center pixel is grown on
        every side before summing.
    flux0 : `float`, optional
        Zero-magnitude flux; a magnitude is only computed when this is set.
    """

    edgeLimit = lsst.pex.config.Field(
        doc="How close to the edge can the object be?",
        dtype=int,
        default=0,
    )
    size = lsst.pex.config.Field(
        doc="size of aperture to measure around the center?",
        dtype=int,
        default=1,
    )
    flux0 = lsst.pex.config.Field(
        doc="Flux for zero mag, used to set mag if defined",
        dtype=float,
        default=None,
        optional=True,
    )

49 

50 

@register("test_PythonPlugin")
class PythonPlugin(SingleFramePlugin):
    """Example Python measurement plugin using a `FlagHandler`.

    This is a sample Python plugin which shows how to create and use a
    `FlagHandler`. The `FlagHandler` defines the known failures which can
    occur when the plugin is called, and should be tested after `measure` to
    detect any potential problems.

    This plugin is a very simple flux measurement algorithm which sums the
    pixel values in a square box of side ``2*size + 1`` pixels centered on
    the center pixel, where ``size`` is `PythonPluginConfig.size`.

    Note that to properly set the error flags when a `MeasurementError` occurs,
    the plugin must implement the `fail` method as shown below. The `fail`
    method should set both the general error flag, and any specific flag as
    designated in the `MeasurementError`.

    This example also demonstrates the use of the `SafeCentroidExtractor`. The
    `SafeCentroidExtractor` and `SafeShapeExtractor` can be used to get some
    reasonable estimate of the centroid or shape in cases where the centroid
    or shape slot has failed on a particular source.

    Parameters
    ----------
    config : `PythonPluginConfig`
        Plugin configuration.
    name : `str`
        Plugin name; used as the prefix of every schema field added here.
    schema : `lsst.afw.table.Schema`
        Schema to which the flag, ``<name>_instFlux`` and ``<name>_mag``
        fields are added.
    metadata
        Passed through unchanged to `SingleFramePlugin`.
    """

    ConfigClass = PythonPluginConfig

    @classmethod
    def getExecutionOrder(cls):
        """Run with the flux-measurement plugins."""
        return cls.FLUX_ORDER

    def __init__(self, config, name, schema, metadata):
        super().__init__(config, name, schema, metadata)
        flagDefs = FlagDefinitionList()
        # The general failure flag must be defined for handleFailure to set it.
        self.FAILURE = flagDefs.addFailureFlag()
        self.CONTAINS_NAN = flagDefs.add("flag_containsNan", "Measurement area contains a nan")
        self.EDGE = flagDefs.add("flag_edge", "Measurement area over edge")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)
        self.centroidExtractor = lsst.meas.base.SafeCentroidExtractor(schema, name)
        self.instFluxKey = schema.addField(name + "_instFlux", "F", doc="flux")
        self.magKey = schema.addField(name + "_mag", "F", doc="mag")

    def measure(self, measRecord, exposure):
        """Perform measurement.

        The measure method is called by the measurement framework when
        `run` is called. If a `MeasurementError` is raised during this method,
        the `fail` method will be called to set the error flags.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record to measure and fill in.
        exposure : `lsst.afw.image.Exposure`
            Exposure containing the source.

        Raises
        ------
        MeasurementError
            If the measurement box extends past the exposure (edge flag), or
            the summed flux is NaN (containsNan flag).
        ZeroDivisionError
            If ``config.flux0`` is set to zero.
        """
        # Call the SafeCentroidExtractor to get a centroid, even if one has
        # not been supplied by the centroid slot. Normally, the centroid is
        # supplied by the centroid slot, but if that fails, the footprint is
        # used as a fallback. If the fallback is needed, the fail flag will
        # be set on this record.
        center = self.centroidExtractor(measRecord, self.flagHandler)

        # Start from the 1x1 box at the (truncated) center pixel and grow it
        # by config.size on every side: the result is (2*size + 1) pixels wide.
        centerPoint = lsst.geom.Point2I(int(center.getX()), int(center.getY()))
        bbox = lsst.geom.Box2I(centerPoint, lsst.geom.Extent2I(1, 1))
        bbox.grow(self.config.size)

        # If the measurement box falls outside the exposure, raise the edge
        # MeasurementError
        if not exposure.getBBox().contains(bbox):
            raise MeasurementError(self.EDGE.doc, self.EDGE.number)

        # Sum the pixels inside the bounding box
        instFlux = lsst.afw.image.ImageF(exposure.getMaskedImage().getImage(), bbox).getArray().sum()
        measRecord.set(self.instFluxKey, instFlux)

        # If there was a NaN inside the bounding box, the instFlux will still
        # be NaN
        if np.isnan(instFlux):
            raise MeasurementError(self.CONTAINS_NAN.doc, self.CONTAINS_NAN.number)

        if self.config.flux0 is not None:
            # Dividing a NumPy scalar by zero only warns, so raise explicitly;
            # this is the "unexpected error" exercised by the test suite.
            if self.config.flux0 == 0:
                raise ZeroDivisionError("self.config.flux0 is zero in divisor")
            mag = -2.5 * np.log10(instFlux/self.config.flux0)
            measRecord.set(self.magKey, mag)

    def fail(self, measRecord, error=None):
        """Handle measurement failures.

        If the exception is a `MeasurementError`, the error will be passed to
        the fail method by the measurement framework. If ``error`` is not
        `None`, ``error.cpp`` should correspond to a specific error and the
        appropriate error flag will be set.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record on which to set the failure flags.
        error : `MeasurementError`, optional
            The error raised by `measure`, or `None` for an unexpected error.
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

144 

class FlagHandlerTestCase(AlgorithmTestCase, lsst.utils.tests.TestCase):
    """Tests of `FlagHandler`, directly and through the sample `PythonPlugin`.

    ``setUp`` builds a single-source `lsst.meas.base.tests.TestDataset` and a
    `SingleFrameMeasurementConfig` that runs only the ``test_PythonPlugin``
    plugin with the measurement slots disabled.
    """

    def setUp(self):
        # Set up a configuration and data source to be used by the plugin
        # tests.
        self.algName = "test_PythonPlugin"
        # Box2I corners are inclusive, so this box is 101x101 pixels.
        bbox = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(100, 100))
        self.bbox = bbox
        self.dataset = lsst.meas.base.tests.TestDataset(bbox)
        self.dataset.addSource(instFlux=1E5, centroid=lsst.geom.Point2D(25, 26))
        config = lsst.meas.base.SingleFrameMeasurementConfig()
        config.plugins = [self.algName]
        config.slots.centroid = None
        config.slots.apFlux = None
        config.slots.calibFlux = None
        config.slots.gaussianFlux = None
        config.slots.modelFlux = None
        config.slots.psfFlux = None
        config.slots.shape = None
        config.slots.psfShape = None
        self.config = config

    def tearDown(self):
        del self.config
        del self.dataset
        del self.bbox

    def testFlagHandler(self):
        """Test creation and invocation of `FlagHandler`.
        """
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()

        # This is a FlagDefinition structure like a plugin might have
        flagDefs = FlagDefinitionList()
        FAILURE = flagDefs.addFailureFlag()
        FIRST = flagDefs.add("1st error", "this is the first failure type")
        SECOND = flagDefs.add("2nd error", "this is the second failure type")
        fh = FlagHandler.addFields(schema, "test", flagDefs)
        # Check to be sure that the FlagHandler was correctly initialized
        for index in range(len(flagDefs)):
            self.assertEqual(flagDefs.getDefinition(index).name, fh.getFlagName(index))

        catalog = lsst.afw.table.SourceCatalog(schema)

        # Now check to be sure that all of the known failures set the bits
        # correctly
        record = catalog.addNew()
        fh.handleFailure(record)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(FAILURE.doc, FAILURE.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(FIRST.doc, FIRST.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertTrue(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(SECOND.doc, SECOND.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FAILURE.number))
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertTrue(fh.getValue(record, SECOND.number))

    def testNoFailureFlag(self):
        """Test with no failure flag.
        """
        schema = lsst.afw.table.SourceTable.makeMinimalSchema()

        # This is a FlagDefinition structure like a plugin might have
        flagDefs = FlagDefinitionList()
        FIRST = flagDefs.add("1st error", "this is the first failure type")
        SECOND = flagDefs.add("2nd error", "this is the second failure type")
        fh = FlagHandler.addFields(schema, "test", flagDefs)
        # Check to be sure that the FlagHandler was correctly initialized
        for index in range(len(flagDefs)):
            self.assertEqual(flagDefs.getDefinition(index).name, fh.getFlagName(index))

        catalog = lsst.afw.table.SourceCatalog(schema)

        # Now check to be sure that all of the known failures set the bits
        # correctly. With no specific error, neither specific flag is set.
        record = catalog.addNew()
        fh.handleFailure(record)
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(FIRST.doc, FIRST.number)
        fh.handleFailure(record, error.cpp)
        self.assertTrue(fh.getValue(record, FIRST.number))
        self.assertFalse(fh.getValue(record, SECOND.number))

        record = catalog.addNew()
        error = MeasurementError(SECOND.doc, SECOND.number)
        fh.handleFailure(record, error.cpp)
        self.assertFalse(fh.getValue(record, FIRST.number))
        self.assertTrue(fh.getValue(record, SECOND.number))

    # This and the following tests use the toy plugin and demonstrate how
    # the FlagHandler is used.

    def testPluginNoError(self):
        """Test that the sample plugin can be run without errors.
        """
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=0)
        task.run(cat, exposure)
        source = cat[0]
        self.assertFalse(source.get(self.algName + "_flag"))
        self.assertFalse(source.get(self.algName + "_flag_containsNan"))
        self.assertFalse(source.get(self.algName + "_flag_edge"))

    def testPluginUnexpectedError(self):
        """Test that unexpected non-fatal errors set the failure flag.

        An unexpected error is a non-fatal error which is not caught by the
        algorithm itself. However, such errors are caught by the measurement
        framework in task.run, and result in the failure flag being set, but
        no other specific flags.
        """
        self.config.plugins[self.algName].flux0 = 0.0  # divide by zero
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=1)
        task.log.setLevel(task.log.FATAL)
        task.run(cat, exposure)
        source = cat[0]
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertFalse(source.get(self.algName + "_flag_containsNan"))
        self.assertFalse(source.get(self.algName + "_flag_edge"))

    def testPluginContainsNan(self):
        """Test that the ``containsNan`` error can be triggered.
        """
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=2)
        source = cat[0]
        # Put a NaN in the center pixel so the summed box flux is NaN.
        exposure.getMaskedImage().getImage().getArray()[int(source.getY()), int(source.getX())] = np.nan
        task.log.setLevel(task.log.FATAL)
        task.run(cat, exposure)
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertTrue(source.get(self.algName + "_flag_containsNan"))
        self.assertFalse(source.get(self.algName + "_flag_edge"))

    def testPluginEdgeError(self):
        """Test that the ``edge`` error can be triggered.
        """
        # Set the size large enough to trigger the edge error. This is done
        # before constructing the task, so the plugin is built with the final
        # configuration rather than relying on it sharing the config object.
        # The realized exposure covers self.bbox.
        self.config.plugins[self.algName].size = self.bbox.getDimensions()[1]//2
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        exposure, cat = self.dataset.realize(noise=100.0, schema=schema, randomSeed=3)
        task.log.setLevel(task.log.FATAL)
        task.run(cat, exposure)
        source = cat[0]
        self.assertTrue(source.get(self.algName + "_flag"))
        self.assertFalse(source.get(self.algName + "_flag_containsNan"))
        self.assertTrue(source.get(self.algName + "_flag_edge"))

    def testSafeCentroider(self):
        """Test `SafeCentroidExtractor` correctly runs and sets flags.
        """
        instFluxName = self.algName + "_instFlux"
        flagName = self.algName + "_flag"

        # Normal case should use the centroid slot to get the center, which
        # should succeed
        schema = self.dataset.makeMinimalSchema()
        task = lsst.meas.base.SingleFrameMeasurementTask(schema=schema, config=self.config)
        task.log.setLevel(task.log.FATAL)
        exposure, cat = self.dataset.realize(noise=0.0, schema=schema, randomSeed=4)
        source = cat[0]
        task.run(cat, exposure)
        self.assertFalse(source.get(flagName))
        instFlux = source.get(instFluxName)
        self.assertFalse(np.isnan(instFlux))

        # If one of the center coordinates is nan and the centroid slot error
        # flag has not been set, the SafeCentroidExtractor will fail.
        source.set('truth_x', np.nan)
        source.set('truth_flag', False)
        source.set(instFluxName, np.nan)
        source.set(flagName, False)
        task.run(cat, exposure)
        self.assertTrue(source.get(flagName))
        self.assertTrue(np.isnan(source.get(instFluxName)))

        # But if the same conditions occur and the centroid slot error flag is
        # set to true, the SafeCentroidExtractor will succeed and the
        # algorithm will complete. However, the failure flag will also be
        # set.
        source.set('truth_x', np.nan)
        source.set('truth_flag', True)
        source.set(instFluxName, np.nan)
        source.set(flagName, False)
        task.run(cat, exposure)
        self.assertTrue(source.get(flagName))
        self.assertEqual(source.get(instFluxName), instFlux)

350 

class TestMemory(lsst.utils.tests.MemoryTestCase):
    """Memory test case inherited unchanged from `lsst.utils.tests.MemoryTestCase`."""

353 

354 

def setup_module(module):
    """Module-level setup hook: initialize the `lsst.utils.tests` helpers.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; not used.
    """
    lsst.utils.tests.init()

357 

358 

if __name__ == "__main__":
    # Running the file directly: initialize the LSST test helpers first,
    # then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()