# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

import warnings

__all__ = ("QuantumGraph", "IncompatibleGraphError")

from collections import defaultdict, deque

from itertools import chain, count
import io
import networkx as nx
from networkx.drawing.nx_agraph import write_dot
import os
import pickle
import lzma
import copy
import struct
import time
from typing import (BinaryIO, DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator,
                    Optional, Tuple, Union, TypeVar)

from ..connections import iterConnections
from ..pipeline import TaskDef
from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse

from ._implDetails import _DatasetTracker, DatasetTypeName
from .quantumNode import QuantumNode, NodeId, BuildId


_T = TypeVar("_T", bound="QuantumGraph")

# Modify this constant any time the on-disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 1

# String used to describe the format for the preamble bytes in a file save.
# This marks a big-endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream.
STRUCT_FMT_STRING = '>HQQ'

# Magic bytes that help determine this is a graph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
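
# Illustrative sketch of how the preamble defined above is laid out.
# ``taskDefPickle`` and ``nodeMapPickle`` are placeholder names; the offsets
# follow directly from STRUCT_FMT_STRING ('>HQQ' is 2 + 8 + 8 = 18 bytes):
#
#     header = MAGIC_BYTES + struct.pack(STRUCT_FMT_STRING, SAVE_VERSION,
#                                        len(taskDefPickle), len(nodeMapPickle))
#     # A reader would verify the magic bytes and then unpack the version and
#     # the two map sizes from the next 18 bytes:
#     version, taskDefMapSize, nodeMapSize = struct.unpack(
#         STRUCT_FMT_STRING, header[len(MAGIC_BYTES):len(MAGIC_BYTES) + 18])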


class IncompatibleGraphError(Exception):
    """Exception class to indicate that a lookup by NodeId is impossible due
    to incompatibilities
    """
    pass


class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)
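
    # Example of constructing a graph (illustrative sketch; ``taskDef`` and
    # the ``quantum*`` objects are assumed to come from pipeline and butler
    # code elsewhere and are not defined here):
    #
    #     qgraph = QuantumGraph({taskDef: {quantum1, quantum2}})
    #     len(qgraph)         # number of quantum nodes, here 2
    #     qgraph.isConnected  # True if the quanta form a single connected graph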

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Builds the graph that is used to store the relations between
        tasks, and the graph that holds the relations between quanta
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for the tasks,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connection's name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                if _quantumToNodeId:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # Unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal data structure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. It is usually preferable to iterate over, and
        use methods of, this class, but sometimes direct access to the
        networkx object may be helpful.

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal data structure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Return an iterable of all the `QuantumNode` objects that are
        'input' nodes to the graph, meaning those nodes do not depend on any
        other nodes in the graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            The nodes that are inputs to the graph
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Return a `list` of all the `QuantumNode` objects that are 'output'
        nodes of the graph, meaning no other nodes in the graph depend on
        them.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            The nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the dataset type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected,
        ignoring the directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The number associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding to the input number

        Raises
        ------
        IndexError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a different graph than this
            instance (or than a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `frozenset` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        return frozenset(self._quanta[taskDef])

    def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
        """Return all the `QuantumNode` objects associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `QuantumNode` objects are to be queried

        Returns
        -------
        frozenset of `QuantumNode`
            The `frozenset` of `QuantumNode` objects that is associated with
            the specified `TaskDef`.
        """
        return frozenset(self._taskToQuantumNode[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str
            for type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input; the iterable will be empty if no tasks use the specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return (c for c in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str
            for type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            The `TaskDef` that produces `DatasetTypeName` as an output, or
            None if none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str
            for type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the taskName property
        of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        results = []
        for task in self._quanta.keys():
            split = task.taskName.split('.')
            if split[-1] == taskName:
                results.append(task)
        return results

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` object in this graph is associated
        with a `str` representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef`
            The `TaskDef` object that has the specified label, or `None` if
            no task has that label.
        """
        for task in self._quanta.keys():
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified
        `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str
            for type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum` objects
            A `set` of `Quantum` that contain the specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        result = result.union(*(self._quanta[task] for task in tasks))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        for qset in self._quanta.values():
            if quantum in qset:
                return True
        return False

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
                             _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs, each of which is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))
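
    # Illustrative sketch of how the subsetting methods compose (``qgraph``
    # is assumed to be an existing QuantumGraph instance):
    #
    #     pieces = qgraph.subsetToConnected()  # tuple of connected subgraphs
    #     first = pieces[0]
    #     node = next(iter(first))
    #     # Node ids (and the build id) are preserved by `subset`, so the
    #     # parent graph can still resolve ids taken from the subset:
    #     qgraph.getQuantumNodeByNodeId(node.nodeId)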

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a
        specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to the specified node
        """
        return set(pred for pred in self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a
        specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs of the specified node
        """
        return set(succ for succ in self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            All the nodes that are directly connected to the specified node
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of the node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check the graph for the presence of cycles and return the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned (rather than None)
            so that ``if graph.findCycle()`` syntax works, as an empty list
            is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def saveUri(self, uri: Union[ButlerURI, str]):
        """Save `QuantumGraph` to the specified URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI to where the graph should be saved.
        """
        buffer = self._buildSaveObject()
        butlerUri = ButlerURI(uri)
        if butlerUri.getExtension() not in (".qgraph",):
            raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
        butlerUri.write(buffer)  # type: ignore  # Ignore because bytearray is safe to use in place of bytes

    @classmethod
    def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
                nodes: Optional[Iterable[int]] = None,
                graphID: Optional[BuildId] = None
                ) -> QuantumGraph:
        """Read `QuantumGraph` from a URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI from where to load the graph.
        universe : `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if the pickle contains an instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph`, if the graphID parameter does not match the graph
            being loaded, or if the supplied uri does not point at a valid
            `QuantumGraph` save file.

        Notes
        -----
        Reading Quanta from pickle requires the existence of a singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that a DimensionUniverse exists this
        method accepts a dummy DimensionUniverse argument.
        """
        uri = ButlerURI(uri)
        # With ButlerURI we have the choice of always using a local file
        # or reading in the bytes directly. Reading in bytes can be more
        # efficient for reasonably-sized pickle files when the resource
        # is remote. For now use the local file variant. For a local file
        # as_local() does nothing.

        if uri.getExtension() in (".pickle", ".pkl"):
            with uri.as_local() as local, open(local.ospath, "rb") as fd:
                warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
                qgraph = pickle.load(fd)
        elif uri.getExtension() in ('.qgraph',):
            with LoadHelper(uri) as loader:
                qgraph = loader.load(nodes, graphID)
        else:
            raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
        return qgraph
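
    # Illustrative round-trip using the URI based save and load methods. The
    # path and the ``universe`` object are assumptions, standing in for a real
    # location and a `DimensionUniverse` obtained from a butler registry:
    #
    #     qgraph.saveUri("/tmp/my_pipeline.qgraph")
    #     restored = QuantumGraph.loadUri("/tmp/my_pipeline.qgraph", universe)
    #     # ``restored`` now contains the same nodes and task graph as
    #     # ``qgraph``; passing ``nodes=[0, 1]`` would load only those node
    #     # numbers, and ``graphID`` can be used to verify the save matches
    #     # an expected build id.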

    def save(self, file: BinaryIO):
        """Save QuantumGraph to a file.

        The graph is written in a custom binary format: a preamble describing
        the layout, followed by individually compressed pickles of the task
        definitions and nodes. This could potentially change in the future if
        a better format is found.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File open in binary mode to write the data to.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore  # Ignore because bytearray is safe to use in place of bytes

    def _buildSaveObject(self) -> bytearray:
        # make some containers
        pickleData = deque()
        nodeMap = {}
        taskDefMap = {}
        protocol = 3

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task defs, recording the start and end bytes of
        # each taskDef
        for taskDef in self.taskGraph:
            # Compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
            taskDefMap[taskDef.label] = (count, count+len(dump))
            count += len(dump)
            pickleData.append(dump)

        # Store the QuantumGraph BuildId alongside the TaskDefs for
        # convenience. This will allow validating BuildIds at load time, prior
        # to loading any QuantumNodes. Name chosen for unlikely conflicts with
        # labels as this is the python standard for private names.
        taskDefMap['__GraphBuildID'] = self.graphID

        # serialize the nodes, recording the start and end bytes of each node
        for node in self:
            node = copy.copy(node)
            taskDef = node.taskDef
            # Explicitly override the "frozen-ness" of nodes to normalize out
            # the taskDef; this saves a lot of space and load time. The label
            # will be used to retrieve the taskDef from the taskDefMap upon
            # load.
            #
            # This strategy was chosen instead of creating a new class that
            # looked just like a QuantumNode but contained a label in place of
            # a TaskDef because it would be needlessly slow to construct a
            # bunch of new objects just to immediately serialize and then
            # destroy them. This seems like an acceptable use of Python's
            # dynamic nature in a controlled way for optimization and
            # simplicity.
            object.__setattr__(node, 'taskDef', taskDef.label)
            # Compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(node, protocol=protocol))
            pickleData.append(dump)
            nodeMap[node.nodeId.number] = (count, count+len(dump))
            count += len(dump)

        # pickle the taskDef byte map
        taskDef_pickle = pickle.dumps(taskDefMap, protocol=protocol)

        # pickle the node byte map
        map_pickle = pickle.dumps(nodeMap, protocol=protocol)

        # Record the save version as an unsigned short and the sizes of the
        # two maps as unsigned long long numbers, for a total of 18 bytes
        map_lengths = struct.pack(STRUCT_FMT_STRING, SAVE_VERSION, len(taskDef_pickle), len(map_pickle))

        # write each component of the save out in a deterministic order
        # buffer = io.BytesIO()
        # buffer.write(map_lengths)
        # buffer.write(taskDef_pickle)
        # buffer.write(map_pickle)
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(map_lengths)
        buffer.extend(taskDef_pickle)
        buffer.extend(map_pickle)
        # Iterate over the length of pickleData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory: as the data is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this pickle data is
        # only on the order of kilobytes to low numbers of megabytes.
        while pickleData:
            buffer.extend(pickleData.popleft())
        return buffer
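
    # Summary of the on-disk layout produced by `_buildSaveObject`; the sizes
    # follow from the code above:
    #
    #     MAGIC_BYTES                                   (10 bytes)
    #     struct '>HQQ' of SAVE_VERSION and the sizes
    #         of the two pickled maps below             (18 bytes)
    #     pickled taskDefMap: {label: (start, end), '__GraphBuildID': id}
    #     pickled nodeMap:    {nodeId.number: (start, end)}
    #     lzma-compressed pickles of each TaskDef, then each node, at the
    #         (start, end) byte ranges recorded in the maps, measured from
    #         the beginning of this compressed-data section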

    @classmethod
    def load(cls, file: BinaryIO, universe: DimensionUniverse,
             nodes: Optional[Iterable[int]] = None,
             graphID: Optional[BuildId] = None
             ) -> QuantumGraph:
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : binary file-like object
            File with data open in binary mode.
        universe : `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes : iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if the pickle contains an instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph`, if the graphID parameter does not match the graph
            being loaded, or if the supplied file does not contain a valid
            `QuantumGraph` save.

        Notes
        -----
        Reading Quanta from pickle requires the existence of a singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that a DimensionUniverse exists this
        method accepts a dummy DimensionUniverse argument.
        """
        # Try to see if the file handle contains pickle data, this will be
        # removed in the future
        try:
            qgraph = pickle.load(file)
            warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
        except pickle.UnpicklingError:
            with LoadHelper(file) as loader:  # type: ignore  # needed because we don't have Protocols yet
                qgraph = loader.load(nodes, graphID)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        taskDef : `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    @property
    def graphID(self) -> BuildId:
        """Return the ID generated by the graph at construction time
        """
        return self._buildId

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Store a compact form of the graph as a list of graph nodes. The
        full graph can be reconstructed from this information, and it
        preserves the ordering of the graph nodes.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstruct the state of the graph from the information persisted
        in getstate.
        """
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        quantumNode: QuantumNode
        for quantumNode in state['nodesList']:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        _buildId = quantumNode.nodeId.buildId if state['nodesList'] else None  # type: ignore
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)