Coverage for python/lsst/meas/algorithms/measureApCorr.py: 16%

164 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 03:52 -0700

1# 

2# LSST Data Management System 

3# 

4# Copyright 2008-2017 AURA/LSST. 

5# 

6# This product includes software developed by the 

7# LSST Project (http://www.lsst.org/). 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the LSST License Statement and 

20# the GNU General Public License along with this program. If not, 

21# see <https://www.lsstcorp.org/LegalNotices/>. 

22# 

23 

24__all__ = ("MeasureApCorrConfig", "MeasureApCorrTask", "MeasureApCorrError") 

25 

26import numpy as np 

27from scipy.stats import median_abs_deviation 

28 

29import lsst.pex.config 

30from lsst.afw.image import ApCorrMap 

31from lsst.afw.math import ChebyshevBoundedField, ChebyshevBoundedFieldConfig 

32from lsst.pipe.base import Task, Struct, AlgorithmError 

33from lsst.meas.base.apCorrRegistry import getApCorrNameSet 

34 

35from .sourceSelector import sourceSelectorRegistry 

36 

37 

class MeasureApCorrError(AlgorithmError):
    """Error raised when an aperture correction cannot be measured.

    The message reports how many sources were available and how many were
    required; the same values are kept as attributes and exposed through
    `metadata`.

    Parameters
    ----------
    name : `str`
        Name of the kind of aperture correction that failed; typically an
        instFlux catalog field.
    nSources : `int`
        Number of sources available to the fitter at the point of failure.
    ndof : `int`
        Number of degrees of freedom required at the point of failure.
    iteration : `int`, optional
        Which fit iteration the failure occurred in.
    """
    def __init__(self, *, name, nSources, ndof, iteration=None):
        # The iteration clause is only included when an iteration was given.
        separator = ":" if iteration is None else f" after {iteration} steps:"
        message = (
            f"Unable to measure aperture correction for '{name}'"
            + separator
            + f" only {nSources} sources, but require at least {ndof}."
        )
        super().__init__(message)
        self.name = name
        self.nSources = nSources
        self.ndof = ndof
        self.iteration = iteration

    @property
    def metadata(self):
        """Failure details as a `dict`; ``iteration`` is present only when
        it was set.
        """
        details = dict(name=self.name, nSources=self.nSources, ndof=self.ndof)
        # NOTE: task metadata doesn't allow None, so the key is omitted
        # rather than stored as None.
        if self.iteration is not None:
            details["iteration"] = self.iteration
        return details

76 

77 

class _FluxNames:
    """A collection of flux-related names for a given flux measurement algorithm.

    Parameters
    ----------
    name : `str`
        Name of flux measurement algorithm, e.g. ``base_PsfFlux``.
    schema : `lsst.afw.table.Schema`
        Catalog schema containing the flux field. The ``{name}_instFlux``,
        ``{name}_instFluxErr``, ``{name}_flag`` fields are checked for
        existence, and the ``apcorr_{name}_used`` field is added.

    Raises
    ------
    KeyError
        Raised if any of the instFlux, instFluxErr, or flag fields is
        missing from ``schema``. Nothing is added to the schema in that case.
    """
    def __init__(self, name, schema):
        self.fluxName = name + "_instFlux"
        self.errName = name + "_instFluxErr"
        self.flagName = name + "_flag"
        # Check in the order flux, error, flag so the first missing field is
        # the one that gets reported.
        for fieldName in (self.fluxName, self.errName, self.flagName):
            if fieldName not in schema:
                raise KeyError("Could not find " + fieldName)
        self.usedName = "apcorr_" + name + "_used"
        schema.addField(self.usedName, type="Flag",
                        doc="Set if source was used in measuring aperture correction.")

107 

108 

