lsst.pipe.base  18.0.0-1-g3877d06
graphBuilder.py
Go to the documentation of this file.
1 # This file is part of pipe_base.
2 #
3 # Developed for the LSST Data Management System.
4 # This product includes software developed by the LSST Project
5 # (http://www.lsst.org).
6 # See the COPYRIGHT file at the top-level directory of this distribution
7 # for details of code ownership.
8 #
9 # This program is free software: you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation, either version 3 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License
20 # along with this program. If not, see <http://www.gnu.org/licenses/>.
21 
22 """Module defining GraphBuilder class and related methods.
23 """
24 
25 __all__ = ['GraphBuilder']
26 
27 # -------------------------------
28 # Imports of standard modules --
29 # -------------------------------
30 import copy
31 from collections import namedtuple
32 from itertools import chain
33 import logging
34 
35 # -----------------------------
36 # Imports for other modules --
37 # -----------------------------
38 from .graph import QuantumGraphTaskNodes, QuantumGraph
39 from lsst.daf.butler import Quantum, DatasetRef, DimensionSet
40 
41 # ----------------------------------
42 # Local non-exported definitions --
43 # ----------------------------------
44 
45 _LOG = logging.getLogger(__name__.partition(".")[2])
46 
47 # Tuple containing TaskDef, its input dataset types and output dataset types
48 #
49 # Attributes
50 # ----------
51 # taskDef : `TaskDef`
52 # inputs : `set` of `DatasetType`
53 # outputs : `set` of `DatasetType`
54 # initTnputs : `set` of `DatasetType`
55 # initOutputs : `set` of `DatasetType`
56 # perDatasetTypeDimensions : `~lsst.daf.butler.DimensionSet`
57 # prerequisite : `set` of `DatasetType`
58 _TaskDatasetTypes = namedtuple("_TaskDatasetTypes", ("taskDef", "inputs", "outputs",
59  "initInputs", "initOutputs",
60  "perDatasetTypeDimensions", "prerequisite"))
61 
62 
class GraphBuilderError(Exception):
    """Base class for exceptions generated by graph builder."""
    pass
67 
68 
class OutputExistsError(GraphBuilderError):
    """Exception generated when output datasets already exist.

    Parameters
    ----------
    taskName : `str`
        Name of the task whose outputs already exist.
    refs : iterable of `~lsst.daf.butler.DatasetRef`
        The pre-existing output dataset references.
    """

    def __init__(self, taskName, refs):
        refs = ', '.join(str(ref) for ref in refs)
        msg = "Output datasets already exist for task {}: {}".format(taskName, refs)
        GraphBuilderError.__init__(self, msg)
77 
78 
class PrerequisiteMissingError(GraphBuilderError):
    """Exception generated when a prerequisite dataset does not exist."""
    pass
