Coverage for python/lsst/pipe/base/graph/graph.py: 17%
369 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-09 03:06 -0800
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25import io
26import json
27import lzma
28import os
29import pickle
30import struct
31import time
32import uuid
33import warnings
34from collections import defaultdict, deque
35from itertools import chain
36from types import MappingProxyType
37from typing import (
38 Any,
39 BinaryIO,
40 DefaultDict,
41 Deque,
42 Dict,
43 FrozenSet,
44 Generator,
45 Iterable,
46 List,
47 Mapping,
48 MutableMapping,
49 Optional,
50 Set,
51 Tuple,
52 TypeVar,
53 Union,
54)
56import networkx as nx
57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum
58from lsst.resources import ResourcePath, ResourcePathExpression
59from lsst.utils.introspection import get_full_type_name
60from networkx.drawing.nx_agraph import write_dot
62from ..connections import iterConnections
63from ..pipeline import TaskDef
64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner
65from ._loadHelpers import LoadHelper
66from ._versionDeserializers import DESERIALIZER_MAP
67from .quantumNode import BuildId, QuantumNode
# Type variable bound to QuantumGraph so methods that return a new graph
# (e.g. ``subset``) are typed correctly for subclasses as well.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
#
# Maps save-format version number -> `struct` format string for the rest of
# the preamble that follows the version short itself.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
# (written at the very start of a save file, before the version preamble)
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be satisfied because the
    identifier is incompatible with this graph.
    """
101class QuantumGraph:
102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
104 This data structure represents a concrete workflow generated from a
105 `Pipeline`.
107 Parameters
108 ----------
109 quanta : Mapping of `TaskDef` to sets of `Quantum`
110 This maps tasks (and their configs) to the sets of data they are to
111 process.
112 metadata : Optional Mapping of `str` to primitives
113 This is an optional parameter of extra data to carry with the graph.
114 Entries in this mapping should be able to be serialized in JSON.
    pruneRefs : iterable [ `DatasetRef` ], optional
        Set of dataset refs to exclude from a graph.
    universe : `~lsst.daf.butler.DimensionUniverse`, optional
        The dimension universe in which the quanta are defined; if `None` it
        is taken from the data IDs of the supplied quanta.
117 initInputs : `Mapping`, optional
118 Maps tasks to their InitInput dataset refs. Dataset refs can be either
119 resolved or non-resolved. Presently the same dataset refs are included
120 in each `Quantum` for the same task.
121 initOutputs : `Mapping`, optional
122 Maps tasks to their InitOutput dataset refs. Dataset refs can be either
123 resolved or non-resolved. For intermediate resolved refs their dataset
124 ID must match ``initInputs`` and Quantum ``initInputs``.
126 Raises
127 ------
128 ValueError
129 Raised if the graph is pruned such that some tasks no longer have nodes
130 associated with them.
131 """
    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
        initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
        initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
    ):
        # All construction is delegated to ``_buildGraphs``, which is also
        # used internally (with extra private keywords) when rebuilding a
        # graph from a subset or after pruning.
        self._buildGraphs(
            quanta,
            metadata=metadata,
            pruneRefs=pruneRefs,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
        )
    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        *,
        _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
        _buildId: Optional[BuildId] = None,
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
        initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
        initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
    ) -> None:
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta

        Parameters
        ----------
        quanta : `Mapping` [ `TaskDef`, `Set` [ `Quantum` ] ]
            Maps tasks to the sets of quanta they are to process.
        _quantumToNodeId : `Mapping` [ `Quantum`, `uuid.UUID` ], optional
            If supplied, every quantum must have an entry; the mapped IDs are
            reused so node identity is preserved when rebuilding a graph.
            Otherwise fresh UUIDs are generated.
        _buildId : `BuildId`, optional
            Build ID to reuse; when `None` a new time/pid-based ID is made.
        metadata : `Mapping` [ `str`, `Any` ], optional
            Extra metadata to carry with the graph.
        pruneRefs : iterable [ `DatasetRef` ], optional
            Dataset refs to exclude from the graph along with dependents.
            NOTE(review): iterated more than once below, so this appears to
            require a re-iterable collection, not a one-shot generator —
            confirm with callers.
        universe : `DimensionUniverse`, optional
            Dimension universe; if `None` it is taken from quanta data IDs.
        initInputs, initOutputs : `Mapping`, optional
            Per-task init input / init output dataset refs.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is missing an entry for a quantum,
            or if pruning leaves one or more tasks with no nodes.
        RuntimeError
            Raised if quanta carry mismatched dimension universes, or if no
            universe can be determined at all.
        """
        self._metadata = metadata
        # Unique-enough build identifier: wall-clock time plus process id.
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if quantum.dataId is not None:
                    # Adopt the first universe seen; all others must agree.
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}. "
                        )

                if _quantumToNodeId:
                    # NOTE: "_quantuMToNodeNumber" in the message is a typo
                    # for the `_quantumToNodeId` parameter.
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            # Components are tracked via their parent
                            # (composite) ref.
                            if sub.isComponent():
                                sub = sub.makeCompositeRef()
                            self._datasetRefDict.addConsumer(sub, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        self._datasetRefDict.addConsumer(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addProducer(dsRef, value)

        if pruneRefs is not None:
            # track what refs were pruned and prune the graph
            prunes: Set[QuantumNode] = set()
            _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)

            # recreate the taskToQuantumNode dict removing nodes that have been
            # pruned. Keep track of task defs that now have no QuantumNodes
            emptyTasks: Set[str] = set()
            newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
            # accumulate all types
            types_ = set()
            # tracker for any pruneRefs that have caused tasks to have no nodes
            # This helps the user find out what caused the issues seen.
            culprits = set()
            # Find all the types from the refs to prune
            for r in pruneRefs:
                types_.add(r.datasetType)

            # For each of the tasks, and their associated nodes, remove any
            # any nodes that were pruned. If there are no nodes associated
            # with a task, record that task, and find out if that was due to
            # a type from an input ref to prune.
            for td, taskNodes in self._taskToQuantumNode.items():
                diff = taskNodes.difference(prunes)
                if len(diff) == 0:
                    if len(taskNodes) != 0:
                        tp: DatasetType
                        for tp in types_:
                            # A type is a "culprit" when a representative
                            # quantum's inputs of that type were all pruned.
                            if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
                                tmpRefs
                            ).difference(pruneRefs):
                                culprits.add(tp.name)
                    emptyTasks.add(td.label)
                newTaskToQuantumNode[td] = diff

            # update the internal dict
            self._taskToQuantumNode = newTaskToQuantumNode

            if emptyTasks:
                raise ValueError(
                    f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
                    f"after graph pruning; {', '.join(culprits)} caused over-pruning"
                )

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # convert default dict into a regular to prevent accidental key
        # insertion
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

        self._initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
        self._initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}
        if initInputs is not None:
            self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
        if initOutputs is not None:
            self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
305 @property
306 def taskGraph(self) -> nx.DiGraph:
307 """Return a graph representing the relations between the tasks inside
308 the quantum graph.
310 Returns
311 -------
312 taskGraph : `networkx.Digraph`
313 Internal datastructure that holds relations of `TaskDef` objects
314 """
315 return self._taskGraph
317 @property
318 def graph(self) -> nx.DiGraph:
319 """Return a graph representing the relations between all the
320 `QuantumNode` objects. Largely it should be preferred to iterate
321 over, and use methods of this class, but sometimes direct access to
322 the networkx object may be helpful
324 Returns
325 -------
326 graph : `networkx.Digraph`
327 Internal datastructure that holds relations of `QuantumNode`
328 objects
329 """
330 return self._connectedQuanta
332 @property
333 def inputQuanta(self) -> Iterable[QuantumNode]:
334 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
335 to the graph, meaning those nodes to not depend on any other nodes in
336 the graph.
338 Returns
339 -------
340 inputNodes : iterable of `QuantumNode`
341 A list of nodes that are inputs to the graph
342 """
343 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
345 @property
346 def outputQuanta(self) -> Iterable[QuantumNode]:
347 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
348 to the graph, meaning those nodes have no nodes that depend them in
349 the graph.
351 Returns
352 -------
353 outputNodes : iterable of `QuantumNode`
354 A list of nodes that are outputs of the graph
355 """
356 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
358 @property
359 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
360 """Return all the `DatasetTypeName` objects that are contained inside
361 the graph.
363 Returns
364 -------
365 tuple of `DatasetTypeName`
366 All the data set type names that are present in the graph
367 """
368 return tuple(self._datasetDict.keys())
    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        # Weak connectivity treats the DAG as undirected: every node must be
        # reachable from every other node when edge direction is ignored.
        return nx.is_weakly_connected(self._connectedQuanta)
377 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
378 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
379 and nodes which depend on them.
381 Parameters
382 ----------
383 refs : `Iterable` of `DatasetRef`
384 Refs which should be removed from resulting graph
386 Returns
387 -------
388 graph : `QuantumGraph`
389 A graph that has been pruned of specified refs and the nodes that
390 depend on them.
391 """
392 newInst = object.__new__(type(self))
393 quantumMap = defaultdict(set)
394 for node in self:
395 quantumMap[node.taskDef].add(node.quantum)
397 # convert to standard dict to prevent accidental key insertion
398 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
400 newInst._buildGraphs(
401 quantumDict,
402 _quantumToNodeId={n.quantum: n.nodeId for n in self},
403 metadata=self._metadata,
404 pruneRefs=refs,
405 universe=self._universe,
406 )
407 return newInst
409 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
410 """Lookup a `QuantumNode` from an id associated with the node.
412 Parameters
413 ----------
414 nodeId : `NodeId`
415 The number associated with a node
417 Returns
418 -------
419 node : `QuantumNode`
420 The node corresponding with input number
422 Raises
423 ------
424 KeyError
425 Raised if the requested nodeId is not in the graph.
426 """
427 return self._nodeIdMap[nodeId]
429 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
430 """Return all the `Quantum` associated with a `TaskDef`.
432 Parameters
433 ----------
434 taskDef : `TaskDef`
435 The `TaskDef` for which `Quantum` are to be queried
437 Returns
438 -------
439 frozenset of `Quantum`
440 The `set` of `Quantum` that is associated with the specified
441 `TaskDef`.
442 """
443 return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ()))
445 def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int:
446 """Return all the number of `Quantum` associated with a `TaskDef`.
448 Parameters
449 ----------
450 taskDef : `TaskDef`
451 The `TaskDef` for which `Quantum` are to be queried
453 Returns
454 -------
455 count : int
456 The number of `Quantum` that are associated with the specified
457 `TaskDef`.
458 """
459 return len(self._taskToQuantumNode.get(taskDef, ()))
461 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
462 """Return all the `QuantumNodes` associated with a `TaskDef`.
464 Parameters
465 ----------
466 taskDef : `TaskDef`
467 The `TaskDef` for which `Quantum` are to be queried
469 Returns
470 -------
471 frozenset of `QuantumNodes`
472 The `frozenset` of `QuantumNodes` that is associated with the
473 specified `TaskDef`.
474 """
475 return frozenset(self._taskToQuantumNode[taskDef])
477 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
478 """Find all tasks that have the specified dataset type name as an
479 input.
481 Parameters
482 ----------
483 datasetTypeName : `str`
484 A string representing the name of a dataset type to be queried,
485 can also accept a `DatasetTypeName` which is a `NewType` of str for
486 type safety in static type checking.
488 Returns
489 -------
490 tasks : iterable of `TaskDef`
491 `TaskDef` objects that have the specified `DatasetTypeName` as an
492 input, list will be empty if no tasks use specified
493 `DatasetTypeName` as an input.
495 Raises
496 ------
497 KeyError
498 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
499 """
500 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
502 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
503 """Find all tasks that have the specified dataset type name as an
504 output.
506 Parameters
507 ----------
508 datasetTypeName : `str`
509 A string representing the name of a dataset type to be queried,
510 can also accept a `DatasetTypeName` which is a `NewType` of str for
511 type safety in static type checking.
513 Returns
514 -------
515 `TaskDef` or `None`
516 `TaskDef` that outputs `DatasetTypeName` as an output or None if
517 none of the tasks produce this `DatasetTypeName`.
519 Raises
520 ------
521 KeyError
522 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
523 """
524 return self._datasetDict.getProducer(datasetTypeName)
526 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
527 """Find all tasks that are associated with the specified dataset type
528 name.
530 Parameters
531 ----------
532 datasetTypeName : `str`
533 A string representing the name of a dataset type to be queried,
534 can also accept a `DatasetTypeName` which is a `NewType` of str for
535 type safety in static type checking.
537 Returns
538 -------
539 result : iterable of `TaskDef`
540 `TaskDef` objects that are associated with the specified
541 `DatasetTypeName`
543 Raises
544 ------
545 KeyError
546 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
547 """
548 return self._datasetDict.getAll(datasetTypeName)
550 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
551 """Determine which `TaskDef` objects in this graph are associated
552 with a `str` representing a task name (looks at the taskName property
553 of `TaskDef` objects).
555 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
556 multiple times in a graph with different labels.
558 Parameters
559 ----------
560 taskName : str
561 Name of a task to search for
563 Returns
564 -------
565 result : list of `TaskDef`
566 List of the `TaskDef` objects that have the name specified.
567 Multiple values are returned in the case that a task is used
568 multiple times with different labels.
569 """
570 results = []
571 for task in self._taskToQuantumNode.keys():
572 split = task.taskName.split(".")
573 if split[-1] == taskName:
574 results.append(task)
575 return results
577 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
578 """Determine which `TaskDef` objects in this graph are associated
579 with a `str` representing a tasks label.
581 Parameters
582 ----------
583 taskName : str
584 Name of a task to search for
586 Returns
587 -------
588 result : `TaskDef`
589 `TaskDef` objects that has the specified label.
590 """
591 for task in self._taskToQuantumNode.keys():
592 if label == task.label:
593 return task
594 return None
596 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
597 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
599 Parameters
600 ----------
601 datasetTypeName : `str`
602 The name of the dataset type to search for as a string,
603 can also accept a `DatasetTypeName` which is a `NewType` of str for
604 type safety in static type checking.
606 Returns
607 -------
608 result : `set` of `QuantumNode` objects
609 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
611 Raises
612 ------
613 KeyError
614 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
616 """
617 tasks = self._datasetDict.getAll(datasetTypeName)
618 result: Set[Quantum] = set()
619 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
620 return result
622 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
623 """Check if specified quantum appears in the graph as part of a node.
625 Parameters
626 ----------
627 quantum : `Quantum`
628 The quantum to search for
630 Returns
631 -------
632 `bool`
633 The result of searching for the quantum
634 """
635 for node in self:
636 if quantum == node.quantum:
637 return True
638 return False
    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        # Delegates to networkx's graphviz writer; it accepts either a path
        # or an open file handle, mirroring this method's own contract.
        write_dot(self._connectedQuanta, output)
650 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
651 """Create a new graph object that contains the subset of the nodes
652 specified as input. Node number is preserved.
654 Parameters
655 ----------
656 nodes : `QuantumNode` or iterable of `QuantumNode`
658 Returns
659 -------
660 graph : instance of graph type
661 An instance of the type from which the subset was created
662 """
663 if not isinstance(nodes, Iterable):
664 nodes = (nodes,)
665 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
666 quantumMap = defaultdict(set)
668 node: QuantumNode
669 for node in quantumSubgraph:
670 quantumMap[node.taskDef].add(node.quantum)
672 # convert to standard dict to prevent accidental key insertion
673 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
674 # Create an empty graph, and then populate it with custom mapping
675 newInst = type(self)({}, universe=self._universe)
676 newInst._buildGraphs(
677 quantumDict,
678 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
679 _buildId=self._buildId,
680 metadata=self._metadata,
681 universe=self._universe,
682 )
683 return newInst
685 def subsetToConnected(self: _T) -> Tuple[_T, ...]:
686 """Generate a list of subgraphs where each is connected.
688 Returns
689 -------
690 result : list of `QuantumGraph`
691 A list of graphs that are each connected
692 """
693 return tuple(
694 self.subset(connectedSet)
695 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
696 )
698 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
699 """Return a set of `QuantumNode` that are direct inputs to a specified
700 node.
702 Parameters
703 ----------
704 node : `QuantumNode`
705 The node of the graph for which inputs are to be determined
707 Returns
708 -------
709 set of `QuantumNode`
710 All the nodes that are direct inputs to specified node
711 """
712 return set(pred for pred in self._connectedQuanta.predecessors(node))
714 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
715 """Return a set of `QuantumNode` that are direct outputs of a specified
716 node.
718 Parameters
719 ----------
720 node : `QuantumNode`
721 The node of the graph for which outputs are to be determined
723 Returns
724 -------
725 set of `QuantumNode`
726 All the nodes that are direct outputs to specified node
727 """
728 return set(succ for succ in self._connectedQuanta.successors(node))
730 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
731 """Return a graph of `QuantumNode` that are direct inputs and outputs
732 of a specified node.
734 Parameters
735 ----------
736 node : `QuantumNode`
737 The node of the graph for which connected nodes are to be
738 determined.
740 Returns
741 -------
742 graph : graph of `QuantumNode`
743 All the nodes that are directly connected to specified node
744 """
745 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
746 nodes.add(node)
747 return self.subset(nodes)
749 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
750 """Return a graph of the specified node and all the ancestor nodes
751 directly reachable by walking edges.
753 Parameters
754 ----------
755 node : `QuantumNode`
756 The node for which all ansestors are to be determined
758 Returns
759 -------
760 graph of `QuantumNode`
761 Graph of node and all of its ansestors
762 """
763 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
764 predecessorNodes.add(node)
765 return self.subset(predecessorNodes)
767 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
768 """Check a graph for the presense of cycles and returns the edges of
769 any cycles found, or an empty list if there is no cycle.
771 Returns
772 -------
773 result : list of tuple of `QuantumNode`, `QuantumNode`
774 A list of any graph edges that form a cycle, or an empty list if
775 there is no cycle. Empty list to so support if graph.find_cycle()
776 syntax as an empty list is falsy.
777 """
778 try:
779 return nx.find_cycle(self._connectedQuanta)
780 except nx.NetworkXNoCycle:
781 return []
783 def saveUri(self, uri: ResourcePathExpression) -> None:
784 """Save `QuantumGraph` to the specified URI.
786 Parameters
787 ----------
788 uri : convertible to `ResourcePath`
789 URI to where the graph should be saved.
790 """
791 buffer = self._buildSaveObject()
792 path = ResourcePath(uri)
793 if path.getExtension() not in (".qgraph"):
794 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
795 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
797 @property
798 def metadata(self) -> Optional[MappingProxyType[str, Any]]:
799 """ """
800 if self._metadata is None:
801 return None
802 return MappingProxyType(self._metadata)
804 def initInputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
805 """Return DatasetRefs for a given task InitInputs.
807 Parameters
808 ----------
809 taskDef : `TaskDef`
810 Task definition structure.
812 Returns
813 -------
814 refs : `list` [ `DatasetRef` ] or None
815 DatasetRef for the task InitInput, can be `None`. This can return
816 either resolved or non-resolved reference.
817 """
818 return self._initInputRefs.get(taskDef)
820 def initOutputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
821 """Return DatasetRefs for a given task InitOutputs.
823 Parameters
824 ----------
825 taskDef : `TaskDef`
826 Task definition structure.
828 Returns
829 -------
830 refs : `list` [ `DatasetRef` ] or None
831 DatasetRefs for the task InitOutput, can be `None`. This can return
832 either resolved or non-resolved reference. Resolved reference will
833 match Quantum's initInputs if this is an intermediate dataset type.
834 """
835 return self._initOutputRefs.get(taskDef)
837 @classmethod
838 def loadUri(
839 cls,
840 uri: ResourcePathExpression,
841 universe: Optional[DimensionUniverse] = None,
842 nodes: Optional[Iterable[uuid.UUID]] = None,
843 graphID: Optional[BuildId] = None,
844 minimumVersion: int = 3,
845 ) -> QuantumGraph:
846 """Read `QuantumGraph` from a URI.
848 Parameters
849 ----------
850 uri : convertible to `ResourcePath`
851 URI from where to load the graph.
852 universe: `~lsst.daf.butler.DimensionUniverse` optional
853 DimensionUniverse instance, not used by the method itself but
854 needed to ensure that registry data structures are initialized.
855 If None it is loaded from the QuantumGraph saved structure. If
856 supplied, the DimensionUniverse from the loaded `QuantumGraph`
857 will be validated against the supplied argument for compatibility.
858 nodes: iterable of `int` or None
859 Numbers that correspond to nodes in the graph. If specified, only
860 these nodes will be loaded. Defaults to None, in which case all
861 nodes will be loaded.
862 graphID : `str` or `None`
863 If specified this ID is verified against the loaded graph prior to
864 loading any Nodes. This defaults to None in which case no
865 validation is done.
866 minimumVersion : int
867 Minimum version of a save file to load. Set to -1 to load all
868 versions. Older versions may need to be loaded, and re-saved
869 to upgrade them to the latest format before they can be used in
870 production.
872 Returns
873 -------
874 graph : `QuantumGraph`
875 Resulting QuantumGraph instance.
877 Raises
878 ------
879 TypeError
880 Raised if pickle contains instance of a type other than
881 QuantumGraph.
882 ValueError
883 Raised if one or more of the nodes requested is not in the
884 `QuantumGraph` or if graphID parameter does not match the graph
885 being loaded or if the supplied uri does not point at a valid
886 `QuantumGraph` save file.
887 RuntimeError
888 Raise if Supplied DimensionUniverse is not compatible with the
889 DimensionUniverse saved in the graph
892 Notes
893 -----
894 Reading Quanta from pickle requires existence of singleton
895 DimensionUniverse which is usually instantiated during Registry
896 initialization. To make sure that DimensionUniverse exists this method
897 accepts dummy DimensionUniverse argument.
898 """
899 uri = ResourcePath(uri)
900 # With ResourcePath we have the choice of always using a local file
901 # or reading in the bytes directly. Reading in bytes can be more
902 # efficient for reasonably-sized pickle files when the resource
903 # is remote. For now use the local file variant. For a local file
904 # as_local() does nothing.
906 if uri.getExtension() in (".pickle", ".pkl"):
907 with uri.as_local() as local, open(local.ospath, "rb") as fd:
908 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
909 qgraph = pickle.load(fd)
910 elif uri.getExtension() in (".qgraph"):
911 with LoadHelper(uri, minimumVersion) as loader:
912 qgraph = loader.load(universe, nodes, graphID)
913 else:
914 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
915 if not isinstance(qgraph, QuantumGraph):
916 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
917 return qgraph
919 @classmethod
920 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]:
921 """Read the header of a `QuantumGraph` pointed to by the uri parameter
922 and return it as a string.
924 Parameters
925 ----------
926 uri : convertible to `ResourcePath`
927 The location of the `QuantumGraph` to load. If the argument is a
928 string, it must correspond to a valid `ResourcePath` path.
929 minimumVersion : int
930 Minimum version of a save file to load. Set to -1 to load all
931 versions. Older versions may need to be loaded, and re-saved
932 to upgrade them to the latest format before they can be used in
933 production.
935 Returns
936 -------
937 header : `str` or `None`
938 The header associated with the specified `QuantumGraph` it there is
939 one, else `None`.
941 Raises
942 ------
943 ValueError
944 Raised if `QuantuGraph` was saved as a pickle.
945 Raised if the extention of the file specified by uri is not a
946 `QuantumGraph` extention.
947 """
948 uri = ResourcePath(uri)
949 if uri.getExtension() in (".pickle", ".pkl"):
950 raise ValueError("Reading a header from a pickle save is not supported")
951 elif uri.getExtension() in (".qgraph"):
952 return LoadHelper(uri, minimumVersion).readHeader()
953 else:
954 raise ValueError("Only know how to handle files saved as `qgraph`")
956 def buildAndPrintHeader(self) -> None:
957 """Creates a header that would be used in a save of this object and
958 prints it out to standard out.
959 """
960 _, header = self._buildSaveObject(returnHeader=True)
961 print(json.dumps(header))
963 def save(self, file: BinaryIO) -> None:
964 """Save QuantumGraph to a file.
966 Parameters
967 ----------
968 file : `io.BufferedIOBase`
969 File to write pickle data open in binary mode.
970 """
971 buffer = self._buildSaveObject()
972 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into an in-memory byte buffer.

        The layout written is: magic bytes, the struct-packed save version,
        the struct-packed length of the compressed header, the
        lzma-compressed JSON header, then one lzma-compressed JSON payload
        per task definition followed by one per quantum node. The header
        records the byte range of every payload so components can later be
        loaded individually without reading the whole buffer.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping that
            was embedded in the buffer.

        Returns
        -------
        buffer : `bytearray`
            The serialized graph.
        header : `dict`
            The header mapping; only returned when ``returnHeader`` is
            `True`.
        """
        # Containers for the compressed payloads and for the byte-range /
        # connectivity bookkeeping that goes into the header.
        jsonData: Deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: Dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # NOTE(review): an earlier comment claimed tasks are sorted by label
        # here, but no sort is performed; serialization follows the iteration
        # order of self.taskGraph -- confirm that order is deterministic if
        # reproducible saves are required.
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription: Dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            # Init input/output refs are optional per task; only record them
            # when present.
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connections between all tasks and save them in
            # the header as a list of connections and edges for each task.
            # This helps in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted.
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output will be said to be the datastore,
                    # in which case the for loop below is a zero length loop.
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, encode that string to bytes, and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            # Record byte range plus the node's graph neighbours (by UUID
            # string) so connectivity can be rebuilt without loading bodies.
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        headerData["DimensionRecords"] = {
            key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation in json: it cannot use anything but strings as keys
        headerData["Nodes"] = nodeMap

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # Pack the save version, then the compressed-header length using the
        # current deserializer's format string.
        # NOTE(review): an older comment described "2 unsigned long long
        # numbers for a total of 16 bytes", but only the header length is
        # packed via fmt_string here -- the actual size depends on
        # STRUCT_FMT_BASE and FMT_STRING().
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # For each element pop the leftmost item off the deque and append it
        # to the buffer. This is to save memory: as each chunk is added to
        # the buffer object, it is removed from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this JSON data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer
1122 @classmethod
1123 def load(
1124 cls,
1125 file: BinaryIO,
1126 universe: Optional[DimensionUniverse] = None,
1127 nodes: Optional[Iterable[uuid.UUID]] = None,
1128 graphID: Optional[BuildId] = None,
1129 minimumVersion: int = 3,
1130 ) -> QuantumGraph:
1131 """Read QuantumGraph from a file that was made by `save`.
1133 Parameters
1134 ----------
1135 file : `io.IO` of bytes
1136 File with pickle data open in binary mode.
1137 universe: `~lsst.daf.butler.DimensionUniverse`, optional
1138 DimensionUniverse instance, not used by the method itself but
1139 needed to ensure that registry data structures are initialized.
1140 If None it is loaded from the QuantumGraph saved structure. If
1141 supplied, the DimensionUniverse from the loaded `QuantumGraph`
1142 will be validated against the supplied argument for compatibility.
1143 nodes: iterable of `int` or None
1144 Numbers that correspond to nodes in the graph. If specified, only
1145 these nodes will be loaded. Defaults to None, in which case all
1146 nodes will be loaded.
1147 graphID : `str` or `None`
1148 If specified this ID is verified against the loaded graph prior to
1149 loading any Nodes. This defaults to None in which case no
1150 validation is done.
1151 minimumVersion : int
1152 Minimum version of a save file to load. Set to -1 to load all
1153 versions. Older versions may need to be loaded, and re-saved
1154 to upgrade them to the latest format before they can be used in
1155 production.
1157 Returns
1158 -------
1159 graph : `QuantumGraph`
1160 Resulting QuantumGraph instance.
1162 Raises
1163 ------
1164 TypeError
1165 Raised if pickle contains instance of a type other than
1166 QuantumGraph.
1167 ValueError
1168 Raised if one or more of the nodes requested is not in the
1169 `QuantumGraph` or if graphID parameter does not match the graph
1170 being loaded or if the supplied uri does not point at a valid
1171 `QuantumGraph` save file.
1173 Notes
1174 -----
1175 Reading Quanta from pickle requires existence of singleton
1176 DimensionUniverse which is usually instantiated during Registry
1177 initialization. To make sure that DimensionUniverse exists this method
1178 accepts dummy DimensionUniverse argument.
1179 """
1180 # Try to see if the file handle contains pickle data, this will be
1181 # removed in the future
1182 try:
1183 qgraph = pickle.load(file)
1184 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
1185 except pickle.UnpicklingError:
1186 with LoadHelper(file, minimumVersion) as loader:
1187 qgraph = loader.load(universe, nodes, graphID)
1188 if not isinstance(qgraph, QuantumGraph):
1189 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
1190 return qgraph
1192 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1193 """Iterate over the `taskGraph` attribute in topological order
1195 Yields
1196 ------
1197 taskDef : `TaskDef`
1198 `TaskDef` objects in topological order
1199 """
1200 yield from nx.topological_sort(self.taskGraph)
1202 @property
1203 def graphID(self) -> BuildId:
1204 """Returns the ID generated by the graph at construction time"""
1205 return self._buildId
1207 def __iter__(self) -> Generator[QuantumNode, None, None]:
1208 yield from nx.topological_sort(self._connectedQuanta)
1210 def __len__(self) -> int:
1211 return self._count
1213 def __contains__(self, node: QuantumNode) -> bool:
1214 return self._connectedQuanta.has_node(node)
1216 def __getstate__(self) -> dict:
1217 """Stores a compact form of the graph as a list of graph nodes, and a
1218 tuple of task labels and task configs. The full graph can be
1219 reconstructed with this information, and it preseves the ordering of
1220 the graph ndoes.
1221 """
1222 universe: Optional[DimensionUniverse] = None
1223 for node in self:
1224 dId = node.quantum.dataId
1225 if dId is None:
1226 continue
1227 universe = dId.graph.universe
1228 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
    def __setstate__(self, state: dict) -> None:
        """Reconstructs the state of the graph from the information persisted
        in getstate.

        Parameters
        ----------
        state : `dict`
            Mapping with keys ``reduced`` (the serialized graph bytes
            produced by `__getstate__`), ``universe`` (a
            `~lsst.daf.butler.DimensionUniverse` or `None`) and ``graphId``
            (the build ID the loaded graph is validated against).
        """
        # Deserialize a complete QuantumGraph from the saved byte buffer,
        # then copy its internal attributes onto this instance.
        buffer = io.BytesIO(state["reduced"])
        with LoadHelper(buffer, minimumVersion=3) as loader:
            qgraph = loader.load(state["universe"], graphID=state["graphId"])

        self._metadata = qgraph._metadata
        self._buildId = qgraph._buildId
        self._datasetDict = qgraph._datasetDict
        self._nodeIdMap = qgraph._nodeIdMap
        self._count = len(qgraph)
        self._taskToQuantumNode = qgraph._taskToQuantumNode
        self._taskGraph = qgraph._taskGraph
        self._connectedQuanta = qgraph._connectedQuanta
        self._initInputRefs = qgraph._initInputRefs
        self._initOutputRefs = qgraph._initOutputRefs
1249 def __eq__(self, other: object) -> bool:
1250 if not isinstance(other, QuantumGraph):
1251 return False
1252 if len(self) != len(other):
1253 return False
1254 for node in self:
1255 if node not in other:
1256 return False
1257 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1258 return False
1259 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1260 return False
1261 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1262 return False
1263 return set(self.taskGraph) == set(other.taskGraph)