Coverage for python/lsst/pipe/base/graph/graph.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25from collections import defaultdict
27from itertools import chain, count
28import io
29import networkx as nx
30from networkx.drawing.nx_agraph import write_dot
31import os
32import pickle
33import time
34from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple,
35 Union, TypeVar)
37from ..connections import iterConnections
38from ..pipeline import TaskDef
39from lsst.daf.butler import Quantum, DatasetRef
41from ._implDetails import _DatasetTracker, DatasetTypeName
42from .quantumNode import QuantumNode, NodeId, BuildId
44_T = TypeVar("_T", bound="QuantumGraph")
class IncompatibleGraphError(Exception):
    """Error signalling that a `NodeId` cannot be used to look up a node
    because of an incompatibility between the id and the graph being queried.
    """
class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode`s

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The tasks and the quanta each of them is to process.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            If not `None`, the `NodeId` to assign to each `Quantum` instead
            of generating a new one. Every quantum in ``quanta`` must then
            have an entry.
        _buildId : `BuildId`, optional
            Build identifier to record on this graph. If `None` a new one is
            created from the current time and the process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is supplied but has no entry for
            one of the quanta.
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                # Test against None rather than truthiness: an empty mapping
                # supplied together with quanta must be reported as an error,
                # not silently replaced by freshly generated node ids.
                if _quantumToNodeId is not None:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal datastructure that holds relations of `TaskDef`s
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode`s. Largely it should be preferred to iterate over, and use
        methods of this class, but sometimes direct access to the networkx
        object may be helpful

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal datastructure that holds relations of `QuantumNode`s
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Iterate over all `QuantumNode`s that are 'input' nodes to the
        graph, meaning those nodes do not depend on any other nodes in the
        graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A single-use generator over the nodes that have no incoming
            edges.
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode`s that are 'output' nodes to the
        graph, meaning those nodes have no nodes that depend on them in the
        graph.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that have no outgoing edges.
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeNames` that are contained inside the
        graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The id associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with the input id

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a different graph than this
            instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.

        Raises
        ------
        KeyError
            Raised if the `TaskDef` is not a key of the quanta mapping held
            by this graph.
        """
        return frozenset(self._quanta[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef`s that have the specified `DatasetTypeName` as an input,
            the iterable will be empty if no tasks use specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return (c for c in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef`s that are associated with the specified
            `DatasetTypeName`; the consuming tasks first, followed by the
            producing task if there is one.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef`s in this graph are associated with a `str`
        representing a task name (looks at the taskName property of
        `TaskDef`s).

        Returns a list of `TaskDef`s as a `PipelineTask` may appear multiple
        times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for. It is compared with the last
            dot-separated component of each `TaskDef`'s ``taskName``.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef`s that have the name specified. Multiple
            values are returned in the case that a task is used multiple times
            with different labels.
        """
        return [task for task in self._quanta.keys() if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` in this graph is associated with a `str`
        representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` found that has the specified label, or `None`
            if no task in the graph has that label.
        """
        for task in self._quanta.keys():
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum`
            The union of the `Quantum` sets of every task associated with
            the specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`

        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        result = result.union(*(self._quanta[task] for task in tasks))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node, or nodes, to keep. Only nodes that are part of this
            graph end up in the result. A single-use iterator is accepted.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}

        # Collect the id mapping from the subgraph itself rather than from
        # ``nodes``: ``nodes`` has already been iterated once above, so a
        # one-shot iterator would yield nothing a second time and the node
        # ids would silently be regenerated.
        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            One graph per weakly connected component of this graph
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            The specified node together with all the nodes that are directly
            connected to it
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned so as to support
            ``if graph.findCycle()`` syntax, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def save(self, file):
        """Save QuantumGraph to a file.

        Presently we store QuantumGraph in pickle format, this could
        potentially change in the future if better format is found.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write pickle data open in binary mode.
        """
        pickle.dump(self, file)

    @classmethod
    def load(cls, file, universe):
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File with pickle data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.
        """
        # SECURITY: unpickling can execute arbitrary code, and the type check
        # below only runs after the data has been loaded. Only call this on
        # files from a trusted source.
        qgraph = pickle.load(file)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        """Iterate over the `QuantumNode`s of the graph in topological order.
        """
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        """Return the number of quanta the graph was built from."""
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        """Return `True` if ``node`` is a node of the quantum graph."""
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Stores a compact form of the graph as a list of graph nodes in
        topological order. The full graph can be reconstructed from this
        information by `__setstate__`.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstructs the state of the graph from the information persisted
        in getstate.
        """
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        quantumNode: QuantumNode
        for quantumNode in state['nodesList']:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        # Reuse the build id carried by the persisted nodes (taken from the
        # last one visited); ``quantumNode`` is only read when the list is
        # non-empty, otherwise a new build id is generated.
        _buildId = quantumNode.nodeId.buildId if state['nodesList'] else None  # type: ignore
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        """Two graphs are equal when they hold the same nodes, every node has
        the same direct inputs and outputs in both, and both contain the same
        tasks.
        """
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        # Compare the tasks as sets: the order in which a networkx graph
        # yields its nodes is not part of what makes two graphs equal.
        return set(self.taskGraph) == set(other.taskGraph)