lsst.pipe.drivers  16.0-4-g17aeff3+7
multiBandDriver.py
from __future__ import absolute_import, division, print_function
import os
from argparse import ArgumentError

from builtins import zip

from lsst.pex.config import Config, Field, ConfigurableField
from lsst.pipe.base import ArgumentParser, TaskRunner
from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
                                       MergeDetectionsTask,
                                       DeblendCoaddSourcesTask,
                                       MeasureMergedCoaddSourcesTask,
                                       MergeMeasurementsTask,)
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, abortOnError, NODE
from lsst.meas.base.references import MultiBandReferencesTask
from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
from lsst.pipe.tasks.coaddBase import CoaddDataIdContainer

import lsst.afw.table as afwTable

class MultiBandDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """!Make self.refList from self.idList

        It's difficult to make a data reference that merely points to an entire
        tract: there is no data product solely at the tract level. Instead, we
        generate a list of data references for patches within the tract.

        @param namespace namespace object that is the result of an argument parser
        """
        datasetType = namespace.config.coaddName + "Coadd_calexp"

        def getPatchRefList(tract, filterName):
            # One data reference per patch in the tract, all in the same filter
            return [namespace.butler.dataRef(datasetType=datasetType,
                                             tract=tract.getId(),
                                             filter=filterName,
                                             patch="%d,%d" % patch.getIndex())
                    for patch in tract]

        tractRefs = {}  # Data references for each tract
        for dataId in self.idList:
            # There's no registry of coadds by filter, so we need to be given
            # the filter
            if "filter" not in dataId:
                raise ArgumentError(None, "--id must include 'filter'")

            skymap = self.getSkymap(namespace, datasetType)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if tractId not in tractRefs:
                    tractRefs[tractId] = []
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=datasetType,
                                                                       tract=tractId,
                                                                       filter=dataId["filter"],
                                                                       patch=dataId["patch"]))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId], dataId["filter"])
            else:
                tractRefs = dict((tract.getId(),
                                  tractRefs.get(tract.getId(), []) + getPatchRefList(tract, dataId["filter"]))
                                 for tract in skymap)

        self.refList = list(tractRefs.values())

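# A hedged sketch of how the container above is typically driven from the
# command line (repository path and data-ID values are hypothetical): an --id
# that names a filter and tract but no patch fans out into one data reference
# per patch in that tract.
#
#     multiBandDriver.py /path/to/repo --rerun multiband \
#         --id tract=8766 filter=HSC-I
#
# The equivalent expansion in code, under the same assumptions:
def _examplePatchRefs(butler, skymap, tractId=8766, filterName="HSC-I"):
    """Illustrative only: one deepCoadd_calexp reference per patch in a tract."""
    tract = skymap[tractId]
    return [butler.dataRef(datasetType="deepCoadd_calexp", tract=tractId,
                           filter=filterName, patch="%d,%d" % patch.getIndex())
            for patch in tract]
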
class MultiBandDriverConfig(Config):
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    doDetection = Field(dtype=bool, default=False,
                        doc="Re-run detection? (requires *Coadd dataset to have been written)")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    deblendCoaddSources = ConfigurableField(target=DeblendCoaddSourcesTask,
                                            doc="Deblend merged detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged and (optionally) deblended detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a "
             "cluster and return later to reprocess them on a machine with more memory or more time "
             "if those footprints are important to recover."),
    )

    def setDefaults(self):
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        for subtask in ("mergeCoaddDetections", "deblendCoaddSources", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))

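# The fields above can be tuned through an ordinary config override file. A
# minimal sketch (file name hypothetical), applied with --configfile:
#
#     # reprocess.py
#     config.reprocessing = True   # re-attempt large footprints
#     config.doDetection = False   # detections already exist in the input repo
#
# validate() above will then reject any override that sets a subtask's
# coaddName out of step with the root coaddName.
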
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandDriverTask

    This is similar to the lsst.pipe.base.ButlerInitializedTaskRunner,
    except that a list of data references is passed to Task.run instead of
    a single data reference, and the results of the '--reuse-outputs-from'
    command-line option are passed to the Task constructor.
    """

    def __init__(self, TaskClass, parsedCmd, doReturnResults=False):
        TaskRunner.__init__(self, TaskClass, parsedCmd, doReturnResults)
        self.reuse = parsedCmd.reuse

    def makeTask(self, parsedCmd=None, args=None):
        """A variant of the base version that passes a butler argument to the task's constructor.

        Either parsedCmd or args must be specified.
        """
        if parsedCmd is not None:
            butler = parsedCmd.butler
        elif args is not None:
            dataRefList, kwargs = args
            butler = dataRefList[0].butlerSubset.butler
        else:
            raise RuntimeError("parsedCmd or args must be specified")
        return self.TaskClass(config=self.config, log=self.log, butler=butler, reuse=self.reuse)

def unpickle(factory, args, kwargs):
    """Unpickle something by calling a factory"""
    return factory(*args, **kwargs)

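# BatchPoolTask instances are shipped to worker nodes by pickling; __reduce__
# on the driver task below routes unpickling through the unpickle() factory
# above. A minimal sketch of the same pattern with an ordinary class (class
# name and attribute are hypothetical, for illustration only):
import pickle


class _PickleExample:
    def __init__(self, size):
        self.size = size

    def __reduce__(self):
        # On unpickle, rebuild the object as unpickle(_PickleExample, [], {...})
        return unpickle, (self.__class__, [], dict(size=self.size))


assert pickle.loads(pickle.dumps(_PickleExample(3))).size == 3
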

class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing"""
    ConfigClass = MultiBandDriverConfig
    _DefaultName = "multiBandDriver"
    RunnerClass = MultiBandDriverTaskRunner

    def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), **kwargs):
        """!
        @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader
            constructor in case it is needed.
        @param[in] schema: the schema of the source detection catalog used as input.
        @param[in] refObjLoader: an instance of LoadReferenceObjectsTask that supplies an external
            reference catalog. May be None if the butler argument is provided or all steps requiring
            a reference catalog are disabled.
        """
        BatchPoolTask.__init__(self, **kwargs)
        if schema is None:
            assert butler is not None, "Butler not provided"
            schema = butler.get(self.config.coaddName +
                                "Coadd_det_schema", immediate=True).schema
        self.butler = butler
        self.reuse = tuple(reuse)
        self.makeSubtask("detectCoaddSources")
        self.makeSubtask("mergeCoaddDetections", schema=schema)
        if self.config.measureCoaddSources.inputCatalog.startswith("deblended"):
            # Ensure that the output from deblendCoaddSources matches the input to measureCoaddSources
            self.measurementInput = self.config.measureCoaddSources.inputCatalog
            self.deblenderOutput = []
            if self.config.deblendCoaddSources.simultaneous:
                if self.config.deblendCoaddSources.multiBandDeblend.conserveFlux:
                    self.deblenderOutput.append("deblendedFlux")
                if self.config.deblendCoaddSources.multiBandDeblend.saveTemplates:
                    self.deblenderOutput.append("deblendedModel")
            else:
                self.deblenderOutput.append("deblendedFlux")
            if self.measurementInput not in self.deblenderOutput:
                err = "Measurement input '{0}' is not in the list of deblender output catalogs '{1}'"
                raise ValueError(err.format(self.measurementInput, self.deblenderOutput))

            self.makeSubtask("deblendCoaddSources",
                             schema=afwTable.Schema(self.mergeCoaddDetections.schema),
                             peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
                             butler=butler)
            measureInputSchema = afwTable.Schema(self.deblendCoaddSources.schema)
        else:
            measureInputSchema = afwTable.Schema(self.mergeCoaddDetections.schema)
        self.makeSubtask("measureCoaddSources", schema=measureInputSchema,
                         peakSchema=afwTable.Schema(
                             self.mergeCoaddDetections.merged.getPeakSchema()),
                         refObjLoader=refObjLoader, butler=butler)
        self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
            self.measureCoaddSources.schema))
        self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
            self.mergeCoaddMeasurements.schema))

    def __reduce__(self):
        """Pickler"""
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler, reuse=self.reuse))

    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        parser.addReuseOption(["detectCoaddSources", "mergeCoaddDetections", "measureCoaddSources",
                               "mergeCoaddMeasurements", "forcedPhotCoadd"])
        return parser

    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of CPUs
        """
        numTargets = 0
        for refList in parsedCmd.id.refList:
            numTargets += len(refList)
        return time*numTargets/float(numCpus)

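    # Worked example of the estimate above (numbers purely illustrative): three
    # tracts of ten patches give numTargets = 30; with time = 600 s per patch
    # and numCpus = 12, the walltime request is 600 * 30 / 12.0 = 1500 s.
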
    @abortOnError
    def runDataRef(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList: Data references to run measurement
        """
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)

        # MultiBand measurements require that the detection stage be completed
        # before measurements can be made.
        #
        # The configuration for coaddDriver.py allows detection to be turned
        # off in the event that fake objects are to be added during the
        # detection process. This allows the long co-addition process to be
        # run once, and multiple different MultiBand reruns (with different
        # fake objects) to exist from the same base co-addition.
        #
        # However, we only re-run detection if doDetection is explicitly True
        # here (this should always be the opposite of coaddDriver.doDetection);
        # otherwise we have no way to tell reliably whether any detections
        # present in an input repo are safe to use.
        if self.config.doDetection:
            detectionList = []
            for patchRef in patchRefList:
                if ("detectCoaddSources" in self.reuse and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_calexp", write=True)):
                    self.log.info("Skipping detectCoaddSources for %s; output already exists." %
                                  patchRef.dataId)
                    continue
                if not patchRef.datasetExists(self.config.coaddName + "Coadd"):
                    self.log.debug("Not processing %s; required input %sCoadd missing." %
                                   (patchRef.dataId, self.config.coaddName))
                    continue
                detectionList.append(patchRef)

            pool.map(self.runDetection, detectionList)

        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det",
                                               write=self.config.doDetection)]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]

            patch = dataId["patch"]
            if patch not in patches:
                patches[patch] = []
            patches[patch].append(dataId)

        pool.map(self.runMergeDetections, patches.values())

        # Deblend merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runDeblendMerged will return whether there are any footprints in that image that required
        # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements, merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the measurements,
        # merge and forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
        try:
            reprocessed = pool.map(self.runDeblendMerged, patches.values())
        finally:
            if self.config.reprocessing:
                patchReprocessing = {}
                for dataId, reprocess in zip(dataIdList, reprocessed):
                    patchId = dataId["patch"]
                    patchReprocessing[patchId] = patchReprocessing.get(
                        patchId, False) or reprocess
                # Persist the determination, to make error recovery easier
                reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
                for patchId in patchReprocessing:
                    dataId = dict(tract=tract, patch=patchId)
                    if patchReprocessing[patchId]:
                        filename = butler.get(
                            reprocessDataset + "_filename", dataId)[0]
                        open(filename, 'a').close()  # Touch file
                    elif butler.datasetExists(reprocessDataset, dataId):
                        # We must have failed at some point while reprocessing
                        # and we're starting over
                        patchReprocessing[patchId] = True

        # Only process patches that have been identified as needing it
        pool.map(self.runMeasurements, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
                                        patchReprocessing[dataId1["patch"]]])
        pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
                                             not self.config.reprocessing or patchReprocessing[patchId]])
        pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
                                      patchReprocessing[dataId1["patch"]]])

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                dataId = dict(tract=tract, patch=patchId)
                filename = butler.get(
                    reprocessDataset + "_filename", dataId)[0]
                os.unlink(filename)

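    # The try/finally above persists the per-patch reprocessing determination
    # as an empty marker file, so that a crash between deblending and the
    # downstream stages can be detected on the next run. A minimal sketch of
    # the touch-file pattern in isolation (path hypothetical):
    #
    #     marker = "/tmp/tract8766-patch1,1.reprocessing"
    #     open(marker, 'a').close()        # touch: existence is the signal
    #     ...                              # dependent stages run here
    #     if os.path.exists(marker):
    #         os.unlink(marker)            # clear once stages complete
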
    def runDetection(self, cache, patchRef):
        """!Run detection on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param patchRef: Patch on which to do detection
        """
        with self.logOperation("do detections on {}".format(patchRef.dataId)):
            idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
            coadd = patchRef.get(self.config.coaddName + "Coadd",
                                 immediate=True)
            expId = int(patchRef.get(self.config.coaddName + "CoaddId"))
            self.detectCoaddSources.emptyMetadata()
            detResults = self.detectCoaddSources.run(coadd, idFactory, expId=expId)
            self.detectCoaddSources.write(coadd, detResults, patchRef)
            self.detectCoaddSources.writeMetadata(patchRef)

    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddDetections" in self.reuse and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet", write=True)):
                self.log.info("Skipping mergeCoaddDetections for %s; output already exists." %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddDetections.runDataRef(dataRefList)

    def runDeblendMerged(self, cache, dataIdList):
        """Run the deblender on a list of dataIds

        Only slave nodes execute this method.

        Parameters
        ----------
        cache: Pool cache
            Pool cache with butler.
        dataIdList: list
            Data identifier for patch in each band.

        Returns
        -------
        result: bool
            Whether the patch requires reprocessing.
        """
        with self.logOperation("deblending %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            reprocessing = False  # Does this patch require reprocessing?
            if ("deblendCoaddSources" in self.reuse and
                    dataRefList[0].datasetExists(self.config.coaddName + self.measurementInput,
                                                 write=True)):
                if not self.config.reprocessing:
                    self.log.info("Skipping deblendCoaddSources for %s; output already exists" %
                                  dataIdList)
                    return False

                catalog = dataRefList[0].get(self.config.coaddName + self.measurementInput)
                bigFlag = catalog["deblend.parent-too-big"]
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" %
                                  (dataRefList[0].dataId,))
                    return False
                numNewBig = sum((self.deblendCoaddSources.isLargeFootprint(src.getFootprint()) for
                                 src in catalog[bigFlag]))
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRefList[0].dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, [dataRef.dataId for dataRef in dataRefList]))
                reprocessing = True

            self.deblendCoaddSources.runDataRef(dataRefList)
            return reprocessing

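    # The decision above boils down to comparing how many formerly-big
    # footprints would still be refused under the current configuration. A
    # hedged sketch of that predicate (names hypothetical):
    #
    #     def needsReprocessing(numOldBig, numNewBig):
    #         # Reprocess only if the new config would deblend at least one
    #         # footprint that the previous run refused
    #         return numOldBig > 0 and numNewBig < numOldBig
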
    def runMeasurements(self, cache, dataId):
        """Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        Parameters
        ----------
        cache: Pool cache
            Pool cache, with butler
        dataId: dict
            Data identifier for patch
        """
        with self.logOperation("measurements on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            if ("measureCoaddSources" in self.reuse and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas", write=True)):
                self.log.info("Skipping measureCoaddSources for %s; output already exists" % dataId)
                return
            self.measureCoaddSources.runDataRef(dataRef)

    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddMeasurements" in self.reuse and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref", write=True)):
                self.log.info("Skipping mergeCoaddMeasurements for %s; output already exists" %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddMeasurements.runDataRef(dataRefList)

    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            if ("forcedPhotCoadd" in self.reuse and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src", write=True)):
                self.log.info("Skipping forcedPhotCoadd for %s; output already exists" % dataId)
                return
            self.forcedPhotCoadd.runDataRef(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass
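
# Each run* method above applies the same skip-if-exists guard before doing
# work. A minimal sketch of that guard outside the butler machinery (function
# and parameter names hypothetical):
def _exampleMaybeRun(step, reuse, outputExists, run):
    """Skip `run` when `step` was requested for reuse and its output exists."""
    if step in reuse and outputExists:
        return  # reuse the previous output
    run()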