Coverage for python/lsst/meas/extensions/trailedSources/NaivePlugin.py: 18%

178 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-18 12:14 -0700

1# 

2# This file is part of meas_extensions_trailedSources. 

3# 

4# Developed for the LSST Data Management System. 

5# This product includes software developed by the LSST Project 

6# (http://www.lsst.org). 

7# See the COPYRIGHT file at the top-level directory of this distribution 

8# for details of code ownership. 

9# 

10# This program is free software: you can redistribute it and/or modify 

11# it under the terms of the GNU General Public License as published by 

12# the Free Software Foundation, either version 3 of the License, or 

13# (at your option) any later version. 

14# 

15# This program is distributed in the hope that it will be useful, 

16# but WITHOUT ANY WARRANTY; without even the implied warranty of 

17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18# GNU General Public License for more details. 

19# 

20# You should have received a copy of the GNU General Public License 

21# along with this program. If not, see <http://www.gnu.org/licenses/>. 

22# 

23 

24import logging 

25import numpy as np 

26import scipy.optimize as sciOpt 

27from scipy.special import erf 

28 

29from lsst.geom import Point2D 

30from lsst.meas.base.pluginRegistry import register 

31from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig 

32from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor 

33 

34from ._trailedSources import VeresModel 

35from .utils import getMeasurementCutout 

36 

# Names exported by ``from ... import *``.
__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)

38 

39 

