Coverage for python/lsst/ip/isr/fringe.py: 20%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ip_isr.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy
24import lsst.geom
25import lsst.afw.image as afwImage
26import lsst.afw.math as afwMath
27import lsst.afw.display as afwDisplay
29from lsst.pipe.base import Task, Struct
30from lsst.pex.config import Config, Field, ListField, ConfigField
31from lsst.utils.timer import timeMethod
32from .isrFunctions import checkFilter
# Module-level side effect (executed on import): set the default mask-plane
# transparency used by lsst.afw.display to 75.
afwDisplay.setDefaultMaskTransparency(75)
def getFrame():
    """Return the next display frame number.

    The running counter is stored as the ``frame`` attribute of this
    function; each call increments it by one and returns the new value,
    so successive calls yield 1, 2, 3, ...
    """
    nextFrame = getFrame.frame + 1
    getFrame.frame = nextFrame
    return nextFrame


# Counter state for getFrame(); the first call returns 1.
getFrame.frame = 0
class FringeStatisticsConfig(Config):
    """Options for measuring fringes on an exposure"""
    # Mask planes whose pixels are excluded (via the StatisticsControl
    # and-mask) when measuring aperture statistics.
    badMaskPlanes = ListField(dtype=str, default=["SAT"], doc="Ignore pixels with these masks")
    # lsst.afw.math statistic property (stored as an int) used for the
    # aperture measurements; defaults to the median.
    stat = Field(dtype=int, default=int(afwMath.MEDIAN), doc="Statistic to use")
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    # This value is passed to StatisticsControl.setNumIter, i.e. it is the
    # number of sigma-clipping iterations used when computing statistics,
    # not a count of fitting iterations.
    iterations = Field(dtype=int, default=3, doc="Number of sigma-clipping iterations for statistics")
    rngSeedOffset = Field(dtype=int, default=0,
                          doc="Offset to the random number generator seed (full seed includes exposure ID)")
class FringeConfig(Config):
    """Fringe subtraction options"""
    # TODO DM-28093: change the doc to specify that these are physical labels
    filters = ListField(dtype=str, default=[], doc="Only fringe-subtract these filters")
    # TODO: remove in DM-27177
    # The two message fragments are joined with an explicit space so the
    # deprecation text reads as two sentences.
    useFilterAliases = Field(dtype=bool, default=False, doc="Search filter aliases during check.",
                             deprecated=("Removed with no replacement (FilterLabel has no aliases). "
                                         "Will be removed after v22."))
    num = Field(dtype=int, default=30000, doc="Number of fringe measurements")
    small = Field(dtype=int, default=3, doc="Half-size of small (fringe) measurements (pixels)")
    large = Field(dtype=int, default=30, doc="Half-size of large (background) measurements (pixels)")
    iterations = Field(dtype=int, default=20, doc="Number of fitting iterations")
    clip = Field(dtype=float, default=3.0, doc="Sigma clip threshold")
    stats = ConfigField(dtype=FringeStatisticsConfig, doc="Statistics for measuring fringes")
    pedestal = Field(dtype=bool, default=False, doc="Remove fringe pedestal?")
