lsst.pipe.drivers g8b6839f0a4+7a34e91110
multiBandDriver.py
import os

from lsst.pex.config import Config, Field, ConfigurableField
from lsst.pipe.base import ArgumentParser, TaskRunner
from lsst.pipe.tasks.multiBand import (DetectCoaddSourcesTask,
                                       MergeDetectionsTask,
                                       DeblendCoaddSourcesTask,
                                       MeasureMergedCoaddSourcesTask,
                                       MergeMeasurementsTask,)
from lsst.ctrl.pool.parallel import BatchPoolTask
from lsst.ctrl.pool.pool import Pool, abortOnError
from lsst.meas.base.references import MultiBandReferencesTask
from lsst.meas.base.forcedPhotCoadd import ForcedPhotCoaddTask
from lsst.pipe.drivers.utils import getDataRef, TractDataIdContainer

import lsst.afw.table as afwTable


class MultiBandDriverConfig(Config):
    coaddName = Field(dtype=str, default="deep", doc="Name of coadd")
    doDetection = Field(dtype=bool, default=False,
                        doc="Re-run detection? (requires *Coadd dataset to have been written)")
    detectCoaddSources = ConfigurableField(target=DetectCoaddSourcesTask,
                                           doc="Detect sources on coadd")
    mergeCoaddDetections = ConfigurableField(
        target=MergeDetectionsTask, doc="Merge detections")
    deblendCoaddSources = ConfigurableField(target=DeblendCoaddSourcesTask,
                                            doc="Deblend merged detections")
    measureCoaddSources = ConfigurableField(target=MeasureMergedCoaddSourcesTask,
                                            doc="Measure merged and (optionally) deblended detections")
    mergeCoaddMeasurements = ConfigurableField(
        target=MergeMeasurementsTask, doc="Merge measurements")
    forcedPhotCoadd = ConfigurableField(target=ForcedPhotCoaddTask,
                                        doc="Forced measurement on coadded images")
    reprocessing = Field(
        dtype=bool, default=False,
        doc=("Are we reprocessing?\n\n"
             "This exists as a workaround for large deblender footprints causing large memory use "
             "and/or very slow processing. We refuse to deblend those footprints when running on a cluster "
             "and return to reprocess on a machine with larger memory or more time "
             "if we consider those footprints important to recover."),
    )

    hasFakes = Field(
        dtype=bool,
        default=False,
        doc="Should be set to True if fakes were inserted into the data being processed."
    )

    def setDefaults(self):
        Config.setDefaults(self)
        self.forcedPhotCoadd.references.retarget(MultiBandReferencesTask)

    def validate(self):
        for subtask in ("mergeCoaddDetections", "deblendCoaddSources", "measureCoaddSources",
                        "mergeCoaddMeasurements", "forcedPhotCoadd"):
            coaddName = getattr(self, subtask).coaddName
            if coaddName != self.coaddName:
                raise RuntimeError("%s.coaddName (%s) doesn't match root coaddName (%s)" %
                                   (subtask, coaddName, self.coaddName))
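# A minimal usage sketch (not part of the original file): validate() enforces
# that every subtask's coaddName matches the root-level coaddName, so a
# non-default coadd must be renamed consistently. The coadd name below is
# illustrative.
#
#     config = MultiBandDriverConfig()
#     config.coaddName = "goodSeeing"              # hypothetical coadd name
#     for subtask in ("mergeCoaddDetections", "deblendCoaddSources", "measureCoaddSources",
#                     "mergeCoaddMeasurements", "forcedPhotCoadd"):
#         getattr(config, subtask).coaddName = "goodSeeing"
#     config.validate()                            # raises RuntimeError on any mismatch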
class MultiBandDriverTaskRunner(TaskRunner):
    """TaskRunner for running MultiBandTask

    This is similar to the lsst.pipe.base.ButlerInitializedTaskRunner,
    except that we have a list of data references instead of a single
    data reference being passed to the Task.run, and we pass the results
    of the '--reuse-outputs-from' command option to the Task constructor.
    """

    def __init__(self, TaskClass, parsedCmd, doReturnResults=False):
        TaskRunner.__init__(self, TaskClass, parsedCmd, doReturnResults)
        self.reuse = parsedCmd.reuse

    def makeTask(self, parsedCmd=None, args=None):
        """A variant of the base version that passes a butler argument to the task's constructor.

        Either parsedCmd or args must be specified.
        """
        if parsedCmd is not None:
            butler = parsedCmd.butler
        elif args is not None:
            dataRefList, kwargs = args
            butler = dataRefList[0].butlerSubset.butler
        else:
            raise RuntimeError("parsedCmd or args must be specified")
        return self.TaskClass(config=self.config, log=self.log, butler=butler, reuse=self.reuse)


def unpickle(factory, args, kwargs):
    """Unpickle something by calling a factory"""
    return factory(*args, **kwargs)
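# A brief sketch (assumption, not in the original file) of how unpickle() and
# MultiBandDriverTask.__reduce__ (below) cooperate: pickling the task records
# only a factory plus constructor arguments, and unpickling rebuilds the task
# by calling the factory, so state that cannot be pickled directly never
# enters the pickle stream.
#
#     import pickle
#     payload = pickle.dumps(task)   # __reduce__ returns (unpickle, (cls, [], kwargs))
#     clone = pickle.loads(payload)  # calls unpickle(cls, [], kwargs) -> cls(**kwargs)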
96 """Multi-node driver for multiband processing"""
97 ConfigClass = MultiBandDriverConfig
98 _DefaultName = "multiBandDriver"
99 RunnerClass = MultiBandDriverTaskRunner
100
101 def __init__(self, butler=None, schema=None, refObjLoader=None, reuse=tuple(), **kwargs):
102 """!
103 @param[in] butler: the butler can be used to retrieve schema or passed to the refObjLoader constructor
104 in case it is needed.
105 @param[in] schema: the schema of the source detection catalog used as input.
106 @param[in] refObjLoader: an instance of LoadReferenceObjectsTasks that supplies an external reference
107 catalog. May be None if the butler argument is provided or all steps requiring a reference
108 catalog are disabled.
109 """
110 BatchPoolTask.__init__(self, **kwargs)
111 if schema is None:
112 assert butler is not None, "Butler not provided"
113 schema = butler.get(self.config.coaddName +
114 "Coadd_det_schema", immediate=True).schema
115 self.butlerbutler = butler
116 self.reusereuse = tuple(reuse)
117 self.makeSubtask("detectCoaddSources")
118 self.makeSubtask("mergeCoaddDetections", schema=schema)
119 if self.config.measureCoaddSources.inputCatalog.startswith("deblended"):
120 # Ensure that the output from deblendCoaddSources matches the input to measureCoaddSources
121 self.measurementInputmeasurementInput = self.config.measureCoaddSources.inputCatalog
122 self.deblenderOutputdeblenderOutput = []
123 self.deblenderOutputdeblenderOutput.append("deblendedFlux")
124 if self.measurementInputmeasurementInput not in self.deblenderOutputdeblenderOutput:
125 err = "Measurement input '{0}' is not in the list of deblender output catalogs '{1}'"
126 raise ValueError(err.format(self.measurementInputmeasurementInput, self.deblenderOutputdeblenderOutput))
127
128 self.makeSubtask("deblendCoaddSources",
129 schema=afwTable.Schema(self.mergeCoaddDetections.schema),
130 peakSchema=afwTable.Schema(self.mergeCoaddDetections.merged.getPeakSchema()),
131 butler=butler)
132 measureInputSchema = afwTable.Schema(self.deblendCoaddSources.schema)
133 else:
134 measureInputSchema = afwTable.Schema(self.mergeCoaddDetections.schema)
135 self.makeSubtask("measureCoaddSources", schema=measureInputSchema,
136 peakSchema=afwTable.Schema(
137 self.mergeCoaddDetections.merged.getPeakSchema()),
138 refObjLoader=refObjLoader, butler=butler)
139 self.makeSubtask("mergeCoaddMeasurements", schema=afwTable.Schema(
140 self.measureCoaddSources.schema))
141 self.makeSubtask("forcedPhotCoadd", refSchema=afwTable.Schema(
142 self.mergeCoaddMeasurements.schema))
143 if self.config.hasFakes:
144 self.coaddTypecoaddType = "fakes_" + self.config.coaddName
145 else:
146 self.coaddTypecoaddType = self.config.coaddName
147
    def __reduce__(self):
        """Pickler"""
        return unpickle, (self.__class__, [], dict(config=self.config, name=self._name,
                                                   parentTask=self._parentTask, log=self.log,
                                                   butler=self.butler, reuse=self.reuse))

    @classmethod
    def _makeArgumentParser(cls, *args, **kwargs):
        kwargs.pop("doBatch", False)
        parser = ArgumentParser(name=cls._DefaultName, *args, **kwargs)
        parser.add_id_argument("--id", "deepCoadd", help="data ID, e.g. --id tract=12345 patch=1,2",
                               ContainerClass=TractDataIdContainer)
        parser.addReuseOption(["detectCoaddSources", "mergeCoaddDetections", "measureCoaddSources",
                               "mergeCoaddMeasurements", "forcedPhotCoadd", "deblendCoaddSources"])
        return parser
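    # Example invocation (illustrative; the data ID follows the --id help text
    # above, and --reuse-outputs-from takes the stage names registered with
    # addReuseOption):
    #
    #     multiBandDriver.py /path/to/repo --id tract=12345 patch=1,2 \
    #         --reuse-outputs-from detectCoaddSources mergeCoaddDetections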
    @classmethod
    def batchWallTime(cls, time, parsedCmd, numCpus):
        """!Return walltime request for batch job

        @param time: Requested time per iteration
        @param parsedCmd: Results of argument parsing
        @param numCpus: Number of CPUs
        """
        numTargets = 0
        for refList in parsedCmd.id.refList:
            numTargets += len(refList)
        return time*numTargets/float(numCpus)
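    # Worked example (illustrative): with 30 patch references counted across
    # all bands and 10 CPUs, a requested time of 3600 s per iteration yields a
    # walltime request of 3600 * 30 / 10 = 10800 s, i.e. 3 hours.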
    @abortOnError
    def runDataRef(self, patchRefList):
        """!Run multiband processing on coadds

        Only the master node runs this method.

        No real MPI communication (scatter/gather) takes place: all I/O goes
        through the disk. We want the intermediate stages on disk, and the
        component Tasks are implemented around this, so we just follow suit.

        @param patchRefList: Data references to run measurement
        """
        for patchRef in patchRefList:
            if patchRef:
                butler = patchRef.getButler()
                break
        else:
            raise RuntimeError("No valid patches")
        pool = Pool("all")
        pool.cacheClear()
        pool.storeSet(butler=butler)
        # MultiBand measurements require that the detection stage be completed
        # before measurements can be made.
        #
        # The configuration for coaddDriver.py allows detection to be turned
        # off in the event that fake objects are to be added during the
        # detection process. This allows the long co-addition process to be
        # run once, and multiple different MultiBand reruns (with different
        # fake objects) to exist from the same base co-addition.
        #
        # However, we only re-run detection if doDetection is explicitly True
        # here (this should always be the opposite of coaddDriver.doDetection);
        # otherwise we have no way to tell reliably whether any detections
        # present in an input repo are safe to use.
        if self.config.doDetection:
            detectionList = []
            for patchRef in patchRefList:
                if ("detectCoaddSources" in self.reuse and
                        patchRef.datasetExists(self.coaddType + "Coadd_calexp", write=True)):
                    self.log.info("Skipping detectCoaddSources for %s; output already exists." %
                                  patchRef.dataId)
                    continue
                if not patchRef.datasetExists(self.coaddType + "Coadd"):
                    self.log.debug("Not processing %s; required input %sCoadd missing." %
                                   (patchRef.dataId, self.config.coaddName))
                    continue
                detectionList.append(patchRef)

            pool.map(self.runDetection, detectionList)

        patchRefList = [patchRef for patchRef in patchRefList if
                        patchRef.datasetExists(self.coaddType + "Coadd_calexp") and
                        patchRef.datasetExists(self.config.coaddName + "Coadd_det",
                                               write=self.config.doDetection)]
        dataIdList = [patchRef.dataId for patchRef in patchRefList]

        # Group by patch
        patches = {}
        tract = None
        for patchRef in patchRefList:
            dataId = patchRef.dataId
            if tract is None:
                tract = dataId["tract"]
            else:
                assert tract == dataId["tract"]

            patch = dataId["patch"]
            if patch not in patches:
                patches[patch] = []
            patches[patch].append(dataId)

        pool.map(self.runMergeDetections, patches.values())

        # Deblend merged detections, and test for reprocessing
        #
        # The reprocessing allows us to have multiple attempts at deblending large footprints. Large
        # footprints can suck up a lot of memory in the deblender, which means that when we process on a
        # cluster, we want to refuse to deblend them (they're flagged "deblend.parent-too-big"). But since
        # they may have astronomically interesting data, we want the ability to go back and reprocess them
        # with a more permissive configuration when we have more memory or processing time.
        #
        # self.runDeblendMerged will return whether there are any footprints in that image that required
        # reprocessing. We need to convert that list of booleans into a dict mapping the patchId (x,y) to
        # a boolean. That tells us whether the merge measurement and forced photometry need to be re-run on
        # a particular patch.
        #
        # This determination of which patches need to be reprocessed exists only in memory (the measurements
        # have been written, clobbering the old ones), so if there was an exception we would lose this
        # information, leaving things in an inconsistent state (measurements, merged measurements and
        # forced photometry old). To attempt to preserve this status, we touch a file (dataset named
        # "deepCoadd_multibandReprocessing") --- if this file exists, we need to re-run the measurements,
        # merge and forced photometry.
        #
        # This is, hopefully, a temporary workaround until we can improve the
        # deblender.
        try:
            reprocessed = pool.map(self.runDeblendMerged, patches.values())
        finally:
            if self.config.reprocessing:
                patchReprocessing = {}
                for dataId, reprocess in zip(dataIdList, reprocessed):
                    patchId = dataId["patch"]
                    patchReprocessing[patchId] = patchReprocessing.get(
                        patchId, False) or reprocess
                # Persist the determination, to make error recovery easier
                reprocessDataset = self.config.coaddName + "Coadd_multibandReprocessing"
                for patchId in patchReprocessing:
                    if not patchReprocessing[patchId]:
                        continue
                    dataId = dict(tract=tract, patch=patchId)
                    if patchReprocessing[patchId]:
                        filename = butler.get(
                            reprocessDataset + "_filename", dataId)[0]
                        open(filename, 'a').close()  # Touch file
                    elif butler.datasetExists(reprocessDataset, dataId):
                        # We must have failed at some point while reprocessing
                        # and we're starting over
                        patchReprocessing[patchId] = True

        # Only process patches that have been identified as needing it
        pool.map(self.runMeasurements, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
                                        patchReprocessing[dataId1["patch"]]])
        pool.map(self.runMergeMeasurements, [idList for patchId, idList in patches.items() if
                                             not self.config.reprocessing or patchReprocessing[patchId]])
        pool.map(self.runForcedPhot, [dataId1 for dataId1 in dataIdList if not self.config.reprocessing or
                                      patchReprocessing[dataId1["patch"]]])

        # Remove persisted reprocessing determination
        if self.config.reprocessing:
            for patchId in patchReprocessing:
                if not patchReprocessing[patchId]:
                    continue
                dataId = dict(tract=tract, patch=patchId)
                filename = butler.get(
                    reprocessDataset + "_filename", dataId)[0]
                os.unlink(filename)
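    # Worked example (illustrative) of the reprocessing reduction above: the
    # per-dataId booleans returned by runDeblendMerged collapse to one flag per
    # patch via logical OR, since any band needing reprocessing taints the patch.
    #
    #     dataIdList  = [{"patch": "1,1", "filter": "HSC-R"}, {"patch": "1,1", "filter": "HSC-I"},
    #                    {"patch": "2,2", "filter": "HSC-R"}]
    #     reprocessed = [False, True, False]
    #     # -> patchReprocessing == {"1,1": True, "2,2": False}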
    def runDetection(self, cache, patchRef):
        """!Run detection on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param patchRef: Patch on which to do detection
        """
        with self.logOperation("do detections on {}".format(patchRef.dataId)):
            idFactory = self.detectCoaddSources.makeIdFactory(patchRef)
            coadd = patchRef.get(self.coaddType + "Coadd", immediate=True)
            expId = int(patchRef.get(self.config.coaddName + "CoaddId"))
            self.detectCoaddSources.emptyMetadata()
            detResults = self.detectCoaddSources.run(coadd, idFactory, expId=expId)
            self.detectCoaddSources.write(detResults, patchRef)
            self.detectCoaddSources.writeMetadata(patchRef)
    def runMergeDetections(self, cache, dataIdList):
        """!Run detection merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge detections from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddDetections" in self.reuse and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_mergeDet", write=True)):
                self.log.info("Skipping mergeCoaddDetections for %s; output already exists." %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddDetections.runDataRef(dataRefList)
    def runDeblendMerged(self, cache, dataIdList):
        """Run the deblender on a list of dataIds

        Only slave nodes execute this method.

        Parameters
        ----------
        cache: Pool cache
            Pool cache with butler.
        dataIdList: list
            Data identifier for patch in each band.

        Returns
        -------
        result: bool
            Whether the patch requires reprocessing.
        """
        with self.logOperation("deblending %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
                           dataId in dataIdList]
            reprocessing = False  # Does this patch require reprocessing?
            if ("deblendCoaddSources" in self.reuse and
                    all([dataRef.datasetExists(self.config.coaddName + "Coadd_" + self.measurementInput,
                                               write=True) for dataRef in dataRefList])):
                if not self.config.reprocessing:
                    self.log.info("Skipping deblendCoaddSources for %s; output already exists" % dataIdList)
                    return False

                # Footprints are the same in every band, so we can check just one
                catalog = dataRefList[0].get(self.config.coaddName + "Coadd_" + self.measurementInput)
                bigFlag = catalog["deblend_parentTooBig"]
                # Footprints marked too large by the previous deblender run
                numOldBig = bigFlag.sum()
                if numOldBig == 0:
                    self.log.info("No large footprints in %s" % (dataRefList[0].dataId))
                    return False

                # This if-statement can be removed after DM-15662
                if self.config.deblendCoaddSources.simultaneous:
                    deblender = self.deblendCoaddSources.multiBandDeblend
                else:
                    deblender = self.deblendCoaddSources.singleBandDeblend

                # isLargeFootprint() can potentially return False for a source that is marked
                # too big in the catalog, because of "new"/different deblender configs.
                # numNewBig is the number of footprints that *will* be too big if reprocessed
                numNewBig = sum((deblender.isLargeFootprint(src.getFootprint()) for
                                 src in catalog[bigFlag]))
                if numNewBig == numOldBig:
                    self.log.info("All %d formerly large footprints continue to be large in %s" %
                                  (numOldBig, dataRefList[0].dataId,))
                    return False
                self.log.info("Found %d large footprints to be reprocessed in %s" %
                              (numOldBig - numNewBig, [dataRef.dataId for dataRef in dataRefList]))
                reprocessing = True

            self.deblendCoaddSources.runDataRef(dataRefList)
            return reprocessing
    def runMeasurements(self, cache, dataId):
        """Run measurement on a patch for a single filter

        Only slave nodes execute this method.

        Parameters
        ----------
        cache: Pool cache
            Pool cache, with butler
        dataId: dict
            Data identifier for patch
        """
        with self.logOperation("measurements on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp")
            if ("measureCoaddSources" in self.reuse and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_meas", write=True)):
                self.log.info("Skipping measureCoaddSources for %s; output already exists" % dataId)
                return
            self.measureCoaddSources.runDataRef(dataRef)
    def runMergeMeasurements(self, cache, dataIdList):
        """!Run measurement merging on a patch

        Only slave nodes execute this method.

        @param cache: Pool cache, containing butler
        @param dataIdList: List of data identifiers for the patch in different filters
        """
        with self.logOperation("merge measurements from %s" % (dataIdList,)):
            dataRefList = [getDataRef(cache.butler, dataId, self.coaddType + "Coadd_calexp") for
                           dataId in dataIdList]
            if ("mergeCoaddMeasurements" in self.reuse and
                    not self.config.reprocessing and
                    dataRefList[0].datasetExists(self.config.coaddName + "Coadd_ref", write=True)):
                self.log.info("Skipping mergeCoaddMeasurements for %s; output already exists" %
                              dataRefList[0].dataId)
                return
            self.mergeCoaddMeasurements.runDataRef(dataRefList)
    def runForcedPhot(self, cache, dataId):
        """!Run forced photometry on a patch for a single filter

        Only slave nodes execute this method.

        @param cache: Pool cache, with butler
        @param dataId: Data identifier for patch
        """
        with self.logOperation("forced photometry on %s" % (dataId,)):
            dataRef = getDataRef(cache.butler, dataId,
                                 self.coaddType + "Coadd_calexp")
            if ("forcedPhotCoadd" in self.reuse and
                    not self.config.reprocessing and
                    dataRef.datasetExists(self.config.coaddName + "Coadd_forced_src", write=True)):
                self.log.info("Skipping forcedPhotCoadd for %s; output already exists" % dataId)
                return
            self.forcedPhotCoadd.runDataRef(dataRef)

    def writeMetadata(self, dataRef):
        """We don't collect any metadata, so skip"""
        pass
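# A minimal entry-point sketch (assumption, not in the original file): as a
# BatchPoolTask, this driver is normally launched through the ctrl_pool batch
# machinery rather than run directly, e.g. from a bin script:
#
#     if __name__ == "__main__":
#         MultiBandDriverTask.parseAndSubmit()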