lsst.pipe.drivers  13.0-21-g61c0bd4+4
multiBandDriver.py
Go to the documentation of this file.
1 from __future__ import absolute_import, division, print_function
2 import os
3 from argparse import ArgumentError
4 
5 from builtins import zip
6 
7 from lsst.pex.config import Config, Field, ConfigurableField
8 from lsst.pipe.base import ArgumentParser, TaskRunner
9 from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
10  MergeDetectionsTask,
11  MeasureMergedCoaddSourcesTask,
12  MergeMeasurementsTask,)
13 from lsst.ctrl.pool.parallel import BatchPoolTask
14 from lsst.ctrl.pool.pool import Pool, abortOnError, NODE
15 from lsst.meas.base.references import MultiBandReferencesTask
16 from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
17 from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
18 from lsst.pipe.tasks.coaddBase import CoaddDataIdContainer
19 
20 import lsst.afw.table as afwTable
21 
22 
class MultiBandDataIdContainer(CoaddDataIdContainer):
    """Data ID container that turns each requested ID into per-tract lists of patch data references"""

    def makeDataRefList(self, namespace):
        """!Make self.refList from self.idList

        No data product lives purely at the tract level, so a tract cannot be
        referenced directly. Each tract is instead represented by a list of
        data references to its patches, and self.refList ends up holding one
        such list per tract.

        @param namespace namespace object that is the result of an argument parser
        """
        datasetType = namespace.config.coaddName + "Coadd_calexp"

        def refsForWholeTract(tractInfo, filterName):
            """Return one data reference per patch of tractInfo, for the given filter"""
            return [namespace.butler.dataRef(datasetType=datasetType,
                                             tract=tractInfo.getId(),
                                             filter=filterName,
                                             patch="%d,%d" % patchInfo.getIndex())
                    for patchInfo in tractInfo]

        tractRefs = {}  # tract ID --> list of patch data references
        for dataId in self.idList:
            # Coadds are not registered by filter, so the user must name one
            if "filter" not in dataId:
                raise ArgumentError(None, "--id must include 'filter'")
            filterName = dataId["filter"]

            skymap = self.getSkymap(namespace, datasetType)

            if "tract" not in dataId:
                # No tract given: take every patch of every tract in the skymap,
                # appended to whatever was already collected for that tract
                expanded = {}
                for tractInfo in skymap:
                    expanded[tractInfo.getId()] = (tractRefs.get(tractInfo.getId(), []) +
                                                   refsForWholeTract(tractInfo, filterName))
                tractRefs = expanded
                continue

            tractId = dataId["tract"]
            collected = tractRefs.setdefault(tractId, [])
            if "patch" in dataId:
                # A single, explicitly named patch
                collected.append(namespace.butler.dataRef(datasetType=datasetType,
                                                          tract=tractId,
                                                          filter=filterName,
                                                          patch=dataId["patch"]))
            else:
                # Tract without a patch: use all of its patches
                collected.extend(refsForWholeTract(skymap[tractId], filterName))

        self.refList = list(tractRefs.values())
69 
70 
class MultiBandDriverConfig(Config):
    """Configuration for MultiBandDriverTask"""
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")

    # Subtasks, in the order the driver runs them
    detectCoaddSources = ConfigurableField(
        target=DetectCoaddSourcesTask,
        doc="Detect sources on coadd",
    )
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask,
        doc="Merge detections",
    )
    measureCoaddSources = ConfigurableField(
        target=MeasureMergedCoaddSourcesTask,
        doc="Measure merged detections",
    )
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask,
        doc="Merge measurements",
    )
    forcedPhotCoadd = ConfigurableField(
        target=ForcedPhotCoaddTask,
        doc="Forced measurement on coadded images",
    )

    # Per-stage switches controlling whether existing outputs are overwritten
    clobberDetections = Field(dtype=bool, default=False, doc="Clobber existing detections?")
    clobberMergedDetections = Field(dtype=bool, default=False, doc="Clobber existing merged detections?")
    clobberMeasurements = Field(dtype=bool, default=False, doc="Clobber existing measurements?")
    clobberMergedMeasurements = Field(dtype=bool, default=False,
                                      doc="Clobber existing merged measurements?")
    clobberForcedPhotometry = Field(dtype=bool, default=False, doc="Clobber existing forced photometry?")

    reprocessing = Field(
        dtype=bool,
        default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    def setDefaults(self):
        """Apply the base defaults, then retarget the forced photometry reference loader"""
        Config.setDefaults(self)
        references = self.forcedPhotCoadd.references
        references.retarget(MultiBandReferencesTask)

    def validate(self):
        """Check that the listed subtasks all use the same coaddName as this config

        A RuntimeError naming the first mismatching subtask is raised otherwise.
        """
        checkedSubtasks = ("mergeCoaddDetections", "measureCoaddSources",
                           "mergeCoaddMeasurements", "forcedPhotCoadd")
        for name in checkedSubtasks:
            subtaskCoaddName = getattr(self, name).coaddName
            if subtaskCoaddName == self.coaddName:
                continue
            raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                               (name, subtaskCoaddName, self.coaddName))
