lsst.ip.isr  17.0.1-17-gad3fdc4
fringe.py
Go to the documentation of this file.
1 # This file is part of ip_isr.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (https://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 
22 import numpy
23 
24 import lsst.geom
25 import lsst.afw.image as afwImage
26 import lsst.afw.math as afwMath
27 import lsst.afw.display as afwDisplay
28 
29 from lsst.pipe.base import Task, Struct, timeMethod
30 from lsst.pex.config import Config, Field, ListField, ConfigField
31 
# Module-level side effect, executed once when this module is imported:
# sets afwDisplay's default mask transparency to 75 (units presumably
# percent — NOTE(review): confirm against lsst.afw.display docs). Only the
# debug displays in FringeTask use afwDisplay in this file.
afwDisplay.setDefaultMaskTransparency(75)
33 
34 
def getFrame():
    """Hand out the next debug-display frame number.

    The running counter is stored on the function object itself as
    ``getFrame.frame``; every call advances it by one and returns the new
    value, so the first call after the counter is reset to 0 yields 1.
    """
    current = getFrame.frame + 1
    getFrame.frame = current
    return current


# Counter backing getFrame(); reset here so numbering starts at 1.
getFrame.frame = 0
42 
43 
class FringeStatisticsConfig(Config):
    """Configuration of the statistics used to measure fringe amplitudes.

    ``badMaskPlanes`` selects the mask planes whose pixels are ignored,
    ``stat`` is the afwMath statistic evaluated in each aperture, and
    ``clip``/``iterations`` configure the StatisticsControl sigma clipping.
    ``rngSeedOffset`` is added to the seed that drives the random choice of
    measurement positions.
    """

    badMaskPlanes = ListField(
        dtype=str,
        default=["SAT"],
        doc="Ignore pixels with these masks",
    )
    stat = Field(
        dtype=int,
        default=int(afwMath.MEDIAN),
        doc="Statistic to use",
    )
    clip = Field(
        dtype=float,
        default=3.0,
        doc="Sigma clip threshold",
    )
    iterations = Field(
        dtype=int,
        default=3,
        doc="Number of fitting iterations",
    )
    rngSeedOffset = Field(
        dtype=int,
        default=0,
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
    )
52 
53 
class FringeConfig(Config):
    """Configuration for FringeTask's fringe subtraction.

    Subtraction happens only for filters named in ``filters``. ``num``
    random positions are measured with apertures of half-size ``small``
    (fringe) and ``large`` (background). ``iterations`` and ``clip`` control
    the iterative rejection while solving for the fringe scale factors, and
    ``pedestal`` enables subtracting each fringe frame's median first.
    """

    filters = ListField(
        dtype=str,
        default=[],
        doc="Only fringe-subtract these filters",
    )
    num = Field(
        dtype=int,
        default=30000,
        doc="Number of fringe measurements",
    )
    small = Field(
        dtype=int,
        default=3,
        doc="Half-size of small (fringe) measurements (pixels)",
    )
    large = Field(
        dtype=int,
        default=30,
        doc="Half-size of large (background) measurements (pixels)",
    )
    iterations = Field(
        dtype=int,
        default=20,
        doc="Number of fitting iterations",
    )
    clip = Field(
        dtype=float,
        default=3.0,
        doc="Sigma clip threshold",
    )
    stats = ConfigField(
        dtype=FringeStatisticsConfig,
        doc="Statistics for measuring fringes",
    )
    pedestal = Field(
        dtype=bool,
        default=False,
        doc="Remove fringe pedestal?",
    )
64 
65 
class FringeTask(Task):
    """Task to remove fringes from a science exposure

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def readFringes(self, dataRef, assembler=None):
        """Read the fringe frame(s)

        The current implementation assumes only a single fringe frame and
        will have to be updated to support multi-mode fringe subtraction.

        This implementation could be optimised by persisting the fringe
        positions and fluxes.

        @param dataRef Data reference for the science exposure
        @param assembler An instance of AssembleCcdTask (for assembling fringe frames)
        @return Struct(fringes: fringe exposure or list of fringe exposures;
                seed: 32-bit uint derived from ccdExposureId for random number generator)
        @throws RuntimeError if the fringe frame cannot be retrieved; the
                original exception is chained as the cause
        """
        try:
            fringe = dataRef.get("fringe", immediate=True)
        except Exception as e:
            raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e)) from e
        if assembler is not None:
            fringe = assembler.assembleCcd(fringe)

        seed = self.config.stats.rngSeedOffset + dataRef.get("ccdExposureId", immediate=True)
        # Seed for numpy.random.RandomState must be convertable to a 32 bit unsigned integer
        seed %= 2**32

        return Struct(fringes=fringe,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask. Fringes are only subtracted if the
        science exposure has a filter listed in the configuration. The
        exposure is modified in place; nothing is returned.

        @param exposure Science exposure from which to remove fringes
        @param fringes Exposure or list of Exposures
        @param seed 32-bit unsigned integer for random number generator;
               if None, config.stats.rngSeedOffset is used
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if not hasattr(fringes, '__iter__'):
            fringes = [fringes]

        # OR the science mask into every fringe frame so that the science
        # exposure's mask bits also apply when the fringe frames are measured.
        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= mask
            if self.config.pedestal:
                self.removePedestal(fringe)

        # Placeholder implementation for multiple fringe frames
        # This needs to be revisited in DM-4441
        positions = self.generatePositions(fringes[0], rng)
        # One row per position, one column per fringe frame; every entry is
        # filled by the loop below.
        fluxes = numpy.empty((self.config.num, len(fringes)))
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")

    @timeMethod
    def runDataRef(self, exposure, dataRef, assembler=None):
        """Remove fringes from the provided science exposure.

        Retrieve fringes from butler dataRef provided and remove from
        provided science exposure.
        Fringes are only subtracted if the science exposure has a filter
        listed in the configuration.

        @param exposure Science exposure from which to remove fringes
        @param dataRef Data reference for the science exposure
        @param assembler An instance of AssembleCcdTask (for assembling fringe frames)
        """
        # Check the filter first so fringe frames are not read needlessly.
        if not self.checkFilter(exposure):
            return
        fringeStruct = self.readFringes(dataRef, assembler=assembler)
        self.run(exposure, **fringeStruct.getDict())

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure

        @return True if the exposure's filter name is in config.filters
        """
        return exposure.getFilter().getName() in self.config.filters

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure

        Subtracts, in place, the afwMath.MEDIAN of the fringe's masked image.
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe amplitudes

        @param exposure Exposure whose dimensions bound the positions
        @param rng numpy.random.RandomState used to draw the positions
        @return Integer array of shape (config.num, 2) holding (x, y); each
                coordinate lies in [config.large, dimension - config.large)
                so that the large aperture stays inside the image
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure

        The fringe amplitudes are measured as the statistic within a square
        aperture. The statistic within a larger aperture are subtracted so
        as to remove the background.

        @param exposure Exposure to measure
        @param positions Array of (x,y) for fringe measurement; at least
               config.num entries are read
        @param title Title for display
        @return Array of config.num fringe measurements
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(exposure.getMaskedImage().getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        num = self.config.num
        fringes = numpy.empty(num)

        # Loop invariants, looked up once.
        mi = exposure.getMaskedImage()
        smallSize = self.config.small
        largeSize = self.config.large
        statistic = self.config.stats.stat
        for i in range(num):
            x, y = positions[i]
            small = measure(mi, x, y, smallSize, statistic, stats)
            large = measure(mi, x, y, largeSize, statistic, stats)
            fringes[i] = small - large

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)
            # Aperture overlays are disabled by default; set to True to draw
            # them when debugging.
            showApertures = False
            if showApertures:
                unitSquare = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with disp.Buffering():
                    for x, y in positions:
                        # Scale the unit square by the half-size, then move it
                        # to the position (scaling must not apply to x, y).
                        center = numpy.array([[x, y]])
                        disp.line(unitSquare*smallSize + center, ctype=afwDisplay.GREEN)
                        disp.line(unitSquare*largeSize + center, ctype=afwDisplay.BLUE)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve (with iterative clipping) for the scale factors

        @param science Array of science exposure fringe amplitudes
        @param fringes Array of arrays of fringe frame fringe amplitudes
               (one row per position, one column per fringe frame)
        @return Array of scale factors for the fringe frames (one per column
                of ``fringes``; all zero if no measurements survive)
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # The number of fringe frames is the number of columns; unlike the
        # number of rows, it does not shrink as measurements are rejected.
        numFringes = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user

            There are no good pixels; the values don't matter, but there must
            be one (zero) scale factor per fringe frame for subtract() to
            accept the result.
            """
            self.log.warn("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFringes)

        # Keep only measurements that are finite in the science exposure AND
        # in every fringe frame: a NaN in any column would poison both the
        # percentile-based rejection and the least-squares fit.
        good = numpy.logical_and(numpy.isfinite(science), numpy.all(numpy.isfinite(fringes), 1))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome values
        # (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(numFringes):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            good = numpy.logical_not(abs(resid) > self.config.clip*rms)
            newNum = good.sum()
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, newNum)
            self.log.debug("Solution %d: %s", i, solution)
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        # RMS of the solution actually returned; computed here so it is also
        # defined when config.iterations == 0.
        rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution

    def _plotSolution(self, science, fringes, solution):
        """Interactively plot the current fit for each fringe frame (debug only)

        For fringe frame j, the science amplitudes, after subtracting the
        fitted contribution of every other frame, are plotted against frame j's
        amplitudes together with the fitted line of slope solution[j]. Then
        waits for keyboard input (Enter/c continue, p pdb, h help).
        """
        import matplotlib.pyplot as plot
        for j in range(fringes.shape[1]):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                # Best effort: only works with a Tk backend.
                pass
            ax = fig.add_subplot(1, 1, 1)
            adjust = science.copy()
            others = set(range(fringes.shape[1]))
            others.discard(j)
            for k in others:
                adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
        while True:
            ans = input("Enter or c to continue [chp]").lower()
            if ans in ("", "c",):
                break
            if ans in ("p",):
                import pdb
                pdb.set_trace()
            elif ans in ("h", ):
                print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors

        Linear least squares (direct SVD) with ``fringes`` as the design
        matrix and ``science`` as the data vector.

        @param science Array of science exposure fringe amplitudes
        @param fringes Array of arrays of fringe frame fringe amplitudes
        @return Array of scale factors for the fringe frames
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes

        Modifies ``science`` in place: science -= sum_i solution[i]*fringes[i].

        @param science Science exposure
        @param fringes List of fringe frames
        @param solution Array of scale factors for the fringe frames
        @throws RuntimeError if the number of scale factors differs from the
                number of fringe frames
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())
356 
357 
def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within an aperture

    The aperture is the 2*size x 2*size pixel box whose minimum corner is
    (int(x) - size, int(y) - size), i.e. it covers pixels
    int(x) - size .. int(x) + size - 1 in x and likewise in y.

    @param mi MaskedImage to measure
    @param x, y Center for aperture
    @param size Half-size of aperture (pixels)
    @param statistic Statistic to measure
    @param stats StatisticsControl object
    @return Value of statistic within aperture
    """
    # Truncate each coordinate before offsetting it, the same way for x and y
    # (int(y - size) differs from int(y) - size when y - size is a negative
    # non-integer).
    bbox = lsst.geom.Box2I(lsst.geom.Point2I(int(x) - size, int(y) - size),
                           lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()
372 
373 
def stdev(vector):
    """Estimate the standard deviation of an array from its interquartile range

    Uses sigma ~= 0.74 * (75th percentile - 25th percentile), which is
    insensitive to outliers.

    @param vector Array of values (must be non-empty)
    @return Robust standard deviation estimate
    """
    lower, upper = numpy.percentile(vector, [25, 75])
    interquartileRange = upper - lower
    return 0.74*interquartileRange
382 
383 
def select(vector, clip):
    """Flag values lying within 'clip' robust standard deviations of the median

    The robust standard deviation is 0.74 times the interquartile range. The
    comparison is strict, so a zero interquartile range selects nothing.

    @param vector Array of values
    @param clip Threshold in units of the robust standard deviation
    @return Boolean array, True where the value is selected
    """
    lower, median, upper = numpy.percentile(vector, [25, 50, 75])
    threshold = clip*0.74*(upper - lower)
    return numpy.abs(vector - median) < threshold
def runDataRef(self, exposure, dataRef, assembler=None)
Definition: fringe.py:148
def run(self, exposure, fringes, seed=None)
Definition: fringe.py:105
def checkFilter(self, exposure)
Definition: fringe.py:165
def subtract(self, science, fringes, solution)
Definition: fringe.py:343
def _solve(self, science, fringes)
Definition: fringe.py:333
def stdev(vector)
Definition: fringe.py:374
def generatePositions(self, exposure, rng)
Definition: fringe.py:179
def measureExposure(self, exposure, positions, title="Fringe")
Definition: fringe.py:189
def removePedestal(self, fringe)
Definition: fringe.py:169
def readFringes(self, dataRef, assembler=None)
Definition: fringe.py:76
def select(vector, clip)
Definition: fringe.py:384
def measure(mi, x, y, size, statistic, stats)
Definition: fringe.py:358
def solve(self, science, fringes)
Definition: fringe.py:230