lsst.pipe.base  16.0-25-g2c6bf4a+1
graphBuilder.py
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

22 """Module defining GraphBuilder class and related methods.
23 """
24 
25 __all__ = ['GraphBuilder']
26 
# -------------------------------
#  Imports of standard modules --
# -------------------------------
import copy
from collections import namedtuple
from itertools import chain
import logging

# -----------------------------
#  Imports for other modules --
# -----------------------------
from .graph import QuantumGraphNodes, QuantumGraph
from lsst.daf.butler import Quantum, DatasetRef
from lsst.daf.butler.exprParser import ParserYacc, ParserYaccError

# ----------------------------------
#  Local non-exported definitions --
# ----------------------------------

# logger name is the module path with the leading component (everything
# before the first dot, e.g. "lsst") stripped off
_LOG = logging.getLogger(__name__.partition(".")[2])

# Tuple containing TaskDef, its input/output dataset types and its
# init-input/init-output dataset types
#
# Attributes
# ----------
# taskDef : `TaskDef`
# inputs : `list` of `DatasetType`
# outputs : `list` of `DatasetType`
# initInputs : `list` of `DatasetType`
# initOutputs : `list` of `DatasetType`
_TaskDatasetTypes = namedtuple("_TaskDatasetTypes", "taskDef inputs outputs initInputs initOutputs")


class GraphBuilderError(Exception):
    """Base class for exceptions generated by graph builder.
    """
    pass


class UserExpressionError(GraphBuilderError):
    """Exception generated by graph builder for error in user expression.
    """

    def __init__(self, expr, exc):
        msg = "Failed to parse user expression `{}' ({})".format(expr, exc)
        GraphBuilderError.__init__(self, msg)


class OutputExistsError(GraphBuilderError):
    """Exception generated when output datasets already exist.
    """

    def __init__(self, taskName, refs):
        refs = ', '.join(str(ref) for ref in refs)
        msg = "Output datasets already exist for task {}: {}".format(taskName, refs)
        GraphBuilderError.__init__(self, msg)


# ------------------------
#  Exported definitions --
# ------------------------

class GraphBuilder(object):
    """GraphBuilder class is responsible for building task execution
    graphs from a Pipeline.

    Parameters
    ----------
    taskFactory : `TaskFactory`
        Factory object used to load/instantiate PipelineTasks.
    registry : `~lsst.daf.butler.Registry`
        Data butler registry instance.
    skipExisting : `bool`, optional
        If ``True`` (default), a Quantum is not created when all its
        outputs already exist; otherwise an exception is raised.
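
    Examples
    --------
    A minimal usage sketch (illustrative only: ``MyTaskFactory`` and the
    surrounding ``butler``, ``pipeline`` and ``originInfo`` objects are
    assumed to exist in the calling code, they are not defined by this
    module)::

        factory = MyTaskFactory()
        builder = GraphBuilder(factory, butler.registry)
        qgraph = builder.makeGraph(pipeline, originInfo,
                                   "visit.visit = 903334")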
102  """
103 
104  def __init__(self, taskFactory, registry, skipExisting=True):
105  self.taskFactory = taskFactory
106  self.registry = registry
107  self.dimensions = registry.dimensions
108  self.skipExisting = skipExisting
109 
    @staticmethod
    def _parseUserQuery(userQuery):
        """Parse user query.

        Parameters
        ----------
        userQuery : `str`
            User expression string specifying data selection.

        Returns
        -------
        tree : `exprTree.Node`
            Parsed expression tree.
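
        Examples
        --------
        An illustrative expression (the exact grammar is defined by
        `~lsst.daf.butler.exprParser.ParserYacc`; treat this as a
        sketch, not a grammar reference)::

            visit.visit = 903334 OR visit.visit = 903336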
122  """
123  parser = ParserYacc()
124  # do parsing, this will raise exception
125  try:
126  tree = parser.parse(userQuery)
127  _LOG.debug("parsed expression: %s", tree)
128  except ParserYaccError as exc:
129  raise UserExpressionError(userQuery, exc)
130  return tree
131 
    def _loadTaskClass(self, taskDef):
        """Make sure task class is loaded.

        Load task class and update task name to make sure it is
        fully-qualified, but do not update the original taskDef in a
        Pipeline.

        Parameters
        ----------
        taskDef : `TaskDef`
            Task definition whose class should be loaded.

        Returns
        -------
        taskDef : `TaskDef`
            May be the same instance as the parameter if the task class
            is already loaded.
        """
        if taskDef.taskClass is None:
            tClass, tName = self.taskFactory.loadTaskClass(taskDef.taskName)
            taskDef = copy.copy(taskDef)
            taskDef.taskClass = tClass
            taskDef.taskName = tName
        return taskDef

    def makeGraph(self, pipeline, originInfo, userQuery):
        """Create execution graph for a pipeline.

        Parameters
        ----------
        pipeline : `Pipeline`
            Pipeline definition, task names/classes and their configs.
        originInfo : `~lsst.daf.butler.DatasetOriginInfo`
            Object which provides names of the input/output collections.
        userQuery : `str`
            String which defines user-defined selection for registry;
            should be empty or `None` if there are no restrictions on
            data selection.

        Returns
        -------
        graph : `QuantumGraph`
            Execution graph for the pipeline.

        Raises
        ------
        UserExpressionError
            Raised when user expression cannot be parsed.
        OutputExistsError
            Raised when output datasets already exist.
        Exception
            Other exception types may be raised by underlying registry
            classes.
        """

        # make sure all task classes are loaded
        taskList = [self._loadTaskClass(taskDef) for taskDef in pipeline]

        # collect inputs/outputs from each task; each get*DatasetTypes
        # class method returns a mapping of dataset type descriptors
        # (or None)
        taskDatasets = []
        for taskDef in taskList:
            taskClass = taskDef.taskClass
            taskIo = []
            for attr in ("Input", "Output", "InitInput", "InitOutput"):
                getter = getattr(taskClass, f"get{attr}DatasetTypes")
                ioObject = getter(taskDef.config) or {}
                taskIo.append([dsTypeDescr.datasetType for dsTypeDescr in ioObject.values()])
            taskDatasets.append(_TaskDatasetTypes(taskDef, *taskIo))
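
        # Illustrative shape of one _TaskDatasetTypes entry (dataset type
        # names are made up): for a task reading "raw" and writing
        # "calexp" with no init datasets, taskIo above ends up as
        #   [[rawType], [calexpType], [], []]
        # i.e. (inputs, outputs, initInputs, initOutputs) in field order.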

        # compute pipeline-level inputs/outputs and init inputs/outputs
        inputs, outputs, initInputs, initOutputs = self._makeFullIODatasetTypes(taskDatasets)

        # make a graph
        return self._makeGraph(taskDatasets, inputs, outputs, initInputs, initOutputs,
                               originInfo, userQuery)

    def _makeFullIODatasetTypes(self, taskDatasets):
        """Return the full set of input and output dataset types for all
        tasks in a pipeline.

        Parameters
        ----------
        taskDatasets : sequence of `_TaskDatasetTypes`
            Tasks with their inputs, outputs, initInputs and initOutputs.

        Returns
        -------
        inputs : `set` of `butler.DatasetType`
            Datasets used as inputs by the pipeline.
        outputs : `set` of `butler.DatasetType`
            Datasets produced by the pipeline.
        initInputs : `set` of `butler.DatasetType`
            Datasets used as init method inputs by the pipeline.
        initOutputs : `set` of `butler.DatasetType`
            Datasets used as init method outputs by the pipeline.
        """
        # to build the initial dataset graph we have to collect info about
        # all datasets used by this pipeline
        allDatasetTypes = {}
        inputs = set()
        outputs = set()
        initInputs = set()
        initOutputs = set()
        for taskDs in taskDatasets:
            for ioType, ioSet in zip(("inputs", "outputs", "initInputs", "initOutputs"),
                                     (inputs, outputs, initInputs, initOutputs)):
                for dsType in getattr(taskDs, ioType):
                    ioSet.add(dsType.name)
                    allDatasetTypes[dsType.name] = dsType
        # remove outputs from inputs: anything produced by some task is an
        # intermediate, not a pipeline-level input
        inputs -= outputs

        # remove initOutputs from initInputs
        initInputs -= initOutputs
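        # Worked example of the set arithmetic above (dataset type names
        # are illustrative): if task A produces "calexp" and task B
        # consumes "calexp" and produces "coadd", then
        #   inputs == {"raw", "calexp"}, outputs == {"calexp", "coadd"}
        # and after `inputs -= outputs` only the true external input
        # {"raw"} remains; "calexp" is an intermediate.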

        inputs = set(allDatasetTypes[name] for name in inputs)
        outputs = set(allDatasetTypes[name] for name in outputs)
        initInputs = set(allDatasetTypes[name] for name in initInputs)
        initOutputs = set(allDatasetTypes[name] for name in initOutputs)
        return inputs, outputs, initInputs, initOutputs

    def _makeGraph(self, taskDatasets, inputs, outputs, initInputs, initOutputs, originInfo, userQuery):
        """Make QuantumGraph instance.

        Parameters
        ----------
        taskDatasets : sequence of `_TaskDatasetTypes`
            Tasks with their inputs and outputs.
        inputs : `set` of `DatasetType`
            Datasets which should already exist in input repository.
        outputs : `set` of `DatasetType`
            Datasets which will be created by tasks.
        initInputs : `set` of `DatasetType`
            Datasets which should exist in input repository and will be
            used in task initialization.
        initOutputs : `set` of `DatasetType`
            Datasets which will be created in task initialization.
        originInfo : `DatasetOriginInfo`
            Object which provides names of the input/output collections.
        userQuery : `str`
            String which defines user-defined selection for registry;
            should be empty or `None` if there are no restrictions on
            data selection.

        Returns
        -------
        graph : `QuantumGraph`
            Quantum execution graph.
        """
        parsedQuery = self._parseUserQuery(userQuery or "")
        expr = None if parsedQuery is None else str(parsedQuery)
        rows = self.registry.selectDimensions(originInfo, expr, inputs, outputs)

        # store result locally for multi-pass algorithm below
        # TODO: change it to single pass
        dimensionVerse = []
        for row in rows:
            _LOG.debug("row: %s", row)
            dimensionVerse.append(row)

        # Next step is to group by task quantum dimensions
        qgraph = QuantumGraph()
        qgraph._inputDatasetTypes = inputs
        qgraph._outputDatasetTypes = outputs
        for dsType in initInputs:
            for collection in originInfo.getInputCollections(dsType.name):
                result = self.registry.find(collection, dsType)
                if result is not None:
                    qgraph.initInputs.append(result)
                    break
            else:
                raise GraphBuilderError(f"Could not find initInput {dsType.name} in any input"
                                        " collection")
        for dsType in initOutputs:
            qgraph.initOutputs.append(DatasetRef(dsType, {}))

        for taskDss in taskDatasets:
            taskQuantaInputs = {}    # key is the quantum dataId (as tuple)
            taskQuantaOutputs = {}   # key is the quantum dataId (as tuple)
            qlinks = []
            for dimensionName in taskDss.taskDef.config.quantum.dimensions:
                dimension = self.dimensions[dimensionName]
                qlinks += dimension.links()
            _LOG.debug("task %s qdimensions: %s", taskDss.taskDef.label, qlinks)

            # some rows will be non-unique for subset of dimensions, create
            # temporary structure to remove duplicates
            for row in dimensionVerse:
                qkey = tuple((col, row.dataId[col]) for col in qlinks)
                _LOG.debug("qkey: %s", qkey)

                def _dataRefKey(dataRef):
                    return tuple(sorted(dataRef.dataId.items()))

                qinputs = taskQuantaInputs.setdefault(qkey, {})
                for dsType in taskDss.inputs:
                    dataRefs = qinputs.setdefault(dsType, {})
                    dataRef = row.datasetRefs[dsType]
                    dataRefs[_dataRefKey(dataRef)] = dataRef
                    _LOG.debug("add input dataRef: %s %s", dsType.name, dataRef)

                qoutputs = taskQuantaOutputs.setdefault(qkey, {})
                for dsType in taskDss.outputs:
                    dataRefs = qoutputs.setdefault(dsType, {})
                    dataRef = row.datasetRefs[dsType]
                    dataRefs[_dataRefKey(dataRef)] = dataRef
                    _LOG.debug("add output dataRef: %s %s", dsType.name, dataRef)

            # Pre-flight does not fill dataset components, and graph users
            # may need them; re-retrieve all input datasets so that their
            # components are properly filled.
            for qinputs in taskQuantaInputs.values():
                for dataRefs in qinputs.values():
                    for key in dataRefs.keys():
                        if dataRefs[key].id is not None:
                            dataRefs[key] = self.registry.getDataset(dataRefs[key].id)

            # all nodes for this task
            quanta = []
            for qkey in taskQuantaInputs:
                # taskQuantaInputs and taskQuantaOutputs have the same keys
                _LOG.debug("make quantum for qkey: %s", qkey)
                quantum = Quantum(run=None, task=None)

                # add all outputs, but check first that outputs don't exist
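                # Illustrative outcomes of the checks below, assuming
                # skipExisting=True:
                #   every output ref has an id  -> quantum is skipped
                #   some (not all) have an id   -> OutputExistsError
                #   none have an id             -> outputs added normally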
                outputs = list(chain.from_iterable(dataRefs.values()
                                                   for dataRefs in taskQuantaOutputs[qkey].values()))
                for ref in outputs:
                    _LOG.debug("add output: %s", ref)
                if self.skipExisting and all(ref.id is not None for ref in outputs):
                    _LOG.debug("all output dataRefs already exist, skip quantum")
                    continue
                if any(ref.id is not None for ref in outputs):
                    # some outputs exist, we cannot overwrite them
                    raise OutputExistsError(taskDss.taskDef.taskName, outputs)
                for ref in outputs:
                    quantum.addOutput(ref)

                # add all inputs
                for dataRefs in taskQuantaInputs[qkey].values():
                    for ref in dataRefs.values():
                        quantum.addPredictedInput(ref)
                        _LOG.debug("add input: %s", ref)

                quanta.append(quantum)

            qgraph.append(QuantumGraphNodes(taskDss.taskDef, quanta))

        return qgraph