Coverage for python/lsst/meas/extensions/trailedSources/NaivePlugin.py: 21%

185 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-13 10:13 +0000

1# 

2# This file is part of meas_extensions_trailedSources. 

3# 

4# Developed for the LSST Data Management System. 

5# This product includes software developed by the LSST Project 

6# (http://www.lsst.org). 

7# See the COPYRIGHT file at the top-level directory of this distribution 

8# for details of code ownership. 

9# 

10# This program is free software: you can redistribute it and/or modify 

11# it under the terms of the GNU General Public License as published by 

12# the Free Software Foundation, either version 3 of the License, or 

13# (at your option) any later version. 

14# 

15# This program is distributed in the hope that it will be useful, 

16# but WITHOUT ANY WARRANTY; without even the implied warranty of 

17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18# GNU General Public License for more details. 

19# 

20# You should have received a copy of the GNU General Public License 

21# along with this program. If not, see <http://www.gnu.org/licenses/>. 

22# 

23 

24import logging 

25import numpy as np 

26import scipy.optimize as sciOpt 

27from scipy.special import erf 

28from math import sqrt 

29 

30from lsst.geom import Point2D, Point2I 

31from lsst.meas.base.pluginRegistry import register 

32from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig 

33from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor 

34 

35from ._trailedSources import VeresModel 

36from .utils import getMeasurementCutout 

37 

# Public names exported by this module.
__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)

39 

40 

