Coverage for python/lsst/meas/extensions/trailedSources/NaivePlugin.py: 22%

179 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-13 10:21 +0000

1# 

2# This file is part of meas_extensions_trailedSources. 

3# 

4# Developed for the LSST Data Management System. 

5# This product includes software developed by the LSST Project 

6# (http://www.lsst.org). 

7# See the COPYRIGHT file at the top-level directory of this distribution 

8# for details of code ownership. 

9# 

10# This program is free software: you can redistribute it and/or modify 

11# it under the terms of the GNU General Public License as published by 

12# the Free Software Foundation, either version 3 of the License, or 

13# (at your option) any later version. 

14# 

15# This program is distributed in the hope that it will be useful, 

16# but WITHOUT ANY WARRANTY; without even the implied warranty of 

17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18# GNU General Public License for more details. 

19# 

20# You should have received a copy of the GNU General Public License 

21# along with this program. If not, see <http://www.gnu.org/licenses/>. 

22# 

23 

24import logging 

25import numpy as np 

26import scipy.optimize as sciOpt 

27from scipy.special import erf 

28from math import sqrt 

29 

30from lsst.geom import Point2D 

31from lsst.meas.base.pluginRegistry import register 

32from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig 

33from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor 

34 

35from ._trailedSources import VeresModel 

36from .utils import getMeasurementCutout 

37 

# Public names exported by this module.
__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)

39 

40 

