lsst.ip.isr  14.0-10-g2094653+5
fringe.py
Go to the documentation of this file.
1 #!/usr/bin/env python
2 #
3 # LSST Data Management System
4 # Copyright 2012 LSST Corporation.
5 #
6 # This product includes software developed by the
7 # LSST Project (http://www.lsst.org/).
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the LSST License Statement and
20 # the GNU General Public License along with this program. If not,
21 # see <http://www.lsstcorp.org/LegalNotices/>.
22 #
23 
24 from __future__ import absolute_import, division, print_function
25 from builtins import zip
26 from builtins import input
27 from builtins import range
28 import numpy
29 
30 import lsst.afw.geom as afwGeom
31 import lsst.afw.image as afwImage
32 import lsst.afw.math as afwMath
33 import lsst.afw.display.ds9 as ds9
34 
35 from lsst.pipe.base import Task, Struct, timeMethod
36 from lsst.pex.config import Config, Field, ListField, ConfigField
37 
38 
def getFrame():
    """Return a fresh display frame number, one greater than the previous one

    The counter is stored in the ``frame`` attribute of this function and
    starts at 0, so the first call returns 1.
    """
    nextFrame = getFrame.frame + 1
    getFrame.frame = nextFrame
    return nextFrame


getFrame.frame = 0
44 
45 
class FringeStatisticsConfig(Config):
    """Configuration of the statistics used when measuring fringe amplitudes"""

    # Mask planes whose pixels are excluded from the aperture statistics.
    badMaskPlanes = ListField(
        dtype=str,
        default=["SAT"],
        doc="Ignore pixels with these masks",
    )
    # afwMath statistic flag evaluated within each aperture (MEDIAN by default).
    stat = Field(
        dtype=int,
        default=int(afwMath.MEDIAN),
        doc="Statistic to use",
    )
    # Passed to StatisticsControl.setNumSigmaClip.
    clip = Field(
        dtype=float,
        default=3.0,
        doc="Sigma clip threshold",
    )
    # Passed to StatisticsControl.setNumIter.
    iterations = Field(
        dtype=int,
        default=3,
        doc="Number of fitting iterations",
    )
    # Added to the random-number seed used to choose measurement positions.
    rngSeedOffset = Field(
        dtype=int,
        default=0,
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
    )
54 
55 
class FringeConfig(Config):
    """Configuration for fringe subtraction"""

    # Only exposures whose filter name appears here are fringe-subtracted.
    filters = ListField(
        dtype=str,
        default=[],
        doc="Only fringe-subtract these filters",
    )
    # Number of random positions at which fringe amplitudes are measured.
    num = Field(
        dtype=int,
        default=30000,
        doc="Number of fringe measurements",
    )
    # Half-size of the aperture measuring the fringe signal.
    small = Field(
        dtype=int,
        default=3,
        doc="Half-size of small (fringe) measurements (pixels)",
    )
    # Half-size of the aperture measuring the local background.
    large = Field(
        dtype=int,
        default=30,
        doc="Half-size of large (background) measurements (pixels)",
    )
    # Maximum number of clip-and-refit iterations when solving for the scales.
    iterations = Field(
        dtype=int,
        default=20,
        doc="Number of fitting iterations",
    )
    # Rejection threshold (in robust sigma) used while solving for the scales.
    clip = Field(
        dtype=float,
        default=3.0,
        doc="Sigma clip threshold",
    )
    stats = ConfigField(
        dtype=FringeStatisticsConfig,
        doc="Statistics for measuring fringes",
    )
    # If set, the median of each fringe frame is subtracted before use.
    pedestal = Field(
        dtype=bool,
        default=False,
        doc="Remove fringe pedestal?",
    )
66 
67 
class FringeTask(Task):
    """Task to remove fringes from a science exposure

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig

    def readFringes(self, dataRef, assembler=None):
        """Read the fringe frame(s)

        The current implementation assumes only a single fringe frame and
        will have to be updated to support multi-mode fringe subtraction.

        This implementation could be optimised by persisting the fringe
        positions and fluxes.

        @param dataRef Data reference for the science exposure
        @param assembler An instance of AssembleCcdTask (for assembling fringe frames)
        @return Struct(fringes: fringe exposure or list of fringe exposures;
                       seed: 32-bit uint derived from ccdExposureId for random number generator)
        @throw RuntimeError if the fringe frame cannot be retrieved from the butler
        """
        try:
            fringe = dataRef.get("fringe", immediate=True)
        except Exception as e:
            raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e))
        if assembler is not None:
            fringe = assembler.assembleCcd(fringe)

        seed = self.config.stats.rngSeedOffset + dataRef.get("ccdExposureId", immediate=True)
        # Seed for numpy.random.RandomState must be convertible to a 32 bit unsigned integer
        seed %= 2**32

        return Struct(fringes=fringe,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        The fringe frames are modified in place: the science mask is OR-ed
        into their masks and, if config.pedestal is set, their pedestal is
        subtracted. The science exposure is modified in place.

        @param exposure Science exposure from which to remove fringes
        @param fringes Exposure or iterable of Exposures
        @param seed 32-bit unsigned integer for random number generator
                    (defaults to config.stats.rngSeedOffset)
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if hasattr(fringes, '__iter__'):
            # Materialize so that indexing and repeated iteration below work
            # for any iterable (e.g. a generator), not just sequences.
            fringes = list(fringes)
        else:
            fringes = [fringes]

        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            # Propagate the science mask so that pixels bad in the science
            # exposure are also ignored when measuring the fringe frames.
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= mask
            if self.config.pedestal:
                self.removePedestal(fringe)

        # Placeholder implementation for multiple fringe frames
        # This needs to be revisited in DM-4441
        positions = self.generatePositions(fringes[0], rng)
        fluxes = numpy.ndarray([self.config.num, len(fringes)])
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            ds9.mtv(exposure, title="Fringe subtracted", frame=getFrame())

    @timeMethod
    def runDataRef(self, exposure, dataRef, assembler=None):
        """Remove fringes from the provided science exposure.

        Retrieve fringes from butler dataRef provided and remove from
        provided science exposure.
        Fringes are only subtracted if the science exposure has a filter
        listed in the configuration.

        @param exposure Science exposure from which to remove fringes
        @param dataRef Data reference for the science exposure
        @param assembler An instance of AssembleCcdTask (for assembling fringe frames)
        """
        # Check before reading so no I/O is done for filters we skip.
        if not self.checkFilter(exposure):
            return
        fringeStruct = self.readFringes(dataRef, assembler=assembler)
        self.run(exposure, **fringeStruct.getDict())

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure

        @return True if the exposure's filter name is listed in config.filters
        """
        return exposure.getFilter().getName() in self.config.filters

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure

        The pedestal is the sigma-clipped median of the whole masked image;
        it is subtracted in place.
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe amplitudes

        Positions are drawn uniformly from [large, width - large) in x and
        [large, height - large) in y, so that the large (background) aperture
        measured around each position lies within the image.

        @param exposure Exposure defining the image dimensions
        @param rng numpy.random.RandomState used to draw positions
        @return Integer array of shape (config.num, 2) holding (x, y) pairs
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure

        The fringe amplitudes are measured as the statistic within a square
        aperture. The statistic within a larger aperture are subtracted so
        as to remove the background.

        @param exposure Exposure to measure
        @param positions Array of (x,y) for fringe measurement
        @param title Title for display
        @return Array of fringe measurements
        """
        mi = exposure.getMaskedImage()

        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(mi.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        # Loop invariants hoisted out of the (config.num-long) measurement loop.
        statistic = self.config.stats.stat
        smallSize = self.config.small
        largeSize = self.config.large

        num = self.config.num
        fringes = numpy.ndarray(num)

        for i in range(num):
            x, y = positions[i]
            small = measure(mi, x, y, smallSize, statistic, stats)
            large = measure(mi, x, y, largeSize, statistic, stats)
            fringes[i] = small - large

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            frame = getFrame()
            ds9.mtv(exposure, frame=frame, title=title)
            if False:  # Developer toggle: overlay every aperture (slow for many positions)
                # Unit square outline, scaled by the aperture half-size and then
                # shifted to the measurement position.
                offsets = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with ds9.Buffering():
                    for x, y in positions:
                        ds9.line(offsets*smallSize + [x, y], frame=frame, ctype="green")
                        ds9.line(offsets*largeSize + [x, y], frame=frame, ctype="blue")

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve (with iterative clipping) for the scale factors

        @param science Array of science exposure fringe amplitudes
        @param fringes 2-D array of fringe frame fringe amplitudes, one
                       column per fringe frame
        @return Array of scale factors for the fringe frames (all zero, one
                per fringe frame, if no good measurements remain)
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Capture the number of fringe frames now: ``fringes`` is filtered by
        # row below, after which its length is the number of surviving
        # measurements rather than the number of frames.
        numFringes = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user

            There are no good pixels; doesn't matter what we return, as long
            as there is one scale factor per fringe frame.
            """
            self.log.warn("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFringes)

        # Drop measurements that are non-finite in the science exposure or in
        # ANY fringe frame; a single NaN would otherwise poison the percentile
        # based rejection and the least-squares fit.
        good = numpy.where(numpy.logical_and(numpy.isfinite(science),
                                             numpy.all(numpy.isfinite(fringes), 1)))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome values
        # (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(numFringes):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            # logical_not(>) rather than (<=) so that NaN residuals are kept.
            good = numpy.logical_not(abs(resid) > self.config.clip*rms)
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum())
            self.log.debug("Solution %d: %s", i, solution)
            newNum = good.sum()
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            good = numpy.where(good)
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection; the RMS is computed from this final
        # fit so it is defined even when config.iterations is 0.
        solution = self._solve(science, fringes)
        rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution

    def _plotSolution(self, science, fringes, solution):
        """Plot science vs. fringe amplitudes per fringe frame, then prompt the user

        Debugging aid for solve: one figure per fringe frame, with the fitted
        contributions of the other frames removed from the science amplitudes.
        """
        import matplotlib.pyplot as plot
        numFringes = fringes.shape[1]
        for j in range(numFringes):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                pass  # Only works with a Tk backend; raising the window is cosmetic
            ax = fig.add_subplot(1, 1, 1)
            adjust = science.copy()
            for k in range(numFringes):
                if k != j:
                    adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
        while True:
            ans = input("Enter or c to continue [chp]").lower()
            if ans in ("", "c",):
                break
            if ans in ("p",):
                import pdb
                pdb.set_trace()
            elif ans in ("h", ):
                print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors

        @param science Array of science exposure fringe amplitudes
        @param fringes 2-D array of fringe frame fringe amplitudes (design matrix)
        @return Array of scale factors for the fringe frames
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the scaled fringes from the science exposure in place

        @param science Science exposure
        @param fringes List of fringe frames
        @param solution Array of scale factors for the fringe frames
        @throw RuntimeError if the number of scale factors differs from the
               number of fringe frames
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())
354 
355  for s, f in zip(solution, fringes):
356  science.getMaskedImage().scaledMinus(s, f.getMaskedImage())
357 
358 
def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture

    @param mi MaskedImage to measure
    @param x, y Center for aperture (pixel coordinates local to mi)
    @param size Half-size of aperture (pixels); the box starts at
                (int(x) - size, int(y) - size) and is 2*size pixels on a side
    @param statistic Statistic to measure (afwMath statistic flag)
    @param stats StatisticsControl object
    @return Value of statistic within aperture
    """
    # Truncate each centre coordinate to a pixel *before* offsetting, so that
    # x and y are treated identically even for fractional positions.
    x0 = int(x) - size
    y0 = int(y) - size
    bbox = afwGeom.Box2I(afwGeom.Point2I(x0, y0), afwGeom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()
372 
373 
def stdev(vector):
    """Robust estimate of the standard deviation of an array of values

    The interquartile range is scaled by 0.74 (about 1/1.349, the ratio of
    standard deviation to interquartile range for a Gaussian), which makes
    the estimate insensitive to outliers.

    @param vector Array of values
    @return Standard deviation estimate
    """
    lowerQuartile, upperQuartile = numpy.percentile(vector, (25, 75))
    interquartileRange = upperQuartile - lowerQuartile
    return 0.74*interquartileRange
382 
383 
def select(vector, clip):
    """Select values within 'clip' robust standard deviations of the median

    The standard deviation is estimated as 0.74 times the interquartile range.
    Values exactly at the threshold are kept, which matches the iterative
    clipping in FringeTask.solve. It also means that when the interquartile
    range is zero (e.g. mostly identical values), values equal to the median
    are kept rather than everything being rejected.

    @param vector Array of values
    @param clip Rejection threshold in units of the robust standard deviation
    @return Boolean array, True for selected values
    """
    q1, q2, q3 = numpy.percentile(vector, (25, 50, 75))
    return numpy.abs(vector - q2) <= clip*0.74*(q3 - q1)
def runDataRef(self, exposure, dataRef, assembler=None)
Definition: fringe.py:149
def run(self, exposure, fringes, seed=None)
Definition: fringe.py:106
def checkFilter(self, exposure)
Definition: fringe.py:166
def subtract(self, science, fringes, solution)
Definition: fringe.py:344
def _solve(self, science, fringes)
Definition: fringe.py:334
def stdev(vector)
Definition: fringe.py:374
def generatePositions(self, exposure, rng)
Definition: fringe.py:180
def measureExposure(self, exposure, positions, title="Fringe")
Definition: fringe.py:190
def removePedestal(self, fringe)
Definition: fringe.py:170
def readFringes(self, dataRef, assembler=None)
Definition: fringe.py:77
def select(vector, clip)
Definition: fringe.py:384
def measure(mi, x, y, size, statistic, stats)
Definition: fringe.py:359
def solve(self, science, fringes)
Definition: fringe.py:231