Coverage for python/lsst/pipe/base/graph/graph.py : 21%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
22import warnings
24__all__ = ("QuantumGraph", "IncompatibleGraphError")
26from collections import defaultdict, deque
28from itertools import chain, count
29import io
30import networkx as nx
31from networkx.drawing.nx_agraph import write_dot
32import os
33import pickle
34import lzma
35import copy
36import struct
37import time
38from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple,
39 Union, TypeVar)
41from ..connections import iterConnections
42from ..pipeline import TaskDef
43from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse
45from ._implDetails import _DatasetTracker, DatasetTypeName
46from .quantumNode import QuantumNode, NodeId, BuildId
47from ._loadHelpers import LoadHelper
# Type variable for methods that return a new instance of whatever
# (possibly derived) graph class they were invoked on.
_T = TypeVar("_T", bound="QuantumGraph")

# Leading bytes of every save file; they mark the file as a QuantumGraph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

# Version number of the on-disk save format. Increment it whenever the
# on-disk representation changes, and update the load helpers so each
# version is still read correctly.
SAVE_VERSION = 1

# `struct` format of the preamble written right after the magic bytes:
# big-endian ('>'), one unsigned short ('H', the save version), followed by
# two unsigned long longs ('Q', 'Q') holding the byte lengths of the two
# pickled lookup maps.
STRUCT_FMT_STRING = '>HQQ'
class IncompatibleGraphError(Exception):
    """Signal that a lookup by `NodeId` cannot be performed because the id
    belongs to a different, incompatible graph instance.
    """