class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    No options are defined beyond those inherited from
    `lsst.meas.base.SingleFramePluginConfig`.
    """

45 

46 

@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config: `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name: `str`
        Plugin name.
    schema: `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata: `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan(2*Ixy/(Ixx - Iyy))/2
    (evaluated with ``arctan2`` so the quadrant is preserved). The end points
    of the trail are then given by (xc +/- (length/2)*cos(theta) and
    yc +/- (length/2)*sin(theta)), with xc and yc being the centroid
    coordinates.

    See also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Measurement Keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement Error Keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

        flagDefs = FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
        self.EDGE = flagDefs.add("flag_edge", "Trail contains edge pixels or extends off chip")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        # ``measure`` uses ``self.centroidExtractor``. The misspelled
        # ``centriodExtractor`` alias is kept for backward compatibility.
        self.centroidExtractor = SafeCentroidExtractor(schema, name)
        self.centriodExtractor = self.centroidExtractor
        self.log = logging.getLogger(self.logName)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        Notes
        -----
        On failure the general failure flag is set together with one of
        ``flag_safeCentroid``, ``flag_noConverge`` or ``flag_noFlux`` and no
        measurement values are written. ``flag_edge`` is set (without
        aborting) when an end point lies off the exposure or on an EDGE pixel.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """

        # Use the SdssShape centroid; there are currently no centroid errors
        # for SdssShape. If it is not finite, run the safe centroid extractor
        # (which may set its own flags), record the fallback and failure, and
        # stop.
        # NOTE(review): the fallback xc, yc are computed but not used because
        # the measurement is aborted here; confirm this is intended.
        xc = measRecord.get("base_SdssShape_x")
        yc = measRecord.get("base_SdssShape_y")
        if not np.isfinite(xc) or not np.isfinite(yc):
            xc, yc = self.centroidExtractor(measRecord, self.flagHandler)
            self.flagHandler.setValue(measRecord, self.SAFE_CENTROID.number, True)
            self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
            return

        ra, dec = self.computeRaDec(exposure, xc, yc)

        # Transform the second-moments to semi-major and minor axes
        # (eigenvalues of the second-moment matrix).
        Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        desc = sqrt(xmy2 + 4.0*xy2)  # Square root of the eigenvalue discriminant
        a2 = 0.5 * (xpy + desc)
        b2 = 0.5 * (xpy - desc)

        # Measure the trail length
        # Check if the second-moments are weighted
        if measRecord.get("base_SdssShape_flag_unweighted"):
            self.log.debug("Unweighted")
            length, gradLength = self.computeLength(a2, b2)
        else:
            self.log.debug("Weighted")
            try:
                length, gradLength, results = self.findLength(a2, b2)
            except ValueError as err:
                # brentq raises ValueError when the root is not bracketed,
                # e.g. for a round (Ixx == Iyy) source.
                self.log.info("Root finder failed: %s", err)
                self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return
            if not results.converged:
                self.log.info("Results not converged: %s", results.flag)
                self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return

        # Compute the angle of the trail from the x-axis
        theta = 0.5 * np.arctan2(2.0 * Ixy, xmy)
        cosTheta = np.cos(theta)
        sinTheta = np.sin(theta)

        # Get end-points of the trail (there is a degeneracy here: head and
        # tail cannot be distinguished from the moments alone).
        radius = length/2.0  # Trail 'radius'
        dx = radius*cosTheta  # x offset of each end point; |dy/dtheta| of an end point
        dy = radius*sinTheta  # y offset of each end point; |dx/dtheta| of an end point
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        # Flag trails that extend off the exposure or touch EDGE pixels
        if self._trailTouchesEdge(exposure, x0, y0, x1, y1):
            self.flagHandler.setValue(measRecord, self.EDGE.number, True)

        # Get a cutout of the object from the exposure
        cutout = getMeasurementCutout(measRecord, exposure)

        # Compute flux assuming fixed parameters for VeresModel
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to aperture flux
        # NOTE(review): in this fallback fluxErr below is still propagated from
        # the VeresModel gradient rather than the aperture-flux error.
        if not np.isfinite(flux):
            if np.isfinite(measRecord.getApInstFlux()):
                flux = measRecord.getApInstFlux()
            else:
                self.flagHandler.setValue(measRecord, self.NO_FLUX.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return

        # Propagate errors from second moments and centroid
        IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())

        # SdssShape does not produce centroid errors. The
        # Slot centroid errors will suffice for now.
        xcErr2, ycErr2 = np.diag(measRecord.getCentroidErr())

        # Error in length (via the eigenvalues a2 and b2)
        da2dIxx = 0.5*(1.0 + (xmy/desc))
        da2dIyy = 0.5*(1.0 - (xmy/desc))
        da2dIxy = 2.0*Ixy / desc
        a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
        # db2/dIxx == da2/dIyy, db2/dIyy == da2/dIxx, |db2/dIxy| == |da2/dIxy|
        b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
        dLda2, dLdb2 = gradLength
        lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

        # Error in theta
        dThetadIxx = -Ixy / (xmy2 + 4.0*xy2)  # dThetadIxx = -dThetadIyy
        dThetadIxy = xmy / (xmy2 + 4.0*xy2)
        thetaErr = sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)

        # Error in flux
        dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
        fluxErr = sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr*thetaErr
                       + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)

        # Errors in end-points (the same for both ends). xErr2 and yErr2 are
        # variances; the single square root gives the standard deviation.
        radiusErr2 = lengthErr*lengthErr/4.0
        xErr2 = xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr*thetaErr*dy*dy
        yErr2 = ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr*thetaErr*dx*dx
        xEndErr = sqrt(xErr2)
        yEndErr = sqrt(yErr2)

        # Record the measurement results
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)
        measRecord.set(self.keyX0Err, xEndErr)
        measRecord.set(self.keyY0Err, yEndErr)
        measRecord.set(self.keyX1Err, xEndErr)
        measRecord.set(self.keyY1Err, yEndErr)
        measRecord.set(self.keyFluxErr, fluxErr)
        measRecord.set(self.keyLengthErr, lengthErr)
        measRecord.set(self.keyAngleErr, thetaErr)

    @staticmethod
    def _trailTouchesEdge(exposure, x0, y0, x1, y1):
        """Return whether either trail end point is off the exposure or on an
        EDGE-masked pixel.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure whose bounding box and mask are checked.
        x0, y0, x1, y1 : `float`
            Coordinates of the two trail end points.

        Returns
        -------
        touchesEdge : `bool`
            `True` if an end point is outside the exposure (including
            non-finite coordinates) or its pixel has the EDGE mask bit set.
        """
        bbox = exposure.getBBox()
        # Pixel centres are assumed to lie at integer coordinates, so the
        # continuous extent of the image is [begin - 0.5, end - 0.5). The
        # exclusive upper bound keeps the rounded pixel index inside the mask.
        xMin, xMax = bbox.beginX - 0.5, bbox.endX - 0.5
        yMin, yMax = bbox.beginY - 0.5, bbox.endY - 0.5
        endPoints = ((x0, y0), (x1, y1))

        # Comparisons with NaN are False, so non-finite end points count as
        # off the exposure.
        if not all(xMin <= x < xMax and yMin <= y < yMax for x, y in endPoints):
            return True

        # The end points are not whole pixel values, so round each one to the
        # pixel containing it before reading the mask.
        edgeBit = exposure.mask.getPlaneBitMask("EDGE")
        for x, y in endPoints:
            pixel = Point2I(int(np.floor(x + 0.5)), int(np.floor(y + 0.5)))
            if exposure.mask[pixel] & edgeBit:
                return True
        return False

    def fail(self, measRecord, error=None):
        """Record failure

        See also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """

        diff = erf(z) - c*z*np.exp(-z*z)
        return diff

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along the trail; ``measure`` passes the
            major-axis eigenvalue of the second-moment matrix.
        Iyy : `float`
            Adaptive second-moment across the trail; ``measure`` passes the
            minor-axis eigenvalue of the second-moment matrix.

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of ``length`` with respect to ``Ixx`` and ``Iyy``.
        results : `scipy.optimize.RootResults`
            Convergence information from the root finder; check
            ``results.converged`` before trusting ``length``.

        Raises
        ------
        ValueError
            Raised by `scipy.optimize.brentq` if the root is not bracketed by
            the search interval (e.g. when ``Ixx == Iyy``).
        """

        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        # disp=False reports non-convergence through ``results.converged``
        # instead of raising RuntimeError.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True, disp=False)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Parameters
        ----------
        Ixx, Iyy : `float`
            Second moments passed to `findLength`.
        z : `float`
            Root found by `findLength`.
        c : `float`
            Constant ``4*Ixx/((Ixx+Iyy)*sqrt(pi))`` used by `findLength`.

        Returns
        -------
        dLdIxx, dLdIyy : `float`
            Derivatives of the trail length with respect to ``Ixx`` and
            ``Iyy``, obtained by implicit differentiation of the root
            condition.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function
        dfdc = z*enz2  # negative of df/dc; the sign is absorbed by implicit differentiation
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Parameters
        ----------
        Ixx : `float`
            Second moment along the trail (major-axis eigenvalue in
            ``measure``).
        Iyy : `float`
            Second moment across the trail (minor-axis eigenvalue in
            ``measure``).

        Returns
        -------
        length : `float`
            ``sqrt(6*(Ixx - 2*Iyy))``; NaN when ``Ixx < 2*Iyy``.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of ``length`` with respect to ``Ixx`` and ``Iyy``.
        """
        # NOTE(review): for a uniform segment of length L convolved with a
        # Gaussian, the unweighted moments along/across the trail would be
        # L**2/12 + sigma**2 and sigma**2, giving L = sqrt(12*(Ixx - Iyy));
        # confirm the derivation behind the expression used here.
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension, in degrees.
        dec : `float`
            Declination, in degrees.
        """

        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec