lsst.pipe.drivers  17.0.1-2-gd73ec07+6
skyCorrection.py
Go to the documentation of this file.
1 from __future__ import absolute_import, division, print_function
2 
3 import lsst.afw.math as afwMath
4 import lsst.afw.image as afwImage
5 import lsst.afw.table as afwTable
6 import lsst.meas.algorithms as measAlg
7 
8 from lsst.pipe.base import ArgumentParser, Struct
9 from lsst.pex.config import Config, Field, ConfigurableField, ConfigField
10 from lsst.ctrl.pool.pool import Pool
11 from lsst.ctrl.pool.parallel import BatchPoolTask
12 from lsst.pipe.drivers.background import SkyMeasurementTask, FocalPlaneBackground, FocalPlaneBackgroundConfig
13 import lsst.pipe.drivers.visualizeVisit as visualizeVisit
14 
15 DEBUG = False # Debugging outputs?
16 
17 
def makeCameraImage(camera, exposures, filename=None, binning=8):
    """Assemble (and optionally write) a binned image of the whole focal plane

    Parameters
    ----------
    camera : `lsst.afw.cameraGeom.Camera`
        Camera description.
    exposures : `list` of `tuple` of `int` and `lsst.afw.image.Exposure`
        ``(detector ID, binned CCD image)`` pairs; `None` entries are
        skipped.
    filename : `str`, optional
        If given, the assembled image is written to this FITS file.
    binning : `int`
        Binning factor that has already been applied to the CCD images.

    Returns
    -------
    image
        Focal-plane image built by
        `lsst.pipe.drivers.visualizeVisit.makeCameraImage`.
    """
    # Drop CCDs that produced nothing, then index the rest by detector ID.
    validPairs = [pair for pair in exposures if pair is not None]
    focalPlane = visualizeVisit.makeCameraImage(camera, dict(validPairs), binning)
    if filename is not None:
        focalPlane.writeFits(filename)
    return focalPlane
36 
37 
class SkyCorrectionConfig(Config):
    """Configuration for SkyCorrectionTask"""
    # Sub-configurations for the focal-plane background model and subtasks.
    bgModel = ConfigField(dtype=FocalPlaneBackgroundConfig, doc="Background model")
    sky = ConfigurableField(target=SkyMeasurementTask, doc="Sky measurement")
    detection = ConfigurableField(target=measAlg.SourceDetectionTask, doc="Detection configuration")
    # Switches and parameters controlling which correction steps run.
    doDetection = Field(dtype=bool, default=True, doc="Detect sources (to find good sky)?")
    detectSigma = Field(dtype=float, default=5.0, doc="Detection PSF gaussian sigma")
    doBgModel = Field(dtype=bool, default=True, doc="Do background model subtraction?")
    doSky = Field(dtype=bool, default=True, doc="Do sky frame subtraction?")
    binning = Field(dtype=int, default=8, doc="Binning factor for constructing focal-plane images")

    def setDefaults(self):
        """Tune the detection subtask defaults for sky-finding."""
        Config.setDefaults(self)
        detectionConfig = self.detection
        # Detection is used only to mask sources; keep its background handling off.
        detectionConfig.reEstimateBackground = False
        detectionConfig.doTempLocalBackground = False
        # Flag both positive and negative excursions, thresholded in units of
        # the pixel standard deviation.
        detectionConfig.thresholdPolarity = "both"
        detectionConfig.thresholdType = "pixel_stdev"
        detectionConfig.thresholdValue = 3.0