class MeasureApCorrConfig(lsst.pex.config.Config):
    """Configuration for MeasureApCorrTask.

    Covers the reference flux to correct to, the selector used to pick
    sources, the spatial fit configuration, the outlier-clipping settings,
    and which algorithms are allowed to fail without raising.
    """
    refFluxName = lsst.pex.config.Field(
        doc="Field name prefix for the flux other measurements should be aperture corrected to match",
        dtype=str,
        default="slot_CalibFlux",
    )
    sourceSelector = sourceSelectorRegistry.makeField(
        doc="Selector that sets the stars that aperture corrections will be measured from.",
        default="science",
    )
    minDegreesOfFreedom = lsst.pex.config.RangeField(
        doc="Minimum number of degrees of freedom (# of valid data points - # of parameters);"
            " if the fit would have fewer than this, the order of the fit is decreased (in both"
            " dimensions), and if we can't decrease it enough, we'll raise MeasureApCorrError"
            " (or log a warning and skip the algorithm if it is listed in allowFailure).",
        dtype=int,
        default=1,
        min=1,
    )
    fitConfig = lsst.pex.config.ConfigField(
        doc="Configuration used in fitting the aperture correction fields.",
        dtype=ChebyshevBoundedFieldConfig,
    )
    numIter = lsst.pex.config.Field(
        doc="Number of iterations for robust MAD sigma clipping.",
        dtype=int,
        default=4,
    )
    numSigmaClip = lsst.pex.config.Field(
        doc="Number of robust MAD sigma to do clipping.",
        dtype=float,
        default=4.0,
    )
    allowFailure = lsst.pex.config.ListField(
        doc="Allow these measurement algorithms to fail without an exception.",
        dtype=str,
        default=[],
    )

    def setDefaults(self):
        """Set the defaults of the "science" source selector.

        Flag, unresolved and signal-to-noise selection are switched on,
        isolated selection is switched off, and sources with any of the
        listed pixel flags set are rejected.
        """
        selector = self.sourceSelector["science"]

        selector.doFlags = True
        selector.doUnresolved = True
        selector.doSignalToNoise = True
        selector.doIsolated = False
        selector.flags.good = []
        selector.flags.bad = [
            "base_PixelFlags_flag_edge",
            "base_PixelFlags_flag_interpolatedCenter",
            "base_PixelFlags_flag_saturatedCenter",
            "base_PixelFlags_flag_crCenter",
            "base_PixelFlags_flag_bad",
            "base_PixelFlags_flag_interpolated",
            "base_PixelFlags_flag_saturated",
        ]
        # Signal-to-noise is computed from the PSF flux; only a lower bound
        # is applied.
        selector.signalToNoise.minimum = 200.0
        selector.signalToNoise.maximum = None
        selector.signalToNoise.fluxField = "base_PsfFlux_instFlux"
        selector.signalToNoise.errField = "base_PsfFlux_instFluxErr"

    def validate(self):
        """Validate the configuration.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the configured source selector requires matches.
        """
        super().validate()
        if self.sourceSelector.target.usesMatches:
            raise lsst.pex.config.FieldValidationError(
                MeasureApCorrConfig.sourceSelector,
                self,
                "Star selectors that require matches are not permitted.")

178 

179 

