Coverage for python/lsst/meas/extensions/trailedSources/NaivePlugin.py: 17%

213 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-11 04:01 -0700

1# 

2# This file is part of meas_extensions_trailedSources. 

3# 

4# Developed for the LSST Data Management System. 

5# This product includes software developed by the LSST Project 

6# (http://www.lsst.org). 

7# See the COPYRIGHT file at the top-level directory of this distribution 

8# for details of code ownership. 

9# 

10# This program is free software: you can redistribute it and/or modify 

11# it under the terms of the GNU General Public License as published by 

12# the Free Software Foundation, either version 3 of the License, or 

13# (at your option) any later version. 

14# 

15# This program is distributed in the hope that it will be useful, 

16# but WITHOUT ANY WARRANTY; without even the implied warranty of 

17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18# GNU General Public License for more details. 

19# 

20# You should have received a copy of the GNU General Public License 

21# along with this program. If not, see <http://www.gnu.org/licenses/>. 

22# 

23 

24import logging 

25import numpy as np 

26import scipy.optimize as sciOpt 

27from scipy.special import erf 

28from math import sqrt 

29 

30from lsst.geom import Point2D, Point2I 

31from lsst.meas.base.pluginRegistry import register 

32from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig 

33from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor 

34 

35from ._trailedSources import VeresModel 

36from .utils import getMeasurementCutout 

37 

# Public API of this module: the plugin and its configuration class.
__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)

39 

40 

class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    No options are added beyond those inherited from
    `lsst.meas.base.SingleFramePluginConfig`.
    """

45 

46 

@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config: `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name: `str`
        Plugin name.
    schema: `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata: `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.
    logName: `str`, optional
        Logger name; defaults to this module's ``__name__``.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan2(2*Ixy, Ixx - Iyy)/2.
    The end points of the trail are then given by (xc +/- (length/2)*cos(theta)
    and yc +/- (length/2)*sin(theta)), with xc and yc being the centroid
    coordinates.

    See also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata, logName=None):
        if logName is None:
            logName = __name__
        super().__init__(config, name, schema, metadata, logName=logName)

        # Measurement Keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement Error Keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

        flagDefs = FlagDefinitionList()
        self.FAILURE = flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
        self.EDGE = flagDefs.add("flag_edge", "Trail contains edge pixels")
        self.OFFIMAGE = flagDefs.add("flag_off_image", "Trail extends off image")
        self.NAN = flagDefs.add("flag_nan", "One or more trail coordinates are missing")
        self.SUSPECT_LONG_TRAIL = flagDefs.add("flag_suspect_long_trail",
                                               "Trail length is greater than three times the psf radius")
        # Set by ``measure`` when the record's shape flag is set. Appended
        # last so the numbers of the flags above are unchanged.
        self.SHAPE = flagDefs.add("flag_shape", "Shape flag is set.")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        self.centroidExtractor = SafeCentroidExtractor(schema, name)
        self.log = logging.getLogger(self.logName)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        Notes
        -----
        On an early exit, the ``FAILURE`` flag is set. A more specific flag is
        set alongside it: ``SAFE_CENTROID``, ``SHAPE``, ``NO_CONVERGE`` or
        ``NO_FLUX``. No measurement values are written in that case.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """
        # Use the SdssShape centroid. There are currently no centroid errors
        # for SdssShape.
        xc = measRecord.get("base_SdssShape_x")
        yc = measRecord.get("base_SdssShape_y")
        if not np.isfinite(xc) or not np.isfinite(yc):
            # The safe centroid extractor is still run (it is handed the flag
            # handler), but the trail measurement is abandoned.
            self.centroidExtractor(measRecord, self.flagHandler)
            self.flagHandler.setValue(measRecord, self.SAFE_CENTROID.number, True)
            self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
            return

        ra, dec = self.computeRaDec(exposure, xc, yc)

        if measRecord.getShapeFlag():
            self.log.warning("Shape flag is set for measRecord: %s. Trail measurement "
                             "will not be made.", measRecord.getId())
            self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
            self.flagHandler.setValue(measRecord, self.SHAPE.number, True)
            return

        # Transform the second-moments to semi-major and minor axes: a2 and b2
        # are the eigenvalues of [[Ixx, Ixy], [Ixy, Iyy]].
        Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        desc = sqrt(xmy2 + 4.0*xy2)  # Square root of the eigenvalue discriminant
        a2 = 0.5 * (xpy + desc)
        b2 = 0.5 * (xpy - desc)

        # Measure the trail length: closed form for unweighted moments,
        # root finding for weighted (adaptive) moments.
        if measRecord.get("base_SdssShape_flag_unweighted"):
            self.log.debug("Unweighted")
            length, gradLength = self.computeLength(a2, b2)
        else:
            self.log.debug("Weighted")
            length, gradLength, results = self.findLength(a2, b2)
            if not results.converged:
                self.log.info("Results not converged: %s", results.flag)
                self.flagHandler.setValue(measRecord, self.NO_CONVERGE.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return

        # Compute the angle of the trail from the x-axis
        theta = 0.5 * np.arctan2(2.0 * Ixy, xmy)
        cosTheta = np.cos(theta)
        sinTheta = np.sin(theta)

        # End-points of the trail. Head and tail are degenerate: the second
        # moments cannot tell which end is which.
        radius = length/2.0  # Trail 'radius'
        dx = radius*cosTheta  # Also |dy/dtheta| for the end points
        dy = radius*sinTheta  # Also |dx/dtheta| for the end points
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        self.check_trail(measRecord, exposure, x0, y0, x1, y1, length)

        # Get a cutout of the object from the exposure
        cutout = getMeasurementCutout(measRecord, exposure)

        # Compute flux assuming fixed parameters for VeresModel
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to the aperture flux if the model flux is not finite.
        if not np.isfinite(flux):
            apFlux = measRecord.getApInstFlux()
            if np.isfinite(apFlux):
                flux = apFlux
            else:
                self.flagHandler.setValue(measRecord, self.NO_FLUX.number, True)
                self.flagHandler.setValue(measRecord, self.FAILURE.number, True)
                return

        # Propagate errors from the second moments and centroid.
        IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())

        # SdssShape does not produce centroid errors. The slot centroid
        # errors will suffice for now.
        xcErr2, ycErr2 = np.diag(measRecord.getCentroidErr())

        # Error in length. The partial derivatives of b2 are those of a2 with
        # the Ixx and Iyy terms swapped; the squared Ixy term is the same.
        da2dIxx = 0.5*(1.0 + (xmy/desc))
        da2dIyy = 0.5*(1.0 - (xmy/desc))
        da2dIxy = 2.0*Ixy / desc
        a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
        b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
        dLda2, dLdb2 = gradLength
        lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

        # Error in theta (dTheta/dIyy = -dTheta/dIxx).
        thetaDenom = xmy2 + 4.0*xy2
        dThetadIxx = -Ixy / thetaDenom
        dThetadIxy = xmy / thetaDenom
        thetaErr = sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)

        # Error in flux
        dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
        fluxErr = sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr*thetaErr
                       + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)

        # Errors in end-points (identical for head and tail). Sum the
        # variances first, then take a single square root.
        radiusErr2 = lengthErr*lengthErr/4.0
        xVar = xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr*thetaErr*dy*dy
        yVar = ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr*thetaErr*dx*dx
        xErr = sqrt(xVar)
        yErr = sqrt(yVar)

        # Record the measurement values and their errors.
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)
        measRecord.set(self.keyX0Err, xErr)
        measRecord.set(self.keyY0Err, yErr)
        measRecord.set(self.keyX1Err, xErr)
        measRecord.set(self.keyY1Err, yErr)
        measRecord.set(self.keyFluxErr, fluxErr)
        measRecord.set(self.keyLengthErr, lengthErr)
        measRecord.set(self.keyAngleErr, thetaErr)

    def check_trail(self, measRecord, exposure, x0, y0, x1, y1, length):
        """Set flags for edge pixels, off-image, NaN end points and long trails.

        End points with a NaN coordinate set the NaN flag and are excluded
        from the remaining checks. If any remaining end point lies outside the
        exposure's bounding box, both the off-image and edge flags are set.
        Otherwise, the edge flag is set if either end point falls on a pixel
        whose mask has the EDGE bit. Finally, the suspect-long-trail flag is
        set if ``length`` exceeds three times the PSF determinant radius
        evaluated at the centre of the exposure's bounding box.

        Parameters
        ----------
        measRecord: `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure: `lsst.afw.image.Exposure`
            Pixel data to be measured.
        x0: `float`
            x coordinate of the beginning of the trail.
        y0: `float`
            y coordinate of the beginning of the trail.
        x1: `float`
            x coordinate of the end of the trail.
        y1: `float`
            y coordinate of the end of the trail.
        length: `float`
            Trail length in pixels.
        """
        # Keep coordinates paired so that a NaN in one axis cannot pair the x
        # of one end point with the y of the other.
        endPoints = [(x0, y0), (x1, y1)]
        if np.isnan([x0, y0, x1, y1]).any():
            self.flagHandler.setValue(measRecord, self.NAN.number, True)
        finitePoints = [(x, y) for x, y in endPoints if not (np.isnan(x) or np.isnan(y))]

        bbox = exposure.getBBox()
        # endX/endY are one past the last pixel, so the upper bound is
        # exclusive. Otherwise int() below could index outside the mask.
        onImage = all(bbox.beginX <= x < bbox.endX and bbox.beginY <= y < bbox.endY
                      for x, y in finitePoints)
        if not onImage:
            self.flagHandler.setValue(measRecord, self.EDGE.number, True)
            self.flagHandler.setValue(measRecord, self.OFFIMAGE.number, True)
        else:
            # End points are not whole pixel values; int() truncates them to
            # a pixel index. Either end point on an EDGE pixel sets the flag.
            edgeBit = exposure.mask.getPlaneBitMask('EDGE')
            if any((exposure.mask[Point2I(int(x), int(y))] & edgeBit) != 0 for x, y in finitePoints):
                self.flagHandler.setValue(measRecord, self.EDGE.number, True)

        psfShape = exposure.psf.computeShape(bbox.getCenter())
        psfRadius = psfShape.getDeterminantRadius()

        if length > psfRadius*3.0:
            self.flagHandler.setValue(measRecord, self.SUSPECT_LONG_TRAIL.number, True)

    def fail(self, measRecord, error=None):
        """Record failure

        See also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """
        return erf(z) - c*z*np.exp(-z*z)

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along x-axis.
        Iyy : `float`
            Adaptive second-moment along y-axis.

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Partial derivatives of ``length`` with respect to ``Ixx`` and
            ``Iyy``.
        results : `scipy.optimize.RootResults`
            Contains messages about convergence from the root finder.
        """
        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Parameters
        ----------
        Ixx, Iyy : `float`
            Second moments passed to ``findLength``.
        z : `float`
            Root found by ``findLength``.
        c : `float`
            Constant used by ``findLength``.

        Returns
        -------
        dLdIxx, dLdIyy : `float`
            Partial derivatives of the length with respect to Ixx and Iyy.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function. By the
        # implicit function theorem dz/dc = -(df/dc)/(df/dz); the minus sign
        # is folded into dfdc here.
        dfdc = z*enz2
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Parameters
        ----------
        Ixx, Iyy : `float`
            Unweighted second moments along the major and minor axes.

        Returns
        -------
        length : `float`
            ``sqrt(6*(Ixx - 2*Iyy))``; NaN if ``Ixx < 2*Iyy``.
        gradLength : `tuple` [`float`, `float`]
            Partial derivatives of ``length`` with respect to Ixx and Iyy.
        """
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension in degrees.
        dec : `float`
            Declination in degrees.
        """
        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec