lsst.pipe.drivers  14.0-5-g782f885+1
multiBandDriver.py
Go to the documentation of this file.
1 from __future__ import absolute_import, division, print_function
2 import os
3 from argparse import ArgumentError
4 
5 from builtins import zip
6 
7 from lsst.pex.config import Config, Field, ConfigurableField
8 from lsst.pipe.base import ArgumentParser, TaskRunner
9 from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
10  MergeDetectionsTask,
11  MeasureMergedCoaddSourcesTask,
12  MergeMeasurementsTask,)
13 from lsst.ctrl.pool.parallel import BatchPoolTask
14 from lsst.ctrl.pool.pool import Pool, abortOnError, NODE
15 from lsst.meas.base.references import MultiBandReferencesTask
16 from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
17 from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
18 from lsst.pipe.tasks.coaddBase import CoaddDataIdContainer
19 
20 import lsst.afw.table as afwTable
21 
22 
class MultiBandDataIdContainer(CoaddDataIdContainer):
    """Data ID container that groups patch-level data references by tract."""

    def makeDataRefList(self, namespace):
        """!Make self.refList from self.idList

        No data product lives solely at the tract level, so a whole tract is
        represented by the data references of all of its patches. On return,
        self.refList holds one entry per tract; each entry is the list of patch
        data references (dataset "<coaddName>Coadd_calexp") for that tract.

        A data ID naming a tract and patch contributes that single patch; one
        naming only a tract contributes every patch of that tract; one naming
        neither contributes every patch of every tract in the skymap.

        @param namespace namespace object that is the result of an argument parser
        @throws ArgumentError if a data ID does not include 'filter'
        """
        datasetType = namespace.config.coaddName + "Coadd_calexp"

        def patchRefsForTract(tract, filterName):
            """Data references for every patch of ``tract`` in ``filterName``."""
            return [namespace.butler.dataRef(datasetType=datasetType,
                                             tract=tract.getId(),
                                             filter=filterName,
                                             patch="%d,%d" % patch.getIndex())
                    for patch in tract]

        tractRefs = {}  # tract ID -> list of patch data references
        for dataId in self.idList:
            # There's no registry of coadds by filter, so the filter must be given
            if "filter" not in dataId:
                raise ArgumentError(None, "--id must include 'filter'")
            filterName = dataId["filter"]

            skymap = self.getSkymap(namespace, datasetType)

            if "tract" not in dataId:
                # Rebuild the mapping over every tract in the skymap, keeping any
                # references already collected for those tracts.
                tractRefs = {tract.getId(): tractRefs.get(tract.getId(), []) +
                             patchRefsForTract(tract, filterName)
                             for tract in skymap}
                continue

            tractId = dataId["tract"]
            refs = tractRefs.setdefault(tractId, [])
            if "patch" in dataId:
                refs.append(namespace.butler.dataRef(datasetType=datasetType,
                                                     tract=tractId,
                                                     filter=filterName,
                                                     patch=dataId['patch']))
            else:
                refs += patchRefsForTract(skymap[tractId], filterName)

        self.refList = list(tractRefs.values())
69 
70 
class MultiBandDriverConfig(Config):
    """Configuration for the multiband coadd processing driver.

    Holds one ConfigurableField per processing stage plus flags controlling
    whether existing outputs of each stage are overwritten.
    """
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    clobberDetections = Field(
        dtype=bool, default=False, doc="Clobber existing detections?")
    clobberMergedDetections = Field(
        dtype=bool, default=False, doc="Clobber existing merged detections?")
    clobberMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing measurements?")
    clobberMergedMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing merged measurements?")
    clobberForcedPhotometry = Field(
        dtype=bool, default=False, doc="Clobber existing forced photometry?")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    def setDefaults(self):
        """Set defaults: forced photometry takes its references from the multiband outputs."""
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        """Validate the configuration.

        Runs the standard Config validation (which covers every field, including
        the sub-task configs), then checks that each sub-task that carries a
        coaddName agrees with the root coaddName.

        @throws RuntimeError if a sub-task's coaddName differs from self.coaddName
        """
        # The override previously skipped this, so field/sub-config validation never ran.
        Config.validate(self)
        for subtask in ("mergeCoaddDetections", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))
