Coverage for python/lsst/pipe/base/graph/graph.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public names exported by this module.
__all__ = ("QuantumGraph", "IncompatibleGraphError")
25from collections import defaultdict
27from itertools import chain, count
28import io
29import networkx as nx
30from networkx.drawing.nx_agraph import write_dot
31import os
32import pickle
33import time
34from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple,
35 Union, TypeVar)
37from ..connections import iterConnections
38from ..pipeline import TaskDef
39from lsst.daf.butler import Quantum, DatasetRef
41from ._implDetails import _DatasetTracker, DatasetTypeName
42from .quantumNode import QuantumNode, NodeId, BuildId
# Type variable bound to QuantumGraph. It annotates methods that return a new
# graph of the same (possibly sub-) class they were called on, e.g. `subset`,
# which builds its result with ``type(self)``.
_T = TypeVar("_T", bound="QuantumGraph")
class IncompatibleGraphError(Exception):
    """Exception raised when a lookup by `NodeId` cannot be performed
    because the id is incompatible with the graph being queried (for
    example, it was produced by a different graph instance).
    """
class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects.

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Build the graph that stores the relations between tasks, and the
        graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The quanta to place in the graph, grouped by the task that
            processes them.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            If not `None`, the `NodeId` to assign to each quantum. Every
            quantum in ``quanta`` must have an entry. If `None`, node ids are
            generated sequentially starting from zero.
        _buildId : `BuildId`, optional
            Build id to associate with this graph. If `None`, a new one is
            made from the current time and the process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is not `None` but has no entry for
            one of the quanta.
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connection's name, with a value of
            # the TaskDef in the appropriate field.
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode in the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                # Test against None (not truthiness) so a supplied mapping is
                # always honored and a missing entry is reported instead of
                # silently generating ids that could collide with the
                # preserved build id.
                if _quantumToNodeId is not None:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.Digraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. It is generally preferable to iterate over
        this object and use the methods of this class, but direct access to
        the networkx object is sometimes helpful.

        Returns
        -------
        graph : `networkx.Digraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'input' nodes
        to the graph, meaning those nodes do not depend on any other nodes in
        the graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A list of nodes that are inputs to the graph
        """
        # Return a list (not a one-shot generator), matching `outputQuanta`
        # and the documented contract.
        return [q for q, n in self._connectedQuanta.in_degree if n == 0]

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'output' nodes
        to the graph, meaning no other nodes in the graph depend on them.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The number associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with input number

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a graph other than this
            instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        return frozenset(self._quanta[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input; the iterable is empty if no tasks use the specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        # The source iterable of a generator expression is evaluated eagerly,
        # so a KeyError from getInputs surfaces at call time.
        return (c for c in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`; consumers come first, then the producer (if
            any).

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the taskName property
        of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for; compared against the last
            dot-separated component of each `TaskDef.taskName`.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        return [task for task in self._quanta.keys() if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` object in this graph is associated
        with a `str` representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` object that has the specified label, or `None`
            if no task has that label.
        """
        for task in self._quanta.keys():
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum` objects
            A `set` of `Quantum` belonging to every task that consumes or
            produces the specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        return result.union(*(self._quanta[task] for task in tasks))

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            Nodes to keep. Any iterable is accepted, including one-shot
            iterators; nodes not present in this graph are ignored.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        # Materialize once: the nodes are needed both to build the subgraph
        # and to preserve node ids. A one-shot iterator would otherwise be
        # exhausted by the first use, leaving the id map empty and causing
        # fresh (colliding) ids to be generated under the same build id.
        nodes = tuple(nodes)
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId,
                             _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a list of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each (weakly) connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            All the nodes that are directly connected to specified node,
            plus the node itself
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        reachable by walking edges backwards.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and return the edges of
        any cycle found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned to support the
            ``if graph.findCycle()`` idiom, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def save(self, file):
        """Save QuantumGraph to a file.

        Presently we store QuantumGraph in pickle format, this could
        potentially change in the future if better format is found.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write pickle data open in binary mode.
        """
        pickle.dump(self, file)

    @classmethod
    def load(cls, file, universe):
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File with pickle data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        Unpickling can execute arbitrary code; only load files from trusted
        sources.
        """
        # NOTE(review): pickle.load on a file of unknown provenance is unsafe;
        # the isinstance check below runs only after unpickling has happened.
        qgraph = pickle.load(file)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        """Iterate over all `QuantumNode` objects in topological order."""
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        """Return the total number of quanta in the graph."""
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        """Return True if ``node`` is a node of the quantum graph."""
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Store a compact form of the graph as a list of its nodes in
        topological order. The full graph can be reconstructed from this
        information, and it preserves the ordering of the graph nodes.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstruct the state of the graph from the information persisted
        in `__getstate__`, reusing the persisted node ids and build id.
        """
        nodes: List[QuantumNode] = state['nodesList']
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        for quantumNode in nodes:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        # Reuse the build id of the persisted nodes so previously issued
        # NodeIds remain valid for lookups; a new one is made if empty.
        _buildId = nodes[-1].nodeId.buildId if nodes else None
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` is a QuantumGraph with the same nodes,
        the same direct inputs/outputs for every node, and the same task
        graph node order.
        """
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)