class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    Defines no options beyond those inherited from
    `lsst.meas.base.SingleFramePluginConfig`.
    """

44 

45 

@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin.

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config : `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name : `str`
        Plugin name.
    schema : `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata : `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.
    logName : `str`, optional
        Logger name; defaults to this module's ``__name__``.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan(2*Ixy/(Ixx - Iyy))/2.
    The end points of the trail are then given by (xc +/- (length/2)*cos(theta)
    and yc +/- (length/2)*sin(theta)), with xc and yc being the centroid
    coordinates.

    See also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Measurement keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement error keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

        flagDefs = FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        # ``measure`` reads ``self.centroidExtractor``; the attribute used to
        # be stored under a misspelled name, so the fallback path raised
        # AttributeError. The old name is kept as an alias for compatibility.
        self.centroidExtractor = SafeCentroidExtractor(schema, name)
        self.centriodExtractor = self.centroidExtractor
        self.log = logging.getLogger(self.logName)

    def _flagFailure(self, measRecord, flag):
        """Set ``flag`` and the general failure flag on ``measRecord``.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record being measured.
        flag : flag definition returned by ``FlagDefinitionList.add``
            Specific reason for the failure.
        """
        # FlagHandler.setValue requires the value explicitly: (record, i, value).
        self.flagHandler.setValue(measRecord, flag.number, True)
        self.flagHandler.setValue(measRecord, self.FAILURE.number, True)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """

        # Get the SdssShape centroid or fall back to slot.
        # There are currently no centroid errors for SdssShape.
        xc = measRecord.get("base_SdssShape_x")
        yc = measRecord.get("base_SdssShape_y")
        if not np.isfinite(xc) or not np.isfinite(yc):
            # NOTE(review): the fallback centroid is computed, but the
            # measurement is abandoned immediately afterwards, so (xc, yc) is
            # never used. Confirm whether processing should continue with it.
            xc, yc = self.centroidExtractor(measRecord, self.flagHandler)
            self._flagFailure(measRecord, self.SAFE_CENTROID)
            return

        ra, dec = self.computeRaDec(exposure, xc, yc)

        # Transform the second moments into the squared semi-major (a2) and
        # semi-minor (b2) axes, i.e. the eigenvalues of the moment matrix.
        Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        disc2 = xmy2 + 4.0*xy2  # Discriminant of the eigenvalue equation
        disc = np.sqrt(disc2)
        a2 = 0.5*(xpy + disc)
        b2 = 0.5*(xpy - disc)

        # Measure the trail length.
        # Check if the second moments are weighted.
        if measRecord.get("base_SdssShape_flag_unweighted"):
            self.log.debug("Unweighted")
            length, gradLength = self.computeLength(a2, b2)
        else:
            self.log.debug("Weighted")
            try:
                length, gradLength, results = self.findLength(a2, b2)
            except (ValueError, RuntimeError) as err:
                # brentq raises ValueError when its search interval does not
                # bracket a root and RuntimeError when it fails to converge.
                self.log.info("Root finder failed: %s", err)
                self._flagFailure(measRecord, self.NO_CONVERGE)
                return
            if not results.converged:
                self.log.info("Results not converged: %s", results.flag)
                self._flagFailure(measRecord, self.NO_CONVERGE)
                return

        # Compute the angle of the trail from the x-axis.
        theta = 0.5*np.arctan2(2.0*Ixy, xmy)

        # End points of the trail. There is a degeneracy here: the second
        # moments do not distinguish the head of the trail from its tail.
        radius = length/2.0  # Trail 'radius'
        cosTheta = np.cos(theta)
        sinTheta = np.sin(theta)
        dx = radius*cosTheta
        dy = radius*sinTheta
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        # Get a cutout of the object from the exposure.
        cutout = getMeasurementCutout(measRecord, exposure)

        # Compute flux assuming fixed parameters for VeresModel.
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to aperture flux.
        fluxFromModel = np.isfinite(flux)
        if not fluxFromModel:
            flux = measRecord.getApInstFlux()
            if not np.isfinite(flux):
                self._flagFailure(measRecord, self.NO_FLUX)
                return

        # Propagate errors from the second moments and centroid.
        IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())

        # SdssShape does not produce centroid errors. The
        # slot centroid errors will suffice for now.
        xcErr2, ycErr2 = np.diag(measRecord.getCentroidErr())

        # Error in length.
        da2dIxx = 0.5*(1.0 + xmy/disc)
        da2dIyy = 0.5*(1.0 - xmy/disc)
        da2dIxy = 2.0*Ixy/disc
        a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
        # For b2: db2/dIxx = da2/dIyy, db2/dIyy = da2/dIxx, db2/dIxy = -da2/dIxy.
        b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
        dLda2, dLdb2 = gradLength
        lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

        # Error in theta (dTheta/dIyy = -dTheta/dIxx).
        dThetadIxx = -Ixy/disc2
        dThetadIxy = xmy/disc2
        thetaErr = np.sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)
        thetaErr2 = thetaErr*thetaErr

        # Error in flux.
        if fluxFromModel:
            dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
            fluxErr = np.sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr2
                              + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)
        else:
            # The model gradient does not describe the aperture flux used as
            # a fallback, so report that measurement's own error instead.
            fluxErr = measRecord.getApInstFluxErr()

        # Errors in the end points (identical for head and tail).
        # |dx/dtheta| = dy and |dy/dtheta| = dx. Variances add in quadrature,
        # so take the square root exactly once.
        radiusErr2 = lengthErr*lengthErr/4.0
        xEndErr = np.sqrt(xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr2*dy*dy)
        yEndErr = np.sqrt(ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr2*dx*dx)

        # Store the measurements.
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)
        measRecord.set(self.keyX0Err, xEndErr)
        measRecord.set(self.keyY0Err, yEndErr)
        measRecord.set(self.keyX1Err, xEndErr)
        measRecord.set(self.keyY1Err, yEndErr)
        measRecord.set(self.keyFluxErr, fluxErr)
        measRecord.set(self.keyLengthErr, lengthErr)
        measRecord.set(self.keyAngleErr, thetaErr)

    def fail(self, measRecord, error=None):
        """Record failure.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record whose failure flags are set.
        error : exception with a ``cpp`` attribute, optional
            Error passed by the measurement framework. When `None`, only the
            general failure flag is handled.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """

        diff = erf(z) - c*z*np.exp(-z*z)
        return diff

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along the trail; `measure` passes the
            larger eigenvalue of the moment matrix.
        Iyy : `float`
            Adaptive second-moment across the trail; `measure` passes the
            smaller eigenvalue of the moment matrix.

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of ``length`` with respect to ``Ixx`` and ``Iyy``.
        results : `scipy.optimize.RootResults`
            Contains messages about convergence from the root finder.

        Raises
        ------
        ValueError
            Raised by `scipy.optimize.brentq` when the search interval
            [0.001, 1] does not bracket a root.
        RuntimeError
            Raised by `scipy.optimize.brentq` when it fails to converge.
        """

        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Parameters
        ----------
        Ixx, Iyy : `float`
            Second moments passed to `findLength`.
        z : `float`
            Root found by `findLength`.
        c : `float`
            Constant used by `findLength` (see `_computeSecondMomentDiff`).

        Returns
        -------
        dLdIxx, dLdIyy : `float`
            Derivatives of the trail length with respect to ``Ixx`` and
            ``Iyy``.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function. dfdc is the
        # negated partial derivative, so dzdf*dfdc equals dz/dc along the
        # root (implicit function theorem).
        dfdc = z*enz2
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Parameters
        ----------
        Ixx : `float`
            Unweighted second-moment along the trail.
        Iyy : `float`
            Unweighted second-moment across the trail.

        Returns
        -------
        length : `float`
            ``sqrt(6*(Ixx - 2*Iyy))``; NaN when ``Ixx < 2*Iyy``.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of ``length`` with respect to ``Ixx`` and ``Iyy``.

        Notes
        -----
        NOTE(review): the derivation of this expression is not given here;
        confirm it against its source.
        """
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension, in degrees.
        dec : `float`
            Declination, in degrees.
        """

        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec