Coverage for python/lsst/meas/extensions/trailedSources/NaivePlugin.py: 19%

170 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-09 11:18 +0000

1# 

2# This file is part of meas_extensions_trailedSources. 

3# 

4# Developed for the LSST Data Management System. 

5# This product includes software developed by the LSST Project 

6# (http://www.lsst.org). 

7# See the COPYRIGHT file at the top-level directory of this distribution 

8# for details of code ownership. 

9# 

10# This program is free software: you can redistribute it and/or modify 

11# it under the terms of the GNU General Public License as published by 

12# the Free Software Foundation, either version 3 of the License, or 

13# (at your option) any later version. 

14# 

15# This program is distributed in the hope that it will be useful, 

16# but WITHOUT ANY WARRANTY; without even the implied warranty of 

17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18# GNU General Public License for more details. 

19# 

20# You should have received a copy of the GNU General Public License 

21# along with this program. If not, see <http://www.gnu.org/licenses/>. 

22# 

23 

24import numpy as np 

25import scipy.optimize as sciOpt 

26from scipy.special import erf 

27 

28import lsst.log 

29from lsst.geom import Point2D 

30from lsst.meas.base.pluginRegistry import register 

31from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig 

32from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor 

33from lsst.meas.base import MeasurementError 

34 

35from ._trailedSources import VeresModel 

36from .utils import getMeasurementCutout 

37 

__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)

39 

40 

