lsst.pipe.drivers  13.0-7-g11ca5b2+4
multiBandDriver.py
from __future__ import absolute_import, division, print_function
import os
from argparse import ArgumentError

from builtins import zip

from lsst.pex.config import Config, Field, ConfigurableField
from lsst.pipe.base import ArgumentParser, TaskRunner
from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
                                       MergeDetectionsTask,
                                       MeasureMergedCoaddSourcesTask,
                                       MergeMeasurementsTask,)
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, abortOnError, NODE
from lsst.meas.base.references import MultiBandReferencesTask
from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer
from lsst.pipe.tasks.coaddBase import CoaddDataIdContainer

import lsst.afw.table as afwTable


class MultiBandDataIdContainer(CoaddDataIdContainer):

    def makeDataRefList(self, namespace):
        """!Make self.refList from self.idList

        It is difficult to make a data reference that merely points to an entire
        tract: there is no data product solely at the tract level. Instead, we
        generate a list of data references for patches within the tract.

        @param namespace namespace object that is the result of an argument parser
        """
        datasetType = namespace.config.coaddName + "Coadd_calexp"

        def getPatchRefList(tract):
            return [namespace.butler.dataRef(datasetType=datasetType,
                                             tract=tract.getId(),
                                             filter=dataId["filter"],
                                             patch="%d,%d" % patch.getIndex())
                    for patch in tract]

        tractRefs = {}  # Data references for each tract
        for dataId in self.idList:
            # There's no registry of coadds by filter, so we need to be given
            # the filter
            if "filter" not in dataId:
                raise ArgumentError(None, "--id must include 'filter'")

            skymap = self.getSkymap(namespace, datasetType)

            if "tract" in dataId:
                tractId = dataId["tract"]
                if tractId not in tractRefs:
                    tractRefs[tractId] = []
                if "patch" in dataId:
                    tractRefs[tractId].append(namespace.butler.dataRef(datasetType=datasetType,
                                                                       tract=tractId,
                                                                       filter=dataId['filter'],
                                                                       patch=dataId['patch']))
                else:
                    tractRefs[tractId] += getPatchRefList(skymap[tractId])
            else:
                tractRefs = dict((tract.getId(), tractRefs.get(tract.getId(), []) + getPatchRefList(tract))
                                 for tract in skymap)

        self.refList = list(tractRefs.values())

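# Illustrative example (tract, patch and filter values are hypothetical): given
#     --id tract=9813 filter=HSC-I
# MultiBandDataIdContainer.makeDataRefList above produces one "deepCoadd_calexp"
# data reference for every patch in tract 9813, whereas
#     --id tract=9813 patch=5,5 filter=HSC-I
# produces a single patch-level reference. The filter must always be supplied
# because there is no registry of coadds by filter.
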
class MultiBandDriverConfig(Config):
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(target=MergeDetectionsTask,
                                             doc="Merge detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged detections")
    mergeCoaddMeasurements = ConfigurableField(target=MergeMeasurementsTask,
                                               doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    clobberDetections = Field(
        dtype=bool, default=False, doc="Clobber existing detections?")
    clobberMergedDetections = Field(
        dtype=bool, default=False, doc="Clobber existing merged detections?")
    clobberMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing measurements?")
    clobberMergedMeasurements = Field(
        dtype=bool, default=False, doc="Clobber existing merged measurements?")
    clobberForcedPhotometry = Field(
        dtype=bool, default=False, doc="Clobber existing forced photometry?")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    def setDefaults(self):
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        for subtask in ("mergeCoaddDetections", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))

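# Illustrative example (values are hypothetical): the clobber and reprocessing flags
# above can be overridden on the command line with the standard pipe_base config
# mechanism, e.g. "--config reprocessing=True clobberMeasurements=True", or through
# a config override file supplied with --configfile.
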
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandDriverTask

    This is similar to the lsst.pipe.base.ButlerInitializedTaskRunner,
    except that we have a list of data references instead of a single
    data reference being passed to the Task.run.
    """

    def makeTask(self, parsedCmd=None, args=None):
        """A variant of the base version that passes a butler argument to the task's constructor

        Either parsedCmd or args must be specified.
        """
        if parsedCmd is not None:
            butler = parsedCmd.butler
        elif args is not None:
            dataRefList, kwargs = args
            butler = dataRefList[0].butlerSubset.butler
        else:
            raise RuntimeError("parsedCmd or args must be specified")
        return self.TaskClass(config=self.config, log=self.log, butler=butler)


def unpickle(factory, args, kwargs):
    """Unpickle something by calling a factory"""
    return factory(*args, **kwargs)

class MultiBandDriverTask(BatchPoolTask):
    """Multi-node driver for multiband processing"""
    ConfigClass = MultiBandDriverConfig
    _DefaultName = "multiBandDriver"
    RunnerClass = MultiBandDriverTaskRunner

    def __init__(self, butler=None, schema=None, refObjLoader=None, **kwargs):
        """!
        @param[in] butler: the butler can be used to retrieve the schema or passed to the refObjLoader
            constructor in case it is needed.
        @param[in] schema: the schema of the source detection catalog used as input.
        @param[in] refObjLoader: an instance of LoadReferenceObjectsTask that supplies an external reference
            catalog. May be None if the butler argument is provided or all steps requiring a reference
            catalog are disabled.
        """
        BatchPoolTask.__init__(self, **kwargs)
        if schema is None:
            assert butler is not None, "Butler not provided"
            schema = butler.get(self.config.coaddName +
                                "Coadd_det_schema", immediate=True).schema
        self.butler = butler
        self.makeSubtask("detectCoaddSources")
        self.makeSubtask("mergeCoaddDetections", schema=schema)
        self.makeSubtask("measureCoaddSources", schema=afwTable.Schema(self.mergeCoaddDetections.schema),
                         peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
                         refObjLoader=refObjLoader, butler=butler)
        self.makeSubtask("mergeCoaddMeasurements",
                         schema=afwTable.Schema(self.measureCoaddSources.schema))
        self.makeSubtask("forcedPhotCoadd",
                         refSchema=afwTable.Schema(self.mergeCoaddMeasurements.schema))

    def __reduce__(self):
        """Pickler"""
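        # The bound methods sent to the slave nodes by the process pool carry the Task
        # with them, so the Task must be picklable. Rather than pickling subtasks and
        # schemas directly, we record the constructor arguments and rebuild the Task
        # on the receiving side via unpickle() above.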
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler))

    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        return parser

    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of cores
        """
        numTargets = 0
        for refList in parsedCmd.id.refList:
            numTargets += len(refList)
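        # Illustrative numbers only: 270 patch/filter targets at 10 minutes each,
        # spread over 90 cores, request 270*10/90 = 30 minutes of walltime.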
        return time*numTargets/float(numCpus)

    @abortOnError
    def run(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList: Data references to run measurement
        """
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)

        detectionList = [patchRef for patchRef in patchRefList if not
                         patchRef.datasetExists(self.config.coaddName + "Coadd_calexp")]

        pool.map(self.runDetection, detectionList)

        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(self.config.coaddName + "Coadd_calexp") and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det")]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]

            patch = dataId["patch"]
            if patch not in patches:
                patches[patch] = []
            patches[patch].append(dataId)
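        # patches now maps each patch identifier to the data identifiers for that patch
        # in every available filter; each per-patch list is handled by one slave node.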

        pool.map(self.runMergeDetections, patches.values())

        # Measure merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runMeasureMerged will return whether there are any footprints in that image that required
        # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements new, but merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the merge and
        # forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
        try:
            reprocessed = pool.map(self.runMeasureMerged, dataIdList)
        finally:
            if self.config.reprocessing:
                patchReprocessing = {}
                for dataId, reprocess in zip(dataIdList, reprocessed):
                    patchId = dataId["patch"]
                    patchReprocessing[patchId] = patchReprocessing.get(patchId, False) or reprocess
                # Persist the determination, to make error recovery easier
                reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
                for patchId in patchReprocessing:
                    if not patchReprocessing[patchId]:
                        continue
                    dataId = dict(tract=tract, patch=patchId)
                    if patchReprocessing[patchId]:
                        filename = butler.get(reprocessDataset + "_filename", dataId)[0]
                        open(filename, 'a').close()  # Touch file
                    elif butler.datasetExists(reprocessDataset, dataId):
                        # We must have failed at some point while reprocessing
                        # and we're starting over
                        patchReprocessing[patchId] = True

        # Only process patches that have been identified as needing it
        pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
                                             not self.config.reprocessing or patchReprocessing[patchId]])
        pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
                                      patchReprocessing[dataId1["patch"]]])

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                dataId = dict(tract=tract, patch=patchId)
                filename = butler.get(reprocessDataset + "_filename", dataId)[0]
                os.unlink(filename)
    def runDetection(self, cache, patchRef):
        """! Run detection on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param patchRef: Patch on which to do detection
        """
        with self.logOperation("do detections on {}".format(patchRef.dataId)):
            idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
            coadd = patchRef.get(self.config.coaddName + "Coadd", immediate=True)
            detResults = self.detectCoaddSources.runDetection(coadd, idFactory)
            self.detectCoaddSources.write(coadd, detResults, patchRef)

    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedDetections and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet")):
                return
            self.mergeCoaddDetections.run(dataRefList)

    def runMeasureMerged(self, cache, dataId):
        """!Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        @return whether the patch requires reprocessing.
        """
        with self.logOperation("measurement on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp")
            reprocessing = False  # Does this patch require reprocessing?
            if (not self.config.clobberMeasurements and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas")):
                if not self.config.reprocessing:
                    return False

                catalog = dataRef.get(self.config.coaddName + "Coadd_meas")
                bigFlag = catalog["deblend.parent-too-big"]
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" % (dataRef.dataId,))
                    return False
                numNewBig = sum((self.measureCoaddSources.deblend.isLargeFootprint(src.getFootprint()) for
                                 src in catalog[bigFlag]))
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRef.dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, dataRef.dataId))
                reprocessing = True

            self.measureCoaddSources.run(dataRef)
            return reprocessing

    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp") for
                           dataId in dataIdList]
            if (not self.config.clobberMergedMeasurements and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref")):
                return
            self.mergeCoaddMeasurements.run(dataRefList)

    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId, self.config.coaddName + "Coadd_calexp")
            if (not self.config.clobberForcedPhotometry and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src")):
                return
            self.forcedPhotCoadd.run(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass
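
# Illustrative invocation (the repository path, rerun name, tract number, filter names
# and batch options are hypothetical and depend on the obs package and site setup):
#
#     multiBandDriver.py /path/to/DATA --rerun coadd:multiband \
#         --id tract=9813 filter=HSC-G^HSC-R^HSC-I \
#         --batch-type=smp --cores=8
#
# The ^-separated filter list produces one data reference per patch and filter;
# run() then groups these by patch for the merge steps.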