Coverage for python/lsst/pipe/base/graph/graph.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25from collections import defaultdict
27from itertools import chain, count
28import io
29import networkx as nx
30from networkx.drawing.nx_agraph import write_dot
31import os
32import pickle
33import time
34from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple,
35 Union, TypeVar)
37from ..connections import iterConnections
38from ..pipeline import TaskDef
39from lsst.daf.butler import Quantum, DatasetRef
41from ._implDetails import _DatasetTracker, DatasetTypeName
42from .quantumNode import QuantumNode, NodeId, BuildId
44_T = TypeVar("_T", bound="QuantumGraph")
class IncompatibleGraphError(Exception):
    """Raised when a `NodeId` cannot be looked up in a graph because the id
    and the graph are incompatible with one another.
    """
class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode`s

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.

    Notes
    -----
    The class defines ``__eq__`` without ``__hash__``, so instances are not
    hashable.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The tasks and the quanta each of them is to process.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            Node ids to reuse instead of generating new ones. When this is
            not `None` every quantum in ``quanta`` must have an entry.
        _buildId : `BuildId`, optional
            Build id to record on the graph. When `None` a new one is made
            from the current time and the process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is supplied but has no entry for
            one of the quanta.
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of a
            # newly created QuantumNode to the appropriate input/output field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                # Compare against None rather than relying on truthiness so
                # that an empty mapping is still treated as "ids supplied" and
                # a missing entry is reported instead of silently renumbered.
                if _quantumToNodeId is not None:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.Digraph`
            Internal datastructure that holds relations of `TaskDef`s
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode`s. Largely it should be preferred to iterate over, and use
        methods of this class, but sometimes direct access to the networkx
        object may be helpful

        Returns
        -------
        graph : `networkx.Digraph`
            Internal datastructure that holds relations of `QuantumNode`s
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Make an iterable of all `QuantumNode`s that are 'input' nodes to
        the graph, meaning those nodes do not depend on any other nodes in the
        graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A generator over the nodes that are inputs to the graph; it can
            be consumed only once.
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode`s that are 'output' nodes to the
        graph, meaning those nodes have no nodes that depend on them in the
        graph.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeNames` that are contained inside the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The number associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with input number

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a different graph than
            this instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.

        Raises
        ------
        KeyError
            Raised if the `TaskDef` is not part of the `QuantumGraph`
        """
        return frozenset(self._quanta[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef`s that have the specified `DatasetTypeName` as an input,
            will be empty if no tasks use specified `DatasetTypeName` as an
            input. The result is an iterator and can be consumed only once.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return iter(self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if none of
            the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef`s that are associated with the specified `DatasetTypeName`;
            consumers first, followed by the producer if there is one.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef`s in this graph are associated with a `str`
        representing a task name (looks at the taskName property of
        `TaskDef`s).

        Returns a list of `TaskDef`s as a `PipelineTask` may appear multiple
        times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for; it is compared against the last
            dot-separated component of each `TaskDef`'s ``taskName``.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef`s that have the name specified. Multiple values
            are returned in the case that a task is used multiple times with
            different labels.
        """
        return [task for task in self._quanta if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` in this graph is associated with a `str`
        representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` found that has the specified label, or `None`
            if no task carries it.
        """
        for task in self._quanta:
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum`
            A `set` of the `Quantum` belonging to every task associated with
            the specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        # union() builds a new set, so the sets stored in self._quanta are
        # never modified.
        return result.union(*(self._quanta[task] for task in tasks))

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node or nodes to keep. Any iterable is accepted, including
            one that can only be consumed once.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)
        # Collect the ids from the subgraph instead of iterating ``nodes`` a
        # second time: a generator would already be exhausted at this point
        # and the node ids would be silently regenerated.
        quantumToNodeId: Dict[Quantum, NodeId] = {}

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be determined

        Returns
        -------
        graph : graph of `QuantumNode`
            The specified node together with all the nodes that are directly
            connected to it
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned so that
            ``if graph.findCycle()`` works, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def save(self, file):
        """Save QuantumGraph to a file.

        Presently we store QuantumGraph in pickle format, this could
        potentially change in the future if better format is found.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write pickle data open in binary mode.
        """
        pickle.dump(self, file)

    @classmethod
    def load(cls, file, universe):
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File with pickle data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        Unpickling can execute arbitrary code, so this method must only be
        given files that come from a trusted source.
        """
        # NOTE(security): pickle.load is unsafe on untrusted input.
        qgraph = pickle.load(file)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        """Iterate over the `QuantumNode`s of the graph in topological order.
        """
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        """Return the number of quanta that were used to build the graph.
        """
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        """Return `True` if the node is part of the graph of quanta.
        """
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Stores a compact form of the graph as a list of graph nodes in
        topological order. The full graph can be reconstructed with this
        information, and it preserves the ordering of the graph nodes.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstructs the state of the graph from the information persisted
        in getstate.
        """
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        # Stays None for an empty node list, in which case `_buildGraphs`
        # creates a fresh build id; otherwise ends up as the build id of the
        # last persisted node.
        buildId: Optional[BuildId] = None
        quantumNode: QuantumNode
        for quantumNode in state['nodesList']:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
            buildId = quantumNode.nodeId.buildId
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=buildId)

    def __eq__(self, other: object) -> bool:
        """Two graphs are equal when they hold the same number of quanta,
        every node of this graph is in the other with identical direct inputs
        and outputs, and the task graphs list their nodes in the same order.
        """
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)