lsst.pipe.drivers  6.0b0-hsc-11-g620b133+1
multiBandDriver.py
from __future__ import absolute_import, division, print_function
import os
from argparse import ArgumentError

from builtins import zip

from lsst.pex.config import Config, Field, ConfigurableField
from lsst.pipe.base import ArgumentParser, TaskRunner
from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
                                       MergeDetectionsTask,
                                       MeasureMergedCoaddSourcesTask,
                                       MergeMeasurementsTask,)
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, abortOnError, NODE
from lsst.meas.base.references import MultiBandReferencesTask
from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
from lsst.pipe.tasks.coaddBase import CoaddDataIdContainer

import lsst.afw.table as afwTable

class MultiBandDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """!Make self.refList from self.idList

        It's difficult to make a data reference that merely points to an entire
        tract: there is no data product solely at the tract level. Instead, we
        generate a list of data references for patches within the tract.

        @param namespace namespace object that is the result of an argument parser
        """
        datasetType = namespace.config.coaddName + "Coadd_calexp"

        def getPatchRefList(tract):
            return [namespace.butler.dataRef(datasetType=datasetType,
                                             tract=tract.getId(),
                                             filter=dataId["filter"],
                                             patch="%d,%d" % patch.getIndex())
                    for patch in tract]

        tractRefs = {}  # Data references for each tract
        for dataId in self.idList:
            # There's no registry of coadds by filter, so we need to be given
            # the filter
            if "filter" not in dataId:
                raise ArgumentError(None, "--id must include 'filter'")

            skymap = self.getSkymap(namespace, datasetType)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if tractId not in tractRefs:
                    tractRefs[tractId] = []
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=datasetType,
                                                                       tract=tractId,
                                                                       filter=dataId['filter'],
                                                                       patch=dataId['patch']))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId])
            else:
                tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
                                 for tract in skymap)

        self.refList = list(tractRefs.values())

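# A minimal sketch of how the container above expands data ids (tract, patch,
# and filter values below are hypothetical; assumes a Gen2 butler repository):
#
#   --id tract=9813 filter=HSC-I
#       -> one <coaddName>Coadd_calexp dataRef per patch in tract 9813,
#          e.g. {'tract': 9813, 'filter': 'HSC-I', 'patch': '0,0'}, ...
#   --id tract=9813 filter=HSC-I patch=5,5
#       -> a single dataRef for that patch
#   --id filter=HSC-I
#       -> patch dataRefs for every tract in the skymap
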
class MultiBandDriverConfig(Config):
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    doDetection = Field(dtype=bool, default=False,
                        doc="Re-run detection? (requires *Coadd dataset to have been written)")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    def setDefaults(self):
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        for subtask in ("mergeCoaddDetections", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))

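# A sketch of how validate() guards against inconsistent overrides (the field
# paths follow the ConfigurableField definitions above; values are
# illustrative):
#
#   config = MultiBandDriverConfig()
#   config.coaddName = "goodSeeing"
#   config.validate()  # raises RuntimeError: mergeCoaddDetections.coaddName
#                      # ("deep") doesn't match root coaddName ("goodSeeing")
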
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandDriverTask

    This is similar to the lsst.pipe.base.ButlerInitializedTaskRunner,
    except that we have a list of data references instead of a single
    data reference being passed to the Task.run, and we pass the results
    of the '--reuse-outputs-from' command option to the Task constructor.
    """

    def __init__(self, TaskClass, parsedCmd, doReturnResults=False):
        TaskRunner.__init__(self, TaskClass, parsedCmd, doReturnResults)
        self.reuse = parsedCmd.reuse

    def makeTask(self, parsedCmd=None, args=None):
        """A variant of the base version that passes a butler argument to the task's constructor.

        Either parsedCmd or args must be specified.
        """
        if parsedCmd is not None:
            butler = parsedCmd.butler
        elif args is not None:
            dataRefList, kwargs = args
            butler = dataRefList[0].butlerSubset.butler
        else:
            raise RuntimeError("parsedCmd or args must be specified")
        return self.TaskClass(config=self.config, log=self.log, butler=butler, reuse=self.reuse)

def unpickle(factory, args, kwargs):
    """Unpickle something by calling a factory"""
    return factory(*args, **kwargs)


class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing"""
    ConfigClass = MultiBandDriverConfig
    _DefaultName = "multiBandDriver"
    RunnerClass = MultiBandDriverTaskRunner

    def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), **kwargs):
        """!
        @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader
            constructor in case it is needed.
        @param[in] schema: the schema of the source detection catalog used as input.
        @param[in] refObjLoader: an instance of LoadReferenceObjectsTask that supplies an external
            reference catalog. May be None if the butler argument is provided or all steps requiring
            a reference catalog are disabled.
        """
        BatchPoolTask.__init__(self, **kwargs)
        if schema is None:
            assert butler is not None, "Butler not provided"
            schema = butler.get(self.config.coaddName +
                                "Coadd_det_schema", immediate=True).schema
        self.butler = butler
        self.reuse = tuple(reuse)
        self.makeSubtask("detectCoaddSources")
        self.makeSubtask("mergeCoaddDetections", schema=schema)
        self.makeSubtask("measureCoaddSources", schema=afwTable.Schema(self.mergeCoaddDetections.schema),
                         peakSchema=afwTable.Schema(
                             self.mergeCoaddDetections.merged.getPeakSchema()),
                         refObjLoader=refObjLoader, butler=butler)
        self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
            self.measureCoaddSources.schema))
        self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
            self.mergeCoaddMeasurements.schema))

    def __reduce__(self):
        """Pickler"""
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler, reuse=self.reuse))

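    # Why the custom pickling above: lsst.ctrl.pool ships bound methods such
    # as self.runDetection to slave nodes via pickle, which requires the task
    # itself to be picklable. A sketch of the round trip (no MPI needed;
    # assumes a task already constructed with a butler):
    #
    #   import pickle
    #   clone = pickle.loads(pickle.dumps(task))
    #   # __reduce__ -> unpickle(MultiBandDriverTask, [],
    #   #                        dict(config=..., butler=..., reuse=...))
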
    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        parser.addReuseOption(["detectCoaddSources", "mergeCoaddDetections", "measureCoaddSources",
                               "mergeCoaddMeasurements", "forcedPhotCoadd"])
        return parser

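    # A hypothetical invocation (repository path, rerun name, and batch
    # options are illustrative; the batch options come from the BatchPoolTask
    # machinery rather than from this parser):
    #
    #   multiBandDriver.py /path/to/repo --rerun coadd:multiband \
    #       --id tract=9813 filter=HSC-G^HSC-R^HSC-I \
    #       --reuse-outputs-from detectCoaddSources \
    #       --batch-type=smp --cores=8
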
    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of cores
        """
        numTargets = 0
        for refList in parsedCmd.id.refList:
            numTargets += len(refList)
        return time*numTargets/float(numCpus)

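    # Worked example of the scaling (illustrative numbers): 3 filters x 81
    # patches gives 243 target dataRefs, so with time=600 s per iteration and
    # numCpus=24 the walltime request is 600*243/24.0 = 6075 s.
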
    @abortOnError
    def run(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList: Data references to run measurement
        """
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)

        # MultiBand measurements require that the detection stage be completed
        # before measurements can be made.
        #
        # The configuration for coaddDriver.py allows detection to be turned
        # off in the event that fake objects are to be added during the
        # detection process. This allows the long co-addition process to be
        # run once, and multiple different MultiBand reruns (with different
        # fake objects) to exist from the same base co-addition.
        #
        # However, we only re-run detection if doDetection is explicitly True
        # here (this should always be the opposite of coaddDriver.doDetection);
        # otherwise we have no way to tell reliably whether any detections
        # present in an input repo are safe to use.
        if self.config.doDetection:
            detectionList = []
            for patchRef in patchRefList:
                if ("detectCoaddSources" in self.reuse and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_calexp", write=True)):
                    self.log.info("Skipping detectCoaddSources for %s; output already exists." %
                                  patchRef.dataId)
                    continue
                if not patchRef.datasetExists(self.config.coaddName + "Coadd"):
                    self.log.debug("Not processing %s; required input %sCoadd missing." %
                                   (patchRef.dataId, self.config.coaddName))
                    continue
                detectionList.append(patchRef)

            pool.map(self.runDetection, detectionList)

        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det",
                                               write=self.config.doDetection)]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]

            patch = dataId["patch"]
            if patch not in patches:
                patches[patch] = []
            patches[patch].append(dataId)

        pool.map(self.runMergeDetections, patches.values())

        # Measure merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runMeasureMerged will return whether there are any footprints in that image that required
        # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements new, but merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the merge and
        # forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
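        #
        # For example (hypothetical ids): if patch '5,5' of the current tract
        # is flagged for reprocessing, a zero-length file is touched at the
        # location the butler maps for <coaddName>Coadd_multibandReprocessing
        # with tract=<tract>, patch='5,5'; on a restart, finding that file
        # re-queues the merge and forced photometry for the patch even though
        # the measurements appear up to date.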
        reprocessed = None
        try:
            reprocessed = pool.map(self.runMeasureMerged, dataIdList)
        finally:
            if self.config.reprocessing and reprocessed is not None:
                patchReprocessing = {}
                for dataId, reprocess in zip(dataIdList, reprocessed):
                    patchId = dataId["patch"]
                    patchReprocessing[patchId] = patchReprocessing.get(
                        patchId, False) or reprocess
                # Persist the determination, to make error recovery easier
                reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
                for patchId in patchReprocessing:
                    dataId = dict(tract=tract, patch=patchId)
                    if patchReprocessing[patchId]:
                        filename = butler.get(
                            reprocessDataset + "_filename", dataId)[0]
                        open(filename, 'a').close()  # Touch file
                    elif butler.datasetExists(reprocessDataset, dataId):
                        # We must have failed at some point while reprocessing
                        # and we're starting over
                        patchReprocessing[patchId] = True

        # Only process patches that have been identified as needing it
        pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
                                             not self.config.reprocessing or patchReprocessing[patchId]])
        pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
                                      patchReprocessing[dataId1["patch"]]])

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                dataId = dict(tract=tract, patch=patchId)
                filename = butler.get(
                    reprocessDataset + "_filename", dataId)[0]
                os.unlink(filename)

    def runDetection(self, cache, patchRef):
        """!Run detection on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param patchRef: Patch on which to do detection
        """
        with self.logOperation("do detections on {}".format(patchRef.dataId)):
            idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
            coadd = patchRef.get(self.config.coaddName + "Coadd",
                                 immediate=True)
            self.detectCoaddSources.emptyMetadata()
            detResults = self.detectCoaddSources.runDetection(coadd, idFactory)
            self.detectCoaddSources.write(coadd, detResults, patchRef)
            self.detectCoaddSources.writeMetadata(patchRef)

    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddDetections" in self.reuse and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet", write=True)):
                self.log.info("Skipping mergeCoaddDetections for %s; output already exists." %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddDetections.run(dataRefList)

    def runMeasureMerged(self, cache, dataId):
        """!Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        @return whether the patch requires reprocessing.
        """
        with self.logOperation("measurement on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            reprocessing = False  # Does this patch require reprocessing?
            if ("measureCoaddSources" in self.reuse and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas", write=True)):
                if not self.config.reprocessing:
                    self.log.info("Skipping measureCoaddSources for %s; output already exists" % dataId)
                    return False

                catalog = dataRef.get(self.config.coaddName + "Coadd_meas")
                bigFlag = catalog["deblend.parent-too-big"]
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" %
                                  (dataRef.dataId,))
                    return False
                numNewBig = sum((self.measureCoaddSources.deblend.isLargeFootprint(src.getFootprint()) for
                                 src in catalog[bigFlag]))
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRef.dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, dataRef.dataId))
                reprocessing = True

            self.measureCoaddSources.run(dataRef)
            return reprocessing

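    # Decision sketch for the method above (illustrative counts): if four
    # footprints were flagged deblend.parent-too-big on the previous run and
    # only one still counts as large under the current (presumably more
    # permissive) deblender configuration, the patch is re-measured and
    # reported as needing reprocessing; if all four remain large, re-running
    # would change nothing, so the method returns False.
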
    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddMeasurements" in self.reuse and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref", write=True)):
                self.log.info("Skipping mergeCoaddMeasurements for %s; output already exists" %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddMeasurements.run(dataRefList)

    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            if ("forcedPhotCoadd" in self.reuse and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src", write=True)):
                self.log.info("Skipping forcedPhotCoadd for %s; output already exists" % dataId)
                return
            self.forcedPhotCoadd.run(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass