Coverage for python/lsst/pipe/base/graph/graph.py: 19%
346 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-07-14 16:10 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25import io
26import json
27import lzma
28import os
29import pickle
30import struct
31import time
32import uuid
33import warnings
34from collections import defaultdict, deque
35from itertools import chain
36from types import MappingProxyType
37from typing import (
38 Any,
39 BinaryIO,
40 DefaultDict,
41 Deque,
42 Dict,
43 FrozenSet,
44 Generator,
45 Iterable,
46 List,
47 Mapping,
48 MutableMapping,
49 Optional,
50 Set,
51 Tuple,
52 TypeVar,
53 Union,
54)
56import networkx as nx
57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum
58from lsst.resources import ResourcePath, ResourcePathExpression
59from lsst.utils.introspection import get_full_type_name
60from networkx.drawing.nx_agraph import write_dot
62from ..connections import iterConnections
63from ..pipeline import TaskDef
64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner
65from ._loadHelpers import LoadHelper
66from ._versionDeserializers import DESERIALIZER_MAP
67from .quantumNode import BuildId, QuantumNode
# Type variable bound to QuantumGraph so methods that build and return a new
# graph (e.g. ``subset``) are typed as returning the caller's subclass.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be satisfied because the node
    belongs to an incompatible graph.
    """
101class QuantumGraph:
102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
104 This data structure represents a concrete workflow generated from a
105 `Pipeline`.
107 Parameters
108 ----------
109 quanta : Mapping of `TaskDef` to sets of `Quantum`
110 This maps tasks (and their configs) to the sets of data they are to
111 process.
112 metadata : Optional Mapping of `str` to primitives
113 This is an optional parameter of extra data to carry with the graph.
114 Entries in this mapping should be able to be serialized in JSON.
116 Raises
117 ------
118 ValueError
119 Raised if the graph is pruned such that some tasks no longer have nodes
120 associated with them.
121 """
    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
    ):
        # All real construction work is delegated to _buildGraphs so that
        # alternate constructors (e.g. ``subset``, ``pruneGraphFromRefs``)
        # can rebuild state on a bare instance made with ``object.__new__``.
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs, universe=universe)
132 def _buildGraphs(
133 self,
134 quanta: Mapping[TaskDef, Set[Quantum]],
135 *,
136 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
137 _buildId: Optional[BuildId] = None,
138 metadata: Optional[Mapping[str, Any]] = None,
139 pruneRefs: Optional[Iterable[DatasetRef]] = None,
140 universe: Optional[DimensionUniverse] = None,
141 ) -> None:
142 """Builds the graph that is used to store the relation between tasks,
143 and the graph that holds the relations between quanta
144 """
145 if universe is None:
146 universe = DimensionUniverse()
147 self._universe = universe
148 self._metadata = metadata
149 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
150 # Data structures used to identify relations between components;
151 # DatasetTypeName -> TaskDef for task,
152 # and DatasetRef -> QuantumNode for the quanta
153 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
154 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()
156 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
157 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
158 for taskDef, quantumSet in quanta.items():
159 connections = taskDef.connections
161 # For each type of connection in the task, add a key to the
162 # `_DatasetTracker` for the connections name, with a value of
163 # the TaskDef in the appropriate field
164 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
165 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)
167 for output in iterConnections(connections, ("outputs",)):
168 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)
170 # For each `Quantum` in the set of all `Quantum` for this task,
171 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
172 # of the individual datasets inside the `Quantum`, with a value of
173 # a newly created QuantumNode to the appropriate input/output
174 # field.
175 for quantum in quantumSet:
176 if _quantumToNodeId:
177 if (nodeId := _quantumToNodeId.get(quantum)) is None:
178 raise ValueError(
179 "If _quantuMToNodeNumber is not None, all quanta must have an "
180 "associated value in the mapping"
181 )
182 else:
183 nodeId = uuid.uuid4()
185 inits = quantum.initInputs.values()
186 inputs = quantum.inputs.values()
187 value = QuantumNode(quantum, taskDef, nodeId)
188 self._taskToQuantumNode[taskDef].add(value)
189 self._nodeIdMap[nodeId] = value
191 for dsRef in chain(inits, inputs):
192 # unfortunately, `Quantum` allows inits to be individual
193 # `DatasetRef`s or an Iterable of such, so there must
194 # be an instance check here
195 if isinstance(dsRef, Iterable):
196 for sub in dsRef:
197 if sub.isComponent():
198 sub = sub.makeCompositeRef()
199 self._datasetRefDict.addConsumer(sub, value)
200 else:
201 assert isinstance(dsRef, DatasetRef)
202 if dsRef.isComponent():
203 dsRef = dsRef.makeCompositeRef()
204 self._datasetRefDict.addConsumer(dsRef, value)
205 for dsRef in chain.from_iterable(quantum.outputs.values()):
206 self._datasetRefDict.addProducer(dsRef, value)
208 if pruneRefs is not None:
209 # track what refs were pruned and prune the graph
210 prunes: Set[QuantumNode] = set()
211 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)
213 # recreate the taskToQuantumNode dict removing nodes that have been
214 # pruned. Keep track of task defs that now have no QuantumNodes
215 emptyTasks: Set[str] = set()
216 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
217 # accumulate all types
218 types_ = set()
219 # tracker for any pruneRefs that have caused tasks to have no nodes
220 # This helps the user find out what caused the issues seen.
221 culprits = set()
222 # Find all the types from the refs to prune
223 for r in pruneRefs:
224 types_.add(r.datasetType)
226 # For each of the tasks, and their associated nodes, remove any
227 # any nodes that were pruned. If there are no nodes associated
228 # with a task, record that task, and find out if that was due to
229 # a type from an input ref to prune.
230 for td, taskNodes in self._taskToQuantumNode.items():
231 diff = taskNodes.difference(prunes)
232 if len(diff) == 0:
233 if len(taskNodes) != 0:
234 tp: DatasetType
235 for tp in types_:
236 if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
237 tmpRefs
238 ).difference(pruneRefs):
239 culprits.add(tp.name)
240 emptyTasks.add(td.label)
241 newTaskToQuantumNode[td] = diff
243 # update the internal dict
244 self._taskToQuantumNode = newTaskToQuantumNode
246 if emptyTasks:
247 raise ValueError(
248 f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
249 f"after graph pruning; {', '.join(culprits)} caused over-pruning"
250 )
252 # Graph of quanta relations
253 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
254 self._count = len(self._connectedQuanta)
256 # Graph of task relations, used in various methods
257 self._taskGraph = self._datasetDict.makeNetworkXGraph()
259 # convert default dict into a regular to prevent accidental key
260 # insertion
261 self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
263 @property
264 def taskGraph(self) -> nx.DiGraph:
265 """Return a graph representing the relations between the tasks inside
266 the quantum graph.
268 Returns
269 -------
270 taskGraph : `networkx.Digraph`
271 Internal datastructure that holds relations of `TaskDef` objects
272 """
273 return self._taskGraph
275 @property
276 def graph(self) -> nx.DiGraph:
277 """Return a graph representing the relations between all the
278 `QuantumNode` objects. Largely it should be preferred to iterate
279 over, and use methods of this class, but sometimes direct access to
280 the networkx object may be helpful
282 Returns
283 -------
284 graph : `networkx.Digraph`
285 Internal datastructure that holds relations of `QuantumNode`
286 objects
287 """
288 return self._connectedQuanta
290 @property
291 def inputQuanta(self) -> Iterable[QuantumNode]:
292 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
293 to the graph, meaning those nodes to not depend on any other nodes in
294 the graph.
296 Returns
297 -------
298 inputNodes : iterable of `QuantumNode`
299 A list of nodes that are inputs to the graph
300 """
301 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
303 @property
304 def outputQuanta(self) -> Iterable[QuantumNode]:
305 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
306 to the graph, meaning those nodes have no nodes that depend them in
307 the graph.
309 Returns
310 -------
311 outputNodes : iterable of `QuantumNode`
312 A list of nodes that are outputs of the graph
313 """
314 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
316 @property
317 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
318 """Return all the `DatasetTypeName` objects that are contained inside
319 the graph.
321 Returns
322 -------
323 tuple of `DatasetTypeName`
324 All the data set type names that are present in the graph
325 """
326 return tuple(self._datasetDict.keys())
328 @property
329 def isConnected(self) -> bool:
330 """Return True if all of the nodes in the graph are connected, ignores
331 directionality of connections.
332 """
333 return nx.is_weakly_connected(self._connectedQuanta)
335 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
336 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
337 and nodes which depend on them.
339 Parameters
340 ----------
341 refs : `Iterable` of `DatasetRef`
342 Refs which should be removed from resulting graph
344 Returns
345 -------
346 graph : `QuantumGraph`
347 A graph that has been pruned of specified refs and the nodes that
348 depend on them.
349 """
350 newInst = object.__new__(type(self))
351 quantumMap = defaultdict(set)
352 for node in self:
353 quantumMap[node.taskDef].add(node.quantum)
355 # convert to standard dict to prevent accidental key insertion
356 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
358 newInst._buildGraphs(
359 quantumDict,
360 _quantumToNodeId={n.quantum: n.nodeId for n in self},
361 metadata=self._metadata,
362 pruneRefs=refs,
363 )
364 return newInst
366 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
367 """Lookup a `QuantumNode` from an id associated with the node.
369 Parameters
370 ----------
371 nodeId : `NodeId`
372 The number associated with a node
374 Returns
375 -------
376 node : `QuantumNode`
377 The node corresponding with input number
379 Raises
380 ------
381 KeyError
382 Raised if the requested nodeId is not in the graph.
383 """
384 return self._nodeIdMap[nodeId]
386 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
387 """Return all the `Quantum` associated with a `TaskDef`.
389 Parameters
390 ----------
391 taskDef : `TaskDef`
392 The `TaskDef` for which `Quantum` are to be queried
394 Returns
395 -------
396 frozenset of `Quantum`
397 The `set` of `Quantum` that is associated with the specified
398 `TaskDef`.
399 """
400 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef])
402 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
403 """Return all the `QuantumNodes` associated with a `TaskDef`.
405 Parameters
406 ----------
407 taskDef : `TaskDef`
408 The `TaskDef` for which `Quantum` are to be queried
410 Returns
411 -------
412 frozenset of `QuantumNodes`
413 The `frozenset` of `QuantumNodes` that is associated with the
414 specified `TaskDef`.
415 """
416 return frozenset(self._taskToQuantumNode[taskDef])
418 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
419 """Find all tasks that have the specified dataset type name as an
420 input.
422 Parameters
423 ----------
424 datasetTypeName : `str`
425 A string representing the name of a dataset type to be queried,
426 can also accept a `DatasetTypeName` which is a `NewType` of str for
427 type safety in static type checking.
429 Returns
430 -------
431 tasks : iterable of `TaskDef`
432 `TaskDef` objects that have the specified `DatasetTypeName` as an
433 input, list will be empty if no tasks use specified
434 `DatasetTypeName` as an input.
436 Raises
437 ------
438 KeyError
439 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
440 """
441 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
443 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
444 """Find all tasks that have the specified dataset type name as an
445 output.
447 Parameters
448 ----------
449 datasetTypeName : `str`
450 A string representing the name of a dataset type to be queried,
451 can also accept a `DatasetTypeName` which is a `NewType` of str for
452 type safety in static type checking.
454 Returns
455 -------
456 `TaskDef` or `None`
457 `TaskDef` that outputs `DatasetTypeName` as an output or None if
458 none of the tasks produce this `DatasetTypeName`.
460 Raises
461 ------
462 KeyError
463 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
464 """
465 return self._datasetDict.getProducer(datasetTypeName)
467 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
468 """Find all tasks that are associated with the specified dataset type
469 name.
471 Parameters
472 ----------
473 datasetTypeName : `str`
474 A string representing the name of a dataset type to be queried,
475 can also accept a `DatasetTypeName` which is a `NewType` of str for
476 type safety in static type checking.
478 Returns
479 -------
480 result : iterable of `TaskDef`
481 `TaskDef` objects that are associated with the specified
482 `DatasetTypeName`
484 Raises
485 ------
486 KeyError
487 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
488 """
489 return self._datasetDict.getAll(datasetTypeName)
491 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
492 """Determine which `TaskDef` objects in this graph are associated
493 with a `str` representing a task name (looks at the taskName property
494 of `TaskDef` objects).
496 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
497 multiple times in a graph with different labels.
499 Parameters
500 ----------
501 taskName : str
502 Name of a task to search for
504 Returns
505 -------
506 result : list of `TaskDef`
507 List of the `TaskDef` objects that have the name specified.
508 Multiple values are returned in the case that a task is used
509 multiple times with different labels.
510 """
511 results = []
512 for task in self._taskToQuantumNode.keys():
513 split = task.taskName.split(".")
514 if split[-1] == taskName:
515 results.append(task)
516 return results
518 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
519 """Determine which `TaskDef` objects in this graph are associated
520 with a `str` representing a tasks label.
522 Parameters
523 ----------
524 taskName : str
525 Name of a task to search for
527 Returns
528 -------
529 result : `TaskDef`
530 `TaskDef` objects that has the specified label.
531 """
532 for task in self._taskToQuantumNode.keys():
533 if label == task.label:
534 return task
535 return None
537 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
538 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
540 Parameters
541 ----------
542 datasetTypeName : `str`
543 The name of the dataset type to search for as a string,
544 can also accept a `DatasetTypeName` which is a `NewType` of str for
545 type safety in static type checking.
547 Returns
548 -------
549 result : `set` of `QuantumNode` objects
550 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
552 Raises
553 ------
554 KeyError
555 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
557 """
558 tasks = self._datasetDict.getAll(datasetTypeName)
559 result: Set[Quantum] = set()
560 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
561 return result
563 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
564 """Check if specified quantum appears in the graph as part of a node.
566 Parameters
567 ----------
568 quantum : `Quantum`
569 The quantum to search for
571 Returns
572 -------
573 `bool`
574 The result of searching for the quantum
575 """
576 for node in self:
577 if quantum == node.quantum:
578 return True
579 return False
581 def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
582 """Write out the graph as a dot graph.
584 Parameters
585 ----------
586 output : str or `io.BufferedIOBase`
587 Either a filesystem path to write to, or a file handle object
588 """
589 write_dot(self._connectedQuanta, output)
591 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
592 """Create a new graph object that contains the subset of the nodes
593 specified as input. Node number is preserved.
595 Parameters
596 ----------
597 nodes : `QuantumNode` or iterable of `QuantumNode`
599 Returns
600 -------
601 graph : instance of graph type
602 An instance of the type from which the subset was created
603 """
604 if not isinstance(nodes, Iterable):
605 nodes = (nodes,)
606 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
607 quantumMap = defaultdict(set)
609 node: QuantumNode
610 for node in quantumSubgraph:
611 quantumMap[node.taskDef].add(node.quantum)
613 # convert to standard dict to prevent accidental key insertion
614 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
615 # Create an empty graph, and then populate it with custom mapping
616 newInst = type(self)({})
617 newInst._buildGraphs(
618 quantumDict,
619 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
620 _buildId=self._buildId,
621 metadata=self._metadata,
622 )
623 return newInst
    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a list of subgraphs where each is connected.

        Returns
        -------
        result : `tuple` of `QuantumGraph`
            A tuple of graphs that are each internally connected (ignoring
            edge direction).
        """
        # Each weakly-connected component of the quantum graph becomes its
        # own graph instance via ``subset``.
        return tuple(
            self.subset(connectedSet)
            for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
        )
638 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
639 """Return a set of `QuantumNode` that are direct inputs to a specified
640 node.
642 Parameters
643 ----------
644 node : `QuantumNode`
645 The node of the graph for which inputs are to be determined
647 Returns
648 -------
649 set of `QuantumNode`
650 All the nodes that are direct inputs to specified node
651 """
652 return set(pred for pred in self._connectedQuanta.predecessors(node))
654 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
655 """Return a set of `QuantumNode` that are direct outputs of a specified
656 node.
658 Parameters
659 ----------
660 node : `QuantumNode`
661 The node of the graph for which outputs are to be determined
663 Returns
664 -------
665 set of `QuantumNode`
666 All the nodes that are direct outputs to specified node
667 """
668 return set(succ for succ in self._connectedQuanta.successors(node))
670 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
671 """Return a graph of `QuantumNode` that are direct inputs and outputs
672 of a specified node.
674 Parameters
675 ----------
676 node : `QuantumNode`
677 The node of the graph for which connected nodes are to be
678 determined.
680 Returns
681 -------
682 graph : graph of `QuantumNode`
683 All the nodes that are directly connected to specified node
684 """
685 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
686 nodes.add(node)
687 return self.subset(nodes)
689 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
690 """Return a graph of the specified node and all the ancestor nodes
691 directly reachable by walking edges.
693 Parameters
694 ----------
695 node : `QuantumNode`
696 The node for which all ansestors are to be determined
698 Returns
699 -------
700 graph of `QuantumNode`
701 Graph of node and all of its ansestors
702 """
703 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
704 predecessorNodes.add(node)
705 return self.subset(predecessorNodes)
    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and return the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list supports the
            ``if graph.findCycle():`` idiom as an empty list is falsy.
        """
        # EAFP: networkx raises when no cycle exists; translate that into a
        # falsy empty list for callers.
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []
723 def saveUri(self, uri: ResourcePathExpression) -> None:
724 """Save `QuantumGraph` to the specified URI.
726 Parameters
727 ----------
728 uri : convertible to `ResourcePath`
729 URI to where the graph should be saved.
730 """
731 buffer = self._buildSaveObject()
732 path = ResourcePath(uri)
733 if path.getExtension() not in (".qgraph"):
734 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
735 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
737 @property
738 def metadata(self) -> Optional[MappingProxyType[str, Any]]:
739 """ """
740 if self._metadata is None:
741 return None
742 return MappingProxyType(self._metadata)
744 @classmethod
745 def loadUri(
746 cls,
747 uri: ResourcePathExpression,
748 universe: Optional[DimensionUniverse] = None,
749 nodes: Optional[Iterable[uuid.UUID]] = None,
750 graphID: Optional[BuildId] = None,
751 minimumVersion: int = 3,
752 ) -> QuantumGraph:
753 """Read `QuantumGraph` from a URI.
755 Parameters
756 ----------
757 uri : convertible to `ResourcePath`
758 URI from where to load the graph.
759 universe: `~lsst.daf.butler.DimensionUniverse` optional
760 DimensionUniverse instance, not used by the method itself but
761 needed to ensure that registry data structures are initialized.
762 If None it is loaded from the QuantumGraph saved structure. If
763 supplied, the DimensionUniverse from the loaded `QuantumGraph`
764 will be validated against the supplied argument for compatibility.
765 nodes: iterable of `int` or None
766 Numbers that correspond to nodes in the graph. If specified, only
767 these nodes will be loaded. Defaults to None, in which case all
768 nodes will be loaded.
769 graphID : `str` or `None`
770 If specified this ID is verified against the loaded graph prior to
771 loading any Nodes. This defaults to None in which case no
772 validation is done.
773 minimumVersion : int
774 Minimum version of a save file to load. Set to -1 to load all
775 versions. Older versions may need to be loaded, and re-saved
776 to upgrade them to the latest format before they can be used in
777 production.
779 Returns
780 -------
781 graph : `QuantumGraph`
782 Resulting QuantumGraph instance.
784 Raises
785 ------
786 TypeError
787 Raised if pickle contains instance of a type other than
788 QuantumGraph.
789 ValueError
790 Raised if one or more of the nodes requested is not in the
791 `QuantumGraph` or if graphID parameter does not match the graph
792 being loaded or if the supplied uri does not point at a valid
793 `QuantumGraph` save file.
794 RuntimeError
795 Raise if Supplied DimensionUniverse is not compatible with the
796 DimensionUniverse saved in the graph
799 Notes
800 -----
801 Reading Quanta from pickle requires existence of singleton
802 DimensionUniverse which is usually instantiated during Registry
803 initialization. To make sure that DimensionUniverse exists this method
804 accepts dummy DimensionUniverse argument.
805 """
806 uri = ResourcePath(uri)
807 # With ResourcePath we have the choice of always using a local file
808 # or reading in the bytes directly. Reading in bytes can be more
809 # efficient for reasonably-sized pickle files when the resource
810 # is remote. For now use the local file variant. For a local file
811 # as_local() does nothing.
813 if uri.getExtension() in (".pickle", ".pkl"):
814 with uri.as_local() as local, open(local.ospath, "rb") as fd:
815 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
816 qgraph = pickle.load(fd)
817 elif uri.getExtension() in (".qgraph"):
818 with LoadHelper(uri, minimumVersion) as loader:
819 qgraph = loader.load(universe, nodes, graphID)
820 else:
821 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
822 if not isinstance(qgraph, QuantumGraph):
823 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
824 return qgraph
826 @classmethod
827 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]:
828 """Read the header of a `QuantumGraph` pointed to by the uri parameter
829 and return it as a string.
831 Parameters
832 ----------
833 uri : convertible to `ResourcePath`
834 The location of the `QuantumGraph` to load. If the argument is a
835 string, it must correspond to a valid `ResourcePath` path.
836 minimumVersion : int
837 Minimum version of a save file to load. Set to -1 to load all
838 versions. Older versions may need to be loaded, and re-saved
839 to upgrade them to the latest format before they can be used in
840 production.
842 Returns
843 -------
844 header : `str` or `None`
845 The header associated with the specified `QuantumGraph` it there is
846 one, else `None`.
848 Raises
849 ------
850 ValueError
851 Raised if `QuantuGraph` was saved as a pickle.
852 Raised if the extention of the file specified by uri is not a
853 `QuantumGraph` extention.
854 """
855 uri = ResourcePath(uri)
856 if uri.getExtension() in (".pickle", ".pkl"):
857 raise ValueError("Reading a header from a pickle save is not supported")
858 elif uri.getExtension() in (".qgraph"):
859 return LoadHelper(uri, minimumVersion).readHeader()
860 else:
861 raise ValueError("Only know how to handle files saved as `qgraph`")
863 def buildAndPrintHeader(self) -> None:
864 """Creates a header that would be used in a save of this object and
865 prints it out to standard out.
866 """
867 _, header = self._buildSaveObject(returnHeader=True)
868 print(json.dumps(header))
870 def save(self, file: BinaryIO) -> None:
871 """Save QuantumGraph to a file.
873 Parameters
874 ----------
875 file : `io.BufferedIOBase`
876 File to write pickle data open in binary mode.
877 """
878 buffer = self._buildSaveObject()
879 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
881 def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
882 # make some containers
883 jsonData: Deque[bytes] = deque()
884 # node map is a list because json does not accept mapping keys that
885 # are not strings, so we store a list of key, value pairs that will
886 # be converted to a mapping on load
887 nodeMap = []
888 taskDefMap = {}
889 headerData: Dict[str, Any] = {}
891 # Store the QauntumGraph BuildId, this will allow validating BuildIds
892 # at load time, prior to loading any QuantumNodes. Name chosen for
893 # unlikely conflicts.
894 headerData["GraphBuildID"] = self.graphID
895 headerData["Metadata"] = self._metadata
897 # Store the universe this graph was created with
898 universeConfig = self._universe.dimensionConfig
899 headerData["universe"] = universeConfig.toDict()
901 # counter for the number of bytes processed thus far
902 count = 0
903 # serialize out the task Defs recording the start and end bytes of each
904 # taskDef
905 inverseLookup = self._datasetDict.inverse
906 taskDef: TaskDef
907 # sort by task label to ensure serialization happens in the same order
908 for taskDef in self.taskGraph:
909 # compressing has very little impact on saving or load time, but
910 # a large impact on on disk size, so it is worth doing
911 taskDescription = {}
912 # save the fully qualified name.
913 taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
914 # save the config as a text stream that will be un-persisted on the
915 # other end
916 stream = io.StringIO()
917 taskDef.config.saveToStream(stream)
918 taskDescription["config"] = stream.getvalue()
919 taskDescription["label"] = taskDef.label
921 inputs = []
922 outputs = []
924 # Determine the connection between all of tasks and save that in
925 # the header as a list of connections and edges in each task
926 # this will help in un-persisting, and possibly in a "quick view"
927 # method that does not require everything to be un-persisted
928 #
929 # Typing returns can't be parameter dependent
930 for connection in inverseLookup[taskDef]: # type: ignore
931 consumers = self._datasetDict.getConsumers(connection)
932 producer = self._datasetDict.getProducer(connection)
933 if taskDef in consumers:
934 # This checks if the task consumes the connection directly
935 # from the datastore or it is produced by another task
936 producerLabel = producer.label if producer is not None else "datastore"
937 inputs.append((producerLabel, connection))
938 elif taskDef not in consumers and producer is taskDef:
939 # If there are no consumers for this tasks produced
940 # connection, the output will be said to be the datastore
941 # in which case the for loop will be a zero length loop
942 if not consumers:
943 outputs.append(("datastore", connection))
944 for td in consumers:
945 outputs.append((td.label, connection))
947 # dump to json string, and encode that string to bytes and then
948 # conpress those bytes
949 dump = lzma.compress(json.dumps(taskDescription).encode())
950 # record the sizing and relation information
951 taskDefMap[taskDef.label] = {
952 "bytes": (count, count + len(dump)),
953 "inputs": inputs,
954 "outputs": outputs,
955 }
956 count += len(dump)
957 jsonData.append(dump)
959 headerData["TaskDefs"] = taskDefMap
961 # serialize the nodes, recording the start and end bytes of each node
962 dimAccumulator = DimensionRecordsAccumulator()
963 for node in self:
964 # compressing has very little impact on saving or load time, but
965 # a large impact on on disk size, so it is worth doing
966 simpleNode = node.to_simple(accumulator=dimAccumulator)
968 dump = lzma.compress(simpleNode.json().encode())
969 jsonData.append(dump)
970 nodeMap.append(
971 (
972 str(node.nodeId),
973 {
974 "bytes": (count, count + len(dump)),
975 "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
976 "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
977 },
978 )
979 )
980 count += len(dump)
982 headerData["DimensionRecords"] = {
983 key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
984 }
986 # need to serialize this as a series of key,value tuples because of
987 # a limitation on how json cant do anyting but strings as keys
988 headerData["Nodes"] = nodeMap
990 # dump the headerData to json
991 header_encode = lzma.compress(json.dumps(headerData).encode())
993 # record the sizes as 2 unsigned long long numbers for a total of 16
994 # bytes
995 save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)
997 fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
998 map_lengths = struct.pack(fmt_string, len(header_encode))
1000 # write each component of the save out in a deterministic order
1001 # buffer = io.BytesIO()
1002 # buffer.write(map_lengths)
1003 # buffer.write(taskDef_pickle)
1004 # buffer.write(map_pickle)
1005 buffer = bytearray()
1006 buffer.extend(MAGIC_BYTES)
1007 buffer.extend(save_bytes)
1008 buffer.extend(map_lengths)
1009 buffer.extend(header_encode)
1010 # Iterate over the length of pickleData, and for each element pop the
1011 # leftmost element off the deque and write it out. This is to save
1012 # memory, as the memory is added to the buffer object, it is removed
1013 # from from the container.
1014 #
1015 # Only this section needs to worry about memory pressue because
1016 # everything else written to the buffer prior to this pickle data is
1017 # only on the order of kilobytes to low numbers of megabytes.
1018 while jsonData:
1019 buffer.extend(jsonData.popleft())
1020 if returnHeader:
1021 return buffer, headerData
1022 else:
1023 return buffer
1025 @classmethod
1026 def load(
1027 cls,
1028 file: BinaryIO,
1029 universe: Optional[DimensionUniverse] = None,
1030 nodes: Optional[Iterable[uuid.UUID]] = None,
1031 graphID: Optional[BuildId] = None,
1032 minimumVersion: int = 3,
1033 ) -> QuantumGraph:
1034 """Read QuantumGraph from a file that was made by `save`.
1036 Parameters
1037 ----------
1038 file : `io.IO` of bytes
1039 File with pickle data open in binary mode.
1040 universe: `~lsst.daf.butler.DimensionUniverse`, optional
1041 DimensionUniverse instance, not used by the method itself but
1042 needed to ensure that registry data structures are initialized.
1043 If None it is loaded from the QuantumGraph saved structure. If
1044 supplied, the DimensionUniverse from the loaded `QuantumGraph`
1045 will be validated against the supplied argument for compatibility.
1046 nodes: iterable of `int` or None
1047 Numbers that correspond to nodes in the graph. If specified, only
1048 these nodes will be loaded. Defaults to None, in which case all
1049 nodes will be loaded.
1050 graphID : `str` or `None`
1051 If specified this ID is verified against the loaded graph prior to
1052 loading any Nodes. This defaults to None in which case no
1053 validation is done.
1054 minimumVersion : int
1055 Minimum version of a save file to load. Set to -1 to load all
1056 versions. Older versions may need to be loaded, and re-saved
1057 to upgrade them to the latest format before they can be used in
1058 production.
1060 Returns
1061 -------
1062 graph : `QuantumGraph`
1063 Resulting QuantumGraph instance.
1065 Raises
1066 ------
1067 TypeError
1068 Raised if pickle contains instance of a type other than
1069 QuantumGraph.
1070 ValueError
1071 Raised if one or more of the nodes requested is not in the
1072 `QuantumGraph` or if graphID parameter does not match the graph
1073 being loaded or if the supplied uri does not point at a valid
1074 `QuantumGraph` save file.
1076 Notes
1077 -----
1078 Reading Quanta from pickle requires existence of singleton
1079 DimensionUniverse which is usually instantiated during Registry
1080 initialization. To make sure that DimensionUniverse exists this method
1081 accepts dummy DimensionUniverse argument.
1082 """
1083 # Try to see if the file handle contains pickle data, this will be
1084 # removed in the future
1085 try:
1086 qgraph = pickle.load(file)
1087 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
1088 except pickle.UnpicklingError:
1089 with LoadHelper(file, minimumVersion) as loader:
1090 qgraph = loader.load(universe, nodes, graphID)
1091 if not isinstance(qgraph, QuantumGraph):
1092 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
1093 return qgraph
1095 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1096 """Iterate over the `taskGraph` attribute in topological order
1098 Yields
1099 ------
1100 taskDef : `TaskDef`
1101 `TaskDef` objects in topological order
1102 """
1103 yield from nx.topological_sort(self.taskGraph)
1105 @property
1106 def graphID(self) -> BuildId:
1107 """Returns the ID generated by the graph at construction time"""
1108 return self._buildId
1110 def __iter__(self) -> Generator[QuantumNode, None, None]:
1111 yield from nx.topological_sort(self._connectedQuanta)
1113 def __len__(self) -> int:
1114 return self._count
1116 def __contains__(self, node: QuantumNode) -> bool:
1117 return self._connectedQuanta.has_node(node)
1119 def __getstate__(self) -> dict:
1120 """Stores a compact form of the graph as a list of graph nodes, and a
1121 tuple of task labels and task configs. The full graph can be
1122 reconstructed with this information, and it preseves the ordering of
1123 the graph ndoes.
1124 """
1125 universe: Optional[DimensionUniverse] = None
1126 for node in self:
1127 dId = node.quantum.dataId
1128 if dId is None:
1129 continue
1130 universe = dId.graph.universe
1131 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
1133 def __setstate__(self, state: dict) -> None:
1134 """Reconstructs the state of the graph from the information persisted
1135 in getstate.
1136 """
1137 buffer = io.BytesIO(state["reduced"])
1138 with LoadHelper(buffer, minimumVersion=3) as loader:
1139 qgraph = loader.load(state["universe"], graphID=state["graphId"])
1141 self._metadata = qgraph._metadata
1142 self._buildId = qgraph._buildId
1143 self._datasetDict = qgraph._datasetDict
1144 self._nodeIdMap = qgraph._nodeIdMap
1145 self._count = len(qgraph)
1146 self._taskToQuantumNode = qgraph._taskToQuantumNode
1147 self._taskGraph = qgraph._taskGraph
1148 self._connectedQuanta = qgraph._connectedQuanta
1150 def __eq__(self, other: object) -> bool:
1151 if not isinstance(other, QuantumGraph):
1152 return False
1153 if len(self) != len(other):
1154 return False
1155 for node in self:
1156 if node not in other:
1157 return False
1158 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1159 return False
1160 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1161 return False
1162 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1163 return False
1164 return set(self.taskGraph) == set(other.taskGraph)