21 from __future__
import annotations
24 __all__ = (
"QuantumGraph",
"IncompatibleGraphError")
26 from collections
import defaultdict, deque
28 from itertools
import chain, count
31 from networkx.drawing.nx_agraph
import write_dot
38 from typing
import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple,
41 from ..connections
import iterConnections
42 from ..pipeline
import TaskDef
43 from lsst.daf.butler
import Quantum, DatasetRef, ButlerURI, DimensionUniverse
45 from ._implDetails
import _DatasetTracker, DatasetTypeName
46 from .quantumNode
import QuantumNode, NodeId, BuildId
47 from ._loadHelpers
import LoadHelper
50 _T = TypeVar(
"_T", bound=
"QuantumGraph")
59 STRUCT_FMT_STRING =
'>HQQ'
63 MAGIC_BYTES = b
"qgraph4\xf6\xe8\xa9"
67 """Exception class to indicate that a lookup by NodeId is impossible due
74 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
76 This data structure represents a concrete workflow generated from a
81 quanta : Mapping of `TaskDef` to sets of `Quantum`
82 This maps tasks (and their configs) to the sets of data they are to
85 def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
88 def _buildGraphs(self,
89 quanta: Mapping[TaskDef, Set[Quantum]],
91 _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] =
None,
92 _buildId: Optional[BuildId] =
None):
93 """Builds the graph that is used to store the relation between tasks,
94 and the graph that holds the relations between quanta
97 self.
_buildId = _buildId
if _buildId
is not None else BuildId(f
"{time.time()}-{os.getpid()}")
101 self.
_datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
104 nodeNumberGenerator = count()
105 self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
107 for taskDef, quantumSet
in self.
_quanta.items():
108 connections = taskDef.connections
113 for inpt
in iterConnections(connections, (
"inputs",
"prerequisiteInputs",
"initInputs")):
116 for output
in iterConnections(connections, (
"outputs",
"initOutputs")):
124 self.
_count += len(quantumSet)
125 for quantum
in quantumSet:
127 nodeId = _quantumToNodeId.get(quantum)
129 raise ValueError(
"If _quantuMToNodeNumber is not None, all quanta must have an "
130 "associated value in the mapping")
134 inits = quantum.initInputs.values()
135 inputs = quantum.inputs.values()
137 self._nodeIdMap[nodeId] = value
139 for dsRef
in chain(inits, inputs):
143 if isinstance(dsRef, Iterable):
148 for dsRef
in chain.from_iterable(quantum.outputs.values()):
159 """Return a graph representing the relations between the tasks inside
164 taskGraph : `networkx.DiGraph`
165 Internal data structure that holds relations of `TaskDef` objects
171 """Return a graph representing the relations between all the
172 `QuantumNode` objects. Largely it should be preferred to iterate
173 over, and use methods of this class, but sometimes direct access to
174 the networkx object may be helpful
178 graph : `networkx.DiGraph`
179 Internal data structure that holds relations of `QuantumNode`
186 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
187 to the graph, meaning those nodes do not depend on any other nodes in
192 inputNodes : iterable of `QuantumNode`
193 A list of nodes that are inputs to the graph
199 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
200 to the graph, meaning those nodes have no nodes that depend on them in
205 outputNodes : iterable of `QuantumNode`
206 A list of nodes that are outputs of the graph
212 """Return all the `DatasetTypeName` objects that are contained inside
217 tuple of `DatasetTypeName`
218 All the data set type names that are present in the graph
224 """Return True if all of the nodes in the graph are connected, ignores
225 directionality of connections.
230 """Lookup a `QuantumNode` from an id associated with the node.
235 The number associated with a node
240 The node corresponding with input number
245 Raised if the requested nodeId is not in the graph.
246 IncompatibleGraphError
247 Raised if the nodeId was built with a different graph than
248 this instance (or a graph instance that produced this instance
249 through an operation such as subset)
253 return self._nodeIdMap[nodeId]
256 """Return all the `Quantum` associated with a `TaskDef`.
261 The `TaskDef` for which `Quantum` are to be queried
265 frozenset of `Quantum`
266 The `set` of `Quantum` that is associated with the specified
269 return frozenset(self.
_quanta[taskDef])
272 """Find all tasks that have the specified dataset type name as an
277 datasetTypeName : `str`
278 A string representing the name of a dataset type to be queried,
279 can also accept a `DatasetTypeName` which is a `NewType` of str for
280 type safety in static type checking.
284 tasks : iterable of `TaskDef`
285 `TaskDef` objects that have the specified `DatasetTypeName` as an
286 input, list will be empty if no tasks use specified
287 `DatasetTypeName` as an input.
292 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
294 return (c
for c
in self.
_datasetDict.getInputs(datasetTypeName))
297 """Find all tasks that have the specified dataset type name as an
302 datasetTypeName : `str`
303 A string representing the name of a dataset type to be queried,
304 can also accept a `DatasetTypeName` which is a `NewType` of str for
305 type safety in static type checking.
310 `TaskDef` that outputs `DatasetTypeName` as an output or None if
311 none of the tasks produce this `DatasetTypeName`.
316 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
321 """Find all tasks that are associated with the specified dataset type
326 datasetTypeName : `str`
327 A string representing the name of a dataset type to be queried,
328 can also accept a `DatasetTypeName` which is a `NewType` of str for
329 type safety in static type checking.
333 result : iterable of `TaskDef`
334 `TaskDef` objects that are associated with the specified
340 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
344 if output
is not None:
345 results = chain(results, (output,))
349 """Determine which `TaskDef` objects in this graph are associated
350 with a `str` representing a task name (looks at the taskName property
351 of `TaskDef` objects).
353 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
354 multiple times in a graph with different labels.
359 Name of a task to search for
363 result : list of `TaskDef`
364 List of the `TaskDef` objects that have the name specified.
365 Multiple values are returned in the case that a task is used
366 multiple times with different labels.
369 for task
in self.
_quanta.keys():
370 split = task.taskName.split(
'.')
371 if split[-1] == taskName:
376 """Determine which `TaskDef` objects in this graph are associated
377 with a `str` representing a task's label.
382 Name of a task to search for
387 `TaskDef` object that has the specified label.
389 for task
in self.
_quanta.keys():
390 if label == task.label:
395 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
399 datasetTypeName : `str`
400 The name of the dataset type to search for as a string,
401 can also accept a `DatasetTypeName` which is a `NewType` of str for
402 type safety in static type checking.
406 result : `set` of `Quantum` objects
407 A `set` of `Quantum` that contain specified `DatasetTypeName`
412 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
416 result: Set[Quantum] = set()
417 result = result.union(*(self.
_quanta[task]
for task
in tasks))
421 """Check if specified quantum appears in the graph as part of a node.
426 The quantum to search for
431 The result of searching for the quantum
433 for qset
in self.
_quanta.values():
439 """Write out the graph as a dot graph.
443 output : str or `io.BufferedIOBase`
444 Either a filesystem path to write to, or a file handle object
448 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
449 """Create a new graph object that contains the subset of the nodes
450 specified as input. Node number is preserved.
454 nodes : `QuantumNode` or iterable of `QuantumNode`
458 graph : instance of graph type
459 An instance of the type from which the subset was created
461 if not isinstance(nodes, Iterable):
464 quantumMap = defaultdict(set)
467 for node
in quantumSubgraph:
468 quantumMap[node.taskDef].add(node.quantum)
470 newInst = type(self)({})
471 newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId
for n
in nodes},
476 """Generate a tuple of subgraphs where each is connected.
480 result : tuple of `QuantumGraph`
481 A tuple of graphs that are each connected
483 return tuple(self.
subset(connectedSet)
487 """Return a set of `QuantumNode` that are direct inputs to a specified
493 The node of the graph for which inputs are to be determined
498 All the nodes that are direct inputs to specified node
503 """Return a set of `QuantumNode` that are direct outputs of a specified
509 The node of the graph for which outputs are to be determined
514 All the nodes that are direct outputs to specified node
519 """Return a graph of `QuantumNode` that are direct inputs and outputs
525 The node of the graph for which connected nodes are to be
530 graph : graph of `QuantumNode`
531 All the nodes that are directly connected to specified node
538 """Return a graph of the specified node and all the ancestor nodes
539 directly reachable by walking edges.
544 The node for which all ancestors are to be determined
548 graph of `QuantumNode`
549 Graph of node and all of its ancestors
552 predecessorNodes.add(node)
553 return self.
subset(predecessorNodes)
555 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
556 """Check a graph for the presence of cycles and return the edges of
557 any cycles found, or an empty list if there is no cycle.
561 result : list of tuple of `QuantumNode`, `QuantumNode`
562 A list of any graph edges that form a cycle, or an empty list if
563 there is no cycle. An empty list is returned to support the
564 ``if graph.findCycle():`` syntax, as an empty list is falsy.
568 except nx.NetworkXNoCycle:
572 """Save `QuantumGraph` to the specified URI.
576 uri : `ButlerURI` or `str`
577 URI to where the graph should be saved.
580 butlerUri = ButlerURI(uri)
581 if butlerUri.getExtension()
not in (
".qgraph"):
582 raise TypeError(f
"Can currently only save a graph in qgraph format not {uri}")
583 butlerUri.write(buffer)
586 def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
587 nodes: Optional[Iterable[int]] =
None,
588 graphID: Optional[BuildId] =
None
590 """Read `QuantumGraph` from a URI.
594 uri : `ButlerURI` or `str`
595 URI from where to load the graph.
596 universe: `~lsst.daf.butler.DimensionUniverse`
597 DimensionUniverse instance, not used by the method itself but
598 needed to ensure that registry data structures are initialized.
599 nodes: iterable of `int` or None
600 Numbers that correspond to nodes in the graph. If specified, only
601 these nodes will be loaded. Defaults to None, in which case all
602 nodes will be loaded.
603 graphID : `str` or `None`
604 If specified this ID is verified against the loaded graph prior to
605 loading any Nodes. This defaults to None in which case no
610 graph : `QuantumGraph`
611 Resulting QuantumGraph instance.
616 Raised if pickle contains instance of a type other than
619 Raised if one or more of the nodes requested is not in the
620 `QuantumGraph` or if graphID parameter does not match the graph
621 being loaded or if the supplied uri does not point at a valid
622 `QuantumGraph` save file.
627 Reading Quanta from pickle requires existence of singleton
628 DimensionUniverse which is usually instantiated during Registry
629 initialization. To make sure that DimensionUniverse exists this method
630 accepts dummy DimensionUniverse argument.
639 if uri.getExtension()
in (
".pickle",
".pkl"):
640 with uri.as_local()
as local, open(local.ospath,
"rb")
as fd:
641 warnings.warn(
"Pickle graphs are deprecated, please re-save your graph with the save method")
642 qgraph = pickle.load(fd)
643 elif uri.getExtension()
in (
'.qgraph'):
645 qgraph = loader.load(nodes, graphID)
647 raise ValueError(
"Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
648 if not isinstance(qgraph, QuantumGraph):
649 raise TypeError(f
"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
652 def save(self, file: io.IO[bytes]):
653 """Save QuantumGraph to a file.
655 Presently we store QuantumGraph in pickle format, this could
656 potentially change in the future if better format is found.
660 file : `io.BufferedIOBase`
661 File to write pickle data open in binary mode.
666 def _buildSaveObject(self) -> bytearray:
680 dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
681 taskDefMap[taskDef.label] = (count, count+len(dump))
683 pickleData.append(dump)
689 taskDefMap[
'__GraphBuildID'] = self.
graphID
693 node = copy.copy(node)
694 taskDef = node.taskDef
705 object.__setattr__(node,
'taskDef', taskDef.label)
708 dump = lzma.compress(pickle.dumps(node, protocol=protocol))
709 pickleData.append(dump)
710 nodeMap[node.nodeId.number] = (count, count+len(dump))
714 taskDef_pickle = pickle.dumps(taskDefMap, protocol=protocol)
717 map_pickle = pickle.dumps(nodeMap, protocol=protocol)
721 map_lengths = struct.pack(STRUCT_FMT_STRING, SAVE_VERSION, len(taskDef_pickle), len(map_pickle))
729 buffer.extend(MAGIC_BYTES)
730 buffer.extend(map_lengths)
731 buffer.extend(taskDef_pickle)
732 buffer.extend(map_pickle)
742 buffer.extend(pickleData.popleft())
746 def load(cls, file: io.IO[bytes], universe: DimensionUniverse,
747 nodes: Optional[Iterable[int]] =
None,
748 graphID: Optional[BuildId] =
None
750 """Read QuantumGraph from a file that was made by `save`.
754 file : `io.IO` of bytes
755 File with pickle data open in binary mode.
756 universe: `~lsst.daf.butler.DimensionUniverse`
757 DimensionUniverse instance, not used by the method itself but
758 needed to ensure that registry data structures are initialized.
759 nodes: iterable of `int` or None
760 Numbers that correspond to nodes in the graph. If specified, only
761 these nodes will be loaded. Defaults to None, in which case all
762 nodes will be loaded.
763 graphID : `str` or `None`
764 If specified this ID is verified against the loaded graph prior to
765 loading any Nodes. This defaults to None in which case no
770 graph : `QuantumGraph`
771 Resulting QuantumGraph instance.
776 Raised if pickle contains instance of a type other than
779 Raised if one or more of the nodes requested is not in the
780 `QuantumGraph` or if graphID parameter does not match the graph
781 being loaded or if the supplied uri does not point at a valid
782 `QuantumGraph` save file.
786 Reading Quanta from pickle requires existence of singleton
787 DimensionUniverse which is usually instantiated during Registry
788 initialization. To make sure that DimensionUniverse exists this method
789 accepts dummy DimensionUniverse argument.
794 qgraph = pickle.load(file)
795 warnings.warn(
"Pickle graphs are deprecated, please re-save your graph with the save method")
796 except pickle.UnpicklingError:
798 qgraph = loader.load(nodes, graphID)
799 if not isinstance(qgraph, QuantumGraph):
800 raise TypeError(f
"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
804 """Iterate over the `taskGraph` attribute in topological order
809 `TaskDef` objects in topological order
811 yield from nx.topological_sort(self.
taskGraph)
815 """Returns the ID generated by the graph at construction time
819 def __iter__(self) -> Generator[QuantumNode, None, None]:
829 """Stores a compact form of the graph as a list of graph nodes, and a
830 tuple of task labels and task configs. The full graph can be
831 reconstructed with this information, and it preserves the ordering of
834 return {
"nodesList": list(self)}
837 """Reconstructs the state of the graph from the information persisted
840 quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
841 quantumToNodeId: Dict[Quantum, NodeId] = {}
842 quantumNode: QuantumNode
843 for quantumNode
in state[
'nodesList']:
844 quanta[quantumNode.taskDef].add(quantumNode.quantum)
845 quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
846 _buildId = quantumNode.nodeId.buildId
if state[
'nodesList']
else None
847 self.
_buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)
850 if not isinstance(other, QuantumGraph):
852 if len(self) != len(other):
855 if node
not in other:
861 return list(self.
taskGraph) == list(other.taskGraph)