multiBandDriver.py
from __future__ import absolute_import, division, print_function
import os
from argparse import ArgumentError

from builtins import zip

from lsst.pex.config import Config, Field, ConfigurableField
from lsst.pipe.base import ArgumentParser, TaskRunner
from lsst.pipe.tasks.multiBand import (MergeDetectionsTask,
                                       MeasureMergedCoaddSourcesTask, MergeMeasurementsTask,)
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, abortOnError
from lsst.meas.base.references import MultiBandReferencesTask
from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
from lsst.pipe.tasks.coaddBase import CoaddDataIdContainer

import lsst.afw.table as afwTable


class MultiBandDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """!Make self.refList from self.idList

        It's difficult to make a data reference that merely points to an entire
        tract: there is no data product solely at the tract level. Instead, we
        generate a list of data references for patches within the tract.

        @param namespace  namespace object that is the result of an argument parser
        """
        datasetType = namespace.config.coaddName + "Coadd_calexp"

        def getPatchRefList(tract):
            # 'dataId' is picked up from the enclosing loop at call time
            return [namespace.butler.dataRef(datasetType=datasetType,
                                             tract=tract.getId(),
                                             filter=dataId["filter"],
                                             patch="%d,%d" % patch.getIndex())
                    for patch in tract]

        tractRefs = {}  # Data references for each tract
        for dataId in self.idList:
            # There's no registry of coadds by filter, so we need to be given
            # the filter
            if "filter" not in dataId:
                raise ArgumentError(None, "--id must include 'filter'")

            skymap = self.getSkymap(namespace, datasetType)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if tractId not in tractRefs:
                    tractRefs[tractId] = []
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=datasetType,
                                                                       tract=tractId,
                                                                       filter=dataId['filter'],
                                                                       patch=dataId['patch']))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId])
            else:
                tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
                                 for tract in skymap)

        self.refList = list(tractRefs.values())

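# A minimal sketch of how the container above expands data IDs (editor's
# illustration; the tract/patch numbers and filter name are assumptions):
#
#   --id filter=HSC-I tract=0 patch=1,1  ->  one dataRef for that patch
#   --id filter=HSC-I tract=0            ->  one dataRef per patch in tract 0
#   --id filter=HSC-I                    ->  one dataRef per patch per tract
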
class MultiBandDriverConfig(Config):
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    clobberDetections = Field(
        dtype=bool, default=False, doc="Clobber existing detections?")
    clobberMergedDetections = Field(
        dtype=bool, default=False, doc="Clobber existing merged detections?")
    clobberMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing measurements?")
    clobberMergedMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing merged measurements?")
    clobberForcedPhotometry = Field(
        dtype=bool, default=False, doc="Clobber existing forced photometry?")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a "
             "cluster, and return later to reprocess them on a machine with more memory or time "
             "if those footprints are deemed important to recover."),
    )

    def setDefaults(self):
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        for subtask in ("mergeCoaddDetections", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))

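# Example config override for a reprocessing run (editor's sketch; the
# deblender field name below is an assumption and may differ by version):
#
#   config.reprocessing = True
#   # Permit deblending of arbitrarily large footprints on the big machine:
#   config.measureCoaddSources.deblend.maxFootprintSize = 0
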
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandTask

    This is similar to the lsst.pipe.base.ButlerInitializedTaskRunner,
    except that a list of data references is passed to Task.run instead
    of a single data reference.
    """

    def makeTask(self, parsedCmd=None, args=None):
        """A variant of the base version that passes a butler argument to the task's constructor.

        Either parsedCmd or args must be specified.
        """
        if parsedCmd is not None:
            butler = parsedCmd.butler
        elif args is not None:
            dataRefList, kwargs = args
            butler = dataRefList[0].butlerSubset.butler
        else:
            raise RuntimeError("parsedCmd or args must be specified")
        return self.TaskClass(config=self.config, log=self.log, butler=butler)


def unpickle(factory, args, kwargs):
    """Unpickle something by calling a factory"""
    return factory(*args, **kwargs)

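# The unpickle() factory implements a standard pattern: MultiBandDriverTask
# defines __reduce__ (below) to return (unpickle, (cls, args, kwargs)), so
# the task is rebuilt through its constructor when shipped to slave nodes.
# In miniature (editor's sketch; Widget is hypothetical and would need to
# live at module level to be picklable):
#
#   class Widget(object):
#       def __init__(self, config):
#           self.config = config
#       def __reduce__(self):
#           return unpickle, (self.__class__, [], dict(config=self.config))
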
class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing"""
    ConfigClass = MultiBandDriverConfig
    _DefaultName = "multiBandDriver"
    RunnerClass = MultiBandDriverTaskRunner

    def __init__(self, butler=None, schema=None, refObjLoader=None, **kwargs):
        """!
        @param[in] butler: the butler can be used to retrieve the schema, or be passed to the
            refObjLoader constructor in case it is needed.
        @param[in] schema: the schema of the source detection catalog used as input.
        @param[in] refObjLoader: an instance of LoadReferenceObjectsTasks that supplies an external
            reference catalog. May be None if the butler argument is provided or all steps requiring
            a reference catalog are disabled.
        """
        BatchPoolTask.__init__(self, **kwargs)
        if schema is None:
            assert butler is not None, "Butler not provided"
            schema = butler.get(self.config.coaddName +
                                "Coadd_det_schema", immediate=True).schema
        self.butler = butler
        self.makeSubtask("mergeCoaddDetections", schema=schema)
        self.makeSubtask("measureCoaddSources", schema=afwTable.Schema(self.mergeCoaddDetections.schema),
                         peakSchema=afwTable.Schema(
                             self.mergeCoaddDetections.merged.getPeakSchema()),
                         refObjLoader=refObjLoader, butler=butler)
        self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
            self.measureCoaddSources.schema))
        self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
            self.mergeCoaddMeasurements.schema))

    def __reduce__(self):
        """Pickler"""
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler))

    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        return parser

    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of CPUs
        """
        numTargets = 0
        for refList in parsedCmd.id.refList:
            numTargets += len(refList)
        return time*numTargets/float(numCpus)
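    # Worked example (editor's note): with time=300 s per data reference,
    # one tract of 81 patches in 5 bands (405 references), and numCpus=100,
    # the walltime request is 300*405/100 = 1215 s.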

    @abortOnError
    def run(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList: Data references to run measurement
        """
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)

        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det")]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]

            patch = dataId["patch"]
            if patch not in patches:
                patches[patch] = []
            patches[patch].append(dataId)

        pool.map(self.runMergeDetections, patches.values())

        # Measure merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runMeasureMerged will return whether there are any footprints in that image that required
        # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements new, but merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the merge and
        # forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
        try:
            reprocessed = pool.map(self.runMeasureMerged, dataIdList)
        finally:
            if self.config.reprocessing:
                patchReprocessing = {}
                for dataId, reprocess in zip(dataIdList, reprocessed):
                    patchId = dataId["patch"]
                    patchReprocessing[patchId] = patchReprocessing.get(
                        patchId, False) or reprocess
                # Persist the determination, to make error recovery easier
                reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
                for patchId in patchReprocessing:
                    dataId = dict(tract=tract, patch=patchId)
                    if patchReprocessing[patchId]:
                        filename = butler.get(
                            reprocessDataset + "_filename", dataId)[0]
                        open(filename, 'a').close()  # Touch file
                    elif butler.datasetExists(reprocessDataset, dataId):
                        # We must have failed at some point while reprocessing
                        # and we're starting over
                        patchReprocessing[patchId] = True

        # Only process patches that have been identified as needing it
        pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
                                             not self.config.reprocessing or patchReprocessing[patchId]])
        pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
                                      patchReprocessing[dataId1["patch"]]])

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                dataId = dict(tract=tract, patch=patchId)
                filename = butler.get(
                    reprocessDataset + "_filename", dataId)[0]
                os.unlink(filename)

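    # The crash-recovery scheme above in miniature (editor's sketch;
    # "flagFile" is hypothetical):
    #
    #   open(flagFile, "a").close()  # mark: merge + forced photometry stale
    #   ...                          # re-run merge measurements and forced photometry
    #   os.unlink(flagFile)          # unmark only after both have completed
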
    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedDetections and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet")):
                return
            self.mergeCoaddDetections.run(dataRefList)

    def runMeasureMerged(self, cache, dataId):
        """!Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        @return whether the patch requires reprocessing.
        """
        with self.logOperation("measurement on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            reprocessing = False  # Does this patch require reprocessing?
            if (not self.config.clobberMeasurements and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas")):
                if not self.config.reprocessing:
                    return False

                catalog = dataRef.get(self.config.coaddName + "Coadd_meas")
                bigFlag = catalog["deblend.parent-too-big"]
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" %
                                  (dataRef.dataId,))
                    return False
                numNewBig = sum((self.measureCoaddSources.deblend.isLargeFootprint(src.getFootprint()) for
                                 src in catalog[bigFlag]))
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRef.dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, dataRef.dataId))
                reprocessing = True

            self.measureCoaddSources.run(dataRef)
            return reprocessing

    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedMeasurements and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref")):
                return
            self.mergeCoaddMeasurements.run(dataRefList)

    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.config.coaddName + "Coadd_calexp")
            if (not self.config.clobberForcedPhotometry and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src")):
                return
            self.forcedPhotCoadd.run(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass
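
# Typical command-line invocation (editor's sketch; the batch options come
# from lsst.ctrl.pool and exact flag names may differ between versions):
#
#   multiBandDriver.py /path/to/DATA --rerun coadds:multiband \
#       --id tract=0 filter=HSC-G^HSC-R^HSC-I \
#       --batch-type=smp --cores=16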