class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Build the graph that stores the relations between tasks, and the
        graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The tasks in the graph and the quanta each of them processes.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            Existing node ids to reuse. When given (even if empty), every
            quantum in ``quanta`` must have an entry. When `None`, new node
            ids are numbered sequentially from zero.
        _buildId : `BuildId`, optional
            Build id to assign to the graph. When `None`, a new id is made
            from the current time and the process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is given but has no entry for one
            of the quanta.
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                # Compare against None rather than testing truthiness: an
                # empty mapping still means "reuse existing ids", so a missing
                # entry must be reported instead of silently creating a new
                # id under a possibly shared build id.
                if _quantumToNodeId is not None:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Return the `QuantumNode` objects that are 'input' nodes to the
        graph, meaning those nodes do not depend on any other nodes in the
        graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A lazily evaluated iterator (generator) over the nodes that are
            inputs to the graph.
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'output' nodes
        to the graph, meaning no other nodes in the graph depend on them.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The number associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with input number

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a graph other than this
            instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        return frozenset(self._quanta[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input; the iterable will be empty if no tasks use the specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        # The outermost iterable of a generator expression is evaluated
        # eagerly, so a missing key raises here rather than on iteration.
        return (c for c in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the taskName property
        of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for; compared against the last
            dot-separated component of each `TaskDef.taskName`.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        return [task for task in self._quanta.keys() if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a tasks label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` object that has the specified label, or `None`
            if no task has that label.
        """
        for task in self._quanta.keys():
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum` objects
            A `set` of all `Quantum` belonging to tasks that use the specified
            `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`

        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        result = result.union(*(self._quanta[task] for task in tasks))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            Nodes to keep. Nodes that are not part of this graph are ignored.
            Any iterable is accepted, including one-shot iterators.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        # ``subgraph`` consumes ``nodes``; everything below must therefore
        # work from the resulting subgraph rather than iterate ``nodes`` again
        # (it may be a generator that is now exhausted).
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            # Preserve the existing node ids in the new graph.
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId,
                             _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a list of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            All the nodes that are directly connected to specified node
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned to support the
            ``if graph.findCycle():`` syntax, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def saveUri(self, uri: Union[ButlerURI, str]):
        """Save `QuantumGraph` to the specified URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI to where the graph should be saved.

        Raises
        ------
        TypeError
            Raised if the URI extension is not exactly ``.qgraph``.
        """
        butlerUri = ButlerURI(uri)
        # Validate before the (potentially expensive) serialization. Use an
        # exact comparison: ``ext not in (".qgraph")`` is a substring test,
        # because ``(".qgraph")`` is a plain string, not a tuple.
        if butlerUri.getExtension() != ".qgraph":
            raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
        buffer = self._buildSaveObject()
        butlerUri.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

    @classmethod
    def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
                nodes: Optional[Iterable[int]] = None,
                graphID: Optional[BuildId] = None
                ) -> QuantumGraph:
        """Read `QuantumGraph` from a URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI from where to load the graph.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied uri does not point at a valid
            `QuantumGraph` save file (its extension must be exactly
            ``.pickle``, ``.pkl`` or ``.qgraph``).

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.
        """
        uri = ButlerURI(uri)
        extension = uri.getExtension()
        # With ButlerURI we have the choice of always using a local file
        # or reading in the bytes directly. Reading in bytes can be more
        # efficient for reasonably-sized pickle files when the resource
        # is remote. For now use the local file variant. For a local file
        # as_local() does nothing.

        if extension in (".pickle", ".pkl"):
            # NOTE: unpickling can execute arbitrary code; only load pickle
            # graphs from trusted sources.
            with uri.as_local() as local, open(local.ospath, "rb") as fd:
                warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
                qgraph = pickle.load(fd)
        elif extension == ".qgraph":
            # Exact comparison; ``in ('.qgraph')`` was a substring test that
            # also accepted e.g. an empty extension.
            with LoadHelper(uri) as loader:
                qgraph = loader.load(nodes, graphID)
        else:
            raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
        return qgraph

    def save(self, file: io.BufferedIOBase):
        """Save QuantumGraph to a file.

        The graph is written in the ``qgraph`` binary format produced by
        `_buildSaveObject` (the same format `saveUri` writes), and can be
        read back with `load`.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write the serialized graph to, open in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

    def _buildSaveObject(self) -> bytearray:
        """Serialize this graph into the ``qgraph`` on-disk byte layout.

        The layout is: `MAGIC_BYTES`; a preamble packed with
        `STRUCT_FMT_STRING` (save version, length of the pickled task byte
        map, length of the pickled node byte map); the two pickled byte maps;
        and finally the lzma-compressed pickles of every `TaskDef` followed
        by every `QuantumNode`. The byte maps record each object's
        ``(start, end)`` offsets relative to the start of that final
        compressed section.

        Returns
        -------
        buffer : `bytearray`
            The complete serialized graph.
        """
        # make some containers
        pickleData = deque()
        nodeMap = {}
        taskDefMap = {}
        protocol = 3

        # Number of compressed bytes produced so far; named to avoid
        # shadowing the module-level ``itertools.count`` import.
        byteCount = 0
        # serialize out the task Defs recording the start and end bytes of each
        # taskDef
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
            taskDefMap[taskDef.label] = (byteCount, byteCount + len(dump))
            byteCount += len(dump)
            pickleData.append(dump)

        # Store the QuantumGraph BuildId alongside the TaskDefs for
        # convenience. This will allow validating BuildIds at load time, prior
        # to loading any QuantumNodes. Name chosen for unlikely conflicts with
        # labels as this is python standard for private.
        taskDefMap['__GraphBuildID'] = self.graphID

        # serialize the nodes, recording the start and end bytes of each node
        for node in self:
            node = copy.copy(node)
            taskDef = node.taskDef
            # Explicitly overload the "frozen-ness" of nodes to normalize out
            # the taskDef, this saves a lot of space and load time. The label
            # will be used to retrieve the taskDef from the taskDefMap upon load
            #
            # This strategy was chosen instead of creating a new class that
            # looked just like a QuantumNode but containing a label in place of
            # a TaskDef because it would be needlessly slow to construct a
            # bunch of new object to immediately serialize them and destroy the
            # object. This seems like an acceptable use of Python's dynamic
            # nature in a controlled way for optimization and simplicity.
            object.__setattr__(node, 'taskDef', taskDef.label)
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(node, protocol=protocol))
            pickleData.append(dump)
            nodeMap[node.nodeId.number] = (byteCount, byteCount + len(dump))
            byteCount += len(dump)

        # pickle the taskDef byte map
        taskDef_pickle = pickle.dumps(taskDefMap, protocol=protocol)

        # pickle the node byte map
        map_pickle = pickle.dumps(nodeMap, protocol=protocol)

        # record the save version (unsigned short) and the two map sizes
        # (unsigned long longs), for a total of 18 bytes
        map_lengths = struct.pack(STRUCT_FMT_STRING, SAVE_VERSION, len(taskDef_pickle), len(map_pickle))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(map_lengths)
        buffer.extend(taskDef_pickle)
        buffer.extend(map_pickle)
        # Iterate over the length of pickleData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this pickle data is
        # only on the order of kilobytes to low numbers of megabytes.
        while pickleData:
            buffer.extend(pickleData.popleft())
        return buffer

    @classmethod
    def load(cls, file: io.BufferedIOBase, universe: DimensionUniverse,
             nodes: Optional[Iterable[int]] = None,
             graphID: Optional[BuildId] = None
             ) -> QuantumGraph:
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File with the saved graph, open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied file does not contain a valid
            `QuantumGraph` save.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.
        """
        # Try to see if the file handle contains pickle data, this will be
        # removed in the future.
        # NOTE: unpickling can execute arbitrary code; only load graphs from
        # trusted sources.
        try:
            qgraph = pickle.load(file)
            warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
        except pickle.UnpicklingError:
            with LoadHelper(file) as loader:  # type: ignore # needed because we don't have Protocols yet
                qgraph = loader.load(nodes, graphID)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        taskDef : `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    @property
    def graphID(self) -> BuildId:
        """Returns the ID generated by the graph at construction time
        """
        return self._buildId

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Store a compact form of the graph: the list of its nodes in
        topological order. Each node carries its quantum, task definition and
        node id, which is enough for `__setstate__` to rebuild the full graph
        with the same node ids and build id.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstructs the state of the graph from the information persisted
        in getstate.
        """
        nodesList = state['nodesList']
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        quantumNode: QuantumNode
        for quantumNode in nodesList:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        # All nodes of one graph share the same build id; take it from the
        # last node (as the original loop-variable form did).
        _buildId = nodesList[-1].nodeId.buildId if nodesList else None
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)