Coverage for python/lsst/meas/extensions/trailedSources/NaivePlugin.py: 20%

191 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 03:17 -0700

1# 

2# This file is part of meas_extensions_trailedSources. 

3# 

4# Developed for the LSST Data Management System. 

5# This product includes software developed by the LSST Project 

6# (http://www.lsst.org). 

7# See the COPYRIGHT file at the top-level directory of this distribution 

8# for details of code ownership. 

9# 

10# This program is free software: you can redistribute it and/or modify 

11# it under the terms of the GNU General Public License as published by 

12# the Free Software Foundation, either version 3 of the License, or 

13# (at your option) any later version. 

14# 

15# This program is distributed in the hope that it will be useful, 

16# but WITHOUT ANY WARRANTY; without even the implied warranty of 

17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18# GNU General Public License for more details. 

19# 

20# You should have received a copy of the GNU General Public License 

21# along with this program. If not, see <http://www.gnu.org/licenses/>. 

22# 

23 

24import logging 

25import numpy as np 

26import scipy.optimize as sciOpt 

27from scipy.special import erf 

28from math import sqrt 

29 

30from lsst.geom import Point2D, Point2I 

31from lsst.meas.base.pluginRegistry import register 

32from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig 

33from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor 

34 

35from ._trailedSources import VeresModel 

36from .utils import getMeasurementCutout 

37 

# Names exported by ``from ... import *``: the plugin and its config.
__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)

39 

40 

