lsst.pipe.drivers  13.0-12-gda22aa7
multiBandDriver.py
from __future__ import absolute_import, division, print_function
import os
from argparse import ArgumentError

from builtins import zip

from lsst.pex.config import Config, Field, ConfigurableField
from lsst.pipe.base import ArgumentParser, TaskRunner
from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
                                       MergeDetectionsTask,
                                       MeasureMergedCoaddSourcesTask,
                                       MergeMeasurementsTask,)
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, abortOnError, NODE
from lsst.meas.base.references import MultiBandReferencesTask
from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
from lsst.pipe.tasks.coaddBase import CoaddDataIdContainer

import lsst.afw.table as afwTable

class MultiBandDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """!Make self.refList from self.idList

        It's difficult to make a data reference that merely points to an entire
        tract: there is no data product solely at the tract level. Instead, we
        generate a list of data references for patches within the tract.

        @param namespace namespace object that is the result of an argument parser
        """
        datasetType = namespace.config.coaddName + "Coadd_calexp"

        def getPatchRefList(tract):
            return [namespace.butler.dataRef(datasetType=datasetType,
                                             tract=tract.getId(),
                                             filter=dataId["filter"],
                                             patch="%d,%d" % patch.getIndex())
                    for patch in tract]

        tractRefs = {}  # Data references for each tract
        for dataId in self.idList:
            # There's no registry of coadds by filter, so we need to be given
            # the filter
            if "filter" not in dataId:
                raise ArgumentError(None, "--id must include 'filter'")

            skymap = self.getSkymap(namespace, datasetType)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if tractId not in tractRefs:
                    tractRefs[tractId] = []
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=datasetType,
                                                                       tract=tractId,
                                                                       filter=dataId["filter"],
                                                                       patch=dataId["patch"]))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId])
            else:
                tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
                                 for tract in skymap)

        self.refList = list(tractRefs.values())

class MultiBandDriverConfig(Config):
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    clobberDetections = Field(
        dtype=bool, default=False, doc="Clobber existing detections?")
    clobberMergedDetections = Field(
        dtype=bool, default=False, doc="Clobber existing merged detections?")
    clobberMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing measurements?")
    clobberMergedMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing merged measurements?")
    clobberForcedPhotometry = Field(
        dtype=bool, default=False, doc="Clobber existing forced photometry?")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    def setDefaults(self):
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        for subtask in ("mergeCoaddDetections", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))

class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandDriverTask

    This is similar to lsst.pipe.base.ButlerInitializedTaskRunner,
    except that a list of data references is passed to Task.run
    instead of a single data reference.
    """

    def makeTask(self, parsedCmd=None, args=None):
        """A variant of the base version that passes a butler argument to the task's constructor

        Either parsedCmd or args must be specified.
        """
        if parsedCmd is not None:
            butler = parsedCmd.butler
        elif args is not None:
            dataRefList, kwargs = args
            butler = dataRefList[0].butlerSubset.butler
        else:
            raise RuntimeError("parsedCmd or args must be specified")
        return self.TaskClass(config=self.config, log=self.log, butler=butler)

def unpickle(factory, args, kwargs):
    """Unpickle something by calling a factory"""
    return factory(*args, **kwargs)

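# Note (added for clarity): unpickle() is the factory used by
# MultiBandDriverTask.__reduce__ below, so that task instances can be rebuilt
# from their constructor arguments when pickled, e.g. when bound methods such
# as self.runDetection are shipped to pool worker nodes.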
class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing"""
    ConfigClass = MultiBandDriverConfig
    _DefaultName = "multiBandDriver"
    RunnerClass = MultiBandDriverTaskRunner

    def __init__(self, butler=None, schema=None, refObjLoader=None, **kwargs):
        """!
        @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader
            constructor in case it is needed.
        @param[in] schema: the schema of the source detection catalog used as input.
        @param[in] refObjLoader: an instance of LoadReferenceObjectsTask that supplies an external
            reference catalog. May be None if the butler argument is provided or all steps requiring
            a reference catalog are disabled.
        """
        BatchPoolTask.__init__(self, **kwargs)
        if schema is None:
            assert butler is not None, "Butler not provided"
            schema = butler.get(self.config.coaddName +
                                "Coadd_det_schema", immediate=True).schema
        self.butler = butler
        self.makeSubtask("detectCoaddSources")
        self.makeSubtask("mergeCoaddDetections", schema=schema)
        self.makeSubtask("measureCoaddSources", schema=afwTable.Schema(self.mergeCoaddDetections.schema),
                         peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
                         refObjLoader=refObjLoader, butler=butler)
        self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
            self.measureCoaddSources.schema))
        self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
            self.mergeCoaddMeasurements.schema))

    def __reduce__(self):
        """Pickler"""
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler))

    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        return parser

    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of CPUs
        """
        numTargets = 0
        for refList in parsedCmd.id.refList:
            numTargets += len(refList)
        return time*numTargets/float(numCpus)

    @abortOnError
    def run(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList: Data references to run measurement
        """
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)

        # MultiBand measurements require that the detection stage be completed before
        # measurements can be made. Determine if data products are present, but detections
        # are not, and attempt to run the detection stage where necessary. The configuration
        # for coaddDriver.py allows detection to be turned off in the event that fake objects
        # are to be added during the detection process. This allows the long co-addition
        # process to be run once, and multiple different MultiBand reruns (with different
        # fake objects) to exist from the same base co-addition.
        detectionList = [patchRef for patchRef in patchRefList if
                         not patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
                         patchRef.datasetExists(self.config.coaddName + "Coadd")]

        pool.map(self.runDetection, detectionList)

        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det")]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]

            patch = dataId["patch"]
            if patch not in patches:
                patches[patch] = []
            patches[patch].append(dataId)

        pool.map(self.runMergeDetections, patches.values())

        # Measure merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runMeasureMerged will return whether there are any footprints in that image that required
        # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements new, but merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the merge and
        # forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
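        #
        # For example (illustrative values): if dataIdList covers patches "1,1" and "2,3"
        # in several filters and only the "2,3" measurements report large footprints, the
        # mapping built below would be patchReprocessing == {"1,1": False, "2,3": True},
        # so only patch 2,3 has its merge and forced photometry re-run.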
        try:
            reprocessed = pool.map(self.runMeasureMerged, dataIdList)
        finally:
            if self.config.reprocessing:
                patchReprocessing = {}
                for dataId, reprocess in zip(dataIdList, reprocessed):
                    patchId = dataId["patch"]
                    patchReprocessing[patchId] = patchReprocessing.get(
                        patchId, False) or reprocess
                # Persist the determination, to make error recovery easier
                reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
                for patchId in patchReprocessing:
                    if not patchReprocessing[patchId]:
                        continue
                    dataId = dict(tract=tract, patch=patchId)
                    if patchReprocessing[patchId]:
                        filename = butler.get(
                            reprocessDataset + "_filename", dataId)[0]
                        open(filename, 'a').close()  # Touch file
                    elif butler.datasetExists(reprocessDataset, dataId):
                        # We must have failed at some point while reprocessing
                        # and we're starting over
                        patchReprocessing[patchId] = True

        # Only process patches that have been identified as needing it
        pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
                                             not self.config.reprocessing or patchReprocessing[patchId]])
        pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if
                                      not self.config.reprocessing or
                                      patchReprocessing[dataId1["patch"]]])

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                dataId = dict(tract=tract, patch=patchId)
                filename = butler.get(
                    reprocessDataset + "_filename", dataId)[0]
                os.unlink(filename)

    def runDetection(self, cache, patchRef):
        """! Run detection on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param patchRef: Patch on which to do detection
        """
        with self.logOperation("do detections on {}".format(patchRef.dataId)):
            idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
            coadd = patchRef.get(self.config.coaddName + "Coadd",
                                 immediate=True)
            detResults = self.detectCoaddSources.runDetection(coadd, idFactory)
            self.detectCoaddSources.write(coadd, detResults, patchRef)

    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedDetections and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet")):
                return
            self.mergeCoaddDetections.run(dataRefList)

    def runMeasureMerged(self, cache, dataId):
        """!Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        @return whether the patch requires reprocessing.
        """
        with self.logOperation("measurement on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            reprocessing = False  # Does this patch require reprocessing?
            if (not self.config.clobberMeasurements and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas")):
                if not self.config.reprocessing:
                    return False

                catalog = dataRef.get(self.config.coaddName + "Coadd_meas")
                bigFlag = catalog["deblend.parent-too-big"]
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" %
                                  (dataRef.dataId,))
                    return False
                numNewBig = sum((self.measureCoaddSources.deblend.isLargeFootprint(src.getFootprint()) for
                                 src in catalog[bigFlag]))
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRef.dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, dataRef.dataId))
                reprocessing = True

            self.measureCoaddSources.run(dataRef)
            return reprocessing

    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedMeasurements and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref")):
                return
            self.mergeCoaddMeasurements.run(dataRefList)

    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            if (not self.config.clobberForcedPhotometry and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src")):
                return
            self.forcedPhotCoadd.run(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass
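

# Usage sketch (added for illustration; not part of the original module). Drivers in
# pipe_drivers are normally launched from a small bin script; the entry point and
# command-line flags below are assumptions based on that pattern, not code from
# this file:
#
#     from lsst.pipe.drivers.multiBandDriver import MultiBandDriverTask
#     MultiBandDriverTask.parseAndSubmit()
#
# On the command line this would look something like (hypothetical repository path,
# rerun name, and data IDs):
#
#     multiBandDriver.py /path/to/DATA --rerun multiband \
#         --id tract=8766 filter=HSC-G^HSC-R^HSC-I --batch-type=smp --cores=16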