Coverage for python/lsst/pipe/base/graph/graph.py: 18%
351 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-23 23:10 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-23 23:10 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25import io
26import json
27import lzma
28import os
29import pickle
30import struct
31import time
32import uuid
33import warnings
34from collections import defaultdict, deque
35from itertools import chain
36from types import MappingProxyType
37from typing import (
38 Any,
39 BinaryIO,
40 DefaultDict,
41 Deque,
42 Dict,
43 FrozenSet,
44 Generator,
45 Iterable,
46 List,
47 Mapping,
48 MutableMapping,
49 Optional,
50 Set,
51 Tuple,
52 TypeVar,
53 Union,
54)
56import networkx as nx
57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum
58from lsst.resources import ResourcePath, ResourcePathExpression
59from lsst.utils.introspection import get_full_type_name
60from networkx.drawing.nx_agraph import write_dot
62from ..connections import iterConnections
63from ..pipeline import TaskDef
64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner
65from ._loadHelpers import LoadHelper
66from ._versionDeserializers import DESERIALIZER_MAP
67from .quantumNode import BuildId, QuantumNode
# Type variable for methods that return an instance of their own (sub)class,
# e.g. `subset` and `pruneGraphFromRefs`.
_T = TypeVar("_T", bound="QuantumGraph")