113 
114 
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for the multiband driver task.

    Like lsst.pipe.base.ButlerInitializedTaskRunner, this hands a butler to the
    task's constructor; the difference is that Task.run receives a list of data
    references rather than a single data reference.
    """

    def makeTask(self, parsedCmd=None, args=None):
        """Construct the task, passing a butler to its constructor.

        The butler comes from ``parsedCmd`` when given; otherwise from the first
        data reference in ``args`` (a ``(dataRefList, kwargs)`` pair).

        @throws RuntimeError if neither parsedCmd nor args is given
        """
        if parsedCmd is None and args is None:
            raise RuntimeError("parsedCmd or args must be specified")

        if parsedCmd is not None:
            butler = parsedCmd.butler
        else:
            dataRefList, _ = args
            butler = dataRefList[0].butlerSubset.butler
        return self.TaskClass(config=self.config, log=self.log, butler=butler)
135 
136 
def unpickle(factory, args, kwargs):
    """Rebuild an object by calling ``factory(*args, **kwargs)``.

    Serves as the reconstruction callable returned from ``__reduce__``.
    """
    instance = factory(*args, **kwargs)
    return instance
140 
141 
143  """Multi-node driver for multiband processing"""
144  ConfigClass = MultiBandDriverConfig
145  _DefaultName = "multiBandDriver"
146  RunnerClass = MultiBandDriverTaskRunner
147 
148  def __init__(self, butler=None, schema=None, refObjLoader=None, **kwargs):
149  """!
150  @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader constructor
151  in case it is needed.
152  @param[in] schema: the schema of the source detection catalog used as input.
153  @param[in] refObjLoader: an instance of LoadReferenceObjectsTasks that supplies an external reference
154  catalog. May be None if the butler argument is provided or all steps requiring a reference
155  catalog are disabled.
156  """
157  BatchPoolTask.__init__(self, **kwargs)
158  if schema is None:
159  assert butler is not None, "Butler not provided"
160  schema = butler.get(self.config.coaddName +
161  "Coadd_det_schema", immediate=True).schema
162  self.butler = butler
163  self.makeSubtask("detectCoaddSources")
164  self.makeSubtask("mergeCoaddDetections", schema=schema)
165  self.makeSubtask("measureCoaddSources", schema=afwTable.Schema(self.mergeCoaddDetections.schema),
166  peakSchema=afwTable.Schema(
167  self.mergeCoaddDetections.merged.getPeakSchema()),
168  refObjLoader=refObjLoader, butler=butler)
169  self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
170  self.measureCoaddSources.schema))
171  self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
172  self.mergeCoaddMeasurements.schema))
173 
174  def __reduce__(self):
175  """Pickler"""
176  return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
177  parentTask=self._parentTask, log=self.log,
178  butler=self.butler))
179 
180  @classmethod
181  def _makeArgumentParser(cls, *args, **kwargs):
182  kwargs.pop("doBatch", False)
183  parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
184  parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
185  ContainerClass=TractDataIdContainer)
186  return parser
187 
188  @classmethod
189  def batchWallTime(cls, time, parsedCmd, numCpus):
190  """!Return walltime request for batch job
191 
192  @param time: Requested time per iteration
193  @param parsedCmd: Results of argument parsing
194  @param numCores: Number of cores
195  """
196  numTargets = 0
197  for refList in parsedCmd.id.refList:
198  numTargets += len(refList)
199  return time*numTargets/float(numCpus)
200 
201  @abortOnError
202  def run(self, patchRefList):
203  """!Run multiband processing on coadds
204 
205  Only the master node runs this method.
206 
207  No real MPI communication (scatter/gather) takes place: all I/O goes
208  through the disk. We want the intermediate stages on disk, and the
209  component Tasks are implemented around this, so we just follow suit.
210 
211  @param patchRefList: Data references to run measurement
212  """
213  for patchRef in patchRefList:
214  if patchRef:
215  butler = patchRef.getButler()
216  break
217  else:
218  raise RuntimeError("No valid patches")
219  pool = Pool("all")
220  pool.cacheClear()
221  pool.storeSet(butler=butler)
222 
223  # MultiBand measurements require that the detection stage be completed before
224  # measurements can be made. Determine if data products are present, but detections
225  # are not, and attempt to run the detection stage where necessary. The configuration
226  # for coaddDriver.py allows detection to be turned of in the event that fake objects
227  # are to be added during the detection process. This allows the long co-addition
228  # process to be run once, and multiple different MultiBand reruns (with different
229  # fake objects) to exist from the same base co-addition.
230  # If the detections are to be clobbered, add all patches to the detection list
231  # unless the datasets necessary to generate detections do not exist
232  if self.config.clobberDetections:
233  detectionList = [patchRef for patchRef in patchRefList if
234  patchRef.datasetExists(self.config.coaddName + "Coadd")]
235  else:
236  detectionList = [patchRef for patchRef in patchRefList if not
237  patchRef.datasetExists(self.config.coaddName +
238  "Coadd_calexp") and
239  patchRef.datasetExists(self.config.coaddName +
240  "Coadd")]
241 
242  pool.map(self.runDetection, detectionList)
243 
244  patchRefList = [patchRef for patchRef in patchRefList if
245  patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
246  patchRef.datasetExists(self.config.coaddName + "Coadd_det")]
247  dataIdList = [patchRef.dataId for patchRef in patchRefList]
248 
249  # Group by patch
250  patches = {}
251  tract = None
252  for patchRef in patchRefList:
253  dataId = patchRef.dataId
254  if tract is None:
255  tract = dataId["tract"]
256  else:
257  assert tract == dataId["tract"]
258 
259  patch = dataId["patch"]
260  if patch not in patches:
261  patches[patch] = []
262  patches[patch].append(dataId)
263 
264  pool.map(self.runMergeDetections, patches.values())
265 
266  # Measure merged detections, and test for reprocessing
267  #
268  # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
269  # footprints can suck up a lot of memory in the deblender, which means that when we process on a
270  # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
271  # they may have astronomically interesting data, we want the ability to go back and reprocess them
272  # with a more permissive configuration when we have more memory or processing time.
273  #
274  # self.runMeasureMerged will return whether there are any footprints in that image that required
275  # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
276  # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
277  # a particular patch.
278  #
279  # This determination of which patches need to be reprocessed exists only in memory (the measurements
280  # have been written, clobbering the old ones), so if there was an exception we would lose this
281  # information, leaving things in an inconsistent state (measurements new, but merged measurements and
282  # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
283  # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the merge and
284  # forced photometry.
285  #
286  # This is, hopefully, a temporary workaround until we can improve the
287  # deblender.
288  try:
289  reprocessed = pool.map(self.runMeasureMerged, dataIdList)
290  finally:
291  if self.config.reprocessing:
292  patchReprocessing = {}
293  for dataId, reprocess in zip(dataIdList, reprocessed):
294  patchId = dataId["patch"]
295  patchReprocessing[patchId] = patchReprocessing.get(
296  patchId, False) or reprocess
297  # Persist the determination, to make error recover easier
298  reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
299  for patchId in patchReprocessing:
300  if not patchReprocessing[patchId]:
301  continue
302  dataId = dict(tract=tract, patch=patchId)
303  if patchReprocessing[patchId]:
304  filename = butler.get(
305  reprocessDataset + "_filename", dataId)[0]
306  open(filename, 'a').close() # Touch file
307  elif butler.datasetExists(reprocessDataset, dataId):
308  # We must have failed at some point while reprocessing
309  # and we're starting over
310  patchReprocessing[patchId] = True
311 
312  # Only process patches that have been identified as needing it
313  pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
314  not self.config.reprocessing or patchReprocessing[patchId]])
315  pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
316  patchReprocessing[dataId["patch"]]])
317 
318  # Remove persisted reprocessing determination
319  if self.config.reprocessing:
320  for patchId in patchReprocessing:
321  if not patchReprocessing[patchId]:
322  continue
323  dataId = dict(tract=tract, patch=patchId)
324  filename = butler.get(
325  reprocessDataset + "_filename", dataId)[0]
326  os.unlink(filename)
327 
328  def runDetection(self, cache, patchRef):
329  """! Run detection on a patch
330 
331  Only slave nodes execute this method.
332 
333  @param cache: Pool cache, containing butler
334  @param patchRef: Patch on which to do detection
335  """
336  with self.logOperation("do detections on {}".format(patchRef.dataId)):
337  idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
338  coadd = patchRef.get(self.config.coaddName + "Coadd",
339  immediate=True)
340  self.detectCoaddSources.emptyMetadata()
341  detResults = self.detectCoaddSources.runDetection(coadd, idFactory)
342  self.detectCoaddSources.write(coadd, detResults, patchRef)
343  self.detectCoaddSources.writeMetadata(patchRef)
344 
345  def runMergeDetections(self, cache, dataIdList):
346  """!Run detection merging on a patch
347 
348  Only slave nodes execute this method.
349 
350  @param cache: Pool cache, containing butler
351  @param dataIdList: List of data identifiers for the patch in different filters
352  """
353  with self.logOperation("merge detections from %s" % (dataIdList,)):
354  dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
355  dataId in dataIdList]
356  if (not self.config.clobberMergedDetections and
357  dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet")):
358  return
359  self.mergeCoaddDetections.run(dataRefList)
360 
361  def runMeasureMerged(self, cache, dataId):
362  """!Run measurement on a patch for a single filter
363 
364  Only slave nodes execute this method.
365 
366  @param cache: Pool cache, with butler
367  @param dataId: Data identifier for patch
368  @return whether the patch requires reprocessing.
369  """
370  with self.logOperation("measurement on %s" % (dataId,)):
371  dataRef = getDataRef(cache.butler, dataId,
372  self.config.coaddName + "Coadd_calexp")
373  reprocessing = False # Does this patch require reprocessing?
374  if (not self.config.clobberMeasurements and
375  dataRef.datasetExists(self.config.coaddName + "Coadd_meas")):
376  if not self.config.reprocessing:
377  return False
378 
379  catalog = dataRef.get(self.config.coaddName + "Coadd_meas")
380  bigFlag = catalog["deblend.parent-too-big"]
381  numOldBig = bigFlag.sum()
382  if numOldBig == 0:
383  self.log.info("No large footprints in %s" %
384  (dataRef.dataId,))
385  return False
386  numNewBig = sum((self.measureCoaddSources.deblend.isLargeFootprint(src.getFootprint()) for
387  src in catalog[bigFlag]))
388  if numNewBig == numOldBig:
389  self.log.info("All %d formerly large footprints continue to be large in %s" %
390  (numOldBig, dataRef.dataId,))
391  return False
392  self.log.info("Found %d large footprints to be reprocessed in %s" %
393  (numOldBig - numNewBig, dataRef.dataId))
394  reprocessing = True
395 
396  self.measureCoaddSources.run(dataRef)
397  return reprocessing
398 
399  def runMergeMeasurements(self, cache, dataIdList):
400  """!Run measurement merging on a patch
401 
402  Only slave nodes execute this method.
403 
404  @param cache: Pool cache, containing butler
405  @param dataIdList: List of data identifiers for the patch in different filters
406  """
407  with self.logOperation("merge measurements from %s" % (dataIdList,)):
408  dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
409  dataId in dataIdList]
410  if (not self.config.clobberMergedMeasurements and
411  not self.config.reprocessing and
412  dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref")):
413  return
414  self.mergeCoaddMeasurements.run(dataRefList)
415 
416  def runForcedPhot(self, cache, dataId):
417  """!Run forced photometry on a patch for a single filter
418 
419  Only slave nodes execute this method.
420 
421  @param cache: Pool cache, with butler
422  @param dataId: Data identifier for patch
423  """
424  with self.logOperation("forced photometry on %s" % (dataId,)):
425  dataRef = getDataRef(cache.butler, dataId,
426  self.config.coaddName + "Coadd_calexp")
427  if (not self.config.clobberForcedPhotometry and
428  not self.config.reprocessing and
429  dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src")):
430  return
431  self.forcedPhotCoadd.run(dataRef)
432 
433  def writeMetadata(self, dataRef):
434  """We don't collect any metadata, so skip"""
435  pass
def unpickle(factory, args, kwargs)
def __init__(self, butler=None, schema=None, refObjLoader=None, kwargs)
def getDataRef(butler, dataId, datasetType="raw")
Definition: utils.py:17
def runForcedPhot(self, cache, dataId)
Run forced photometry on a patch for a single filter.
def batchWallTime(cls, time, parsedCmd, numCpus)
Return walltime request for batch job.
def runMeasureMerged(self, cache, dataId)
Run measurement on a patch for a single filter.
def logOperation(self, operation, catch=False, trace=True)
def makeDataRefList(self, namespace)
Make self.refList from self.idList.
def runDetection(self, cache, patchRef)
Run detection on a patch.
def runMergeDetections(self, cache, dataIdList)
Run detection merging on a patch.
def run(self, patchRefList)
Run multiband processing on coadds.
def runMergeMeasurements(self, cache, dataIdList)
Run measurement merging on a patch.