class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    Defines no options beyond those inherited from
    `SingleFramePluginConfig`.
    """

45 

46 

@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config : `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name : `str`
        Plugin name.
    schema : `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata : `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.
    logName : `str`, optional
        Logger name; defaults to this module's ``__name__``.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan(2*Ixy/(Ixx - Iyy))/2.
    The end points of the trail are then given by (xc +/- (length/2)*cos(theta)
    and yc +/- (length/2)*sin(theta)), with xc and yc being the centroid
    coordinates.

    See also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Measurement Keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement Error Keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

        flagDefs = FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
        self.EDGE = flagDefs.add("flag_edge", "Trail contains edge pixels or extends off chip")
        self.SHAPE = flagDefs.add("flag_shape", "Shape flag is set, trail not calculated")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        self.centroidExtractor = SafeCentroidExtractor(schema, name)
        self.log = logging.getLogger(self.logName)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Sets the FAILURE flag (plus a more specific flag) and returns early
        when the SdssShape centroid is non-finite, the shape flag is set, the
        length root finder fails, or no finite flux can be obtained.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """

        # Get the SdssShape centroid or fall back to slot
        # There are currently no centroid errors for SdssShape
        xc = measRecord.get("base_SdssShape_x")
        yc = measRecord.get("base_SdssShape_y")
        if not np.isfinite(xc) or not np.isfinite(yc):
            xc, yc = self.centroidExtractor(measRecord, self.flagHandler)
            self.flagHandler.setValue(measRecord, self.SAFE_CENTROID.number, True)
            self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
            return

        ra, dec = self.computeRaDec(exposure, xc, yc)

        if measRecord.getShapeFlag():
            self.log.warning("Shape flag is set for measRecord: %s. Trail measurement "
                             "will not be made.", measRecord.getId())
            self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
            self.flagHandler.setValue(measRecord, self.SHAPE.number, True)
            return

        # Transform the second-moments to semi-major and minor axes
        # (the eigenvalues a2 >= b2 of the second-moment matrix).
        Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        desc = sqrt(xmy2 + 4.0*xy2)  # Square root of the eigenvalue discriminant
        a2 = 0.5 * (xpy + desc)
        b2 = 0.5 * (xpy - desc)

        # Measure the trail length
        # Check if the second-moments are weighted
        if measRecord.get("base_SdssShape_flag_unweighted"):
            self.log.debug("Unweighted")
            length, gradLength = self.computeLength(a2, b2)
        else:
            self.log.debug("Weighted")
            try:
                length, gradLength, results = self.findLength(a2, b2)
            except ValueError as err:
                # brentq raises ValueError when the search interval does not
                # bracket a root; record it as a convergence failure.
                self.log.info("Root finder failed: %s", err)
                self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return
            if not results.converged:
                self.log.info("Results not converged: %s", results.flag)
                self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return

        # Compute the angle of the trail from the x-axis
        theta = 0.5 * np.arctan2(2.0 * Ixy, xmy)

        # Get end-points of the trail (there is a degeneracy here: head and
        # tail cannot be distinguished from the second moments alone).
        radius = length/2.0  # Trail 'radius'
        cosTheta = np.cos(theta)
        sinTheta = np.sin(theta)
        dx = radius*cosTheta  # x offset of each end point from the centroid
        dy = radius*sinTheta  # y offset of each end point from the centroid
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        # Flag trails that extend off the exposure or touch EDGE pixels
        if self._isEdgeTrail(exposure, x0, y0, x1, y1):
            self.flagHandler.setValue(measRecord, self.EDGE.number, True)

        # Get a cutout of the object from the exposure
        cutout = getMeasurementCutout(measRecord, exposure)

        # Compute flux assuming fixed parameters for VeresModel
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to aperture flux
        if not np.isfinite(flux):
            if np.isfinite(measRecord.getApInstFlux()):
                flux = measRecord.getApInstFlux()
            else:
                self.flagHandler.setValue(measRecord, self.NO_FLUX.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return

        # Propagate errors from second moments and centroid
        IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())

        # SdssShape does not produce centroid errors. The
        # Slot centroid errors will suffice for now.
        xcErr2, ycErr2 = np.diag(measRecord.getCentroidErr())

        # Error in length (via the errors in the eigenvalues a2 and b2;
        # db2/dIxx == da2/dIyy and db2/dIyy == da2/dIxx)
        da2dIxx = 0.5*(1.0 + (xmy/desc))
        da2dIyy = 0.5*(1.0 - (xmy/desc))
        da2dIxy = 2.0*Ixy / desc
        a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
        b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
        dLda2, dLdb2 = gradLength
        lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

        # Error in theta
        dThetadIxx = -Ixy / (xmy2 + 4.0*xy2)  # dThetadIxx = -dThetadIyy
        dThetadIxy = xmy / (xmy2 + 4.0*xy2)
        thetaErr = sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)

        # Error in flux
        dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
        fluxErr = sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr*thetaErr
                       + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)

        # Errors in end-points. With x = xc -/+ radius*cos(theta) and
        # y = yc -/+ radius*sin(theta): |dx/dtheta| = |dy| and
        # |dy/dtheta| = |dx|. xErr2/yErr2 are variances; take one sqrt.
        radiusErr2 = lengthErr*lengthErr/4.0
        thetaErr2 = thetaErr*thetaErr
        xErr2 = xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr2*dy*dy
        yErr2 = ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr2*dx*dx
        x0Err = sqrt(xErr2)  # Same for x1
        y0Err = sqrt(yErr2)  # Same for y1

        # Record measurements
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)
        measRecord.set(self.keyX0Err, x0Err)
        measRecord.set(self.keyY0Err, y0Err)
        measRecord.set(self.keyX1Err, x0Err)
        measRecord.set(self.keyY1Err, y0Err)
        measRecord.set(self.keyFluxErr, fluxErr)
        measRecord.set(self.keyLengthErr, lengthErr)
        measRecord.set(self.keyAngleErr, thetaErr)

    @staticmethod
    def _isEdgeTrail(exposure, x0, y0, x1, y1):
        """Return True if either trail end point is off the exposure or on an
        EDGE-masked pixel.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure providing the bounding box and mask.
        x0, y0, x1, y1 : `float`
            Coordinates of the two trail end points.

        Returns
        -------
        isEdge : `bool`
            True when the trail should be flagged as touching the edge.
            Non-finite end points are treated as off the exposure.
        """
        bbox = exposure.getBBox()
        edgeBit = exposure.mask.getPlaneBitMask('EDGE')
        for x, y in ((x0, y0), (x1, y1)):
            if not (np.isfinite(x) and np.isfinite(y)):
                return True
            # Pixel centres lie on integer coordinates, so the pixel that
            # contains a point is found by rounding to the nearest integer.
            pixel = Point2I(int(np.floor(x + 0.5)), int(np.floor(y + 0.5)))
            if not bbox.contains(pixel):
                return True
            if exposure.mask[pixel] & edgeBit != 0:
                return True
        return False

    def fail(self, measRecord, error=None):
        """Record failure

        See also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """

        diff = erf(z) - c*z*np.exp(-z*z)
        return diff

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along x-axis.
        Iyy : `float`
            Adaptive second-moment along y-axis.

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of the length with respect to ``Ixx`` and ``Iyy``.
        results : `scipy.optimize.RootResults`
            Contains messages about convergence from the root finder.

        Raises
        ------
        ValueError
            Raised by `scipy.optimize.brentq` when the moment difference has
            the same sign at both ends of the search interval.
        """

        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Applies the implicit function theorem to the root condition of
        `_computeSecondMomentDiff` to get dz/dc. Combines that with
        length = 2*z*sqrt(2*(Ixx+Iyy)) and c = 4*Ixx/((Ixx+Iyy)*sqrt(pi)).

        Returns
        -------
        dLdIxx, dLdIyy : `float`
            Derivatives of the length with respect to ``Ixx`` and ``Iyy``.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function
        dfdc = z*enz2
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Returns
        -------
        length : `float`
            ``sqrt(6*(Ixx - 2*Iyy))``; NaN when ``Ixx < 2*Iyy``.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of the length with respect to ``Ixx`` and ``Iyy``.
        """
        # NOTE(review): for a uniform trail convolved with a Gaussian one
        # might expect length = sqrt(12*(Ixx - Iyy)); confirm the derivation
        # of this expression.
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension, in degrees.
        dec : `float`
            Declination, in degrees.
        """

        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec