lsst.meas.extensions.trailedSources ga5ce0eedc2+87645b091e
NaivePlugin.py
Go to the documentation of this file.
2# This file is part of meas_extensions_trailedSources.
3#
4# Developed for the LSST Data Management System.
5# This product includes software developed by the LSST Project
6# (http://www.lsst.org).
7# See the COPYRIGHT file at the top-level directory of this distribution
8# for details of code ownership.
9#
10# This program is free software: you can redistribute it and/or modify
11# it under the terms of the GNU General Public License as published by
12# the Free Software Foundation, either version 3 of the License, or
13# (at your option) any later version.
14#
15# This program is distributed in the hope that it will be useful,
16# but WITHOUT ANY WARRANTY; without even the implied warranty of
17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18# GNU General Public License for more details.
19#
20# You should have received a copy of the GNU General Public License
21# along with this program. If not, see <http://www.gnu.org/licenses/>.
22#
23
24import numpy as np
25import scipy.optimize as sciOpt
26from scipy.special import erf
27
28import lsst.log
29from lsst.geom import Point2D
30from lsst.meas.base.pluginRegistry import register
31from lsst.meas.base import SingleFramePlugin, SingleFramePluginConfig
32from lsst.meas.base import FlagHandler, FlagDefinitionList, SafeCentroidExtractor
33from lsst.meas.base import MeasurementError
34
35from ._trailedSources import VeresModel
36from .utils import getMeasurementCutout
37
__all__ = (
    "SingleFrameNaiveTrailConfig",
    "SingleFrameNaiveTrailPlugin",
)


class SingleFrameNaiveTrailConfig(SingleFramePluginConfig):
    """Configuration for `SingleFrameNaiveTrailPlugin`.

    No options are added beyond those inherited from
    `SingleFramePluginConfig`.
    """
45
46
@register("ext_trailedSources_Naive")
class SingleFrameNaiveTrailPlugin(SingleFramePlugin):
    """Naive trailed source measurement plugin

    Measures the length, angle from +x-axis, and end points of an extended
    source using the second moments.

    Parameters
    ----------
    config: `SingleFrameNaiveTrailConfig`
        Plugin configuration.
    name: `str`
        Plugin name.
    schema: `lsst.afw.table.Schema`
        Schema for the output catalog.
    metadata: `lsst.daf.base.PropertySet`
        Metadata to be attached to output catalog.

    Notes
    -----
    This measurement plugin aims to utilize the already measured adaptive
    second moments to naively estimate the length and angle, and thus
    end-points, of a fast-moving, trailed source. The length is solved for via
    finding the root of the difference between the numerical (stack computed)
    and the analytic adaptive second moments. The angle, theta, from the x-axis
    is also computed via adaptive moments: theta = arctan2(2*Ixy, Ixx - Iyy)/2.
    The end points of the trail are then given by (xc +/- (length/2)*cos(theta)
    and yc +/- (length/2)*sin(theta)), with xc and yc being the centroid
    coordinates.

    See also
    --------
    lsst.meas.base.SingleFramePlugin
    """

    ConfigClass = SingleFrameNaiveTrailConfig

    @classmethod
    def getExecutionOrder(cls):
        """Return the plugin's execution order (after aperture correction)."""
        # Needs centroids, shape, and flux measurements.
        # VeresPlugin is run after, which requires image data.
        return cls.APCORR_ORDER + 0.1

    def __init__(self, config, name, schema, metadata):
        super().__init__(config, name, schema, metadata)

        # Measurement Keys
        self.keyRa = schema.addField(name + "_ra", type="D", doc="Trail centroid right ascension.")
        self.keyDec = schema.addField(name + "_dec", type="D", doc="Trail centroid declination.")
        self.keyX0 = schema.addField(name + "_x0", type="D", doc="Trail head X coordinate.", units="pixel")
        self.keyY0 = schema.addField(name + "_y0", type="D", doc="Trail head Y coordinate.", units="pixel")
        self.keyX1 = schema.addField(name + "_x1", type="D", doc="Trail tail X coordinate.", units="pixel")
        self.keyY1 = schema.addField(name + "_y1", type="D", doc="Trail tail Y coordinate.", units="pixel")
        self.keyFlux = schema.addField(name + "_flux", type="D", doc="Trailed source flux.", units="count")
        self.keyLength = schema.addField(name + "_length", type="D", doc="Trail length.", units="pixel")
        self.keyAngle = schema.addField(name + "_angle", type="D", doc="Angle measured from +x-axis.")

        # Measurement Error Keys
        self.keyX0Err = schema.addField(name + "_x0Err", type="D",
                                        doc="Trail head X coordinate error.", units="pixel")
        self.keyY0Err = schema.addField(name + "_y0Err", type="D",
                                        doc="Trail head Y coordinate error.", units="pixel")
        self.keyX1Err = schema.addField(name + "_x1Err", type="D",
                                        doc="Trail tail X coordinate error.", units="pixel")
        self.keyY1Err = schema.addField(name + "_y1Err", type="D",
                                        doc="Trail tail Y coordinate error.", units="pixel")
        self.keyFluxErr = schema.addField(name + "_fluxErr", type="D",
                                          doc="Trail flux error.", units="count")
        self.keyLengthErr = schema.addField(name + "_lengthErr", type="D",
                                            doc="Trail length error.", units="pixel")
        self.keyAngleErr = schema.addField(name + "_angleErr", type="D", doc="Trail angle error.")

        flagDefs = FlagDefinitionList()
        flagDefs.addFailureFlag("No trailed-source measured")
        self.NO_FLUX = flagDefs.add("flag_noFlux", "No suitable prior flux measurement")
        self.NO_CONVERGE = flagDefs.add("flag_noConverge", "The root finder did not converge")
        self.NO_SIGMA = flagDefs.add("flag_noSigma", "No PSF width (sigma)")
        self.SAFE_CENTROID = flagDefs.add("flag_safeCentroid", "Fell back to safe centroid extractor")
        self.flagHandler = FlagHandler.addFields(schema, name, flagDefs)

        # Used in measure() when the SdssShape centroid is not finite.
        self.centroidExtractor = SafeCentroidExtractor(schema, name)

    def measure(self, measRecord, exposure):
        """Run the Naive trailed source measurement algorithm.

        Parameters
        ----------
        measRecord : `lsst.afw.table.SourceRecord`
            Record describing the object being measured.
        exposure : `lsst.afw.image.Exposure`
            Pixel data to be measured.

        Raises
        ------
        MeasurementError
            With the ``SAFE_CENTROID`` flag if the SdssShape centroid is not
            finite, ``NO_CONVERGE`` if the trail length cannot be solved for,
            or ``NO_FLUX`` if no finite flux is available.

        See also
        --------
        lsst.meas.base.SingleFramePlugin.measure
        """

        # Use the SdssShape centroid. SdssShape produces no centroid errors;
        # the slot centroid errors are used for error propagation below.
        xc = measRecord.get("base_SdssShape_x")
        yc = measRecord.get("base_SdssShape_y")
        if not np.isfinite(xc) or not np.isfinite(yc):
            # The safe centroid extractor is still invoked (it is handed the
            # flag handler), but the measurement is then failed with the
            # SAFE_CENTROID flag; the fallback position is not used further.
            xc, yc = self.centroidExtractor(measRecord, self.flagHandler)
            raise MeasurementError(self.SAFE_CENTROID.doc, self.SAFE_CENTROID.number)

        ra, dec = self.computeRaDec(exposure, xc, yc)

        # Eigenvalues of the second-moment matrix: the squared semi-major (a2)
        # and semi-minor (b2) axes.
        Ixx, Iyy, Ixy = measRecord.getShape().getParameterVector()
        xmy = Ixx - Iyy
        xpy = Ixx + Iyy
        xmy2 = xmy*xmy
        xy2 = Ixy*Ixy
        desc = np.sqrt(xmy2 + 4.0*xy2)  # Square root of the eigenvalue discriminant
        a2 = 0.5*(xpy + desc)
        b2 = 0.5*(xpy - desc)

        # Measure the trail length; the method depends on whether the
        # second moments are weighted.
        if measRecord.get("base_SdssShape_flag_unweighted"):
            lsst.log.debug("Unweighed")
            length, gradLength = self.computeLength(a2, b2)
        else:
            lsst.log.debug("Weighted")
            try:
                length, gradLength, results = self.findLength(a2, b2)
            except (ValueError, RuntimeError) as err:
                # brentq raises ValueError when the root is not bracketed by
                # its search interval (e.g. a nearly round source) and, with
                # its default disp=True, RuntimeError when it fails to converge.
                raise MeasurementError(self.NO_CONVERGE.doc, self.NO_CONVERGE.number) from err
            if not results.converged:
                lsst.log.info(results.flag)
                raise MeasurementError(self.NO_CONVERGE.doc, self.NO_CONVERGE.number)

        # Angle of the major axis from the +x axis.
        theta = 0.5*np.arctan2(2.0*Ixy, xmy)

        # End points of the trail (which end is the "head" is degenerate).
        radius = length/2.0  # Trail 'radius'
        cosTheta = np.cos(theta)
        sinTheta = np.sin(theta)
        dx = radius*cosTheta  # x offset of each end from the centroid
        dy = radius*sinTheta  # y offset of each end from the centroid
        x0 = xc - dx
        y0 = yc - dy
        x1 = xc + dx
        y1 = yc + dy

        # Get a cutout of the object from the exposure
        cutout = getMeasurementCutout(measRecord, exposure)

        # Compute flux assuming fixed parameters for VeresModel
        params = np.array([xc, yc, 1.0, length, theta])  # Flux = 1.0
        model = VeresModel(cutout)
        flux, gradFlux = model.computeFluxWithGradient(params)

        # Fall back to aperture flux
        usedApFlux = False
        if not np.isfinite(flux):
            if np.isfinite(measRecord.getApInstFlux()):
                flux = measRecord.getApInstFlux()
                usedApFlux = True
            else:
                raise MeasurementError(self.NO_FLUX.doc, self.NO_FLUX.number)

        # Propagate errors from the second moments and centroid (variances
        # only; covariances between the inputs are ignored).
        IxxErr2, IyyErr2, IxyErr2 = np.diag(measRecord.getShapeErr())
        xcErr2, ycErr2 = np.diag(measRecord.getCentroidErr())

        # Error in length. Derivatives of a2 w.r.t. the moments; b2 has the
        # Ixx/Iyy derivatives swapped and the Ixy derivative negated.
        da2dIxx = 0.5*(1.0 + (xmy/desc))
        da2dIyy = 0.5*(1.0 - (xmy/desc))
        da2dIxy = 2.0*Ixy/desc
        a2Err2 = IxxErr2*da2dIxx*da2dIxx + IyyErr2*da2dIyy*da2dIyy + IxyErr2*da2dIxy*da2dIxy
        b2Err2 = IxxErr2*da2dIyy*da2dIyy + IyyErr2*da2dIxx*da2dIxx + IxyErr2*da2dIxy*da2dIxy
        dLda2, dLdb2 = gradLength
        lengthErr = np.sqrt(dLda2*dLda2*a2Err2 + dLdb2*dLdb2*b2Err2)

        # Error in theta (dTheta/dIyy == -dTheta/dIxx).
        thetaDenom = xmy2 + 4.0*xy2
        dThetadIxx = -Ixy/thetaDenom
        dThetadIxy = xmy/thetaDenom
        thetaErr = np.sqrt(dThetadIxx*dThetadIxx*(IxxErr2 + IyyErr2) + dThetadIxy*dThetadIxy*IxyErr2)

        # Error in flux
        if usedApFlux:
            # The model gradient belongs to the failed model flux; report
            # the error of the flux actually stored.
            fluxErr = measRecord.getApInstFluxErr()
        else:
            dFdxc, dFdyc, _, dFdL, dFdTheta = gradFlux
            fluxErr = np.sqrt(dFdL*dFdL*lengthErr*lengthErr + dFdTheta*dFdTheta*thetaErr*thetaErr
                              + dFdxc*dFdxc*xcErr2 + dFdyc*dFdyc*ycErr2)

        # Errors in end-points. x = xc -/+ radius*cos(theta) and
        # y = yc -/+ radius*sin(theta), so |dx/dtheta| = |dy| and
        # |dy/dtheta| = |dx|. These are variances; take one square root.
        radiusErr2 = lengthErr*lengthErr/4.0
        thetaErr2 = thetaErr*thetaErr
        xErr2 = xcErr2 + radiusErr2*cosTheta*cosTheta + thetaErr2*dy*dy
        yErr2 = ycErr2 + radiusErr2*sinTheta*sinTheta + thetaErr2*dx*dx
        x0Err = np.sqrt(xErr2)  # Same for x1
        y0Err = np.sqrt(yErr2)  # Same for y1

        # Record the measurements
        measRecord.set(self.keyRa, ra)
        measRecord.set(self.keyDec, dec)
        measRecord.set(self.keyX0, x0)
        measRecord.set(self.keyY0, y0)
        measRecord.set(self.keyX1, x1)
        measRecord.set(self.keyY1, y1)
        measRecord.set(self.keyFlux, flux)
        measRecord.set(self.keyLength, length)
        measRecord.set(self.keyAngle, theta)
        measRecord.set(self.keyX0Err, x0Err)
        measRecord.set(self.keyY0Err, y0Err)
        measRecord.set(self.keyX1Err, x0Err)
        measRecord.set(self.keyY1Err, y0Err)
        measRecord.set(self.keyFluxErr, fluxErr)
        measRecord.set(self.keyLengthErr, lengthErr)
        measRecord.set(self.keyAngleErr, thetaErr)

    def fail(self, measRecord, error=None):
        """Record failure

        See also
        --------
        lsst.meas.base.SingleFramePlugin.fail
        """
        if error is None:
            self.flagHandler.handleFailure(measRecord)
        else:
            self.flagHandler.handleFailure(measRecord, error.cpp)

    @staticmethod
    def _computeSecondMomentDiff(z, c):
        """Compute difference of the numerical and analytic second moments.

        Parameters
        ----------
        z : `float`
            Proportional to the length of the trail. (see notes)
        c : `float`
            Constant (see notes)

        Returns
        -------
        diff : `float`
            Difference in numerical and analytic second moments.

        Notes
        -----
        This is a simplified expression for the difference between the stack
        computed adaptive second-moment and the analytic solution. The variable
        z is proportional to the length such that length=2*z*sqrt(2*(Ixx+Iyy)),
        and c is a constant (c = 4*Ixx/((Ixx+Iyy)*sqrt(pi))). Both have been
        defined to avoid unnecessary floating-point operations in the root
        finder.
        """

        diff = erf(z) - c*z*np.exp(-z*z)
        return diff

    @classmethod
    def findLength(cls, Ixx, Iyy):
        """Find the length of a trail, given adaptive second-moments.

        Uses a root finder to compute the length of a trail corresponding to
        the adaptive second-moments computed by previous measurements
        (ie. SdssShape).

        Parameters
        ----------
        Ixx : `float`
            Adaptive second-moment along the trail (the caller passes the
            major-axis eigenvalue). Assumed to satisfy Ixx > Iyy.
        Iyy : `float`
            Adaptive second-moment across the trail (the caller passes the
            minor-axis eigenvalue).

        Returns
        -------
        length : `float`
            Length of the trail.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of the length w.r.t. ``Ixx`` and ``Iyy``.
        results : `scipy.optimize.RootResults`
            Contains messages about convergence from the root finder.

        Raises
        ------
        ValueError
            Raised by `scipy.optimize.brentq` if the root is not bracketed
            by the search interval (e.g. for a nearly round source).
        RuntimeError
            Raised by `scipy.optimize.brentq` if it fails to converge.
        """

        xpy = Ixx + Iyy
        c = 4.0*Ixx/(xpy*np.sqrt(np.pi))

        # Given a 'c' in (c_min, c_max], the root is contained in (0,1].
        # c_min is given by the case: Ixx == Iyy, ie. a point source.
        # c_max is given by the limit Ixx >> Iyy.
        # Empirically, 0.001 is a suitable lower bound, assuming Ixx > Iyy.
        z, results = sciOpt.brentq(lambda z: cls._computeSecondMomentDiff(z, c),
                                   0.001, 1.0, full_output=True)

        length = 2.0*z*np.sqrt(2.0*xpy)
        gradLength = cls._gradFindLength(Ixx, Iyy, z, c)
        return length, gradLength, results

    @staticmethod
    def _gradFindLength(Ixx, Iyy, z, c):
        """Compute the gradient of the findLength function.

        Parameters
        ----------
        Ixx, Iyy : `float`
            Second moments passed to `findLength`.
        z : `float`
            Root found by `findLength`.
        c : `float`
            Constant used by `findLength` (see `_computeSecondMomentDiff`).

        Returns
        -------
        dLdIxx, dLdIyy : `float`
            Derivatives of the length w.r.t. ``Ixx`` and ``Iyy``, including
            the implicit dependence of the root ``z`` on ``c``.
        """
        spi = np.sqrt(np.pi)
        xpy = Ixx+Iyy
        xpy2 = xpy*xpy
        enz2 = np.exp(-z*z)
        sxpy = np.sqrt(xpy)

        fac = 4.0 / (spi*xpy2)
        dcdIxx = Iyy*fac
        dcdIyy = -Ixx*fac

        # Derivatives of the _computeSecondMomentDiff function. dfdc is the
        # negated partial derivative, so dzdf*dfdc is dz/dc at the root
        # (implicit function theorem).
        dfdc = z*enz2
        dzdf = spi / (enz2*(spi*c*(2.0*z*z - 1.0) + 2.0))  # inverse of dfdz

        dLdz = 2.0*np.sqrt(2.0)*sxpy
        pLpIxx = np.sqrt(2.0)*z / sxpy  # Same as pLpIyy

        dLdc = dLdz*dzdf*dfdc
        dLdIxx = dLdc*dcdIxx + pLpIxx
        dLdIyy = dLdc*dcdIyy + pLpIxx
        return dLdIxx, dLdIyy

    @staticmethod
    def computeLength(Ixx, Iyy):
        """Compute the length of a trail, given unweighted second-moments.

        Parameters
        ----------
        Ixx, Iyy : `float`
            Unweighted second moments along and across the trail.

        Returns
        -------
        length : `float`
            ``sqrt(6*(Ixx - 2*Iyy))``; NaN if ``Ixx < 2*Iyy``.
        gradLength : `tuple` [`float`, `float`]
            Derivatives of the length w.r.t. ``Ixx`` and ``Iyy``.
        """
        denom = np.sqrt(Ixx - 2.0*Iyy)

        length = np.sqrt(6.0)*denom

        dLdIxx = np.sqrt(1.5) / denom
        dLdIyy = -np.sqrt(6.0) / denom
        return length, (dLdIxx, dLdIyy)

    @staticmethod
    def computeRaDec(exposure, x, y):
        """Convert pixel coordinates to RA and Dec.

        Parameters
        ----------
        exposure : `lsst.afw.image.ExposureF`
            Exposure object containing the WCS.
        x : `float`
            x coordinate of the trail centroid
        y : `float`
            y coordinate of the trail centroid

        Returns
        -------
        ra : `float`
            Right ascension, in degrees.
        dec : `float`
            Declination, in degrees.
        """

        wcs = exposure.getWcs()
        center = wcs.pixelToSky(Point2D(x, y))
        ra = center.getRa().asDegrees()
        dec = center.getDec().asDegrees()
        return ra, dec
def getMeasurementCutout(measRecord, exposure)
Definition: utils.py:28