lsst.pipe.drivers  17.0.1-8-g4eedfc1+1
multiBandDriver.py
Go to the documentation of this file.
1 from __future__ import absolute_import, division, print_function
2 import os
3 
4 from builtins import zip
5 
6 from lsst.pex.config import Config, Field, ConfigurableField
7 from lsst.pipe.base import ArgumentParser, TaskRunner
8 from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
9  MergeDetectionsTask,
10  DeblendCoaddSourcesTask,
11  MeasureMergedCoaddSourcesTask,
12  MergeMeasurementsTask,)
13 from lsst.ctrl.pool.parallel import BatchPoolTask
14 from lsst.ctrl.pool.pool import Pool, abortOnError
15 from lsst.meas.base.references import MultiBandReferencesTask
16 from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
17 from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
18 
19 import lsst.afw.table as afwTable
20 
21 
class MultiBandDriverConfig(Config):
    """Configuration of the multiband driver and of each processing stage.

    The root ``coaddName`` is compared against the ``coaddName`` of the
    subtasks listed in `validate`; a mismatch is reported as an error.
    """

    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    doDetection = Field(
        dtype=bool,
        default=False,
        doc="Re-run detection? (requires *Coadd dataset to have been written)",
    )
    detectCoaddSources = ConfigurableField(
        target=DetectCoaddSourcesTask,
        doc="Detect sources on coadd",
    )
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask,
        doc="Merge detections",
    )
    deblendCoaddSources = ConfigurableField(
        target=DeblendCoaddSourcesTask,
        doc="Deblend merged detections",
    )
    measureCoaddSources = ConfigurableField(
        target=MeasureMergedCoaddSourcesTask,
        doc="Measure merged and (optionally) deblended detections",
    )
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask,
        doc="Merge measurements",
    )
    forcedPhotCoadd = ConfigurableField(
        target=ForcedPhotCoaddTask,
        doc="Forced measurement on coadded images",
    )
    reprocessing = Field(
        dtype=bool,
        default=False,
        doc=(
            "Are we reprocessing?\n\n"
            "This exists as a workaround for large deblender footprints causing large memory use "
            "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
            "and return to reprocess on a machine with larger memory or more time "
            "if we consider those footprints important to recover."
        ),
    )

    def setDefaults(self):
        """Apply the base defaults, then retarget the forced-photometry
        reference subtask to MultiBandReferencesTask.
        """
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        """Check that the listed subtasks use the root ``coaddName``.

        Raises RuntimeError for the first subtask whose ``coaddName``
        differs from ``self.coaddName``.

        NOTE(review): ``detectCoaddSources`` is not in the checked list and
        the base-class validation is not invoked here; confirm both are
        intentional.
        """
        checkedSubtasks = (
            "mergeCoaddDetections",
            "deblendCoaddSources",
            "measureCoaddSources",
            "mergeCoaddMeasurements",
            "forcedPhotCoadd",
        )
        for subtaskName in checkedSubtasks:
            subtaskCoaddName = getattr(self, subtaskName).coaddName
            if subtaskCoaddName == self.coaddName:
                continue
            raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                               (subtaskName, subtaskCoaddName, self.coaddName))
