lsst.pipe.drivers  17.0.1-4-gce169aa+1
multiBandDriver.py
Go to the documentation of this file.
1 from __future__ import absolute_import, division, print_function
2 import os
3 
4 from builtins import zip
5 
6 from lsst.pex.config import Config, Field, ConfigurableField
7 from lsst.pipe.base import ArgumentParser, TaskRunner
8 from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
9  MergeDetectionsTask,
10  DeblendCoaddSourcesTask,
11  MeasureMergedCoaddSourcesTask,
12  MergeMeasurementsTask,)
13 from lsst.ctrl.pool.parallel import BatchPoolTask
14 from lsst.ctrl.pool.pool import Pool, abortOnError
15 from lsst.meas.base.references import MultiBandReferencesTask
16 from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
17 from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
18 
19 import lsst.afw.table as afwTable
20 
21 
class MultiBandDriverConfig(Config):
    """Configuration for the multi-band coadd processing driver.

    Every coadd-consuming subtask must use the same ``coaddName`` as the
    driver itself; `validate` enforces this.
    """
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    doDetection = Field(dtype=bool, default=False,
                        doc="Re-run detection? (requires *Coadd dataset to have been written)")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    deblendCoaddSources = ConfigurableField(target=DeblendCoaddSourcesTask, doc="Deblend merged detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged and (optionally) deblended detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    def setDefaults(self):
        Config.setDefaults(self)
        # Forced photometry on coadds takes its reference sources from the multi-band outputs.
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        """Validate the configuration.

        The base-class `Config.validate` runs first, so the standard field
        validation is not skipped by this override. Then each coadd-consuming
        subtask is checked to use the same ``coaddName`` as the driver.

        Raises
        ------
        RuntimeError
            If a subtask's ``coaddName`` differs from ``self.coaddName``.
        """
        Config.validate(self)
        for subtask in ("mergeCoaddDetections", "deblendCoaddSources", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))
57 
58 
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for the multi-band driver task.

    Works like lsst.pipe.base.ButlerInitializedTaskRunner, with two differences:
    the Task's run method receives a list of data references rather than a single
    one, and the value of the '--reuse-outputs-from' command-line option is
    forwarded to the Task constructor.
    """

    def __init__(self, TaskClass, parsedCmd, doReturnResults=False):
        TaskRunner.__init__(self, TaskClass, parsedCmd, doReturnResults)
        # Stages whose existing outputs may be reused; handed to the Task in makeTask.
        self.reuse = parsedCmd.reuse

    def makeTask(self, parsedCmd=None, args=None):
        """Construct the Task, passing it a butler and the reuse list.

        The butler comes from ``parsedCmd`` when that is given. Otherwise it
        comes from the first data reference in ``args``.

        @param parsedCmd: parsed command line; its ``butler`` is used if not None
        @param args: a (dataRefList, kwargs) pair, consulted only when parsedCmd is None
        @return a new instance of self.TaskClass
        @throws RuntimeError if neither parsedCmd nor args is supplied
        """
        if parsedCmd is None and args is None:
            raise RuntimeError("parsedCmd or args must be specified")
        if parsedCmd is not None:
            butler = parsedCmd.butler
        else:
            refs, _ = args
            butler = refs[0].butlerSubset.butler
        return self.TaskClass(config=self.config, log=self.log, butler=butler, reuse=self.reuse)
84 
85 
def unpickle(factory, args, kwargs):
    """Rebuild an object by calling ``factory(*args, **kwargs)``.

    This is the reconstruction callable that ``__reduce__`` hands to pickle.
    """
    instance = factory(*args, **kwargs)
    return instance
89 
90 
class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing"""
    # Configuration class for this task.
    ConfigClass = MultiBandDriverConfig
    # Default task name; _makeArgumentParser also uses it as the parser name.
    _DefaultName = "multiBandDriver"
    # Custom runner that passes a butler and the reuse list to the constructor.
    RunnerClass = MultiBandDriverTaskRunner