class FringeTask(Task):
    """Task to remove fringes from a science exposure

    We measure fringe amplitudes at random positions on the science exposure
    and at the same positions on the (potentially multiple) fringe frames
    and solve for the scales simultaneously.
    """
    ConfigClass = FringeConfig
    _DefaultName = 'isrFringe'

    def readFringes(self, dataRef, expId=None, assembler=None):
        """Read the fringe frame(s), and pack data into a Struct

        The current implementation assumes only a single fringe frame and
        will have to be updated to support multi-mode fringe subtraction.

        This implementation could be optimised by persisting the fringe
        positions and fluxes.

        Parameters
        ----------
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference for the exposure that will have fringing
            removed.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation.

        Raises
        ------
        RuntimeError
            Raised if the fringe frame cannot be retrieved from the butler.
        """
        try:
            fringe = dataRef.get("fringe", immediate=True)
        except Exception as e:
            # Chain the original error so its traceback is preserved.
            raise RuntimeError("Unable to retrieve fringe for %s: %s" % (dataRef.dataId, e)) from e

        return self.loadFringes(fringe, expId=expId, assembler=assembler)

    def loadFringes(self, fringeExp, expId=None, assembler=None):
        """Pack the fringe data into a Struct.

        This method moves the struct parsing code into a butler
        generation agnostic handler.

        Parameters
        ----------
        fringeExp : `lsst.afw.image.Exposure`
            The exposure containing the fringe data.
        expId : `int`, optional
            Exposure id to be fringe corrected, used to set RNG seed.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        fringeData : `pipeBase.Struct`
            Struct containing fringe data:
            - ``fringes`` : `lsst.afw.image.Exposure` or `list` thereof
                Calibration fringe files containing master fringe frames.
            - ``seed`` : `int`, optional
                Seed for random number generation.
        """
        if assembler is not None:
            fringeExp = assembler.assembleCcd(fringeExp)

        if expId is None:
            seed = self.config.stats.rngSeedOffset
        else:
            # Diagnostic output goes through the task logger, not stdout.
            self.log.debug("Fringe RNG seed offset: %s; exposure id: %s",
                           self.config.stats.rngSeedOffset, expId)
            seed = self.config.stats.rngSeedOffset + expId

        # Seed for numpy.random.RandomState must be convertible to a 32 bit unsigned integer
        seed %= 2**32

        return Struct(fringes=fringeExp,
                      seed=seed)

    @timeMethod
    def run(self, exposure, fringes, seed=None):
        """Remove fringes from the provided science exposure.

        Primary method of FringeTask.  Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.  Modified in
            place.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.  Their
            masks are modified in place: OR-ed with the science mask here,
            then cleared in `subtract`.
        seed : `int`, optional
            Seed for random number generation.  Defaults to
            ``config.stats.rngSeedOffset``.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        If the exposure's filter is not in ``config.filters``, nothing is
        done and `None` is returned.
        """
        import lsstDebug
        display = lsstDebug.Info(__name__).display

        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return

        if seed is None:
            seed = self.config.stats.rngSeedOffset
        rng = numpy.random.RandomState(seed=seed)

        if not hasattr(fringes, '__iter__'):
            fringes = [fringes]

        mask = exposure.getMaskedImage().getMask()
        for fringe in fringes:
            # Copy the science mask bits onto each fringe frame so that the
            # bad-mask-plane rejection in measureExposure ignores the same
            # pixels in both images.
            fringeMask = fringe.getMaskedImage().getMask()
            fringeMask |= mask
            if self.config.pedestal:
                self.removePedestal(fringe)

        positions = self.generatePositions(fringes[0], rng)
        # One row per position, one column per fringe frame.
        fluxes = numpy.ndarray([self.config.num, len(fringes)])
        for i, f in enumerate(fringes):
            fluxes[:, i] = self.measureExposure(f, positions, title="Fringe frame")

        expFringes = self.measureExposure(exposure, positions, title="Science")
        solution, rms = self.solve(expFringes, fluxes)
        self.subtract(exposure, fringes, solution)
        if display:
            afwDisplay.Display(frame=getFrame()).mtv(exposure, title="Fringe subtracted")
        return solution, rms

    @timeMethod
    def runDataRef(self, exposure, dataRef, assembler=None):
        """Remove fringes from the provided science exposure.

        Retrieve fringes from butler dataRef provided and remove from
        provided science exposure.  Fringes are only subtracted if the
        science exposure has a filter listed in the configuration.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.
        dataRef : `daf.persistence.butlerSubset.ButlerDataRef`
            Butler reference to the exposure.  Used to find
            appropriate fringe data.
        assembler : `lsst.ip.isr.AssembleCcdTask`, optional
            An instance of AssembleCcdTask (for assembling fringe
            frames).

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        rms : `float`
            RMS error for the fit solution for this exposure.

        If the exposure's filter is not in ``config.filters``, the fringe
        frame is not read and `None` is returned.
        """
        # Check the filter before reading, so no butler I/O is done for
        # exposures that will not be fringe-corrected.
        if not self.checkFilter(exposure):
            self.log.info("Filter not found in FringeTaskConfig.filters. Skipping fringe correction.")
            return
        fringeStruct = self.readFringes(dataRef, assembler=assembler)
        return self.run(exposure, **fringeStruct.getDict())

    def checkFilter(self, exposure):
        """Check whether we should fringe-subtract the science exposure.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to check the filter of.

        Returns
        -------
        needsFringe : `bool`
            If True, then the exposure has a filter listed in the
            configuration, and should have the fringe applied.
        """
        # Delegates to the module-level isrFunctions.checkFilter.
        return checkFilter(exposure, self.config.filters, log=self.log)

    def removePedestal(self, fringe):
        """Remove pedestal from fringe exposure.

        The pedestal is the sigma-clipped median of the fringe image.  The
        median is always used here, regardless of ``config.stats.stat``.

        Parameters
        ----------
        fringe : `lsst.afw.image.Exposure`
            Fringe data to subtract the pedestal value from.  Modified in
            place.
        """
        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        mi = fringe.getMaskedImage()
        pedestal = afwMath.makeStatistics(mi, afwMath.MEDIAN, stats).getValue()
        self.log.info("Removing fringe pedestal: %f", pedestal)
        mi -= pedestal

    def generatePositions(self, exposure, rng):
        """Generate a random distribution of positions for measuring fringe amplitudes.

        Positions are drawn uniformly in ``[large, width - large)`` and
        ``[large, height - large)``, so that the large background aperture
        stays inside the image.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        rng : `numpy.random.RandomState`
            Random number generator to use.

        Returns
        -------
        positions : `numpy.array`
            Array of shape ``(config.num, 2)`` holding ``(x, y)`` integer
            positions to sample for fringe amplitudes.
        """
        start = self.config.large
        num = self.config.num
        width = exposure.getWidth() - self.config.large
        height = exposure.getHeight() - self.config.large
        return numpy.array([rng.randint(start, width, size=num),
                            rng.randint(start, height, size=num)]).swapaxes(0, 1)

    @timeMethod
    def measureExposure(self, exposure, positions, title="Fringe"):
        """Measure fringe amplitudes for an exposure

        The fringe amplitudes are measured as the statistic within a square
        aperture.  The statistic within a larger aperture are subtracted so
        as to remove the background.

        Parameters
        ----------
        exposure : `lsst.afw.image.Exposure`
            Exposure to measure the positions on.
        positions : `numpy.array`
            Two-dimensional array containing the positions to sample
            for fringe amplitudes.
        title : `str`, optional
            Title used for debug out plots.

        Returns
        -------
        fringes : `numpy.array`
            Array of measured exposure values at each of the positions
            supplied.
        """
        mi = exposure.getMaskedImage()

        stats = afwMath.StatisticsControl()
        stats.setNumSigmaClip(self.config.stats.clip)
        stats.setNumIter(self.config.stats.iterations)
        stats.setAndMask(mi.getMask().getPlaneBitMask(self.config.stats.badMaskPlanes))

        # Hoist loop-invariant configuration lookups.
        num = self.config.num
        statistic = self.config.stats.stat
        smallSize = self.config.small
        largeSize = self.config.large
        fringes = numpy.ndarray(num)

        for i in range(num):
            x, y = positions[i]
            small = measure(mi, x, y, smallSize, statistic, stats)
            large = measure(mi, x, y, largeSize, statistic, stats)
            # Local (fringe) level minus local background.
            fringes[i] = small - large

        import lsstDebug
        display = lsstDebug.Info(__name__).display
        if display:
            disp = afwDisplay.Display(frame=getFrame())
            disp.mtv(exposure, title=title)
            if False:
                # Disabled debugging aid: outline the small (green) and large
                # (blue) apertures around every measurement position.
                square = numpy.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
                with disp.Buffering():
                    for x, y in positions:
                        disp.line(square*smallSize + [[x, y]], ctype=afwDisplay.GREEN)
                        disp.line(square*largeSize + [[x, y]], ctype=afwDisplay.BLUE)

        return fringes

    @timeMethod
    def solve(self, science, fringes):
        """Solve for the scale factors with iterative clipping.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Array of measured fringe values at each of the positions
            supplied; shape ``(number of positions, number of fringe
            frames)``.

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.  All
            zeros (one per fringe frame) if no good measurements remain.
        rms : `float`
            RMS error for the fit solution for this exposure (NaN if no
            good measurements remain).
        """
        import lsstDebug
        doPlot = lsstDebug.Info(__name__).plot

        origNum = len(science)
        # Number of fringe frames (columns); unchanged by row rejection.
        numFringes = fringes.shape[1]

        def emptyResult(msg=""):
            """Generate an empty result for return to the user

            There are no good pixels; the amplitudes are all zero, but there
            is still one per fringe frame so that `subtract` accepts them.
            """
            self.log.warning("Unable to solve for fringes: no good pixels%s", msg)
            return numpy.zeros(numFringes), numpy.nan

        # Keep only positions where the science value and *every* fringe
        # value are finite: a single NaN column would otherwise propagate
        # through numpy.percentile in `select` and reject all rows.
        good = numpy.where(numpy.logical_and(numpy.isfinite(science),
                                             numpy.all(numpy.isfinite(fringes), 1)))
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult()

        # Up-front rejection to get rid of extreme, potentially troublesome values
        # (e.g., fringe apertures that fall on objects).
        good = select(science, self.config.clip)
        for ff in range(numFringes):
            good &= select(fringes[:, ff], self.config.clip)
        science = science[good]
        fringes = fringes[good]
        oldNum = len(science)
        if oldNum == 0:
            return emptyResult(" after initial rejection")

        # Remains None only if no rejection iterations are configured.
        rms = None
        for i in range(self.config.iterations):
            solution = self._solve(science, fringes)
            resid = science - numpy.sum(solution*fringes, 1)
            rms = stdev(resid)
            good = numpy.logical_not(abs(resid) > self.config.clip*rms)
            self.log.debug("Iteration %d: RMS=%f numGood=%d", i, rms, good.sum())
            self.log.debug("Solution %d: %s", i, solution)
            newNum = good.sum()
            if newNum == 0:
                return emptyResult(" after %d rejection iterations" % i)

            if doPlot:
                self._plotSolution(science, fringes, solution)

            if newNum == oldNum:
                # Not gaining
                break
            oldNum = newNum
            good = numpy.where(good)
            science = science[good]
            fringes = fringes[good]

        # Final solution without rejection
        solution = self._solve(science, fringes)
        if rms is None:
            # No rejection iterations ran (config.iterations == 0); report
            # the robust scatter of the final solution's residuals instead.
            rms = stdev(science - numpy.sum(solution*fringes, 1))
        self.log.info("Fringe solution: %s RMS: %f Good: %d/%d", solution, rms, len(science), origNum)
        return solution, rms

    def _plotSolution(self, science, fringes, solution):
        """Interactively plot the current fit for each fringe frame (debug only).

        For each fringe frame, plot the science values (with the other
        frames' scaled contributions removed) against that frame's values,
        together with the fitted line, then wait for user input.

        Parameters
        ----------
        science : `numpy.array`
            Science measurements currently being fit.
        fringes : `numpy.array`
            Fringe measurements currently being fit, one column per frame.
        solution : `np.array`
            Current fringe amplitudes.
        """
        import matplotlib.pyplot as plot
        for j in range(fringes.shape[1]):
            fig = plot.figure(j)
            fig.clf()
            try:
                fig.canvas._tkcanvas._root().lift()  # == Tk's raise
            except Exception:
                # Best effort only: not every matplotlib backend is Tk-based.
                pass
            ax = fig.add_subplot(1, 1, 1)
            adjust = science.copy()
            others = set(range(fringes.shape[1]))
            others.discard(j)
            for k in others:
                adjust -= solution[k]*fringes[:, k]
            ax.plot(fringes[:, j], adjust, 'r.')
            xmin = fringes[:, j].min()
            xmax = fringes[:, j].max()
            ymin = solution[j]*xmin
            ymax = solution[j]*xmax
            ax.plot([xmin, xmax], [ymin, ymax], 'b-')
            ax.set_title("Fringe %d: %f" % (j, solution[j]))
            ax.set_xlabel("Fringe amplitude")
            ax.set_ylabel("Science amplitude")
            ax.set_autoscale_on(False)
            ax.set_xbound(lower=xmin, upper=xmax)
            ax.set_ybound(lower=ymin, upper=ymax)
            fig.show()
        while True:
            ans = input("Enter or c to continue [chp]").lower()
            if ans in ("", "c",):
                break
            if ans in ("p",):
                import pdb
                pdb.set_trace()
            elif ans in ("h", ):
                print("h[elp] c[ontinue] p[db]")

    def _solve(self, science, fringes):
        """Solve for the scale factors.

        Parameters
        ----------
        science : `numpy.array`
            Array of measured science image values at each of the
            positions supplied.
        fringes : `numpy.array`
            Array of measured fringe values at each of the positions
            supplied (used as the design matrix).

        Returns
        -------
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.
        """
        return afwMath.LeastSquares.fromDesignMatrix(fringes, science,
                                                     afwMath.LeastSquares.DIRECT_SVD).getSolution()

    def subtract(self, science, fringes, solution):
        """Subtract the fringes.

        Parameters
        ----------
        science : `lsst.afw.image.Exposure`
            Science exposure from which to remove fringes.  Modified in
            place.
        fringes : `lsst.afw.image.Exposure` or `list` thereof
            Calibration fringe files containing master fringe frames.  Their
            mask planes are zeroed in place.
        solution : `np.array`
            Fringe solution amplitudes for each input fringe frame.

        Raises
        ------
        RuntimeError :
            Raised if the number of fringe frames does not match the
            number of measured amplitudes.
        """
        if len(solution) != len(fringes):
            raise RuntimeError("Number of fringe frames (%s) != number of scale factors (%s)" %
                               (len(fringes), len(solution)))

        for s, f in zip(solution, fringes):
            # We do not want to add the mask from the fringe to the image.
            f.getMaskedImage().getMask().getArray()[:] = 0
            science.getMaskedImage().scaledMinus(s, f.getMaskedImage())
def measure(mi, x, y, size, statistic, stats):
    """Measure a statistic within a square aperture.

    Parameters
    ----------
    mi : `lsst.afw.image.MaskedImage`
        Image to measure.
    x, y : `int` or `float`
        Aperture position; each coordinate is truncated with ``int()``.
    size : `int`
        Half-size of the aperture (pixels).
    statistic : `int`
        ``lsst.afw.math`` statistic property to compute.
    stats : `lsst.afw.math.StatisticsControl`
        Statistics control object.

    Returns
    -------
    value : `float`
        Value of the statistic within the aperture.
    """
    # Truncate both coordinates the same way before applying the offset.
    xMin = int(x) - size
    yMin = int(y) - size
    # NOTE(review): the box is 2*size pixels on a side starting at
    # (x - size, y - size), so it covers [x - size, x + size - 1] and is
    # not exactly centred on (x, y).
    bbox = lsst.geom.Box2I(lsst.geom.Point2I(xMin, yMin),
                           lsst.geom.Extent2I(2*size, 2*size))
    subImage = mi.Factory(mi, bbox, afwImage.LOCAL)
    return afwMath.makeStatistics(subImage, statistic, stats).getValue()
def stdev(vector):
    """Return a robust estimate of the standard deviation of ``vector``.

    The estimate is 0.74 times the interquartile range (75th minus 25th
    percentile), which approximates the standard deviation for Gaussian
    data while being insensitive to outliers.

    Parameters
    ----------
    vector : array-like
        Values to characterise.

    Returns
    -------
    sigma : `float`
        Robust standard deviation.
    """
    lower, upper = numpy.percentile(vector, (25, 75))
    interquartileRange = upper - lower
    return 0.74*interquartileRange
def select(vector, clip):
    """Select values within ``clip`` robust standard deviations of the median.

    The standard deviation is estimated as 0.74 times the interquartile
    range.

    Parameters
    ----------
    vector : `numpy.ndarray`
        Values to test.
    clip : `float`
        Clipping threshold, in units of the robust standard deviation.

    Returns
    -------
    selected : `numpy.ndarray` of `bool`
        True where ``|value - median| <= clip*0.74*IQR``.  Comparisons
        involving NaN are False, so NaN values are never selected.
    """
    q1, q2, q3 = numpy.percentile(vector, (25, 50, 75))
    # Inclusive bound: with a zero interquartile range the values equal to
    # the median are kept instead of every value being rejected.  This also
    # matches the inclusive rejection used in FringeTask.solve.
    return numpy.abs(vector - q2) <= clip*0.74*(q3 - q1)