113 
114 
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandTask

    Much like lsst.pipe.base.ButlerInitializedTaskRunner, but Task.run is
    handed a list of data references rather than a single data reference.
    """

    def makeTask(self, parsedCmd=None, args=None):
        """Construct the task, supplying a butler to its constructor

        The butler comes from parsedCmd when that is given; otherwise it is
        taken from the first data reference in args, which must be a
        (dataRefList, kwargs) pair. At least one of the two must be given.
        """
        if parsedCmd is None and args is None:
            raise RuntimeError("parsedCmd or args must be specified")

        if parsedCmd is not None:
            butler = parsedCmd.butler
        else:
            refList, _ = args
            butler = refList[0].butlerSubset.butler
        return self.TaskClass(config=self.config, log=self.log, butler=butler)
135 
136 
def unpickle(factory, args, kwargs):
    """Rebuild an object by calling factory with the given positional and keyword arguments"""
    rebuilt = factory(*args, **kwargs)
    return rebuilt
140 
141 
class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing

    The stages run in order: detection on each coadd, merging of detections
    across filters, measurement of the merged detections, merging of the
    measurements and forced photometry. Each stage is distributed with
    Pool.map.
    """
    ConfigClass = MultiBandDriverConfig
    _DefaultName = "multiBandDriver"
    RunnerClass = MultiBandDriverTaskRunner

    def __init__(self, butler=None, schema=None, refObjLoader=None, **kwargs):
        """!
        @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader constructor
            in case it is needed.
        @param[in] schema: the schema of the source detection catalog used as input.
        @param[in] refObjLoader: an instance of LoadReferenceObjectsTasks that supplies an external reference
            catalog.  May be None if the butler argument is provided or all steps requiring a reference
            catalog are disabled.
        @param[in] **kwargs: forwarded to BatchPoolTask.__init__
        """
        BatchPoolTask.__init__(self, **kwargs)
        if schema is None:
            # Without an explicit schema, the detection schema is read through the butler
            assert butler is not None, "Butler not provided"
            schema = butler.get(self.config.coaddName + "Coadd_det_schema", immediate=True).schema
        self.butler = butler
        # Each subtask is built from (a copy of) the schema produced by the one before it
        self.makeSubtask("detectCoaddSources")
        self.makeSubtask("mergeCoaddDetections", schema=schema)
        self.makeSubtask("measureCoaddSources",
                         schema=afwTable.Schema(self.mergeCoaddDetections.schema),
                         peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
                         refObjLoader=refObjLoader, butler=butler)
        self.makeSubtask("mergeCoaddMeasurements",
                         schema=afwTable.Schema(self.measureCoaddSources.schema))
        self.makeSubtask("forcedPhotCoadd",
                         refSchema=afwTable.Schema(self.mergeCoaddMeasurements.schema))

    def __reduce__(self):
        """Pickler

        The task is rebuilt on unpickling by calling its class with the config,
        name, parent task, log and butler of this instance.
        """
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler))

    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        """Build the argument parser, with an --id argument handled by TractDataIdContainer

        A "doBatch" keyword argument, if supplied, is discarded before the
        remaining arguments are passed on to ArgumentParser.
        """
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        return parser

    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of CPUs the work is divided over
        @return time multiplied by the total number of data references in
            parsedCmd.id.refList, divided by numCpus
        """
        numTargets = sum(len(refList) for refList in parsedCmd.id.refList)
        return time*numTargets/float(numCpus)

    @abortOnError
    def run(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList: Data references to run measurement; all of them
            must belong to the same tract (this is asserted)
        """
        # The butler is taken from the first usable data reference
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)

        # MultiBand measurements require that the detection stage be completed before
        # measurements can be made. Determine if data products are present, but detections
        # are not, and attempt to run the detection stage where necessary. The configuration
        # for coaddDriver.py allows detection to be turned off in the event that fake objects
        # are to be added during the detection process. This allows the long co-addition
        # process to be run once, and multiple different MultiBand reruns (with different
        # fake objects) to exist from the same base co-addition.
        # If the detections are to be clobbered, add all patches to the detection list
        # unless the datasets necessary to generate detections do not exist
        coaddDataset = self.config.coaddName + "Coadd"
        calexpDataset = self.config.coaddName + "Coadd_calexp"
        if self.config.clobberDetections:
            detectionList = [patchRef for patchRef in patchRefList if
                             patchRef.datasetExists(coaddDataset)]
        else:
            detectionList = [patchRef for patchRef in patchRefList if
                             not patchRef.datasetExists(calexpDataset) and
                             patchRef.datasetExists(coaddDataset)]

        pool.map(self.runDetection, detectionList)

        # From here on, only patches that have detection outputs are processed
        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(calexpDataset) and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det")]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]
            patches.setdefault(dataId["patch"], []).append(dataId)

        pool.map(self.runMergeDetections, patches.values())

        # Measure merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runMeasureMerged will return whether there are any footprints in that image that required
        # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements new, but merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the merge and
        # forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
        #
        # The per-patch results are only available once pool.map has returned, so there is nothing
        # to persist if it raises; the exception is simply allowed to propagate.
        reprocessed = pool.map(self.runMeasureMerged, dataIdList)

        patchReprocessing = {}
        reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
        if self.config.reprocessing:
            # A patch needs reprocessing if any of its filters does
            for dataId, reprocess in zip(dataIdList, reprocessed):
                patchId = dataId["patch"]
                patchReprocessing[patchId] = patchReprocessing.get(patchId, False) or reprocess
            # Persist the determination, to make error recovery easier
            for patchId in patchReprocessing:
                markerId = dict(tract=tract, patch=patchId)
                if patchReprocessing[patchId]:
                    filename = butler.get(reprocessDataset + "_filename", markerId)[0]
                    with open(filename, 'a'):  # Touch file
                        pass
                elif butler.datasetExists(reprocessDataset, markerId):
                    # We must have failed at some point while reprocessing
                    # and we're starting over
                    patchReprocessing[patchId] = True

        # Only process patches that have been identified as needing it
        reprocessing = self.config.reprocessing
        pool.map(self.runMergeMeasurements,
                 [idList for patchId, idList in patches.items() if
                  not reprocessing or patchReprocessing[patchId]])
        # Each data ID is selected by the status of its own patch
        pool.map(self.runForcedPhot,
                 [dataId for dataId in dataIdList if
                  not reprocessing or patchReprocessing[dataId["patch"]]])

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                markerId = dict(tract=tract, patch=patchId)
                filename = butler.get(reprocessDataset + "_filename", markerId)[0]
                os.unlink(filename)

    def runDetection(self, cache, patchRef):
        """! Run detection on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param patchRef: Patch on which to do detection
        """
        with self.logOperation("do detections on {}".format(patchRef.dataId)):
            idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
            coadd = patchRef.get(self.config.coaddName + "Coadd", immediate=True)
            detResults = self.detectCoaddSources.runDetection(coadd, idFactory)
            self.detectCoaddSources.write(coadd, detResults, patchRef)

    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        Nothing is done if the merged detections already exist for the first
        data reference and clobberMergedDetections is not set.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedDetections and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet")):
                return
            self.mergeCoaddDetections.run(dataRefList)

    def runMeasureMerged(self, cache, dataId):
        """!Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        @return whether the patch requires reprocessing.
        """
        with self.logOperation("measurement on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp")
            reprocessing = False  # Does this patch require reprocessing?
            if (not self.config.clobberMeasurements and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas")):
                # Measurements already exist: only redo them when reprocessing would change something
                if not self.config.reprocessing:
                    return False

                catalog = dataRef.get(self.config.coaddName + "Coadd_meas")
                bigFlag = catalog["deblend.parent-too-big"]
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" % (dataRef.dataId,))
                    return False
                # Count how many of the previously flagged sources the current deblender
                # configuration still considers too large
                numNewBig = sum((self.measureCoaddSources.deblend.isLargeFootprint(src.getFootprint()) for
                                 src in catalog[bigFlag]))
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRef.dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, dataRef.dataId))
                reprocessing = True

            self.measureCoaddSources.run(dataRef)
            return reprocessing

    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        Nothing is done if the merged measurements already exist for the first
        data reference, unless clobberMergedMeasurements or reprocessing is set.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedMeasurements and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref")):
                return
            self.mergeCoaddMeasurements.run(dataRefList)

    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        Nothing is done if the forced photometry already exists, unless
        clobberForcedPhotometry or reprocessing is set.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp")
            if (not self.config.clobberForcedPhotometry and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src")):
                return
            self.forcedPhotCoadd.run(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass
def unpickle(factory, args, kwargs)
def __init__(self, butler=None, schema=None, refObjLoader=None, **kwargs)
def getDataRef(butler, dataId, datasetType="raw")
Definition: utils.py:17
def runForcedPhot(self, cache, dataId)
Run forced photometry on a patch for a single filter.
def batchWallTime(cls, time, parsedCmd, numCpus)
Return walltime request for batch job.
def runMeasureMerged(self, cache, dataId)
Run measurement on a patch for a single filter.
def makeDataRefList(self, namespace)
Make self.refList from self.idList.
def runDetection(self, cache, patchRef)
Run detection on a patch.
def runMergeDetections(self, cache, dataIdList)
Run detection merging on a patch.
def run(self, patchRefList)
Run multiband processing on coadds.
def runMergeMeasurements(self, cache, dataIdList)
Run measurement merging on a patch.