Coverage for python/lsst/pipe/base/graph/graph.py : 21%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25from collections import defaultdict
27from itertools import chain, count
28import io
29import networkx as nx
30from networkx.drawing.nx_agraph import write_dot
31import os
32import pickle
33import time
34from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple,
35 Union, TypeVar)
37from ..connections import iterConnections
38from ..pipeline import TaskDef
39from lsst.daf.butler import Quantum, DatasetRef, ButlerURI
41from ._implDetails import _DatasetTracker, DatasetTypeName
42from .quantumNode import QuantumNode, NodeId, BuildId
# Type variable bound to QuantumGraph so that graph-producing methods
# (e.g. ``subset``) are typed as returning the class they were called on.
_T = TypeVar("_T", bound="QuantumGraph")


class IncompatibleGraphError(Exception):
    """Signal that a lookup by `NodeId` cannot be performed because the id
    belongs to a different, incompatible graph instance.
    """
class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects.

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Build the graph that stores the relations between tasks and the
        graph that stores the relations between quanta, along with the
        lookup tables used by the public methods.

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            Tasks and the quanta each one is to process.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            If not `None`, the `NodeId` to reuse for every quantum (used by
            `subset` and unpickling so that node ids are preserved). When
            `None`, new sequential node ids are generated.
        _buildId : `BuildId`, optional
            Build id for this graph. When `None` a new one is made from the
            current time and process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is not `None` but has no entry for
            one of the quanta.
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for tasks,
        # and DatasetRef -> QuantumNode for the quanta.
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connection's name, with a value of
            # the TaskDef in the appropriate field.
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                if _quantumToNodeId is not None:
                    # An explicit mapping, even an empty one, means node ids
                    # must be preserved; never silently fall back to
                    # generating new ones.
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.Digraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful.

        Returns
        -------
        graph : `networkx.Digraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'input' nodes
        to the graph, meaning those nodes do not depend on any other nodes in
        the graph.

        Returns
        -------
        inputNodes : `list` of `QuantumNode`
            A list of nodes that are inputs to the graph.
        """
        # A list (as in `outputQuanta`), so the result can be iterated more
        # than once and supports ``len``.
        return [q for q, n in self._connectedQuanta.in_degree if n == 0]

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'output' nodes
        to the graph, meaning no other nodes in the graph depend on them.

        Returns
        -------
        outputNodes : `list` of `QuantumNode`
            A list of nodes that are outputs of the graph.
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the dataset type names that are present in the graph.
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignoring
        the directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The number associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with input number

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built by a graph other than this
            instance (or a graph instance that produced this instance
            through an operation such as subset).
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        return frozenset(self._quanta[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input; the iterable is empty if no tasks use the specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return (task for task in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`: the consuming tasks first, followed by the
            producing task if there is one.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the taskName property
        of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for. It is compared against the last
            dot-separated component of each `TaskDef.taskName`.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        return [task for task in self._quanta.keys() if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` object in this graph is associated
        with a `str` representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` object that has the specified label, or
            `None` if no task has that label.
        """
        return next((task for task in self._quanta.keys() if label == task.label), None)

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum` objects
            A `set` of all `Quantum` belonging to the tasks that consume or
            produce the specified `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        return set().union(*(self._quanta[task] for task in tasks))

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node numbers are preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node(s) to keep. Any iterable is accepted, including a
            one-shot iterator such as a generator. Nodes that are not part of
            this graph are ignored.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}

        # Derive both mappings from the subgraph rather than from ``nodes``:
        # ``subgraph`` has already iterated ``nodes``, so a generator argument
        # would be exhausted here and node ids would be silently renumbered.
        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId,
                             _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs, one per weakly connected component.
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs of specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            The specified node together with all the nodes that are directly
            connected to it.
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        reachable by walking edges backwards.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and return the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned (rather than `None`)
            so that ``if graph.findCycle():`` works, as an empty list is
            falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def saveUri(self, uri):
        """Save `QuantumGraph` to the specified URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI to where the graph should be saved.

        Raises
        ------
        TypeError
            Raised if the URI extension is not ``.pickle`` or ``.pkl``.
        """
        uri = ButlerURI(uri)
        if uri.getExtension() not in (".pickle", ".pkl"):
            raise TypeError(f"Can currently only save a graph in pickle format not {uri}")
        uri.write(pickle.dumps(self))

    @classmethod
    def loadUri(cls, uri, universe):
        """Read `QuantumGraph` from a URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI from where to load the graph.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        Unpickling can execute arbitrary code; only load graphs from trusted
        sources.
        """
        uri = ButlerURI(uri)
        # With ButlerURI we have the choice of always using a local file
        # or reading in the bytes directly. Reading in bytes can be more
        # efficient for reasonably-sized pickle files when the resource
        # is remote. For now use the local file variant. For a local file
        # as_local() does nothing.
        # SECURITY: pickle.load on an untrusted file can run arbitrary code.
        with uri.as_local() as local, open(local.ospath, "rb") as fd:
            qgraph = pickle.load(fd)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def save(self, file):
        """Save QuantumGraph to a file.

        Presently we store QuantumGraph in pickle format, this could
        potentially change in the future if better format is found.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write pickle data open in binary mode.
        """
        pickle.dump(self, file)

    @classmethod
    def load(cls, file, universe):
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File with pickle data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        Unpickling can execute arbitrary code; only load graphs from trusted
        sources.
        """
        # SECURITY: pickle.load on an untrusted file can run arbitrary code.
        qgraph = pickle.load(file)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order.

        Yields
        ------
        `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        # Nodes are yielded in topological order of the quantum graph.
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Store a compact form of the graph: the list of its nodes in
        topological order. Each node carries its quantum, `TaskDef` and
        `NodeId`, so the full graph can be reconstructed from this list with
        node ids preserved.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstruct the graph from the node list persisted by
        `__getstate__`, preserving node ids and the build id.
        """
        nodesList: List[QuantumNode] = state['nodesList']
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        for quantumNode in nodesList:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        # Reuse the persisted build id (taken from the last node, as before)
        # so NodeIds created before pickling still pass the buildId check in
        # getQuantumNodeByNodeId.
        _buildId = nodesList[-1].nodeId.buildId if nodesList else None
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        # Graphs are equal when they hold the same nodes with the same direct
        # inputs/outputs and their task graphs list tasks in the same order.
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)