class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    Defines no options beyond those inherited from `SingleFramePluginConfig`.
    """

45 

46 

@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config: `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name: `str`
        Plugin name.
    schema: `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata: `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.
    logName: `str`, optional
        Logger name; defaults to this module's ``__name__``.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan2(2*Ixy, Ixx - Iyy)/2
    (in radians). The end points of the trail are then given by
    (xc +/- (length/2)*cos(theta) and yc +/- (length/2)*sin(theta)), with xc
    and yc being the centroid coordinates.

    See also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Measurement Keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement Error Keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

        flagDefs = FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        # ``measure`` reads ``self.centroidExtractor``; the attribute was
        # previously stored under the misspelled name ``centriodExtractor``,
        # which made the centroid fallback raise AttributeError. The old name
        # is kept as an alias for any external code that referenced it.
        self.centroidExtractor = SafeCentroidExtractor(schema, name)
        self.centriodExtractor = self.centroidExtractor
        self.log = logging.getLogger(self.logName)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """

        # Get the SdssShape centroid. If it is not finite, run the safe
        # centroid extractor, record that the fallback was used, and mark the
        # measurement as failed.
        # There are currently no centroid errors for SdssShape.
        xc = measRecord.get("base_SdssShape_x")
        yc = measRecord.get("base_SdssShape_y")
        if not np.isfinite(xc) or not np.isfinite(yc):
            xc, yc = self.centroidExtractor(measRecord, self.flagHandler)
            self.flagHandler.setValue(measRecord, self.SAFE_CENTROID.number)
            self.flagHandler.setValue(measRecord, self.FAILURE.number)
            return

        ra, dec = self.computeRaDec(exposure, xc, yc)

        # Transform the second-moments to semi-major and minor axes: a2 and b2
        # are the eigenvalues of [[Ixx, Ixy], [Ixy, Iyy]].
        Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        desc = sqrt(xmy2 + 4.0*xy2)  # Square root of the eigenvalue discriminant
        a2 = 0.5 * (xpy + desc)
        b2 = 0.5 * (xpy - desc)

        # Measure the trail length
        # Check if the second-moments are weighted
        if measRecord.get("base_SdssShape_flag_unweighted"):
            self.log.debug("Unweighted")
            length, gradLength = self.computeLength(a2, b2)
        else:
            self.log.debug("Weighted")
            length, gradLength, results = self.findLength(a2, b2)
            if not results.converged:
                self.log.info("Results not converged: %s", results.flag)
                self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number)
                self.flagHandler.setValue(measRecord, self.FAILURE.number)
                return

        # A non-finite length (e.g. computeLength's square root of a negative
        # value) cannot yield meaningful end points, flux, or errors; flag the
        # failure instead of silently recording NaNs.
        if not np.isfinite(length):
            self.flagHandler.setValue(measRecord, self.FAILURE.number)
            return

        # Compute the angle of the trail from the x-axis (radians)
        theta = 0.5 * np.arctan2(2.0 * Ixy, xmy)

        # Get end-points of the trail (there is a degeneracy here: head and
        # tail are interchangeable).
        radius = length/2.0  # Trail 'radius'
        cosTheta = np.cos(theta)
        sinTheta = np.sin(theta)
        dx = radius*cosTheta  # x offset of each end from the centroid
        dy = radius*sinTheta  # y offset of each end from the centroid
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        # Get a cutout of the object from the exposure
        cutout = getMeasurementCutout(measRecord, exposure)

        # Compute flux assuming fixed parameters for VeresModel
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to aperture flux
        # NOTE(review): after this fallback, fluxErr below is still propagated
        # from the VeresModel gradient, which may itself be non-finite —
        # confirm whether the aperture flux error should be used instead.
        if not np.isfinite(flux):
            if np.isfinite(measRecord.getApInstFlux()):
                flux = measRecord.getApInstFlux()
            else:
                self.flagHandler.setValue(measRecord, self.NO_FLUX.number)
                self.flagHandler.setValue(measRecord, self.FAILURE.number)
                return

        # Propagate errors from second moments and centroid
        IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())

        # SdssShape does not produce centroid errors. The
        # slot centroid errors will suffice for now.
        xcErr2, ycErr2 = np.diag(measRecord.getCentroidErr())

        # Error in length. Derivatives of a2 w.r.t. the moments; for b2 the
        # roles of Ixx and Iyy swap and d(b2)/dIxy = -d(a2)/dIxy.
        da2dIxx = 0.5*(1.0 + (xmy/desc))
        da2dIyy = 0.5*(1.0 - (xmy/desc))
        da2dIxy = 2.0*Ixy / desc
        a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
        b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
        dLda2, dLdb2 = gradLength
        lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

        # Error in theta (dTheta/dIyy == -dTheta/dIxx)
        thetaDenom = xmy2 + 4.0*xy2
        dThetadIxx = -Ixy / thetaDenom
        dThetadIxy = xmy / thetaDenom
        thetaErr = sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)

        # Error in flux
        dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
        fluxErr = sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr*thetaErr
                       + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)

        # Errors in end-points. x = xc +/- r*cos(theta), y = yc +/- r*sin(theta),
        # so |dx/dtheta| = r*sin(theta) = dy and |dy/dtheta| = r*cos(theta) = dx.
        # These are variances; the square root is taken exactly once.
        radiusErr2 = lengthErr*lengthErr/4.0
        xVar = xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr*thetaErr*dy*dy
        yVar = ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr*thetaErr*dx*dx
        xEndErr = sqrt(xVar)  # Same for x0 and x1
        yEndErr = sqrt(yVar)  # Same for y0 and y1

        # Record the measurements
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)
        measRecord.set(self.keyX0Err, xEndErr)
        measRecord.set(self.keyY0Err, yEndErr)
        measRecord.set(self.keyX1Err, xEndErr)
        measRecord.set(self.keyY1Err, yEndErr)
        measRecord.set(self.keyFluxErr, fluxErr)
        measRecord.set(self.keyLengthErr, lengthErr)
        measRecord.set(self.keyAngleErr, thetaErr)

    def fail(self, measRecord, error=None):
        """Record failure

        See also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """

        diff = erf(z) - c*z*np.exp(-z*z)
        return diff

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along x-axis.
        Iyy : `float`
            Adaptive second-moment along y-axis.

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of ``length`` with respect to ``Ixx`` and ``Iyy``.
        results : `scipy.optimize.RootResults`
            Convergence information from the root finder. Non-convergence is
            reported through ``results.converged`` rather than raised.

        Raises
        ------
        ValueError
            Raised by `scipy.optimize.brentq` if the moment difference does
            not change sign on the search interval.
        """

        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        # disp=False makes brentq report non-convergence in `results` instead
        # of raising RuntimeError, so callers can set a dedicated flag.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True, disp=False)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Returns ``(dL/dIxx, dL/dIyy)`` where ``L = 2*z*sqrt(2*(Ixx+Iyy))`` and
        ``z`` is the root of ``erf(z) - c*z*exp(-z*z)``; ``dz/dc`` is obtained
        by implicit differentiation of that root condition.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function.
        # Note: this is -df/dc, so dzdf*dfdc below equals dz/dc by the
        # implicit function theorem.
        dfdc = z*enz2
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Returns ``(length, (dL/dIxx, dL/dIyy))`` with
        ``length = sqrt(6*(Ixx - 2*Iyy))``. When ``Ixx < 2*Iyy`` the result is
        NaN (numpy square root of a negative number).

        NOTE(review): for a uniform segment of length L convolved with a
        Gaussian one would expect Ixx - Iyy = L**2/12; confirm the factors
        used here against the intended moment convention.
        """
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension, in degrees.
        dec : `float`
            Declination, in degrees.
        """

        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec