Coverage for python/lsst/ip/isr/fringe.py : 17%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ip_isr.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy
24import lsst.geom
25import lsst.afw.image as afwImage
26import lsst.afw.math as afwMath
27import lsst.afw.display as afwDisplay
29from lsst.pipe.base import Task, Struct, timeMethod
30from lsst.pex.config import Config, Field, ListField, ConfigField
32afwDisplay.setDefaultMaskTransparency(75)
def getFrame():
    """Return a fresh display frame number.

    The running count is stored as the ``frame`` attribute on this function
    itself; every call increments it by one and returns the new value.

    Returns
    -------
    frame : `int`
        The next frame number (the first call returns 1).
    """
    nextFrame = getFrame.frame + 1
    getFrame.frame = nextFrame
    return nextFrame


# Counter backing getFrame(); incremented before being returned.
getFrame.frame = 0
class FringeStatisticsConfig(Config):
    """Options for measuring fringes on an exposure.

    These settings control the per-aperture statistics computed when
    sampling fringe amplitudes, and the offset applied to the random
    number generator seed used to choose sample positions.
    """
    # Mask planes whose pixels are excluded from the aperture statistics.
    badMaskPlanes = ListField(
        dtype=str,
        default=["SAT"],
        doc="Ignore pixels with these masks",
    )
    # Statistic identifier, stored as the integer value of an afwMath
    # statistics property (MEDIAN by default).
    stat = Field(
        dtype=int,
        default=int(afwMath.MEDIAN),
        doc="Statistic to use",
    )
    # Passed to StatisticsControl.setNumSigmaClip.
    clip = Field(
        dtype=float,
        default=3.0,
        doc="Sigma clip threshold",
    )
    # Passed to StatisticsControl.setNumIter.
    iterations = Field(
        dtype=int,
        default=3,
        doc="Number of fitting iterations",
    )
    # Added to the exposure ID (when available) to form the RNG seed.
    rngSeedOffset = Field(
        dtype=int,
        default=0,
        doc="Offset to the random number generator seed (full seed includes exposure ID)",
    )
class FringeConfig(Config):
    """Fringe subtraction options.

    Controls which filters are fringe-corrected, how many sample apertures
    are measured and how large they are, and the clipping applied while
    solving for the fringe scale factors.
    """
    # TODO DM-28093: change the doc to specify that these are physical labels
    filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters")
    # TODO: remove in DM-27177
    # NOTE: the two message fragments are concatenated by the parser, so the
    # first one must end with a space to keep the sentences separated.
    useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.",
                             deprecated=("Removed with no replacement (FilterLabel has no aliases). "
                                         "Will be removed after v22."))
    # Number of random positions sampled on each exposure.
    num = Field(dtype=int, default=30000, doc="Number of fringe measurements")
    # Half-sizes of the two square apertures measured at every position.
    small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)")
    large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)")
    # Clipping controls for the scale-factor solution (FringeTask.solve).
    iterations = Field(dtype=int, default=20, doc="Number of fitting iterations")
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    # Controls for the per-aperture statistics.
    stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes")
    # When True, the median of each fringe frame is subtracted before use.
    pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?")
class FringeTask(Task):
    """Task to remove fringes from a science exposure.

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def readFringes(self, dataRef, assembler=None):
        """Read the fringe frame(s), and pack data into a Struct.

        The current implementation assumes only a single fringe frame and
        will have to be updated to support multi-mode fringe subtraction.

        This implementation could be optimised by persisting the fringe
        positions and fluxes.

        Parameters
        ----------
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference for the exposure that will have fringing
            removed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `lsst.pipe.base.Struct`
            Struct containing fringe data:

            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation.

        Raises
        ------
        RuntimeError
            Raised if the fringe frame cannot be retrieved from the butler.
        """
        try:
            fringe = dataRef.get("fringe", immediate=True)
        except Exception as e:
            raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e)) from e

        # The assembler must be passed by keyword: the second positional
        # parameter of loadFringes is expId, not assembler.
        return self.loadFringes(fringe, assembler=assembler)

    def loadFringes(self, fringeExp, expId=0, assembler=None):
        """Pack the fringe data into a Struct.

        This method moves the struct parsing code into a butler
        generation agnostic handler.

        Parameters
        ----------
        fringeExp : `lsst.afw.image.Exposure`
            The exposure containing the fringe data.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
            If `None`, only the configured seed offset is used.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `lsst.pipe.base.Struct`
            Struct containing fringe data:

            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation.
        """
        if assembler is not None:
            fringeExp = assembler.assembleCcd(fringeExp)

        seed = self.config.stats.rngSeedOffset
        if expId is not None:
            seed = seed + expId

        # Seed for numpy.random.RandomState must be convertable to a 32 bit unsigned integer
        seed %= 2**32

        return Struct(fringes=fringeExp,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask.  Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        The science exposure is modified in place.  The fringe frames are
        modified too: the science mask is OR-ed into each fringe mask and,
        if ``config.pedestal`` is set, the pedestal is subtracted from them.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.
        seed : `int`, optional
            Seed for random number generation.  Defaults to
            ``config.stats.rngSeedOffset``.

        Returns
        -------
        solution : `numpy.ndarray`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        Notes
        -----
        If the exposure's filter is not listed in ``config.filters`` nothing
        is done and `None` is returned instead of the two values above.
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if not hasattr(fringes, '__iter__'):
            fringes = [fringes]

        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            # Propagate the science mask so that the same pixels are ignored
            # when measuring the fringe frame.
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= mask
            if self.config.pedestal:
                self.removePedestal(fringe)

        positions = self.generatePositions(fringes[0], rng)
        # One column of measurements per fringe frame; every element is
        # assigned in the loop below.
        fluxes = numpy.empty([self.config.num, len(fringes)])
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution, rms = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")
        return solution, rms

    @timeMethod
    def runDataRef(self, exposure, dataRef, assembler=None):
        """Remove fringes from the provided science exposure.

        Retrieve fringes from butler dataRef provided and remove from
        provided science exposure.  Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference to the exposure.  Used to find
            appropriate fringe data.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        solution : `numpy.ndarray`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        Notes
        -----
        If the exposure's filter is not listed in ``config.filters`` nothing
        is read or subtracted and `None` is returned.
        """
        # Checked here as well as in run() so the fringe frame is not read
        # when it is not going to be used.
        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return
        fringeStruct = self.readFringes(dataRef, assembler=assembler)
        return self.run(exposure, **fringeStruct.getDict())

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to check the filter of.

        Returns
        -------
        needsFringe : `bool`
            If True, then the exposure has a filter listed in the
            configuration, and should have the fringe applied.
        """
        filterObj = afwImage.Filter(exposure.getFilter().getId())
        # TODO: remove this check along with the config option in DM-27177
        if self.config.useFilterAliases:
            filterNameSet = set(filterObj.getAliases() + [filterObj.getName()])
        else:
            filterNameSet = {filterObj.getName()}
        return bool(filterNameSet.intersection(self.config.filters))

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure.

        The pedestal is the median of the fringe frame, computed with the
        configured clipping settings, and is subtracted in place.

        Parameters
        ----------
        fringe : `lsst.afw.image.Exposure`
            Fringe data to subtract the pedestal value from.
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe amplitudes.

        Positions are drawn at least ``config.large`` pixels away from every
        edge so that the large aperture fits on the image.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        rng : `numpy.random.RandomState`
            Random number generator to use.

        Returns
        -------
        positions : `numpy.ndarray`
            Two-dimensional array of shape ``(config.num, 2)`` containing
            the ``(x, y)`` positions to sample for fringe amplitudes.
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure.

        The fringe amplitudes are measured as the statistic within a square
        aperture.  The statistic within a larger aperture are subtracted so
        as to remove the background.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        positions : `numpy.ndarray`
            Two-dimensional array containing the positions to sample
            for fringe amplitudes.
        title : `str`, optional
            Title used for debug out plots.

        Returns
        -------
        fringes : `numpy.ndarray`
            Array of measured exposure values at each of the positions
            supplied.
        """
        mi = exposure.getMaskedImage()
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(mi.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        num = self.config.num
        # Every element is assigned in the loop below.
        fringes = numpy.empty(num)

        for i in range(num):
            x, y = positions[i]
            small = measure(mi, x, y, self.config.small, self.config.stats.stat, stats)
            large = measure(mi, x, y, self.config.large, self.config.stats.stat, stats)
            fringes[i] = small - large

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)
            # Aperture overlay is deliberately disabled; flip the condition
            # to draw the small and large apertures at every position.
            if False:
                unitSquare = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with disp.Buffering():
                    for x, y in positions:
                        # Scale the unit square first, then move it to the
                        # sample position, so only the size is scaled.
                        disp.line(unitSquare*self.config.small + [[x, y]], ctype=afwDisplay.GREEN)
                        disp.line(unitSquare*self.config.large + [[x, y]], ctype=afwDisplay.BLUE)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve for the scale factors with iterative clipping.

        Parameters
        ----------
        science : `numpy.ndarray`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.ndarray`
            Two-dimensional array of measured fringe values, one row per
            position and one column per fringe frame.

        Returns
        -------
        solution : `numpy.ndarray`
            Fringe solution amplitudes for each input fringe frame.
            All zeros if no usable measurements remain.
        rms : `float`
            RMS error for the fit solution for this exposure
            (NaN if no usable measurements remain).
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Number of fringe frames; the row count changes as measurements are
        # rejected, but the column count does not.
        numFringes = numpy.shape(fringes)[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user.

            There are no good pixels; doesn't matter what we return, as long
            as there is one scale factor per fringe frame.
            """
            self.log.warn("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.array([0]*numFringes), numpy.nan

        good = numpy.where(numpy.logical_and(numpy.isfinite(science), numpy.any(numpy.isfinite(fringes), 1)))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome values
        # (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(fringes.shape[1]):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        # Stays None only if no clipping iteration is configured.
        rms = None
        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            good = numpy.logical_not(abs(resid) > self.config.clip*rms)
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum())
            self.log.debug("Solution %d: %s", i, solution)
            newNum = good.sum()
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                import matplotlib.pyplot as plot
                for j in range(fringes.shape[1]):
                    fig = plot.figure(j)
                    fig.clf()
                    try:
                        fig.canvas._tkcanvas._root().lift()  # == Tk's raise
                    except Exception:
                        # Best effort only: not every backend is Tk.
                        pass
                    ax = fig.add_subplot(1, 1, 1)
                    # Remove the contribution of all the other fringe frames
                    # so that frame j can be plotted on its own.
                    adjust = science.copy()
                    others = set(range(fringes.shape[1]))
                    others.discard(j)
                    for k in others:
                        adjust -= solution[k]*fringes[:, k]
                    ax.plot(fringes[:, j], adjust, 'r.')
                    xmin = fringes[:, j].min()
                    xmax = fringes[:, j].max()
                    ymin = solution[j]*xmin
                    ymax = solution[j]*xmax
                    ax.plot([xmin, xmax], [ymin, ymax], 'b-')
                    ax.set_title("Fringe %d: %f" % (j, solution[j]))
                    ax.set_xlabel("Fringe amplitude")
                    ax.set_ylabel("Science amplitude")
                    ax.set_autoscale_on(False)
                    ax.set_xbound(lower=xmin, upper=xmax)
                    ax.set_ybound(lower=ymin, upper=ymax)
                    fig.show()
                while True:
                    ans = input("Enter or c to continue [chp]").lower()
                    if ans in ("", "c",):
                        break
                    if ans in ("p",):
                        import pdb
                        pdb.set_trace()
                    elif ans in ("h", ):
                        print("h[elp] c[ontinue] p[db]")

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            good = numpy.where(good)
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        if rms is None:
            # No clipping iterations were run, so report the scatter of the
            # final solution rather than failing on an unset value.
            rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution, rms

    def _solve(self, science, fringes):
        """Solve for the scale factors.

        Parameters
        ----------
        science : `numpy.ndarray`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.ndarray`
            Array of measured fringe values at each of the positions
            supplied; used as the least-squares design matrix.

        Returns
        -------
        solution : `numpy.ndarray`
            Fringe solution amplitudes for each input fringe frame.
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes.

        Each fringe frame is scaled by its amplitude and subtracted from the
        science exposure in place.

        Parameters
        ----------
        science : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        fringes : `list` of `lsst.afw.image.Exposure`
            Calibration fringe files containing master fringe frames.
        solution : `numpy.ndarray`
            Fringe solution amplitudes for each input fringe frame.

        Raises
        ------
        RuntimeError
            Raised if the number of fringe frames does not match the
            number of measured amplitudes.
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())
def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        MaskedImage to measure.
    x, y : `int`
        Center for aperture, in coordinates local to ``mi``.
    size : `int`
        Half-size of the aperture; the box measured is ``2*size`` pixels
        on a side.
    statistic : `int`
        Statistic to measure (an afwMath statistics property value).
    stats : `lsst.afw.math.StatisticsControl`
        StatisticsControl object.

    Returns
    -------
    value : `float`
        Value of statistic within aperture.
    """
    # Truncate both coordinates before offsetting so that x and y are
    # treated identically and the box stays centred on the position.
    corner = lsst.geom.Point2I(int(x) - size, int(y) - size)
    bbox = lsst.geom.Box2I(corner, lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()
def stdev(vector):
    """Calculate a robust standard deviation of an array of values.

    The estimate is 0.74 times the inter-quartile range of the values.

    Parameters
    ----------
    vector : array-like
        Array of values.

    Returns
    -------
    sigma : `float`
        Robust standard deviation estimate.
    """
    quartiles = numpy.percentile(vector, (25, 75))
    interQuartileRange = quartiles[1] - quartiles[0]
    return 0.74*interQuartileRange
543def select(vector, clip):
544 """Select values within 'clip' standard deviations of the median
546 Returns a boolean array.
547 """
548 q1, q2, q3 = numpy.percentile(vector, (25, 50, 75))
549 return numpy.abs(vector - q2) < clip*0.74*(q3 - q1)