96 
97  def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), **kwargs):
98  """!
99  @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader constructor
100  in case it is needed.
101  @param[in] schema: the schema of the source detection catalog used as input.
102  @param[in] refObjLoader: an instance of LoadReferenceObjectsTasks that supplies an external reference
103  catalog. May be None if the butler argument is provided or all steps requiring a reference
104  catalog are disabled.
105  """
106  BatchPoolTask.__init__(self, **kwargs)
107  if schema is None:
108  assert butler is not None, "Butler not provided"
109  schema = butler.get(self.config.coaddName +
110  "Coadd_det_schema", immediate=True).schema
111  self.butler = butler
112  self.reuse = tuple(reuse)
113  self.makeSubtask("detectCoaddSources")
114  self.makeSubtask("mergeCoaddDetections", schema=schema)
115  if self.config.measureCoaddSources.inputCatalog.startswith("deblended"):
116  # Ensure that the output from deblendCoaddSources matches the input to measureCoaddSources
117  self.measurementInput = self.config.measureCoaddSources.inputCatalog
118  self.deblenderOutput = []
119  if self.config.deblendCoaddSources.simultaneous:
120  if self.config.deblendCoaddSources.multiBandDeblend.conserveFlux:
121  self.deblenderOutput.append("deblendedFlux")
122  if self.config.deblendCoaddSources.multiBandDeblend.saveTemplates:
123  self.deblenderOutput.append("deblendedModel")
124  else:
125  self.deblenderOutput.append("deblendedFlux")
126  if self.measurementInput not in self.deblenderOutput:
127  err = "Measurement input '{0}' is not in the list of deblender output catalogs '{1}'"
128  raise ValueError(err.format(self.measurementInput, self.deblenderOutput))
129 
130  self.makeSubtask("deblendCoaddSources",
131  schema=afwTable.Schema(self.mergeCoaddDetections.schema),
132  peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
133  butler=butler)
134  measureInputSchema = afwTable.Schema(self.deblendCoaddSources.schema)
135  else:
136  measureInputSchema = afwTable.Schema(self.mergeCoaddDetections.schema)
137  self.makeSubtask("measureCoaddSources", schema=measureInputSchema,
138  peakSchema=afwTable.Schema(
139  self.mergeCoaddDetections.merged.getPeakSchema()),
140  refObjLoader=refObjLoader, butler=butler)
141  self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
142  self.measureCoaddSources.schema))
143  self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
144  self.mergeCoaddMeasurements.schema))
145 
146  def __reduce__(self):
147  """Pickler"""
148  return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
149  parentTask=self._parentTask, log=self.log,
150  butler=self.butler, reuse=self.reuse))
151 
152  @classmethod
153  def _makeArgumentParser(cls, *args, **kwargs):
154  kwargs.pop("doBatch", False)
155  parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
156  parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
157  ContainerClass=TractDataIdContainer)
158  parser.addReuseOption(["detectCoaddSources", "mergeCoaddDetections", "measureCoaddSources",
159  "mergeCoaddMeasurements", "forcedPhotCoadd", "deblendCoaddSources"])
160  return parser
161 
162  @classmethod
163  def batchWallTime(cls, time, parsedCmd, numCpus):
164  """!Return walltime request for batch job
165 
166  @param time: Requested time per iteration
167  @param parsedCmd: Results of argument parsing
168  @param numCores: Number of cores
169  """
170  numTargets = 0
171  for refList in parsedCmd.id.refList:
172  numTargets += len(refList)
173  return time*numTargets/float(numCpus)
174 
175  @abortOnError
176  def runDataRef(self, patchRefList):
177  """!Run multiband processing on coadds
178 
179  Only the master node runs this method.
180 
181  No real MPI communication (scatter/gather) takes place: all I/O goes
182  through the disk. We want the intermediate stages on disk, and the
183  component Tasks are implemented around this, so we just follow suit.
184 
185  @param patchRefList: Data references to run measurement
186  """
187  for patchRef in patchRefList:
188  if patchRef:
189  butler = patchRef.getButler()
190  break
191  else:
192  raise RuntimeError("No valid patches")
193  pool = Pool("all")
194  pool.cacheClear()
195  pool.storeSet(butler=butler)
196 
197  # MultiBand measurements require that the detection stage be completed
198  # before measurements can be made.
199  #
200  # The configuration for coaddDriver.py allows detection to be turned
201  # of in the event that fake objects are to be added during the
202  # detection process. This allows the long co-addition process to be
203  # run once, and multiple different MultiBand reruns (with different
204  # fake objects) to exist from the same base co-addition.
205  #
206  # However, we only re-run detection if doDetection is explicitly True
207  # here (this should always be the opposite of coaddDriver.doDetection);
208  # otherwise we have no way to tell reliably whether any detections
209  # present in an input repo are safe to use.
210  if self.config.doDetection:
211  detectionList = []
212  for patchRef in patchRefList:
213  if ("detectCoaddSources" in self.reuse and
214  patchRef.datasetExists(self.config.coaddName + "Coadd_calexp", write=True)):
215  self.log.info("Skipping detectCoaddSources for %s; output already exists." %
216  patchRef.dataId)
217  continue
218  if not patchRef.datasetExists(self.config.coaddName + "Coadd"):
219  self.log.debug("Not processing %s; required input %sCoadd missing." %
220  (patchRef.dataId, self.config.coaddName))
221  continue
222  detectionList.append(patchRef)
223 
224  pool.map(self.runDetection, detectionList)
225 
226  patchRefList = [patchRef for patchRef in patchRefList if
227  patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
228  patchRef.datasetExists(self.config.coaddName + "Coadd_det",
229  write=self.config.doDetection)]
230  dataIdList = [patchRef.dataId for patchRef in patchRefList]
231 
232  # Group by patch
233  patches = {}
234  tract = None
235  for patchRef in patchRefList:
236  dataId = patchRef.dataId
237  if tract is None:
238  tract = dataId["tract"]
239  else:
240  assert tract == dataId["tract"]
241 
242  patch = dataId["patch"]
243  if patch not in patches:
244  patches[patch] = []
245  patches[patch].append(dataId)
246 
247  pool.map(self.runMergeDetections, patches.values())
248 
249  # Deblend merged detections, and test for reprocessing
250  #
251  # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
252  # footprints can suck up a lot of memory in the deblender, which means that when we process on a
253  # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
254  # they may have astronomically interesting data, we want the ability to go back and reprocess them
255  # with a more permissive configuration when we have more memory or processing time.
256  #
257  # self.runDeblendMerged will return whether there are any footprints in that image that required
258  # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
259  # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
260  # a particular patch.
261  #
262  # This determination of which patches need to be reprocessed exists only in memory (the measurements
263  # have been written, clobbering the old ones), so if there was an exception we would lose this
264  # information, leaving things in an inconsistent state (measurements, merged measurements and
265  # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
266  # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the measurements,
267  # merge and forced photometry.
268  #
269  # This is, hopefully, a temporary workaround until we can improve the
270  # deblender.
271  try:
272  reprocessed = pool.map(self.runDeblendMerged, patches.values())
273  finally:
274  if self.config.reprocessing:
275  patchReprocessing = {}
276  for dataId, reprocess in zip(dataIdList, reprocessed):
277  patchId = dataId["patch"]
278  patchReprocessing[patchId] = patchReprocessing.get(
279  patchId, False) or reprocess
280  # Persist the determination, to make error recover easier
281  reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
282  for patchId in patchReprocessing:
283  if not patchReprocessing[patchId]:
284  continue
285  dataId = dict(tract=tract, patch=patchId)
286  if patchReprocessing[patchId]:
287  filename = butler.get(
288  reprocessDataset + "_filename", dataId)[0]
289  open(filename, 'a').close() # Touch file
290  elif butler.datasetExists(reprocessDataset, dataId):
291  # We must have failed at some point while reprocessing
292  # and we're starting over
293  patchReprocessing[patchId] = True
294 
295  # Only process patches that have been identified as needing it
296  pool.map(self.runMeasurements, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
297  patchReprocessing[dataId1["patch"]]])
298  pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
299  not self.config.reprocessing or patchReprocessing[patchId]])
300  pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
301  patchReprocessing[dataId1["patch"]]])
302 
303  # Remove persisted reprocessing determination
304  if self.config.reprocessing:
305  for patchId in patchReprocessing:
306  if not patchReprocessing[patchId]:
307  continue
308  dataId = dict(tract=tract, patch=patchId)
309  filename = butler.get(
310  reprocessDataset + "_filename", dataId)[0]
311  os.unlink(filename)
312 
313  def runDetection(self, cache, patchRef):
314  """! Run detection on a patch
315 
316  Only slave nodes execute this method.
317 
318  @param cache: Pool cache, containing butler
319  @param patchRef: Patch on which to do detection
320  """
321  with self.logOperation("do detections on {}".format(patchRef.dataId)):
322  idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
323  coadd = patchRef.get(self.config.coaddName + "Coadd",
324  immediate=True)
325  expId = int(patchRef.get(self.config.coaddName + "CoaddId"))
326  self.detectCoaddSources.emptyMetadata()
327  detResults = self.detectCoaddSources.run(coadd, idFactory, expId=expId)
328  self.detectCoaddSources.write(detResults, patchRef)
329  self.detectCoaddSources.writeMetadata(patchRef)
330 
331  def runMergeDetections(self, cache, dataIdList):
332  """!Run detection merging on a patch
333 
334  Only slave nodes execute this method.
335 
336  @param cache: Pool cache, containing butler
337  @param dataIdList: List of data identifiers for the patch in different filters
338  """
339  with self.logOperation("merge detections from %s" % (dataIdList,)):
340  dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
341  dataId in dataIdList]
342  if ("mergeCoaddDetections" in self.reuse and
343  dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet", write=True)):
344  self.log.info("Skipping mergeCoaddDetections for %s; output already exists." %
345  dataRefList[0].dataId)
346  return
347  self.mergeCoaddDetections.runDataRef(dataRefList)
348 
349  def runDeblendMerged(self, cache, dataIdList):
350  """Run the deblender on a list of dataId's
351 
352  Only slave nodes execute this method.
353 
354  Parameters
355  ----------
356  cache: Pool cache
357  Pool cache with butler.
358  dataIdList: list
359  Data identifier for patch in each band.
360 
361  Returns
362  -------
363  result: bool
364  whether the patch requires reprocessing.
365  """
366  with self.logOperation("deblending %s" % (dataIdList,)):
367  dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
368  dataId in dataIdList]
369  reprocessing = False # Does this patch require reprocessing?
370  if ("deblendCoaddSources" in self.reuse and
371  all([dataRef.datasetExists(self.config.coaddName + "Coadd_" + self.measurementInput,
372  write=True) for dataRef in dataRefList])):
373  if not self.config.reprocessing:
374  self.log.info("Skipping deblendCoaddSources for %s; output already exists" % dataIdList)
375  return False
376 
377  # Footprints are the same every band, therefore we can check just one
378  catalog = dataRefList[0].get(self.config.coaddName + "Coadd_" + self.measurementInput)
379  bigFlag = catalog["deblend_parentTooBig"]
380  # Footprints marked too large by the previous deblender run
381  numOldBig = bigFlag.sum()
382  if numOldBig == 0:
383  self.log.info("No large footprints in %s" % (dataRefList[0].dataId))
384  return False
385 
386  # This if-statement can be removed after DM-15662
387  if self.config.deblendCoaddSources.simultaneous:
388  deblender = self.deblendCoaddSources.multiBandDeblend
389  else:
390  deblender = self.deblendCoaddSources.singleBandDeblend
391 
392  # isLargeFootprint() can potentially return False for a source that is marked
393  # too big in the catalog, because of "new"/different deblender configs.
394  # numNewBig is the number of footprints that *will* be too big if reprocessed
395  numNewBig = sum((deblender.isLargeFootprint(src.getFootprint()) for
396  src in catalog[bigFlag]))
397  if numNewBig == numOldBig:
398  self.log.info("All %d formerly large footprints continue to be large in %s" %
399  (numOldBig, dataRefList[0].dataId,))
400  return False
401  self.log.info("Found %d large footprints to be reprocessed in %s" %
402  (numOldBig - numNewBig, [dataRef.dataId for dataRef in dataRefList]))
403  reprocessing = True
404 
405  self.deblendCoaddSources.runDataRef(dataRefList)
406  return reprocessing
407 
408  def runMeasurements(self, cache, dataId):
409  """Run measurement on a patch for a single filter
410 
411  Only slave nodes execute this method.
412 
413  Parameters
414  ----------
415  cache: Pool cache
416  Pool cache, with butler
417  dataId: dataRef
418  Data identifier for patch
419  """
420  with self.logOperation("measurements on %s" % (dataId,)):
421  dataRef = getDataRef(cache.butler, dataId,
422  self.config.coaddName + "Coadd_calexp")
423  if ("measureCoaddSources" in self.reuse and
424  not self.config.reprocessing and
425  dataRef.datasetExists(self.config.coaddName + "Coadd_meas", write=True)):
426  self.log.info("Skipping measuretCoaddSources for %s; output already exists" % dataId)
427  return
428  self.measureCoaddSources.runDataRef(dataRef)
429 
430  def runMergeMeasurements(self, cache, dataIdList):
431  """!Run measurement merging on a patch
432 
433  Only slave nodes execute this method.
434 
435  @param cache: Pool cache, containing butler
436  @param dataIdList: List of data identifiers for the patch in different filters
437  """
438  with self.logOperation("merge measurements from %s" % (dataIdList,)):
439  dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
440  dataId in dataIdList]
441  if ("mergeCoaddMeasurements" in self.reuse and
442  not self.config.reprocessing and
443  dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref", write=True)):
444  self.log.info("Skipping mergeCoaddMeasurements for %s; output already exists" %
445  dataRefList[0].dataId)
446  return
447  self.mergeCoaddMeasurements.runDataRef(dataRefList)
448 
449  def runForcedPhot(self, cache, dataId):
450  """!Run forced photometry on a patch for a single filter
451 
452  Only slave nodes execute this method.
453 
454  @param cache: Pool cache, with butler
455  @param dataId: Data identifier for patch
456  """
457  with self.logOperation("forced photometry on %s" % (dataId,)):
458  dataRef = getDataRef(cache.butler, dataId,
459  self.config.coaddName + "Coadd_calexp")
460  if ("forcedPhotCoadd" in self.reuse and
461  not self.config.reprocessing and
462  dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src", write=True)):
463  self.log.info("Skipping forcedPhotCoadd for %s; output already exists" % dataId)
464  return
465  self.forcedPhotCoadd.runDataRef(dataRef)
466 
467  def writeMetadata(self, dataRef):
468  """We don't collect any metadata, so skip"""
469  pass
def unpickle(factory, args, kwargs)
def runDataRef(self, patchRefList)
Run multiband processing on coadds.
def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), kwargs)
def getDataRef(butler, dataId, datasetType="raw")
Definition: utils.py:17
def runForcedPhot(self, cache, dataId)
Run forced photometry on a patch for a single filter.
def batchWallTime(cls, time, parsedCmd, numCpus)
Return walltime request for batch job.
def __init__(self, TaskClass, parsedCmd, doReturnResults=False)
def logOperation(self, operation, catch=False, trace=True)
def runDetection(self, cache, patchRef)
Run detection on a patch.
def runMergeDetections(self, cache, dataIdList)
Run detection merging on a patch.
def runMergeMeasurements(self, cache, dataIdList)
Run measurement merging on a patch.