56 
57 
59  """Correct sky over entire focal plane"""
60  ConfigClass = SkyCorrectionConfig
61  _DefaultName = "skyCorr"
62 
63  def __init__(self, *args, **kwargs):
64  BatchPoolTask.__init__(self, *args, **kwargs)
65  self.makeSubtask("sky")
66  # Disposable schema suppresses warning from SourceDetectionTask.__init__
67  self.makeSubtask("detection", schema=afwTable.Schema())
68 
69  @classmethod
70  def _makeArgumentParser(cls, *args, **kwargs):
71  kwargs.pop("doBatch", False)
72  parser = ArgumentParser(name="skyCorr", *args, **kwargs)
73  parser.add_id_argument("--id", datasetType="calexp", level="visit",
74  help="data ID, e.g. --id visit=12345")
75  return parser
76 
77  @classmethod
78  def batchWallTime(cls, time, parsedCmd, numCores):
79  """Return walltime request for batch job
80 
81  Subclasses should override if the walltime should be calculated
82  differently (e.g., addition of some serial time).
83 
84  Parameters
85  ----------
86  time : `float`
87  Requested time per iteration.
88  parsedCmd : `argparse.Namespace`
89  Results of argument parsing.
90  numCores : `int`
91  Number of cores.
92  """
93  numTargets = len(cls.RunnerClass.getTargetList(parsedCmd))
94  return time*numTargets
95 
96  def runDataRef(self, expRef):
97  """Perform sky correction on an exposure
98 
99  We restore the original sky, and remove it again using multiple
100  algorithms. We optionally apply:
101 
102  1. A large-scale background model.
103  2. A sky frame.
104 
105  Only the master node executes this method. The data is held on
106  the slave nodes, which do all the hard work.
107 
108  Parameters
109  ----------
110  expRef : `lsst.daf.persistence.ButlerDataRef`
111  Data reference for exposure.
112  """
113  if DEBUG:
114  extension = "-%(visit)d.fits" % expRef.dataId
115 
116  with self.logOperation("processing %s" % (expRef.dataId,)):
117  pool = Pool()
118  pool.cacheClear()
119  pool.storeSet(butler=expRef.getButler())
120  camera = expRef.get("camera")
121 
122  dataIdList = [ccdRef.dataId for ccdRef in expRef.subItems("ccd") if
123  ccdRef.datasetExists("calexp")]
124 
125  exposures = pool.map(self.loadImage, dataIdList)
126  if DEBUG:
127  makeCameraImage(camera, exposures, "restored" + extension)
128  exposures = pool.mapToPrevious(self.collectOriginal, dataIdList)
129  makeCameraImage(camera, exposures, "original" + extension)
130  exposures = pool.mapToPrevious(self.collectMask, dataIdList)
131  makeCameraImage(camera, exposures, "mask" + extension)
132 
133  if self.config.doBgModel:
134  bgModel = FocalPlaneBackground.fromCamera(self.config.bgModel, camera)
135  data = [Struct(dataId=dataId, bgModel=bgModel.clone()) for dataId in dataIdList]
136  bgModelList = pool.mapToPrevious(self.accumulateModel, data)
137  for ii, bg in enumerate(bgModelList):
138  self.log.info("Background %d: %d pixels", ii, bg._numbers.getArray().sum())
139  bgModel.merge(bg)
140 
141  if DEBUG:
142  bgModel.getStatsImage().writeFits("bgModel" + extension)
143  bgImages = pool.mapToPrevious(self.realiseModel, dataIdList, bgModel)
144  makeCameraImage(camera, bgImages, "bgModelCamera" + extension)
145 
146  exposures = pool.mapToPrevious(self.subtractModel, dataIdList, bgModel)
147  if DEBUG:
148  makeCameraImage(camera, exposures, "modelsub" + extension)
149 
150  if self.config.doSky:
151  measScales = pool.mapToPrevious(self.measureSkyFrame, dataIdList)
152  scale = self.sky.solveScales(measScales)
153  self.log.info("Sky frame scale: %s" % (scale,))
154  exposures = pool.mapToPrevious(self.subtractSkyFrame, dataIdList, scale)
155  if DEBUG:
156  makeCameraImage(camera, exposures, "skysub" + extension)
157  calibs = pool.mapToPrevious(self.collectSky, dataIdList)
158  makeCameraImage(camera, calibs, "sky" + extension)
159 
160  # Persist camera-level image of calexp
161  image = makeCameraImage(camera, exposures)
162  expRef.put(image, "calexp_camera")
163 
164  pool.mapToPrevious(self.write, dataIdList)
165 
166  def loadImage(self, cache, dataId):
167  """Load original image and restore the sky
168 
169  This method runs on the slave nodes.
170 
171  Parameters
172  ----------
173  cache : `lsst.pipe.base.Struct`
174  Process pool cache.
175  dataId : `dict`
176  Data identifier.
177 
178  Returns
179  -------
180  exposure : `lsst.afw.image.Exposure`
181  Resultant exposure.
182  """
183  cache.dataId = dataId
184  cache.exposure = cache.butler.get("calexp", dataId, immediate=True).clone()
185  bgOld = cache.butler.get("calexpBackground", dataId, immediate=True)
186  image = cache.exposure.getMaskedImage()
187 
188  if self.config.doDetection:
189  # We deliberately use the specified 'detectSigma' instead of the PSF, in order to better pick up
190  # the faint wings of objects.
191  results = self.detection.detectFootprints(cache.exposure, doSmooth=True,
192  sigma=self.config.detectSigma, clearMask=True)
193  if hasattr(results, "background") and results.background:
194  # Restore any background that was removed during detection
195  image += results.background.getImage()
196 
197  # We're removing the old background, so change the sense of all its components
198  for bgData in bgOld:
199  statsImage = bgData[0].getStatsImage()
200  statsImage *= -1
201 
202  image -= bgOld.getImage()
203  cache.bgList = afwMath.BackgroundList()
204  for bgData in bgOld:
205  cache.bgList.append(bgData)
206 
207  return self.collect(cache)
208 
209  def measureSkyFrame(self, cache, dataId):
210  """Measure scale for sky frame
211 
212  This method runs on the slave nodes.
213 
214  Parameters
215  ----------
216  cache : `lsst.pipe.base.Struct`
217  Process pool cache.
218  dataId : `dict`
219  Data identifier.
220 
221  Returns
222  -------
223  scale : `float`
224  Scale for sky frame.
225  """
226  assert cache.dataId == dataId
227  cache.sky = self.sky.getSkyData(cache.butler, dataId)
228  scale = self.sky.measureScale(cache.exposure.getMaskedImage(), cache.sky)
229  return scale
230 
231  def subtractSkyFrame(self, cache, dataId, scale):
232  """Subtract sky frame
233 
234  This method runs on the slave nodes.
235 
236  Parameters
237  ----------
238  cache : `lsst.pipe.base.Struct`
239  Process pool cache.
240  dataId : `dict`
241  Data identifier.
242  scale : `float`
243  Scale for sky frame.
244 
245  Returns
246  -------
247  exposure : `lsst.afw.image.Exposure`
248  Resultant exposure.
249  """
250  assert cache.dataId == dataId
251  self.sky.subtractSkyFrame(cache.exposure.getMaskedImage(), cache.sky, scale, cache.bgList)
252  return self.collect(cache)
253 
254  def accumulateModel(self, cache, data):
255  """Fit background model for CCD
256 
257  This method runs on the slave nodes.
258 
259  Parameters
260  ----------
261  cache : `lsst.pipe.base.Struct`
262  Process pool cache.
263  data : `lsst.pipe.base.Struct`
264  Data identifier, with `dataId` (data identifier) and `bgModel`
265  (background model) elements.
266 
267  Returns
268  -------
269  bgModel : `lsst.pipe.drivers.background.FocalPlaneBackground`
270  Background model.
271  """
272  assert cache.dataId == data.dataId
273  data.bgModel.addCcd(cache.exposure)
274  return data.bgModel
275 
276  def subtractModel(self, cache, dataId, bgModel):
277  """Subtract background model
278 
279  This method runs on the slave nodes.
280 
281  Parameters
282  ----------
283  cache : `lsst.pipe.base.Struct`
284  Process pool cache.
285  dataId : `dict`
286  Data identifier.
287  bgModel : `lsst.pipe.drivers.background.FocalPlaneBackround`
288  Background model.
289 
290  Returns
291  -------
292  exposure : `lsst.afw.image.Exposure`
293  Resultant exposure.
294  """
295  assert cache.dataId == dataId
296  exposure = cache.exposure
297  image = exposure.getMaskedImage()
298  detector = exposure.getDetector()
299  bbox = image.getBBox()
300  cache.bgModel = bgModel.toCcdBackground(detector, bbox)
301  image -= cache.bgModel.getImage()
302  cache.bgList.append(cache.bgModel[0])
303  return self.collect(cache)
304 
305  def realiseModel(self, cache, dataId, bgModel):
306  """Generate an image of the background model for visualisation
307 
308  Useful for debugging.
309 
310  Parameters
311  ----------
312  cache : `lsst.pipe.base.Struct`
313  Process pool cache.
314  dataId : `dict`
315  Data identifier.
316  bgModel : `lsst.pipe.drivers.background.FocalPlaneBackround`
317  Background model.
318 
319  Returns
320  -------
321  detId : `int`
322  Detector identifier.
323  image : `lsst.afw.image.MaskedImage`
324  Binned background model image.
325  """
326  assert cache.dataId == dataId
327  exposure = cache.exposure
328  detector = exposure.getDetector()
329  bbox = exposure.getMaskedImage().getBBox()
330  image = bgModel.toCcdBackground(detector, bbox).getImage()
331  return self.collectBinnedImage(exposure, image)
332 
333  def collectBinnedImage(self, exposure, image):
334  """Return the binned image required for visualization
335 
336  This method just helps to cut down on boilerplate.
337 
338  Parameters
339  ----------
340  image : `lsst.afw.image.MaskedImage`
341  Image to go into visualisation.
342 
343  Returns
344  -------
345  detId : `int`
346  Detector identifier.
347  image : `lsst.afw.image.MaskedImage`
348  Binned image.
349  """
350  return (exposure.getDetector().getId(), afwMath.binImage(image, self.config.binning))
351 
352  def collect(self, cache):
353  """Collect exposure for potential visualisation
354 
355  This method runs on the slave nodes.
356 
357  Parameters
358  ----------
359  cache : `lsst.pipe.base.Struct`
360  Process pool cache.
361 
362  Returns
363  -------
364  detId : `int`
365  Detector identifier.
366  image : `lsst.afw.image.MaskedImage`
367  Binned image.
368  """
369  return self.collectBinnedImage(cache.exposure, cache.exposure.maskedImage)
370 
371  def collectOriginal(self, cache, dataId):
372  """Collect original image for visualisation
373 
374  This method runs on the slave nodes.
375 
376  Parameters
377  ----------
378  cache : `lsst.pipe.base.Struct`
379  Process pool cache.
380  dataId : `dict`
381  Data identifier.
382 
383  Returns
384  -------
385  detId : `int`
386  Detector identifier.
387  image : `lsst.afw.image.MaskedImage`
388  Binned image.
389  """
390  exposure = cache.butler.get("calexp", dataId, immediate=True)
391  return self.collectBinnedImage(exposure, exposure.maskedImage)
392 
393  def collectSky(self, cache, dataId):
394  """Collect original image for visualisation
395 
396  This method runs on the slave nodes.
397 
398  Parameters
399  ----------
400  cache : `lsst.pipe.base.Struct`
401  Process pool cache.
402  dataId : `dict`
403  Data identifier.
404 
405  Returns
406  -------
407  detId : `int`
408  Detector identifier.
409  image : `lsst.afw.image.MaskedImage`
410  Binned image.
411  """
412  return self.collectBinnedImage(cache.exposure, cache.sky.getImage())
413 
414  def collectMask(self, cache, dataId):
415  """Collect mask for visualisation
416 
417  This method runs on the slave nodes.
418 
419  Parameters
420  ----------
421  cache : `lsst.pipe.base.Struct`
422  Process pool cache.
423  dataId : `dict`
424  Data identifier.
425 
426  Returns
427  -------
428  detId : `int`
429  Detector identifier.
430  image : `lsst.afw.image.Image`
431  Binned image.
432  """
433  # Convert Mask to floating-point image, because that's what's required for focal plane construction
434  image = afwImage.ImageF(cache.exposure.maskedImage.getBBox())
435  image.array[:] = cache.exposure.maskedImage.mask.array
436  return self.collectBinnedImage(cache.exposure, image)
437 
438  def write(self, cache, dataId):
439  """Write resultant background list
440 
441  This method runs on the slave nodes.
442 
443  Parameters
444  ----------
445  cache : `lsst.pipe.base.Struct`
446  Process pool cache.
447  dataId : `dict`
448  Data identifier.
449  """
450  cache.butler.put(cache.bgList, "skyCorr", dataId)
451 
452  def _getMetadataName(self):
453  """There's no metadata to write out"""
454  return None
def subtractModel(self, cache, dataId, bgModel)
def batchWallTime(cls, time, parsedCmd, numCores)
def makeCameraImage(camera, exposures, filename=None, binning=8)
def realiseModel(self, cache, dataId, bgModel)
def subtractSkyFrame(self, cache, dataId, scale)
def logOperation(self, operation, catch=False, trace=True)