83 
84 
class GraphBuilder(object):
    """GraphBuilder class is responsible for building task execution graph
    from a Pipeline.

    Parameters
    ----------
    taskFactory : `TaskFactory`
        Factory object used to load/instantiate PipelineTasks.
    registry : `~lsst.daf.butler.Registry`
        Data butler instance.
    skipExisting : `bool`, optional
        If ``True`` (default) then Quantum is not created if all its outputs
        already exist, otherwise exception is raised.
    """

    def __init__(self, taskFactory, registry, skipExisting=True):
        self.taskFactory = taskFactory
        self.registry = registry
        self.dimensions = registry.dimensions
        self.skipExisting = skipExisting

    def _loadTaskClass(self, taskDef):
        """Make sure task class is loaded.

        Load task class, update task name to make sure it is fully-qualified,
        do not update original taskDef in a Pipeline though.

        Parameters
        ----------
        taskDef : `TaskDef`

        Returns
        -------
        `TaskDef` instance, may be the same as parameter if task class is
        already loaded.
        """
        if taskDef.taskClass is None:
            tClass, tName = self.taskFactory.loadTaskClass(taskDef.taskName)
            # Copy so the caller's Pipeline entry is left untouched.
            taskDef = copy.copy(taskDef)
            taskDef.taskClass = tClass
            taskDef.taskName = tName
        return taskDef

    def makeGraph(self, pipeline, originInfo, userQuery):
        """Create execution graph for a pipeline.

        Parameters
        ----------
        pipeline : `Pipeline`
            Pipeline definition, task names/classes and their configs.
        originInfo : `~lsst.daf.butler.DatasetOriginInfo`
            Object which provides names of the input/output collections.
        userQuery : `str`
            String which defines user-defined selection for registry, should
            be empty or `None` if there is no restrictions on data selection.

        Returns
        -------
        graph : `QuantumGraph`

        Raises
        ------
        UserExpressionError
            Raised when user expression cannot be parsed.
        OutputExistsError
            Raised when output datasets already exist.
        Exception
            Other exceptions types may be raised by underlying registry
            classes.
        """
        # Make sure all task classes are loaded.
        taskList = [self._loadTaskClass(taskDef) for taskDef in pipeline]

        # Collect inputs/outputs from each task.
        taskDatasets = []
        for taskDef in taskList:
            taskClass = taskDef.taskClass
            inputs = {k: v.makeDatasetType(self.registry.dimensions)
                      for k, v in taskClass.getInputDatasetTypes(taskDef.config).items()}
            # Prerequisites are a (named) subset of the regular inputs.
            prerequisite = set(inputs[k] for k in taskClass.getPrerequisiteDatasetTypes(taskDef.config))
            # taskIo accumulates (inputs, outputs, initInputs, initOutputs)
            # in the positional order of the _TaskDatasetTypes fields.
            taskIo = [inputs.values()]
            for attr in ("Output", "InitInput", "InitOutput"):
                getter = getattr(taskClass, f"get{attr}DatasetTypes")
                ioObject = getter(taskDef.config) or {}
                taskIo.append(set(dsTypeDescr.makeDatasetType(self.registry.dimensions)
                                  for dsTypeDescr in ioObject.values()))
            perDatasetTypeDimensions = DimensionSet(self.registry.dimensions,
                                                    taskClass.getPerDatasetTypeDimensions(taskDef.config))
            taskDatasets.append(_TaskDatasetTypes(taskDef, *taskIo, prerequisite=prerequisite,
                                                  perDatasetTypeDimensions=perDatasetTypeDimensions))

        perDatasetTypeDimensions = self._extractPerDatasetTypeDimensions(taskDatasets)

        # Categorize dataset types for the full Pipeline.
        required, optional, prerequisite, initInputs, initOutputs = self._makeFullIODatasetTypes(taskDatasets)

        # Make a graph.
        return self._makeGraph(taskDatasets, required, optional, prerequisite, initInputs, initOutputs,
                               originInfo, userQuery, perDatasetTypeDimensions=perDatasetTypeDimensions)

    def _extractPerDatasetTypeDimensions(self, taskDatasets):
        """Return the complete set of all per-DatasetType dimensions declared
        by any task.

        Per-DatasetType dimensions are those that need not have the same
        values for different Datasets within a Quantum.

        Parameters
        ----------
        taskDatasets : sequence of `_TaskDatasetTypes`
            Information for each task in the pipeline.

        Returns
        -------
        perDatasetTypeDimensions : `~lsst.daf.butler.DimensionSet`
            All per-DatasetType dimensions.

        Raises
        ------
        ValueError
            Raised if tasks disagree on whether a dimension is declared
            per-DatasetType.
        """
        # Empty dimension set, just used to construct more DimensionSets via
        # union method.
        noDimensions = DimensionSet(self.registry.dimensions, ())
        # Construct pipeline-wide perDatasetTypeDimensions set from union of
        # all Task-level perDatasetTypeDimensions.
        perDatasetTypeDimensions = noDimensions.union(
            *[taskDs.perDatasetTypeDimensions for taskDs in taskDatasets]
        )
        # Check that no tasks want any of these as common (i.e. not
        # per-DatasetType) dimensions.
        for taskDs in taskDatasets:
            allTaskDimensions = noDimensions.union(
                *[datasetType.dimensions for datasetType in chain(taskDs.inputs, taskDs.outputs)]
            )
            commonTaskDimensions = allTaskDimensions - taskDs.perDatasetTypeDimensions
            if not commonTaskDimensions.isdisjoint(perDatasetTypeDimensions):
                # NOTE(review): original code called ``intersections`` here,
                # which does not match the set-like DimensionSet API and would
                # raise AttributeError instead of the intended ValueError on
                # this error path; fixed to ``intersection``.
                overlap = commonTaskDimensions.intersection(perDatasetTypeDimensions)
                raise ValueError(
                    f"Task {taskDs.taskDef.taskName} uses dimensions {overlap} without declaring them "
                    f"per-DatasetType, but they are declared per-DatasetType by another task."
                )
        return perDatasetTypeDimensions

    def _makeFullIODatasetTypes(self, taskDatasets):
        """Returns full set of input and output dataset types for all tasks.

        Parameters
        ----------
        taskDatasets : sequence of `_TaskDatasetTypes`
            Tasks with their inputs, outputs, initInputs and initOutputs.

        Returns
        -------
        required : `set` of `~lsst.daf.butler.DatasetType`
            Datasets that must exist in the repository in order to generate
            a QuantumGraph node that consumes them.
        optional : `set` of `~lsst.daf.butler.DatasetType`
            Datasets that will be produced by the graph, but may exist in the
            repository. If ``self.skipExisting`` is `True` and all outputs of
            a particular node already exist, it will be skipped. Otherwise
            pre-existing datasets of these types will cause
            `OutputExistsError` to be raised.
        prerequisite : `set` of `~lsst.daf.butler.DatasetType`
            Datasets that must exist in the repository, but whose absence
            should cause `PrerequisiteMissingError` to be raised if they
            are needed by any graph node that would otherwise be created.
        initInputs : `set` of `~lsst.daf.butler.DatasetType`
            Datasets used as init method inputs by the pipeline.
        initOutputs : `set` of `~lsst.daf.butler.DatasetType`
            Datasets used as init method outputs by the pipeline.
        """
        # To build initial dataset graph we have to collect info about all
        # datasets to be used by this pipeline.  Names are collected first so
        # that the same dataset type contributed by several tasks is only
        # counted once; allDatasetTypes maps each name back to its type.
        allDatasetTypes = {}
        required = set()
        optional = set()
        prerequisite = set()
        initInputs = set()
        initOutputs = set()
        for taskDs in taskDatasets:
            for ioType, ioSet in zip(("inputs", "outputs", "prerequisite", "initInputs", "initOutputs"),
                                     (required, optional, prerequisite, initInputs, initOutputs)):
                for dsType in getattr(taskDs, ioType):
                    ioSet.add(dsType.name)
                    allDatasetTypes[dsType.name] = dsType

        # Any dataset the pipeline produces can't be required or prerequisite.
        required -= optional
        prerequisite -= optional

        # Remove initOutputs from initInputs.
        initInputs -= initOutputs

        # Map the name sets back to sets of DatasetType instances.
        required = set(allDatasetTypes[name] for name in required)
        optional = set(allDatasetTypes[name] for name in optional)
        prerequisite = set(allDatasetTypes[name] for name in prerequisite)
        initInputs = set(allDatasetTypes[name] for name in initInputs)
        initOutputs = set(allDatasetTypes[name] for name in initOutputs)
        return required, optional, prerequisite, initInputs, initOutputs

    def _makeGraph(self, taskDatasets, required, optional, prerequisite,
                   initInputs, initOutputs, originInfo, userQuery,
                   perDatasetTypeDimensions=()):
        """Make QuantumGraph instance.

        Parameters
        ----------
        taskDatasets : sequence of `_TaskDatasetTypes`
            Tasks with their inputs and outputs.
        required : `set` of `~lsst.daf.butler.DatasetType`
            Datasets that must exist in the repository in order to generate
            a QuantumGraph node that consumes them.
        optional : `set` of `~lsst.daf.butler.DatasetType`
            Datasets that will be produced by the graph, but may exist in
            the repository. If ``self.skipExisting`` and all outputs of a
            particular node already exist, it will be skipped. Otherwise
            pre-existing datasets of these types will cause
            `OutputExistsError` to be raised.
        prerequisite : `set` of `~lsst.daf.butler.DatasetType`
            Datasets that must exist in the repository, but whose absence
            should cause `PrerequisiteMissingError` to be raised if they
            are needed by any graph node that would otherwise be created.
        initInputs : `set` of `DatasetType`
            Datasets which should exist in input repository, and will be used
            in task initialization.
        initOutputs : `set` of `DatasetType`
            Datasets which will be created in task initialization.
        originInfo : `DatasetOriginInfo`
            Object which provides names of the input/output collections.
        userQuery : `str`
            String which defines user-defined selection for registry, should
            be empty or `None` if there is no restrictions on data selection.
        perDatasetTypeDimensions : iterable of `Dimension` or `str`
            Dimensions (or names thereof) that may have different values for
            different dataset types within the same quantum.

        Returns
        -------
        `QuantumGraph` instance.
        """
        rows = self.registry.selectMultipleDatasetTypes(
            originInfo, userQuery,
            required=required, optional=optional, prerequisite=prerequisite,
            perDatasetTypeDimensions=perDatasetTypeDimensions
        )

        # Store result locally for multi-pass algorithm below.
        # TODO: change it to single pass
        dimensionVerse = []
        try:
            for row in rows:
                _LOG.debug("row: %s", row)
                dimensionVerse.append(row)
        except LookupError as err:
            # Registry signals a missing prerequisite with LookupError while
            # iterating; re-raise as the graph-builder-specific exception.
            raise PrerequisiteMissingError(str(err)) from err

        # Next step is to group by task quantum dimensions.
        qgraph = QuantumGraph()
        qgraph._inputDatasetTypes = (required | prerequisite)
        qgraph._outputDatasetTypes = optional
        for dsType in initInputs:
            # Search input collections in order; first collection that has
            # the dataset wins.
            for collection in originInfo.getInputCollections(dsType.name):
                result = self.registry.find(collection, dsType)
                if result is not None:
                    qgraph.initInputs.append(result)
                    break
            else:
                raise GraphBuilderError(f"Could not find initInput {dsType.name} in any input"
                                        " collection")
        for dsType in initOutputs:
            # Init outputs have no dimensions, hence the empty dataId.
            qgraph.initOutputs.append(DatasetRef(dsType, {}))

        def _datasetRefKey(datasetRef):
            # Hashable, order-independent key used to de-duplicate refs.
            return tuple(sorted(datasetRef.dataId.items()))

        for taskDss in taskDatasets:
            taskQuantaInputs = {}    # key is the quantum dataId (as tuple)
            taskQuantaOutputs = {}   # key is the quantum dataId (as tuple)
            qlinks = []
            for dimensionName in taskDss.taskDef.config.quantum.dimensions:
                dimension = self.dimensions[dimensionName]
                qlinks += dimension.links()
            _LOG.debug("task %s qdimensions: %s", taskDss.taskDef.label, qlinks)

            # Some rows will be non-unique for subset of dimensions, create
            # temporary structure to remove duplicates.
            for row in dimensionVerse:
                qkey = tuple((col, row.dataId[col]) for col in qlinks)
                _LOG.debug("qkey: %s", qkey)

                qinputs = taskQuantaInputs.setdefault(qkey, {})
                for dsType in taskDss.inputs:
                    datasetRefs = qinputs.setdefault(dsType, {})
                    datasetRef = row.datasetRefs[dsType]
                    datasetRefs[_datasetRefKey(datasetRef)] = datasetRef
                    _LOG.debug("add input datasetRef: %s %s", dsType.name, datasetRef)

                qoutputs = taskQuantaOutputs.setdefault(qkey, {})
                for dsType in taskDss.outputs:
                    datasetRefs = qoutputs.setdefault(dsType, {})
                    datasetRef = row.datasetRefs[dsType]
                    datasetRefs[_datasetRefKey(datasetRef)] = datasetRef
                    _LOG.debug("add output datasetRef: %s %s", dsType.name, datasetRef)

            # All nodes for this task.
            quanta = []
            for qkey in taskQuantaInputs:
                # taskQuantaInputs and taskQuantaOutputs have the same keys.
                _LOG.debug("make quantum for qkey: %s", qkey)
                quantum = Quantum(run=None, task=None)

                # Add all outputs, but check first that outputs don't exist.
                outputs = list(chain.from_iterable(datasetRefs.values()
                                                   for datasetRefs in taskQuantaOutputs[qkey].values()))
                for ref in outputs:
                    _LOG.debug("add output: %s", ref)
                # A ref with a non-None id already exists in the registry.
                if self.skipExisting and all(ref.id is not None for ref in outputs):
                    _LOG.debug("all output datasetRefs already exist, skip quantum")
                    continue
                if any(ref.id is not None for ref in outputs):
                    # Some outputs exist, can't override them.
                    raise OutputExistsError(taskDss.taskDef.taskName, outputs)

                for ref in outputs:
                    quantum.addOutput(ref)

                # Add all inputs.
                for datasetRefs in taskQuantaInputs[qkey].values():
                    for ref in datasetRefs.values():
                        quantum.addPredictedInput(ref)
                        _LOG.debug("add input: %s", ref)

                quanta.append(quantum)

            qgraph.append(QuantumGraphTaskNodes(taskDss.taskDef, quanta))

        return qgraph
def makeGraph(self, pipeline, originInfo, userQuery)
def _makeFullIODatasetTypes(self, taskDatasets)
def _makeGraph(self, taskDatasets, required, optional, prerequisite, initInputs, initOutputs, originInfo, userQuery, perDatasetTypeDimensions=())
def __init__(self, taskFactory, registry, skipExisting=True)
def _extractPerDatasetTypeDimensions(self, taskDatasets)