class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    No plugin-specific options are defined here; every field is inherited
    from `lsst.meas.base.SingleFramePluginConfig`.
    """

45 

46 

@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin.

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config : `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name : `str`
        Plugin name.
    schema : `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata : `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan(2*Ixy/(Ixx - Iyy))/2.
    The end points of the trail are then given by (xc +/- (length/2)*cos(theta)
    and yc +/- (length/2)*sin(theta)), with xc and yc being the centroid
    coordinates.

    See Also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata):
        super().__init__(config, name, schema, metadata)

        # Measurement Keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement Error Keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

        flagDefs = FlagDefinitionList()
        flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        # Queried by measure() when the SdssShape centroid is not finite.
        # The attribute name must match the one measure() reads.
        self.centroidExtractor = SafeCentroidExtractor(schema, name)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        Raises
        ------
        MeasurementError
            Raised with the ``flag_safeCentroid`` bit if the SdssShape
            centroid is not finite, with ``flag_noConverge`` if the root
            finder used for weighted moments fails, and with ``flag_noFlux``
            if neither the model flux nor the aperture flux is finite.

        See Also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """
        # Get the SdssShape centroid.
        # There are currently no centroid errors for SdssShape.
        xc = measRecord.get("base_SdssShape_x")
        yc = measRecord.get("base_SdssShape_y")
        if not np.isfinite(xc) or not np.isfinite(yc):
            # The safe centroid extractor is still queried (it is handed the
            # flag handler, so it may record flags of its own), but its
            # result is not used: the source is marked failed instead.
            self.centroidExtractor(measRecord, self.flagHandler)
            raise MeasurementError(self.SAFE_CENTROID.doc, self.SAFE_CENTROID.number)

        ra, dec = self.computeRaDec(exposure, xc, yc)

        # Transform the second moments into the squared semi-major (a2) and
        # semi-minor (b2) axes, i.e. the eigenvalues of the moment matrix.
        Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        desc2 = xmy2 + 4.0*xy2  # Discriminant of the eigenvalue equation
        desc = np.sqrt(desc2)
        a2 = 0.5*(xpy + desc)
        b2 = 0.5*(xpy - desc)

        # Measure the trail length; the model used depends on whether
        # SdssShape reported unweighted moments.
        if measRecord.get("base_SdssShape_flag_unweighted"):
            lsst.log.debug("Unweighed")
            length, gradLength = self.computeLength(a2, b2)
        else:
            lsst.log.debug("Weighted")
            try:
                length, gradLength, results = self.findLength(a2, b2)
            except (ValueError, RuntimeError) as err:
                # brentq raises ValueError when the search interval does not
                # bracket a root and RuntimeError when it fails to converge;
                # report both through the plugin's noConverge flag.
                raise MeasurementError(self.NO_CONVERGE.doc, self.NO_CONVERGE.number) from err
            if not results.converged:
                lsst.log.info(results.flag)
                raise MeasurementError(self.NO_CONVERGE.doc, self.NO_CONVERGE.number)

        # Compute the angle of the trail from the x-axis
        theta = 0.5*np.arctan2(2.0*Ixy, xmy)

        # Get end-points of the trail (there is a degeneracy here: the
        # moments do not say which end is the head).
        radius = length/2.0  # Trail 'radius'
        dx = radius*np.cos(theta)
        dy = radius*np.sin(theta)
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        # Get a cutout of the object from the exposure
        cutout = getMeasurementCutout(measRecord, exposure)

        # Compute flux assuming fixed parameters for VeresModel
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to aperture flux
        useApFlux = not np.isfinite(flux)
        if useApFlux:
            flux = measRecord.getApInstFlux()
            if not np.isfinite(flux):
                raise MeasurementError(self.NO_FLUX.doc, self.NO_FLUX.number)

        # Propagate errors from second moments and centroid
        IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())

        # SdssShape does not produce centroid errors. The
        # Slot centroid errors will suffice for now.
        xcErr2, ycErr2 = np.diag(measRecord.getCentroidErr())

        # Error in length. Derivatives of a2 = (xpy + desc)/2 with respect to
        # the moments; b2 has the same magnitudes with the Ixx/Iyy roles
        # swapped. The a2-b2 covariance is neglected.
        da2dIxx = 0.5*(1.0 + (xmy/desc))
        da2dIyy = 0.5*(1.0 - (xmy/desc))
        da2dIxy = 2.0*Ixy/desc
        a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
        b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
        dLda2, dLdb2 = gradLength
        lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

        # Error in theta
        dThetadIxx = -Ixy/desc2  # dThetadIxx = -dThetadIyy
        dThetadIxy = xmy/desc2
        thetaErr = np.sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)

        # Error in flux
        if useApFlux:
            # The model gradient does not describe the aperture flux, so
            # report the error that belongs to the reported flux.
            fluxErr = measRecord.getApInstFluxErr()
        else:
            dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
            fluxErr = np.sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr*thetaErr
                              + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)

        # Errors in end-points (identical for both ends). With
        # x = xc -/+ radius*cos(theta) and y = yc -/+ radius*sin(theta):
        # |dx/dtheta| = |radius*sin(theta)| = |dy| and
        # |dy/dtheta| = |radius*cos(theta)| = |dx|.
        radiusErr2 = lengthErr*lengthErr/4.0
        cosTheta = np.cos(theta)
        sinTheta = np.sin(theta)
        # xErr2 and yErr2 are variances; take the square root exactly once.
        xErr2 = xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr*thetaErr*dy*dy
        yErr2 = ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr*thetaErr*dx*dx
        x0Err = np.sqrt(xErr2)  # Same for x1
        y0Err = np.sqrt(yErr2)  # Same for y1

        # Store results
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)
        measRecord.set(self.keyX0Err, x0Err)
        measRecord.set(self.keyY0Err, y0Err)
        measRecord.set(self.keyX1Err, x0Err)
        measRecord.set(self.keyY1Err, y0Err)
        measRecord.set(self.keyFluxErr, fluxErr)
        measRecord.set(self.keyLengthErr, lengthErr)
        measRecord.set(self.keyAngleErr, thetaErr)

    def fail(self, measRecord, error=None):
        """Record failure.

        Sets the general failure flag, plus the specific flag carried by
        ``error`` when one is given.

        See Also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """
        diff = erf(z) - c*z*np.exp(-z*z)
        return diff

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along the trail (`measure` passes the
            larger eigenvalue of the moment matrix).
        Iyy : `float`
            Adaptive second-moment across the trail (`measure` passes the
            smaller eigenvalue of the moment matrix).

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of ``length`` with respect to ``Ixx`` and ``Iyy``.
        results : `scipy.optimize.RootResults`
            Contains messages about convergence from the root finder.

        Raises
        ------
        ValueError
            Raised by `scipy.optimize.brentq` if the search interval does not
            bracket a root.
        RuntimeError
            Raised by `scipy.optimize.brentq` if it fails to converge.
        """
        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Parameters
        ----------
        Ixx, Iyy : `float`
            Second moments passed to `findLength`.
        z : `float`
            Root found by `findLength`.
        c : `float`
            Constant used by `findLength` (see `_computeSecondMomentDiff`).

        Returns
        -------
        dLdIxx, dLdIyy : `float`
            Derivatives of the trail length with respect to ``Ixx`` and
            ``Iyy``. ``z`` depends on ``c`` implicitly through the root
            condition; that dependence is handled via the inverse function
            theorem.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function
        dfdc = z*enz2  # Negative of the partial derivative wrt c
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Parameters
        ----------
        Ixx : `float`
            Second moment along the trail.
        Iyy : `float`
            Second moment across the trail.

        Returns
        -------
        length : `float`
            Trail length, ``sqrt(6*(Ixx - 2*Iyy))``; NaN if
            ``Ixx < 2*Iyy``.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of ``length`` with respect to ``Ixx`` and ``Iyy``.

        Notes
        -----
        NOTE(review): for a uniform line of length L convolved with a Gaussian
        of width sigma, the unweighted moments along/across the line are
        commonly L**2/12 + sigma**2 and sigma**2, giving
        L = sqrt(12*(Ixx - Iyy)). Confirm the derivation of the expression
        used here.
        """
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension, in degrees.
        dec : `float`
            Declination, in degrees.
        """
        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec

397 return ra, dec