# Modify this constant any time the on-disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save.
# The base is a big-endian encoded unsigned short that is used to hold the
# file format version. This allows reading the version bytes first to
# determine which loading code should be used for the rest of the file.
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big-endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream.
# Version 2
# A big-endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# Magic bytes that identify a file as a QuantumGraph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Exception raised when a lookup by NodeId cannot be performed because
    the identifier is incompatible with this graph.
    """
101class QuantumGraph:
102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
104 This data structure represents a concrete workflow generated from a
105 `Pipeline`.
107 Parameters
108 ----------
109 quanta : Mapping of `TaskDef` to sets of `Quantum`
110 This maps tasks (and their configs) to the sets of data they are to
111 process.
112 metadata : Optional Mapping of `str` to primitives
113 This is an optional parameter of extra data to carry with the graph.
114 Entries in this mapping should be able to be serialized in JSON.
116 Raises
117 ------
118 ValueError
119 Raised if the graph is pruned such that some tasks no longer have nodes
120 associated with them.
121 """
    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
    ):
        # All real construction happens in _buildGraphs so that alternate
        # construction paths (e.g. `subset`, `pruneGraphFromRefs`) can rebuild
        # state on a bare instance created with object.__new__.
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs, universe=universe)
    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        *,
        _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
        _buildId: Optional[BuildId] = None,
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
    ) -> None:
        """Build the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : `Mapping` of `TaskDef` to sets of `Quantum`
            The quanta to process, grouped by task.
        _quantumToNodeId : `Mapping` of `Quantum` to `uuid.UUID`, optional
            If supplied, every quantum must have an entry; the mapped ids are
            reused instead of generating new ones (used when rebuilding an
            existing graph, e.g. by `subset`).
        _buildId : `BuildId`, optional
            Build id to reuse; a new time/pid-based id is minted if `None`.
        metadata : `Mapping` of `str` to primitives, optional
            Extra data to carry with the graph.
        pruneRefs : iterable of `DatasetRef`, optional
            Refs whose consumers should be pruned from the graph.
            NOTE(review): this iterable is iterated more than once below, so
            it is presumably expected to be a re-iterable collection rather
            than a one-shot generator — confirm with callers.
        universe : `DimensionUniverse`, optional
            Dimension universe; if `None` it is taken from the first quantum
            data ID, and all quanta must agree on it.

        Raises
        ------
        RuntimeError
            Raised if quanta have mismatched universes, or if no universe can
            be determined at all.
        ValueError
            Raised if ids are missing from ``_quantumToNodeId``, or if pruning
            leaves some task with no nodes.
        """
        self._metadata = metadata
        # Mint a new build id (time + pid) unless one was handed in.
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta.
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connection's name, with a value of
            # the TaskDef in the appropriate (consumer/producer) field.
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if quantum.dataId is not None:
                    # Infer the universe from the first data ID seen, and
                    # insist that every later quantum agrees with it.
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}. "
                        )

                if _quantumToNodeId:
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # Unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here.
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            # Track components via their parent composite ref.
                            if sub.isComponent():
                                sub = sub.makeCompositeRef()
                            self._datasetRefDict.addConsumer(sub, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        self._datasetRefDict.addConsumer(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addProducer(dsRef, value)

        if pruneRefs is not None:
            # Track what refs were pruned and prune the graph.
            prunes: Set[QuantumNode] = set()
            _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)

            # Recreate the taskToQuantumNode dict removing nodes that have
            # been pruned. Keep track of task defs that now have no
            # QuantumNodes.
            emptyTasks: Set[str] = set()
            newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
            # Accumulate all dataset types of the refs to prune.
            types_ = set()
            # Tracker for any pruneRefs that have caused tasks to have no
            # nodes. This helps the user find out what caused the issues seen.
            culprits = set()
            # Find all the types from the refs to prune.
            for r in pruneRefs:
                types_.add(r.datasetType)

            # For each of the tasks, and their associated nodes, remove any
            # nodes that were pruned. If there are no nodes associated
            # with a task, record that task, and find out if that was due to
            # a type from an input ref to prune.
            for td, taskNodes in self._taskToQuantumNode.items():
                diff = taskNodes.difference(prunes)
                if len(diff) == 0:
                    if len(taskNodes) != 0:
                        # Blame a pruned type when all of a sample node's
                        # inputs of that type were in the prune set.
                        tp: DatasetType
                        for tp in types_:
                            if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
                                tmpRefs
                            ).difference(pruneRefs):
                                culprits.add(tp.name)
                    emptyTasks.add(td.label)
                newTaskToQuantumNode[td] = diff

            # Update the internal dict.
            self._taskToQuantumNode = newTaskToQuantumNode

            if emptyTasks:
                raise ValueError(
                    f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
                    f"after graph pruning; {', '.join(culprits)} caused over-pruning"
                )

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Convert default dict into a regular dict to prevent accidental key
        # insertion.
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
277 @property
278 def taskGraph(self) -> nx.DiGraph:
279 """Return a graph representing the relations between the tasks inside
280 the quantum graph.
282 Returns
283 -------
284 taskGraph : `networkx.Digraph`
285 Internal datastructure that holds relations of `TaskDef` objects
286 """
287 return self._taskGraph
289 @property
290 def graph(self) -> nx.DiGraph:
291 """Return a graph representing the relations between all the
292 `QuantumNode` objects. Largely it should be preferred to iterate
293 over, and use methods of this class, but sometimes direct access to
294 the networkx object may be helpful
296 Returns
297 -------
298 graph : `networkx.Digraph`
299 Internal datastructure that holds relations of `QuantumNode`
300 objects
301 """
302 return self._connectedQuanta
304 @property
305 def inputQuanta(self) -> Iterable[QuantumNode]:
306 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
307 to the graph, meaning those nodes to not depend on any other nodes in
308 the graph.
310 Returns
311 -------
312 inputNodes : iterable of `QuantumNode`
313 A list of nodes that are inputs to the graph
314 """
315 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
317 @property
318 def outputQuanta(self) -> Iterable[QuantumNode]:
319 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
320 to the graph, meaning those nodes have no nodes that depend them in
321 the graph.
323 Returns
324 -------
325 outputNodes : iterable of `QuantumNode`
326 A list of nodes that are outputs of the graph
327 """
328 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
330 @property
331 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
332 """Return all the `DatasetTypeName` objects that are contained inside
333 the graph.
335 Returns
336 -------
337 tuple of `DatasetTypeName`
338 All the data set type names that are present in the graph
339 """
340 return tuple(self._datasetDict.keys())
342 @property
343 def isConnected(self) -> bool:
344 """Return True if all of the nodes in the graph are connected, ignores
345 directionality of connections.
346 """
347 return nx.is_weakly_connected(self._connectedQuanta)
349 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
350 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
351 and nodes which depend on them.
353 Parameters
354 ----------
355 refs : `Iterable` of `DatasetRef`
356 Refs which should be removed from resulting graph
358 Returns
359 -------
360 graph : `QuantumGraph`
361 A graph that has been pruned of specified refs and the nodes that
362 depend on them.
363 """
364 newInst = object.__new__(type(self))
365 quantumMap = defaultdict(set)
366 for node in self:
367 quantumMap[node.taskDef].add(node.quantum)
369 # convert to standard dict to prevent accidental key insertion
370 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
372 newInst._buildGraphs(
373 quantumDict,
374 _quantumToNodeId={n.quantum: n.nodeId for n in self},
375 metadata=self._metadata,
376 pruneRefs=refs,
377 universe=self._universe,
378 )
379 return newInst
381 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
382 """Lookup a `QuantumNode` from an id associated with the node.
384 Parameters
385 ----------
386 nodeId : `NodeId`
387 The number associated with a node
389 Returns
390 -------
391 node : `QuantumNode`
392 The node corresponding with input number
394 Raises
395 ------
396 KeyError
397 Raised if the requested nodeId is not in the graph.
398 """
399 return self._nodeIdMap[nodeId]
401 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
402 """Return all the `Quantum` associated with a `TaskDef`.
404 Parameters
405 ----------
406 taskDef : `TaskDef`
407 The `TaskDef` for which `Quantum` are to be queried
409 Returns
410 -------
411 frozenset of `Quantum`
412 The `set` of `Quantum` that is associated with the specified
413 `TaskDef`.
414 """
415 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef])
417 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
418 """Return all the `QuantumNodes` associated with a `TaskDef`.
420 Parameters
421 ----------
422 taskDef : `TaskDef`
423 The `TaskDef` for which `Quantum` are to be queried
425 Returns
426 -------
427 frozenset of `QuantumNodes`
428 The `frozenset` of `QuantumNodes` that is associated with the
429 specified `TaskDef`.
430 """
431 return frozenset(self._taskToQuantumNode[taskDef])
433 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
434 """Find all tasks that have the specified dataset type name as an
435 input.
437 Parameters
438 ----------
439 datasetTypeName : `str`
440 A string representing the name of a dataset type to be queried,
441 can also accept a `DatasetTypeName` which is a `NewType` of str for
442 type safety in static type checking.
444 Returns
445 -------
446 tasks : iterable of `TaskDef`
447 `TaskDef` objects that have the specified `DatasetTypeName` as an
448 input, list will be empty if no tasks use specified
449 `DatasetTypeName` as an input.
451 Raises
452 ------
453 KeyError
454 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
455 """
456 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
458 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
459 """Find all tasks that have the specified dataset type name as an
460 output.
462 Parameters
463 ----------
464 datasetTypeName : `str`
465 A string representing the name of a dataset type to be queried,
466 can also accept a `DatasetTypeName` which is a `NewType` of str for
467 type safety in static type checking.
469 Returns
470 -------
471 `TaskDef` or `None`
472 `TaskDef` that outputs `DatasetTypeName` as an output or None if
473 none of the tasks produce this `DatasetTypeName`.
475 Raises
476 ------
477 KeyError
478 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
479 """
480 return self._datasetDict.getProducer(datasetTypeName)
482 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
483 """Find all tasks that are associated with the specified dataset type
484 name.
486 Parameters
487 ----------
488 datasetTypeName : `str`
489 A string representing the name of a dataset type to be queried,
490 can also accept a `DatasetTypeName` which is a `NewType` of str for
491 type safety in static type checking.
493 Returns
494 -------
495 result : iterable of `TaskDef`
496 `TaskDef` objects that are associated with the specified
497 `DatasetTypeName`
499 Raises
500 ------
501 KeyError
502 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
503 """
504 return self._datasetDict.getAll(datasetTypeName)
506 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
507 """Determine which `TaskDef` objects in this graph are associated
508 with a `str` representing a task name (looks at the taskName property
509 of `TaskDef` objects).
511 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
512 multiple times in a graph with different labels.
514 Parameters
515 ----------
516 taskName : str
517 Name of a task to search for
519 Returns
520 -------
521 result : list of `TaskDef`
522 List of the `TaskDef` objects that have the name specified.
523 Multiple values are returned in the case that a task is used
524 multiple times with different labels.
525 """
526 results = []
527 for task in self._taskToQuantumNode.keys():
528 split = task.taskName.split(".")
529 if split[-1] == taskName:
530 results.append(task)
531 return results
533 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
534 """Determine which `TaskDef` objects in this graph are associated
535 with a `str` representing a tasks label.
537 Parameters
538 ----------
539 taskName : str
540 Name of a task to search for
542 Returns
543 -------
544 result : `TaskDef`
545 `TaskDef` objects that has the specified label.
546 """
547 for task in self._taskToQuantumNode.keys():
548 if label == task.label:
549 return task
550 return None
552 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
553 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
555 Parameters
556 ----------
557 datasetTypeName : `str`
558 The name of the dataset type to search for as a string,
559 can also accept a `DatasetTypeName` which is a `NewType` of str for
560 type safety in static type checking.
562 Returns
563 -------
564 result : `set` of `QuantumNode` objects
565 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
567 Raises
568 ------
569 KeyError
570 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
572 """
573 tasks = self._datasetDict.getAll(datasetTypeName)
574 result: Set[Quantum] = set()
575 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
576 return result
578 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
579 """Check if specified quantum appears in the graph as part of a node.
581 Parameters
582 ----------
583 quantum : `Quantum`
584 The quantum to search for
586 Returns
587 -------
588 `bool`
589 The result of searching for the quantum
590 """
591 for node in self:
592 if quantum == node.quantum:
593 return True
594 return False
    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : `str` or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object.
        """
        # Delegates to networkx's agraph (graphviz) writer.
        write_dot(self._connectedQuanta, output)
606 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
607 """Create a new graph object that contains the subset of the nodes
608 specified as input. Node number is preserved.
610 Parameters
611 ----------
612 nodes : `QuantumNode` or iterable of `QuantumNode`
614 Returns
615 -------
616 graph : instance of graph type
617 An instance of the type from which the subset was created
618 """
619 if not isinstance(nodes, Iterable):
620 nodes = (nodes,)
621 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
622 quantumMap = defaultdict(set)
624 node: QuantumNode
625 for node in quantumSubgraph:
626 quantumMap[node.taskDef].add(node.quantum)
628 # convert to standard dict to prevent accidental key insertion
629 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
630 # Create an empty graph, and then populate it with custom mapping
631 newInst = type(self)({}, universe=self._universe)
632 newInst._buildGraphs(
633 quantumDict,
634 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
635 _buildId=self._buildId,
636 metadata=self._metadata,
637 universe=self._universe,
638 )
639 return newInst
641 def subsetToConnected(self: _T) -> Tuple[_T, ...]:
642 """Generate a list of subgraphs where each is connected.
644 Returns
645 -------
646 result : list of `QuantumGraph`
647 A list of graphs that are each connected
648 """
649 return tuple(
650 self.subset(connectedSet)
651 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
652 )
654 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
655 """Return a set of `QuantumNode` that are direct inputs to a specified
656 node.
658 Parameters
659 ----------
660 node : `QuantumNode`
661 The node of the graph for which inputs are to be determined
663 Returns
664 -------
665 set of `QuantumNode`
666 All the nodes that are direct inputs to specified node
667 """
668 return set(pred for pred in self._connectedQuanta.predecessors(node))
670 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
671 """Return a set of `QuantumNode` that are direct outputs of a specified
672 node.
674 Parameters
675 ----------
676 node : `QuantumNode`
677 The node of the graph for which outputs are to be determined
679 Returns
680 -------
681 set of `QuantumNode`
682 All the nodes that are direct outputs to specified node
683 """
684 return set(succ for succ in self._connectedQuanta.successors(node))
686 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
687 """Return a graph of `QuantumNode` that are direct inputs and outputs
688 of a specified node.
690 Parameters
691 ----------
692 node : `QuantumNode`
693 The node of the graph for which connected nodes are to be
694 determined.
696 Returns
697 -------
698 graph : graph of `QuantumNode`
699 All the nodes that are directly connected to specified node
700 """
701 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
702 nodes.add(node)
703 return self.subset(nodes)
705 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
706 """Return a graph of the specified node and all the ancestor nodes
707 directly reachable by walking edges.
709 Parameters
710 ----------
711 node : `QuantumNode`
712 The node for which all ansestors are to be determined
714 Returns
715 -------
716 graph of `QuantumNode`
717 Graph of node and all of its ansestors
718 """
719 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
720 predecessorNodes.add(node)
721 return self.subset(predecessorNodes)
723 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
724 """Check a graph for the presense of cycles and returns the edges of
725 any cycles found, or an empty list if there is no cycle.
727 Returns
728 -------
729 result : list of tuple of `QuantumNode`, `QuantumNode`
730 A list of any graph edges that form a cycle, or an empty list if
731 there is no cycle. Empty list to so support if graph.find_cycle()
732 syntax as an empty list is falsy.
733 """
734 try:
735 return nx.find_cycle(self._connectedQuanta)
736 except nx.NetworkXNoCycle:
737 return []
739 def saveUri(self, uri: ResourcePathExpression) -> None:
740 """Save `QuantumGraph` to the specified URI.
742 Parameters
743 ----------
744 uri : convertible to `ResourcePath`
745 URI to where the graph should be saved.
746 """
747 buffer = self._buildSaveObject()
748 path = ResourcePath(uri)
749 if path.getExtension() not in (".qgraph"):
750 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
751 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
753 @property
754 def metadata(self) -> Optional[MappingProxyType[str, Any]]:
755 """ """
756 if self._metadata is None:
757 return None
758 return MappingProxyType(self._metadata)
760 @classmethod
761 def loadUri(
762 cls,
763 uri: ResourcePathExpression,
764 universe: Optional[DimensionUniverse] = None,
765 nodes: Optional[Iterable[uuid.UUID]] = None,
766 graphID: Optional[BuildId] = None,
767 minimumVersion: int = 3,
768 ) -> QuantumGraph:
769 """Read `QuantumGraph` from a URI.
771 Parameters
772 ----------
773 uri : convertible to `ResourcePath`
774 URI from where to load the graph.
775 universe: `~lsst.daf.butler.DimensionUniverse` optional
776 DimensionUniverse instance, not used by the method itself but
777 needed to ensure that registry data structures are initialized.
778 If None it is loaded from the QuantumGraph saved structure. If
779 supplied, the DimensionUniverse from the loaded `QuantumGraph`
780 will be validated against the supplied argument for compatibility.
781 nodes: iterable of `int` or None
782 Numbers that correspond to nodes in the graph. If specified, only
783 these nodes will be loaded. Defaults to None, in which case all
784 nodes will be loaded.
785 graphID : `str` or `None`
786 If specified this ID is verified against the loaded graph prior to
787 loading any Nodes. This defaults to None in which case no
788 validation is done.
789 minimumVersion : int
790 Minimum version of a save file to load. Set to -1 to load all
791 versions. Older versions may need to be loaded, and re-saved
792 to upgrade them to the latest format before they can be used in
793 production.
795 Returns
796 -------
797 graph : `QuantumGraph`
798 Resulting QuantumGraph instance.
800 Raises
801 ------
802 TypeError
803 Raised if pickle contains instance of a type other than
804 QuantumGraph.
805 ValueError
806 Raised if one or more of the nodes requested is not in the
807 `QuantumGraph` or if graphID parameter does not match the graph
808 being loaded or if the supplied uri does not point at a valid
809 `QuantumGraph` save file.
810 RuntimeError
811 Raise if Supplied DimensionUniverse is not compatible with the
812 DimensionUniverse saved in the graph
815 Notes
816 -----
817 Reading Quanta from pickle requires existence of singleton
818 DimensionUniverse which is usually instantiated during Registry
819 initialization. To make sure that DimensionUniverse exists this method
820 accepts dummy DimensionUniverse argument.
821 """
822 uri = ResourcePath(uri)
823 # With ResourcePath we have the choice of always using a local file
824 # or reading in the bytes directly. Reading in bytes can be more
825 # efficient for reasonably-sized pickle files when the resource
826 # is remote. For now use the local file variant. For a local file
827 # as_local() does nothing.
829 if uri.getExtension() in (".pickle", ".pkl"):
830 with uri.as_local() as local, open(local.ospath, "rb") as fd:
831 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
832 qgraph = pickle.load(fd)
833 elif uri.getExtension() in (".qgraph"):
834 with LoadHelper(uri, minimumVersion) as loader:
835 qgraph = loader.load(universe, nodes, graphID)
836 else:
837 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
838 if not isinstance(qgraph, QuantumGraph):
839 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
840 return qgraph
842 @classmethod
843 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]:
844 """Read the header of a `QuantumGraph` pointed to by the uri parameter
845 and return it as a string.
847 Parameters
848 ----------
849 uri : convertible to `ResourcePath`
850 The location of the `QuantumGraph` to load. If the argument is a
851 string, it must correspond to a valid `ResourcePath` path.
852 minimumVersion : int
853 Minimum version of a save file to load. Set to -1 to load all
854 versions. Older versions may need to be loaded, and re-saved
855 to upgrade them to the latest format before they can be used in
856 production.
858 Returns
859 -------
860 header : `str` or `None`
861 The header associated with the specified `QuantumGraph` it there is
862 one, else `None`.
864 Raises
865 ------
866 ValueError
867 Raised if `QuantuGraph` was saved as a pickle.
868 Raised if the extention of the file specified by uri is not a
869 `QuantumGraph` extention.
870 """
871 uri = ResourcePath(uri)
872 if uri.getExtension() in (".pickle", ".pkl"):
873 raise ValueError("Reading a header from a pickle save is not supported")
874 elif uri.getExtension() in (".qgraph"):
875 return LoadHelper(uri, minimumVersion).readHeader()
876 else:
877 raise ValueError("Only know how to handle files saved as `qgraph`")
879 def buildAndPrintHeader(self) -> None:
880 """Creates a header that would be used in a save of this object and
881 prints it out to standard out.
882 """
883 _, header = self._buildSaveObject(returnHeader=True)
884 print(json.dumps(header))
886 def save(self, file: BinaryIO) -> None:
887 """Save QuantumGraph to a file.
889 Parameters
890 ----------
891 file : `io.BufferedIOBase`
892 File to write pickle data open in binary mode.
893 """
894 buffer = self._buildSaveObject()
895 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
897 def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
898 # make some containers
899 jsonData: Deque[bytes] = deque()
900 # node map is a list because json does not accept mapping keys that
901 # are not strings, so we store a list of key, value pairs that will
902 # be converted to a mapping on load
903 nodeMap = []
904 taskDefMap = {}
905 headerData: Dict[str, Any] = {}
907 # Store the QauntumGraph BuildId, this will allow validating BuildIds
908 # at load time, prior to loading any QuantumNodes. Name chosen for
909 # unlikely conflicts.
910 headerData["GraphBuildID"] = self.graphID
911 headerData["Metadata"] = self._metadata
913 # Store the universe this graph was created with
914 universeConfig = self._universe.dimensionConfig
915 headerData["universe"] = universeConfig.toDict()
917 # counter for the number of bytes processed thus far
918 count = 0
919 # serialize out the task Defs recording the start and end bytes of each
920 # taskDef
921 inverseLookup = self._datasetDict.inverse
922 taskDef: TaskDef
923 # sort by task label to ensure serialization happens in the same order
924 for taskDef in self.taskGraph:
925 # compressing has very little impact on saving or load time, but
926 # a large impact on on disk size, so it is worth doing
927 taskDescription = {}
928 # save the fully qualified name.
929 taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
930 # save the config as a text stream that will be un-persisted on the
931 # other end
932 stream = io.StringIO()
933 taskDef.config.saveToStream(stream)
934 taskDescription["config"] = stream.getvalue()
935 taskDescription["label"] = taskDef.label
937 inputs = []
938 outputs = []
940 # Determine the connection between all of tasks and save that in
941 # the header as a list of connections and edges in each task
942 # this will help in un-persisting, and possibly in a "quick view"
943 # method that does not require everything to be un-persisted
944 #
945 # Typing returns can't be parameter dependent
946 for connection in inverseLookup[taskDef]: # type: ignore
947 consumers = self._datasetDict.getConsumers(connection)
948 producer = self._datasetDict.getProducer(connection)
949 if taskDef in consumers:
950 # This checks if the task consumes the connection directly
951 # from the datastore or it is produced by another task
952 producerLabel = producer.label if producer is not None else "datastore"
953 inputs.append((producerLabel, connection))
954 elif taskDef not in consumers and producer is taskDef:
955 # If there are no consumers for this tasks produced
956 # connection, the output will be said to be the datastore
957 # in which case the for loop will be a zero length loop
958 if not consumers:
959 outputs.append(("datastore", connection))
960 for td in consumers:
961 outputs.append((td.label, connection))
963 # dump to json string, and encode that string to bytes and then
# compress those bytes
965 dump = lzma.compress(json.dumps(taskDescription).encode())
966 # record the sizing and relation information
967 taskDefMap[taskDef.label] = {
968 "bytes": (count, count + len(dump)),
969 "inputs": inputs,
970 "outputs": outputs,
971 }
972 count += len(dump)
973 jsonData.append(dump)
975 headerData["TaskDefs"] = taskDefMap
977 # serialize the nodes, recording the start and end bytes of each node
978 dimAccumulator = DimensionRecordsAccumulator()
979 for node in self:
980 # compressing has very little impact on saving or load time, but
# a large impact on on-disk size, so it is worth doing
982 simpleNode = node.to_simple(accumulator=dimAccumulator)
984 dump = lzma.compress(simpleNode.json().encode())
985 jsonData.append(dump)
986 nodeMap.append(
987 (
988 str(node.nodeId),
989 {
990 "bytes": (count, count + len(dump)),
991 "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
992 "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
993 },
994 )
995 )
996 count += len(dump)
998 headerData["DimensionRecords"] = {
999 key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
1000 }
1002 # need to serialize this as a series of key,value tuples because of
# a limitation in JSON: mapping keys can only be strings
1004 headerData["Nodes"] = nodeMap
1006 # dump the headerData to json
1007 header_encode = lzma.compress(json.dumps(headerData).encode())
1009 # record the sizes as 2 unsigned long long numbers for a total of 16
1010 # bytes
1011 save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)
1013 fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
1014 map_lengths = struct.pack(fmt_string, len(header_encode))
1016 # write each component of the save out in a deterministic order
1017 # buffer = io.BytesIO()
1018 # buffer.write(map_lengths)
1019 # buffer.write(taskDef_pickle)
1020 # buffer.write(map_pickle)
1021 buffer = bytearray()
1022 buffer.extend(MAGIC_BYTES)
1023 buffer.extend(save_bytes)
1024 buffer.extend(map_lengths)
1025 buffer.extend(header_encode)
1026 # Iterate over the length of pickleData, and for each element pop the
1027 # leftmost element off the deque and write it out. This is to save
1028 # memory, as the memory is added to the buffer object, it is removed
# from the container.
1030 #
# Only this section needs to worry about memory pressure because
1032 # everything else written to the buffer prior to this pickle data is
1033 # only on the order of kilobytes to low numbers of megabytes.
1034 while jsonData:
1035 buffer.extend(jsonData.popleft())
1036 if returnHeader:
1037 return buffer, headerData
1038 else:
1039 return buffer
@classmethod
def load(
    cls,
    file: BinaryIO,
    universe: Optional[DimensionUniverse] = None,
    nodes: Optional[Iterable[uuid.UUID]] = None,
    graphID: Optional[BuildId] = None,
    minimumVersion: int = 3,
) -> QuantumGraph:
    """Read QuantumGraph from a file that was made by `save`.

    Parameters
    ----------
    file : `io.IO` of bytes
        File with the saved graph data open in binary mode.
    universe : `~lsst.daf.butler.DimensionUniverse`, optional
        DimensionUniverse instance, not used by the method itself but
        needed to ensure that registry data structures are initialized.
        If None it is loaded from the QuantumGraph saved structure. If
        supplied, the DimensionUniverse from the loaded `QuantumGraph`
        will be validated against the supplied argument for compatibility.
    nodes : iterable of `uuid.UUID` or `None`
        UUIDs that correspond to nodes in the graph. If specified, only
        these nodes will be loaded. Defaults to None, in which case all
        nodes will be loaded.
    graphID : `str` or `None`
        If specified this ID is verified against the loaded graph prior to
        loading any Nodes. This defaults to None in which case no
        validation is done.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format before they can be used in
        production.

    Returns
    -------
    graph : `QuantumGraph`
        Resulting QuantumGraph instance.

    Raises
    ------
    TypeError
        Raised if the loaded data contains an instance of a type other
        than QuantumGraph.
    ValueError
        Raised if one or more of the nodes requested is not in the
        `QuantumGraph` or if graphID parameter does not match the graph
        being loaded or if the supplied uri does not point at a valid
        `QuantumGraph` save file.

    Notes
    -----
    Reading Quanta from pickle requires existence of singleton
    DimensionUniverse which is usually instantiated during Registry
    initialization. To make sure that DimensionUniverse exists this method
    accepts dummy DimensionUniverse argument.
    """
    # Try to see if the file handle contains legacy pickle data; this
    # support will be removed in the future.
    try:
        qgraph = pickle.load(file)
        warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
    except pickle.UnpicklingError:
        # Not a pickle: parse it with the versioned binary loader.
        with LoadHelper(file, minimumVersion) as loader:
            qgraph = loader.load(universe, nodes, graphID)
    if not isinstance(qgraph, QuantumGraph):
        raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
    return qgraph
def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
    """Iterate over the tasks of the `taskGraph` attribute in
    topological order.

    Yields
    ------
    taskDef : `TaskDef`
        `TaskDef` objects in topological order.
    """
    for task in nx.topological_sort(self.taskGraph):
        yield task
@property
def graphID(self) -> BuildId:
    """The unique ID assigned to this graph when it was constructed."""
    return self._buildId
def __iter__(self) -> Generator[QuantumNode, None, None]:
    """Iterate over the quantum nodes in topological order, so a node is
    always yielded before any node that depends on it.
    """
    for quantumNode in nx.topological_sort(self._connectedQuanta):
        yield quantumNode
def __len__(self) -> int:
    """Return the number of quantum nodes in this graph."""
    return self._count
def __contains__(self, node: QuantumNode) -> bool:
    """Report whether ``node`` is one of this graph's quantum nodes."""
    # Membership is delegated to the underlying networkx graph.
    return self._connectedQuanta.has_node(node)
def __getstate__(self) -> dict:
    """Store a compact form of the graph as a list of graph nodes, and a
    tuple of task labels and task configs. The full graph can be
    reconstructed with this information, and it preserves the ordering of
    the graph nodes.
    """
    universe: Optional[DimensionUniverse] = None
    # A node's dataId carries the universe the graph was built with; stop
    # at the first node that has one instead of scanning every node (all
    # nodes in a graph share the same universe).
    for node in self:
        dId = node.quantum.dataId
        if dId is None:
            continue
        universe = dId.graph.universe
        break
    return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
def __setstate__(self, state: dict) -> None:
    """Reconstruct the state of the graph from the information persisted
    in `__getstate__`.
    """
    # Re-parse the serialized form produced by __getstate__ with the
    # same loader used for on-disk saves.
    buffer = io.BytesIO(state["reduced"])
    with LoadHelper(buffer, minimumVersion=3) as loader:
        restored = loader.load(state["universe"], graphID=state["graphId"])

    # Adopt the restored graph's internals wholesale.
    for attr in (
        "_metadata",
        "_buildId",
        "_datasetDict",
        "_nodeIdMap",
        "_taskToQuantumNode",
        "_taskGraph",
        "_connectedQuanta",
    ):
        setattr(self, attr, getattr(restored, attr))
    self._count = len(restored)
def __eq__(self, other: object) -> bool:
    """Two QuantumGraphs are equal when they contain the same nodes with
    the same connectivity, the same dataset types, and the same tasks.
    """
    # Cheap structural guards first.
    if not isinstance(other, QuantumGraph) or len(self) != len(other):
        return False
    # Every node must exist in both graphs with identical neighbours.
    # The `or` chain preserves short-circuit order: connectivity is only
    # compared once membership in `other` is established.
    for node in self:
        if (
            node not in other
            or self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node)
            or self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node)
        ):
            return False
    if set(self.allDatasetTypes) != set(other.allDatasetTypes):
        return False
    return set(self.taskGraph) == set(other.taskGraph)