57 
58 
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandTask

    Comparable to lsst.pipe.base.ButlerInitializedTaskRunner, with two
    differences: Task.run receives a list of data references rather than a
    single one, and the value of the '--reuse-outputs-from' command option
    is handed to the Task constructor.
    """

    def __init__(self, TaskClass, parsedCmd, doReturnResults=False):
        """Initialise the base runner and remember ``parsedCmd.reuse``."""
        TaskRunner.__init__(self, TaskClass, parsedCmd, doReturnResults)
        self.reuse = parsedCmd.reuse

    def makeTask(self, parsedCmd=None, args=None):
        """Construct the task, passing a butler and the reuse list.

        Either parsedCmd or args must be specified. When parsedCmd is given
        its butler is used; otherwise the butler is taken from the first
        data reference in ``args``, which must be a
        ``(dataRefList, kwargs)`` pair.
        """
        if parsedCmd is None and args is None:
            raise RuntimeError("parsedCmd or args must be specified")
        if parsedCmd is not None:
            butler = parsedCmd.butler
        else:
            # Only the data references are needed; the kwargs half is ignored.
            dataRefList, _ = args
            butler = dataRefList[0].butlerSubset.butler
        return self.TaskClass(config=self.config, log=self.log, butler=butler, reuse=self.reuse)
84 
85 
def unpickle(factory, args, kwargs):
    """Rebuild an object by calling ``factory`` with the given positional
    (``args``) and keyword (``kwargs``) arguments, and return the result.
    """
    rebuilt = factory(*args, **kwargs)
    return rebuilt
89 
90 
92  """Multi-node driver for multiband processing"""
93  ConfigClass = MultiBandDriverConfig
94  _DefaultName = "multiBandDriver"
95  RunnerClass = MultiBandDriverTaskRunner
96 
def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), **kwargs):
    """!
    @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader constructor
        in case it is needed.
    @param[in] schema: the schema of the source detection catalog used as input. When None it is read
        from the butler, which must then be provided.
    @param[in] refObjLoader: an instance of LoadReferenceObjectsTasks that supplies an external reference
        catalog. May be None if the butler argument is provided or all steps requiring a reference
        catalog are disabled.
    @param[in] reuse: iterable of subtask names; stored as a tuple in self.reuse, which the run*
        methods consult before skipping work whose output already exists.
    @param[in] kwargs: forwarded unchanged to BatchPoolTask.__init__
    """
    BatchPoolTask.__init__(self, **kwargs)
    if schema is None:
        assert butler is not None, "Butler not provided"
        detSchemaName = self.config.coaddName + "Coadd_det_schema"
        schema = butler.get(detSchemaName, immediate=True).schema
    self.butler = butler
    self.reuse = tuple(reuse)

    self.makeSubtask("detectCoaddSources")
    self.makeSubtask("mergeCoaddDetections", schema=schema)

    inputCatalog = self.config.measureCoaddSources.inputCatalog
    if inputCatalog.startswith("deblended"):
        # Ensure that the output from deblendCoaddSources matches the input to measureCoaddSources
        self.measurementInput = inputCatalog
        simultaneous = self.config.deblendCoaddSources.simultaneous
        self.deblenderOutput = ["deblendedModel" if simultaneous else "deblendedFlux"]
        if self.measurementInput not in self.deblenderOutput:
            err = "Measurement input '{0}' is not in the list of deblender output catalogs '{1}'"
            raise ValueError(err.format(self.measurementInput, self.deblenderOutput))

        self.makeSubtask(
            "deblendCoaddSources",
            schema=afwTable.Schema(self.mergeCoaddDetections.schema),
            peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
            butler=butler,
        )
        measureInputSchema = afwTable.Schema(self.deblendCoaddSources.schema)
    else:
        # No deblended input requested: measure directly on the merged detections.
        measureInputSchema = afwTable.Schema(self.mergeCoaddDetections.schema)

    # Each later stage receives its own copy of the schema produced by the stage before it.
    self.makeSubtask(
        "measureCoaddSources",
        schema=measureInputSchema,
        peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
        refObjLoader=refObjLoader,
        butler=butler,
    )
    self.makeSubtask("mergeCoaddMeasurements",
                     schema=afwTable.Schema(self.measureCoaddSources.schema))
    self.makeSubtask("forcedPhotCoadd",
                     refSchema=afwTable.Schema(self.mergeCoaddMeasurements.schema))
142 
def __reduce__(self):
    """Pickler

    Returns the ``unpickle`` factory together with the task class, no
    positional arguments, and the keyword arguments needed to construct an
    equivalent task.
    """
    constructorKwargs = dict(
        config=self.config,
        name=self._name,
        parentTask=self._parentTask,
        log=self.log,
        butler=self.butler,
        reuse=self.reuse,
    )
    return unpickle, (self.__class__, [], constructorKwargs)
148 
@classmethod
def _makeArgumentParser(cls, *args, **kwargs):
    """Build the command-line parser for this task.

    Any ``doBatch`` keyword is discarded before the remaining arguments are
    passed to ArgumentParser. The parser gets an ``--id`` argument for the
    "deepCoadd" dataset and a reuse option listing the subtasks.
    """
    # ArgumentParser is not given doBatch; drop it whether or not it was supplied.
    kwargs.pop("doBatch", False)
    parser = ArgumentParser(*args, name=cls._DefaultName, **kwargs)
    parser.add_id_argument(
        "--id",
        "deepCoadd",
        help="data ID, e.g. --id tract=12345 patch=1,2",
        ContainerClass=TractDataIdContainer,
    )
    reusableSubtasks = [
        "detectCoaddSources",
        "mergeCoaddDetections",
        "measureCoaddSources",
        "mergeCoaddMeasurements",
        "forcedPhotCoadd",
        "deblendCoaddSources",
    ]
    parser.addReuseOption(reusableSubtasks)
    return parser
158 
@classmethod
def batchWallTime(cls, time, parsedCmd, numCpus):
    """!Return walltime request for batch job

    @param time: Requested time per iteration
    @param parsedCmd: Results of argument parsing
    @param numCpus: Number of CPUs the work is divided over
    @return time multiplied by the total number of data references in
        parsedCmd.id.refList, divided by numCpus
    """
    numTargets = sum(len(refList) for refList in parsedCmd.id.refList)
    return time*numTargets/float(numCpus)
171 
@abortOnError
def runDataRef(self, patchRefList):
    """!Run multiband processing on coadds

    Only the master node runs this method.

    No real MPI communication (scatter/gather) takes place: all I/O goes
    through the disk. We want the intermediate stages on disk, and the
    component Tasks are implemented around this, so we just follow suit.

    The stages are run in order through the pool: (optional) detection,
    detection merging, deblending, measurement, measurement merging and
    forced photometry.

    @param patchRefList: Data references to run measurement
    @throws RuntimeError if patchRefList contains no valid data reference
    """
    for patchRef in patchRefList:
        if patchRef:
            butler = patchRef.getButler()
            break
    else:
        raise RuntimeError("No valid patches")
    pool = Pool("all")
    pool.cacheClear()
    pool.storeSet(butler=butler)

    # MultiBand measurements require that the detection stage be completed
    # before measurements can be made.
    #
    # The configuration for coaddDriver.py allows detection to be turned
    # off in the event that fake objects are to be added during the
    # detection process. This allows the long co-addition process to be
    # run once, and multiple different MultiBand reruns (with different
    # fake objects) to exist from the same base co-addition.
    #
    # However, we only re-run detection if doDetection is explicitly True
    # here (this should always be the opposite of coaddDriver.doDetection);
    # otherwise we have no way to tell reliably whether any detections
    # present in an input repo are safe to use.
    if self.config.doDetection:
        detectionList = []
        for patchRef in patchRefList:
            if ("detectCoaddSources" in self.reuse and
                    patchRef.datasetExists(self.config.coaddName + "Coadd_calexp", write=True)):
                self.log.info("Skipping detectCoaddSources for %s; output already exists." %
                              patchRef.dataId)
                continue
            if not patchRef.datasetExists(self.config.coaddName + "Coadd"):
                self.log.debug("Not processing %s; required input %sCoadd missing." %
                               (patchRef.dataId, self.config.coaddName))
                continue
            detectionList.append(patchRef)

        pool.map(self.runDetection, detectionList)

    # Keep only the references for which the detection outputs are available.
    patchRefList = [patchRef for patchRef in patchRefList if
                    patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
                    patchRef.datasetExists(self.config.coaddName + "Coadd_det",
                                           write=self.config.doDetection)]
    dataIdList = [patchRef.dataId for patchRef in patchRefList]

    # Group by patch
    patches = {}
    tract = None
    for patchRef in patchRefList:
        dataId = patchRef.dataId
        if tract is None:
            tract = dataId["tract"]
        else:
            assert tract == dataId["tract"]
        patches.setdefault(dataId["patch"], []).append(dataId)

    # Freeze one ordering of the patches so that per-patch results returned by
    # the pool can be matched back to the patch they belong to.
    patchIds = list(patches)
    patchDataIdLists = [patches[patchId] for patchId in patchIds]

    pool.map(self.runMergeDetections, patchDataIdLists)

    # Deblend merged detections, and test for reprocessing
    #
    # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
    # footprints can suck up a lot of memory in the deblender, which means that when we process on a
    # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
    # they may have astronomically interesting data, we want the ability to go back and reprocess them
    # with a more permissive configuration when we have more memory or processing time.
    #
    # self.runDeblendMerged returns, for each patch, whether there are any footprints that required
    # reprocessing. We convert that list of booleans (one per patch, in the order of patchIds) into
    # a dict mapping the patchId to a boolean. That tells us whether the measurement, merge
    # measurement and forced photometry need to be re-run on a particular patch.
    #
    # This determination of which patches need to be reprocessed exists only in memory (the measurements
    # have been written, clobbering the old ones), so if there was an exception later on we would lose
    # this information, leaving things in an inconsistent state (measurements, merged measurements and
    # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
    # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the measurements,
    # merge and forced photometry.
    #
    # If the deblending itself raises there are no results to persist, so the exception simply
    # propagates.
    #
    # This is, hopefully, a temporary workaround until we can improve the
    # deblender.
    reprocessed = pool.map(self.runDeblendMerged, patchDataIdLists)

    patchReprocessing = {}
    if self.config.reprocessing:
        reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
        for patchId, reprocess in zip(patchIds, reprocessed):
            markerId = dict(tract=tract, patch=patchId)
            if reprocess:
                # Persist the determination, to make error recovery easier
                filename = butler.get(reprocessDataset + "_filename", markerId)[0]
                with open(filename, 'a'):
                    pass  # Touch file
            elif butler.datasetExists(reprocessDataset, markerId):
                # We must have failed at some point while reprocessing
                # and we're starting over
                reprocess = True
            patchReprocessing[patchId] = bool(reprocess)

    def needsProcessing(patchId):
        # Without reprocessing everything is processed; with it, only flagged patches.
        return not self.config.reprocessing or patchReprocessing[patchId]

    # Only process patches that have been identified as needing it
    selectedDataIds = [dataId1 for dataId1 in dataIdList if needsProcessing(dataId1["patch"])]
    selectedPatches = [patches[patchId] for patchId in patchIds if needsProcessing(patchId)]
    pool.map(self.runMeasurements, selectedDataIds)
    pool.map(self.runMergeMeasurements, selectedPatches)
    pool.map(self.runForcedPhot, selectedDataIds)

    # Remove persisted reprocessing determination
    if self.config.reprocessing:
        for patchId, flagged in patchReprocessing.items():
            if not flagged:
                continue
            markerId = dict(tract=tract, patch=patchId)
            filename = butler.get(reprocessDataset + "_filename", markerId)[0]
            os.unlink(filename)
309 
def runDetection(self, cache, patchRef):
    """! Run detection on a patch

    Only slave nodes execute this method.

    The coadd and its exposure ID are read through patchRef, detection is
    run, and both the results and the subtask metadata are written back
    through patchRef.

    @param cache: Pool cache, containing butler
    @param patchRef: Patch on which to do detection
    """
    with self.logOperation("do detections on {}".format(patchRef.dataId)):
        detectTask = self.detectCoaddSources
        coaddName = self.config.coaddName
        idFactory = detectTask.makeIdFactory(patchRef)
        coadd = patchRef.get(coaddName + "Coadd", immediate=True)
        expId = int(patchRef.get(coaddName + "CoaddId"))
        # Reset the subtask's metadata before this run.
        detectTask.emptyMetadata()
        detResults = detectTask.run(coadd, idFactory, expId=expId)
        detectTask.write(detResults, patchRef)
        detectTask.writeMetadata(patchRef)
327 
def runMergeDetections(self, cache, dataIdList):
    """!Run detection merging on a patch

    Only slave nodes execute this method.

    The merge is skipped when "mergeCoaddDetections" is listed in
    self.reuse and the merged-detection output already exists for the
    first data reference.

    @param cache: Pool cache, containing butler
    @param dataIdList: List of data identifiers for the patch in different filters
    """
    with self.logOperation("merge detections from %s" % (dataIdList,)):
        calexpName = self.config.coaddName + "Coadd_calexp"
        dataRefList = [getDataRef(cache.butler, dataId, calexpName) for dataId in dataIdList]
        if "mergeCoaddDetections" in self.reuse:
            mergeDetName = self.config.coaddName + "Coadd_mergeDet"
            if dataRefList[0].datasetExists(mergeDetName, write=True):
                self.log.info("Skipping mergeCoaddDetections for %s; output already exists." %
                              dataRefList[0].dataId)
                return
        self.mergeCoaddDetections.runDataRef(dataRefList)
345 
def runDeblendMerged(self, cache, dataIdList):
    """Run the deblender on a list of dataId's

    Only slave nodes execute this method.

    When "deblendCoaddSources" is listed in ``self.reuse`` and the deblender
    output exists for every band, the deblender is only re-run if
    reprocessing is enabled and some footprint previously flagged as too
    big would no longer be considered large.

    Parameters
    ----------
    cache: Pool cache
        Pool cache with butler.
    dataIdList: list
        Data identifier for patch in each band.

    Returns
    -------
    result: bool
        whether the patch requires reprocessing.
    """
    with self.logOperation("deblending %s" % (dataIdList,)):
        coaddName = self.config.coaddName
        dataRefList = [getDataRef(cache.butler, dataId, coaddName + "Coadd_calexp")
                       for dataId in dataIdList]

        outputsReusable = (
            "deblendCoaddSources" in self.reuse and
            all([dataRef.datasetExists(coaddName + "Coadd_" + self.measurementInput, write=True)
                 for dataRef in dataRefList])
        )
        if not outputsReusable:
            # Nothing to reuse: deblend normally; this is not a reprocessing.
            self.deblendCoaddSources.runDataRef(dataRefList)
            return False

        if not self.config.reprocessing:
            self.log.info("Skipping deblendCoaddSources for %s; output already exists" % dataIdList)
            return False

        # Footprints are the same every band, therefore we can check just one
        firstRef = dataRefList[0]
        catalog = firstRef.get(coaddName + "Coadd_" + self.measurementInput)
        bigFlag = catalog["deblend_parentTooBig"]
        # Footprints marked too large by the previous deblender run
        numOldBig = bigFlag.sum()
        if numOldBig == 0:
            self.log.info("No large footprints in %s" % (dataRefList[0].dataId))
            return False

        # This if-statement can be removed after DM-15662
        if self.config.deblendCoaddSources.simultaneous:
            deblender = self.deblendCoaddSources.multiBandDeblend
        else:
            deblender = self.deblendCoaddSources.singleBandDeblend

        # isLargeFootprint() can potentially return False for a source that is marked
        # too big in the catalog, because of "new"/different deblender configs.
        # numNewBig is the number of footprints that *will* be too big if reprocessed
        numNewBig = sum((deblender.isLargeFootprint(src.getFootprint()) for
                         src in catalog[bigFlag]))
        if numNewBig == numOldBig:
            self.log.info("All %d formerly large footprints continue to be large in %s" %
                          (numOldBig, dataRefList[0].dataId,))
            return False

        self.log.info("Found %d large footprints to be reprocessed in %s" %
                      (numOldBig - numNewBig, [dataRef.dataId for dataRef in dataRefList]))
        self.deblendCoaddSources.runDataRef(dataRefList)
        return True
404 
def runMeasurements(self, cache, dataId):
    """Run measurement on a patch for a single filter

    Only slave nodes execute this method.

    The measurement is skipped when "measureCoaddSources" is listed in
    ``self.reuse``, reprocessing is disabled, and the measurement output
    already exists.

    Parameters
    ----------
    cache: Pool cache
        Pool cache, with butler
    dataId: data identifier
        Data identifier for patch
    """
    with self.logOperation("measurements on %s" % (dataId,)):
        dataRef = getDataRef(cache.butler, dataId,
                             self.config.coaddName + "Coadd_calexp")
        if ("measureCoaddSources" in self.reuse and
                not self.config.reprocessing and
                dataRef.datasetExists(self.config.coaddName + "Coadd_meas", write=True)):
            # Wrap dataId in a tuple so it is always formatted as a single value.
            self.log.info("Skipping measureCoaddSources for %s; output already exists" % (dataId,))
            return
        self.measureCoaddSources.runDataRef(dataRef)
426 
def runMergeMeasurements(self, cache, dataIdList):
    """!Run measurement merging on a patch

    Only slave nodes execute this method.

    The merge is skipped when "mergeCoaddMeasurements" is listed in
    self.reuse, reprocessing is disabled, and the merged output already
    exists for the first data reference.

    @param cache: Pool cache, containing butler
    @param dataIdList: List of data identifiers for the patch in different filters
    """
    with self.logOperation("merge measurements from %s" % (dataIdList,)):
        calexpName = self.config.coaddName + "Coadd_calexp"
        dataRefList = [getDataRef(cache.butler, dataId, calexpName) for dataId in dataIdList]
        mayReuse = "mergeCoaddMeasurements" in self.reuse and not self.config.reprocessing
        if mayReuse and dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref", write=True):
            self.log.info("Skipping mergeCoaddMeasurements for %s; output already exists" %
                          dataRefList[0].dataId)
            return
        self.mergeCoaddMeasurements.runDataRef(dataRefList)
445 
def runForcedPhot(self, cache, dataId):
    """!Run forced photometry on a patch for a single filter

    Only slave nodes execute this method.

    The photometry is skipped when "forcedPhotCoadd" is listed in
    self.reuse, reprocessing is disabled, and the forced-source output
    already exists.

    @param cache: Pool cache, with butler
    @param dataId: Data identifier for patch
    """
    with self.logOperation("forced photometry on %s" % (dataId,)):
        coaddName = self.config.coaddName
        dataRef = getDataRef(cache.butler, dataId, coaddName + "Coadd_calexp")
        mayReuse = "forcedPhotCoadd" in self.reuse and not self.config.reprocessing
        if mayReuse and dataRef.datasetExists(coaddName + "Coadd_forced_src", write=True):
            self.log.info("Skipping forcedPhotCoadd for %s; output already exists" % dataId)
            return
        self.forcedPhotCoadd.runDataRef(dataRef)
463 
def writeMetadata(self, dataRef):
    """Do nothing: this task collects no metadata of its own, so there is
    nothing to write for ``dataRef``.
    """
    return None
def unpickle(factory, args, kwargs)
def runDataRef(self, patchRefList)
Run multiband processing on coadds.
def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), **kwargs)
def getDataRef(butler, dataId, datasetType="raw")
Definition: utils.py:17
def runForcedPhot(self, cache, dataId)
Run forced photometry on a patch for a single filter.
def batchWallTime(cls, time, parsedCmd, numCpus)
Return walltime request for batch job.
def __init__(self, TaskClass, parsedCmd, doReturnResults=False)
def logOperation(self, operation, catch=False, trace=True)
def runDetection(self, cache, patchRef)
Run detection on a patch.
def runMergeDetections(self, cache, dataIdList)
Run detection merging on a patch.
def runMergeMeasurements(self, cache, dataIdList)
Run measurement merging on a patch.