class MeasureApCorrTask(Task):
    """Task to measure aperture correction.

    For each correctable flux field, the ratio of the reference flux to that
    flux is fit as a `ChebyshevBoundedField` over the exposure bounding box,
    with iterative MAD-based outlier rejection.
    """
    ConfigClass = MeasureApCorrConfig
    _DefaultName = "measureApCorr"

    def __init__(self, schema, **kwargs):
        """Construct a MeasureApCorrTask

        For every name in lsst.meas.base.getApCorrNameSet():
        - If the corresponding flux fields exist in the schema:
            - Add a new field apcorr_{name}_used
            - Add an entry to the self.toCorrect dict
        - Otherwise silently skip the name

        Parameters
        ----------
        schema : `lsst.afw.table.Schema`
            Catalog schema; must contain the reference flux fields, and is
            modified in place by adding the ``apcorr_{name}_used`` flags.
        **kwargs
            Forwarded to `lsst.pipe.base.Task.__init__`.

        Raises
        ------
        KeyError
            Raised if the fields for ``config.refFluxName`` are not in
            ``schema``.
        """
        Task.__init__(self, **kwargs)
        self.refFluxNames = _FluxNames(self.config.refFluxName, schema)
        self.toCorrect = {}  # dict of flux field name prefix: _FluxNames instance
        for name in sorted(getApCorrNameSet()):
            try:
                self.toCorrect[name] = _FluxNames(name, schema)
            except KeyError:
                # if a field in the registry is missing, just ignore it.
                pass
        self.makeSubtask("sourceSelector")

    def run(self, exposure, catalog):
        """Measure aperture correction

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure aperture corrections are being measured on. The
            bounding box is retrieved from it, and it is passed to the
            sourceSelector. The output aperture correction map is *not*
            added to the exposure; this is left to the caller.
        catalog : `lsst.afw.table.SourceCatalog`
            SourceCatalog containing measurements to be used to
            compute aperture corrections. The ``apcorr_{name}_used`` flag
            columns are overwritten for every successfully fit algorithm.

        Returns
        -------
        Struct : `lsst.pipe.base.Struct`
            Contains the following:

            ``apCorrMap``
                aperture correction map (`lsst.afw.image.ApCorrMap`)
                that contains two entries for each flux field:
                - flux field (e.g. base_PsfFlux_instFlux): 2d model
                - flux sigma field (e.g. base_PsfFlux_instFluxErr): 2d model of error

        Raises
        ------
        MeasureApCorrError
            Raised if too few sources are available to fit an algorithm that
            is not listed in ``config.allowFailure``.
        """
        bbox = exposure.getBBox()
        import lsstDebug
        display = lsstDebug.Info(__name__).display
        doPause = lsstDebug.Info(__name__).doPause

        self.log.info("Measuring aperture corrections for %d flux fields", len(self.toCorrect))

        # First, create a subset of the catalog that contains only selected stars
        # with non-flagged reference fluxes.
        selected = self.sourceSelector.run(catalog, exposure=exposure)

        use = (
            ~selected.sourceCat[self.refFluxNames.flagName]
            & (np.isfinite(selected.sourceCat[self.refFluxNames.fluxName]))
        )
        goodRefCat = selected.sourceCat[use].copy()

        apCorrMap = ApCorrMap()

        # Outer loop over the fields we want to correct
        for name, fluxNames in self.toCorrect.items():
            # Create a more restricted subset with only the objects where the
            # to-be-corrected flux is not flagged, finite and positive.
            fluxes = goodRefCat[fluxNames.fluxName]
            with np.errstate(invalid="ignore"):  # suppress NaN warnings.
                isGood = (
                    (~goodRefCat[fluxNames.flagName])
                    & (np.isfinite(fluxes))
                    & (fluxes > 0.0)
                )

            # The 1 is the minimum number of ctrl.computeSize() when the order
            # drops to 0 in both x and y.
            if (isGood.sum() - 1) < self.config.minDegreesOfFreedom:
                if name in self.config.allowFailure:
                    self.log.warning("Unable to measure aperture correction for '%s': "
                                     "only %d sources, but require at least %d.",
                                     name, isGood.sum(), self.config.minDegreesOfFreedom + 1)
                    continue
                else:
                    raise MeasureApCorrError(name=name, nSources=isGood.sum(),
                                             ndof=self.config.minDegreesOfFreedom + 1)

            goodCat = goodRefCat[isGood].copy()

            x = goodCat['slot_Centroid_x']
            y = goodCat['slot_Centroid_y']
            z = goodCat[self.refFluxNames.fluxName]/goodCat[fluxNames.fluxName]
            ids = goodCat['id']

            # We start with an initial fit that is the median offset; this
            # works well in practice.
            fitValues = np.median(z)

            ctrl = self.config.fitConfig.makeControl()

            # NOTE(review): if config.numIter < 1 this loop body never runs and
            # apCorrField/apCorrErr are unbound when used below; numIter is
            # assumed to be at least 1.
            allBad = False
            for iteration in range(self.config.numIter):
                resid = z - fitValues
                # We add a small (epsilon) amount of floating-point slop because
                # the median_abs_deviation may give a value that is just larger than 0
                # even if given a completely flat residual field (as in tests).
                apCorrErr = median_abs_deviation(resid, scale="normal") + 1e-7
                keep = np.abs(resid) <= self.config.numSigmaClip * apCorrErr

                self.log.debug("Removing %d sources as outliers.", len(resid) - keep.sum())

                x = x[keep]
                y = y[keep]
                z = z[keep]
                ids = ids[keep]

                # Lower the fit order (x first, then y, one step each per pass)
                # until enough degrees of freedom remain, or give up once an
                # order that is already 0 would have to be lowered.
                while (len(x) - ctrl.computeSize()) < self.config.minDegreesOfFreedom:
                    if ctrl.orderX > 0:
                        ctrl.orderX -= 1
                    else:
                        allBad = True
                        break
                    if ctrl.orderY > 0:
                        ctrl.orderY -= 1
                    else:
                        allBad = True
                        break

                if allBad:
                    if name in self.config.allowFailure:
                        # Pass the values as logger arguments so formatting is
                        # deferred, consistent with the other log calls here.
                        self.log.warning("Unable to measure aperture correction for '%s': "
                                         "only %d sources remain, but require at least %d.",
                                         name, keep.sum(), self.config.minDegreesOfFreedom + 1)
                        break
                    else:
                        raise MeasureApCorrError(name=name, nSources=keep.sum(),
                                                 ndof=self.config.minDegreesOfFreedom + 1,
                                                 iteration=iteration+1)

                apCorrField = ChebyshevBoundedField.fit(bbox, x, y, z, ctrl)
                fitValues = apCorrField.evaluate(x, y)

            if allBad:
                # Allowed failure: leave this algorithm out of the map.
                continue

            self.log.info(
                "Aperture correction for %s from %d stars: MAD %f, RMS %f",
                name,
                len(x),
                median_abs_deviation(fitValues - z, scale="normal"),
                np.mean((fitValues - z)**2.)**0.5,
            )

            if display:
                plotApCorr(bbox, x, y, z, apCorrField, "%s, final" % (name,), doPause)

            # Record which sources were used.
            # NOTE(review): np.searchsorted only gives the right rows if
            # catalog['id'] is sorted ascending -- confirm callers guarantee it.
            used = np.zeros(len(catalog), dtype=bool)
            used[np.searchsorted(catalog['id'], ids)] = True
            catalog[fluxNames.usedName] = used

            # Save the result in the output map
            # The error is constant spatially (we could imagine being
            # more clever, but we're not yet sure if it's worth the effort).
            # We save the errors as a 0th-order ChebyshevBoundedField
            apCorrMap[fluxNames.fluxName] = apCorrField
            apCorrMap[fluxNames.errName] = ChebyshevBoundedField(
                bbox,
                np.array([[apCorrErr]]),
            )

        return Struct(
            apCorrMap=apCorrMap,
        )

361 

362 

def _residualLimits(residuals, padFraction=0.1):
    """Compute padded y-axis limits that enclose every finite residual.

    Parameters
    ----------
    residuals : array-like of `float`
        Values to be plotted.
    padFraction : `float`, optional
        Padding added on each side, as a fraction of the data range.

    Returns
    -------
    limits : `tuple` [`float`, `float`] or `None`
        ``(low, high)`` limits, or `None` if there are no finite values.
    """
    values = np.asarray(residuals, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return None
    low = float(finite.min())
    high = float(finite.max())
    pad = padFraction * (high - low)
    if pad == 0.0:
        # All values are identical: pad relative to the value itself (or to
        # 1 if it is zero) so the axis still has a non-zero extent.
        pad = padFraction * (abs(low) or 1.0)
    return low - pad, high + pad


def plotApCorr(bbox, xx, yy, zzMeasure, field, title, doPause):
    """Plot aperture correction fit residuals

    There are two subplots: residuals against x and y.

    Intended for debugging.

    Parameters
    ----------
    bbox : `lsst.geom.Box2I`
        Bounding box (for bounds)
    xx, yy : `numpy.ndarray`, (N)
        x and y coordinates
    zzMeasure : `numpy.ndarray`, (N)
        Measured values of the aperture correction
    field : `lsst.afw.math.ChebyshevBoundedField`
        Fit aperture correction field
    title : `str`
        Title for plot
    doPause : `bool`
        Pause to inspect the residuals plot? If
        False, there will be a 4 second delay to
        allow for inspection of the plot before
        closing it and moving on.
    """
    import matplotlib.pyplot as plt

    zzFit = field.evaluate(xx, yy)
    residuals = zzMeasure - zzFit
    # Scaling min/max by 0.9/1.1 clips points when the extrema are negative,
    # so pad by a fraction of the data range instead.
    yLimits = _residualLimits(residuals)

    fig, axes = plt.subplots(2, 1)

    axes[0].scatter(xx, residuals, s=3, marker='o', lw=0, alpha=0.7)
    axes[1].scatter(yy, residuals, s=3, marker='o', lw=0, alpha=0.7)
    for ax in axes:
        ax.set_ylabel("ApCorr Fit Residual")
        if yLimits is not None:
            ax.set_ylim(*yLimits)
    axes[0].set_xlabel("x")
    axes[0].set_xlim(bbox.getMinX(), bbox.getMaxX())
    axes[1].set_xlabel("y")
    axes[1].set_xlim(bbox.getMinY(), bbox.getMaxY())
    plt.suptitle(title)

    if not doPause:
        try:
            plt.pause(4)
            plt.close()
        except Exception:
            # Best effort: fall back to a blocking window if pause() fails.
            print("%s: plt.pause() failed. Please close plots when done." % __name__)
            plt.show()
    else:
        print("%s: Please close plots when done." % __name__)
        plt.show()