Coverage for python/lsst/pipe/base/graph/graph.py: 24%
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
22import warnings
24__all__ = ("QuantumGraph", "IncompatibleGraphError")
26from collections import defaultdict, deque
28from itertools import chain, count
29import io
30import json
31import networkx as nx
32from networkx.drawing.nx_agraph import write_dot
33import os
34import pickle
35import lzma
36import copy
37import struct
38import time
39from types import MappingProxyType
40 from typing import (Any, BinaryIO, DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator,
41 Optional, Tuple, Union, TypeVar)
43from ..connections import iterConnections
44from ..pipeline import TaskDef
45from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse
47from ._implDetails import _DatasetTracker, DatasetTypeName
48from .quantumNode import QuantumNode, NodeId, BuildId
49from ._loadHelpers import LoadHelper
52_T = TypeVar("_T", bound="QuantumGraph")
54# modify this constant any time the on disk representation of the save file
55# changes, and update the load helpers to behave properly for each version.
56SAVE_VERSION = 2
58# Strings used to describe the format for the preamble bytes in a file save
59# The base is a big endian encoded unsigned short that is used to hold the
60 # file format version. This allows reading the version bytes and determining
61 # which loading code should be used for the rest of the file.
62STRUCT_FMT_BASE = '>H'
63#
64# Version 1
65# This marks a big endian encoded format with an unsigned short, an unsigned
66# long long, and an unsigned long long in the byte stream
67# Version 2
68 # A big endian encoded format with an unsigned long long used to indicate
69 # the total length of the entire header
70STRUCT_FMT_STRING = {
71 1: '>QQ',
72 2: '>Q'
73}
76# magic bytes that help determine this is a graph save
77MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
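# Illustrative sketch only (not part of this module's API): how the constants
# above could be used to read back the preamble of a version-2 save.
# ``fileobj`` is an assumed open binary file handle positioned at the start.
#
#     magic = fileobj.read(len(MAGIC_BYTES))
#     if magic != MAGIC_BYTES:
#         raise ValueError("Not a QuantumGraph save file")
#     (version,) = struct.unpack(STRUCT_FMT_BASE,
#                                fileobj.read(struct.calcsize(STRUCT_FMT_BASE)))
#     fmt = STRUCT_FMT_STRING[version]  # '>Q' for version 2
#     (headerLength,) = struct.unpack(fmt, fileobj.read(struct.calcsize(fmt)))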
80class IncompatibleGraphError(Exception):
81 """Exception class to indicate that a lookup by NodeId is impossible due
82 to incompatibilities
83 """
84 pass
87class QuantumGraph:
88 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
90 This data structure represents a concrete workflow generated from a
91 `Pipeline`.
93 Parameters
94 ----------
95 quanta : Mapping of `TaskDef` to sets of `Quantum`
96 This maps tasks (and their configs) to the sets of data they are to
97 process.
98 metadata : Optional Mapping of `str` to primitives
99 This is an optional parameter of extra data to carry with the graph.
100 Entries in this mapping should be able to be serialized in JSON.
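    Examples
    --------
    A minimal construction sketch; ``taskDef``, ``quantum1`` and ``quantum2``
    are assumed to be pre-existing `TaskDef` and `Quantum` instances and are
    not defined here::

        qgraph = QuantumGraph({taskDef: {quantum1, quantum2}},
                              metadata={"comment": "example run"})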
101 """
102 def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]],
103 metadata: Optional[Mapping[str, Any]] = None):
104 self._buildGraphs(quanta, metadata=metadata)
106 def _buildGraphs(self,
107 quanta: Mapping[TaskDef, Set[Quantum]],
108 *,
109 _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
110 _buildId: Optional[BuildId] = None,
111 metadata: Optional[Mapping[str, Any]] = None):
112 """Builds the graph that is used to store the relation between tasks,
113 and the graph that holds the relations between quanta
114 """
115 self._metadata = metadata
116 self._quanta = quanta
117 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
118 # Data structures used to identify relations between components;
119 # DatasetTypeName -> TaskDef for task,
120 # and DatasetRef -> QuantumNode for the quanta
121 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
122 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()
124 nodeNumberGenerator = count()
125 self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
126 self._taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
127 self._count = 0
128 for taskDef, quantumSet in self._quanta.items():
129 connections = taskDef.connections
131 # For each type of connection in the task, add a key to the
132 # `_DatasetTracker` for the connection's name, with a value of
133 # the TaskDef in the appropriate field
134 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
135 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)
137 for output in iterConnections(connections, ("outputs", "initOutputs")):
138 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)
140 # For each `Quantum` in the set of all `Quantum` for this task,
141 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
142 # of the individual datasets inside the `Quantum`, with a value of
143 a newly created QuantumNode, in the appropriate input/output
144 # field.
145 self._count += len(quantumSet)
146 for quantum in quantumSet:
147 if _quantumToNodeId:
148 nodeId = _quantumToNodeId.get(quantum)
149 if nodeId is None:
150 raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
151 "associated value in the mapping")
152 else:
153 nodeId = NodeId(next(nodeNumberGenerator), self._buildId)
155 inits = quantum.initInputs.values()
156 inputs = quantum.inputs.values()
157 value = QuantumNode(quantum, taskDef, nodeId)
158 self._taskToQuantumNode[taskDef].add(value)
159 self._nodeIdMap[nodeId] = value
161 for dsRef in chain(inits, inputs):
162 # unfortunately, `Quantum` allows inits to be individual
163 # `DatasetRef`s or an Iterable of such, so there must
164 # be an instance check here
165 if isinstance(dsRef, Iterable):
166 for sub in dsRef:
167 self._datasetRefDict.addConsumer(sub, value)
168 else:
169 self._datasetRefDict.addConsumer(dsRef, value)
170 for dsRef in chain.from_iterable(quantum.outputs.values()):
171 self._datasetRefDict.addProducer(dsRef, value)
173 # Graph of task relations, used in various methods
174 self._taskGraph = self._datasetDict.makeNetworkXGraph()
176 # Graph of quanta relations
177 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
179 @property
180 def taskGraph(self) -> nx.DiGraph:
181 """Return a graph representing the relations between the tasks inside
182 the quantum graph.
184 Returns
185 -------
186 taskGraph : `networkx.DiGraph`
187 Internal data structure that holds relations of `TaskDef` objects
188 """
189 return self._taskGraph
191 @property
192 def graph(self) -> nx.DiGraph:
193 """Return a graph representing the relations between all the
194 `QuantumNode` objects. It is generally preferable to iterate over,
195 and use the methods of, this class, but sometimes direct access to
196 the networkx object may be helpful.
198 Returns
199 -------
200 graph : `networkx.DiGraph`
201 Internal data structure that holds relations of `QuantumNode`
202 objects
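        Examples
        --------
        The underlying object can be handed to networkx routines directly;
        ``qgraph`` is assumed to be an existing `QuantumGraph`::

            import networkx as nx

            ordered = list(nx.topological_sort(qgraph.graph))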
203 """
204 return self._connectedQuanta
206 @property
207 def inputQuanta(self) -> Iterable[QuantumNode]:
208 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
209 to the graph, meaning those nodes to not depend on any other nodes in
210 the graph.
212 Returns
213 -------
214 inputNodes : iterable of `QuantumNode`
215 A list of nodes that are inputs to the graph
216 """
217 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
219 @property
220 def outputQuanta(self) -> Iterable[QuantumNode]:
221 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
222 to the graph, meaning those nodes have no nodes that depend them in
223 the graph.
225 Returns
226 -------
227 outputNodes : iterable of `QuantumNode`
228 A list of nodes that are outputs of the graph
229 """
230 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
232 @property
233 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
234 """Return all the `DatasetTypeName` objects that are contained inside
235 the graph.
237 Returns
238 -------
239 tuple of `DatasetTypeName`
240 All the dataset type names that are present in the graph
241 """
242 return tuple(self._datasetDict.keys())
244 @property
245 def isConnected(self) -> bool:
246 """Return True if all of the nodes in the graph are connected, ignores
247 directionality of connections.
248 """
249 return nx.is_weakly_connected(self._connectedQuanta)
251 def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
252 """Lookup a `QuantumNode` from an id associated with the node.
254 Parameters
255 ----------
256 nodeId : `NodeId`
257 The id associated with a node
259 Returns
260 -------
261 node : `QuantumNode`
262 The node corresponding to the input id
264 Raises
265 ------
266 IndexError
267 Raised if the requested nodeId is not in the graph.
268 IncompatibleGraphError
269 Raised if the nodeId was built with a different graph than this
270 instance (or than a graph instance that produced this instance
271 through an operation such as subset)
272 """
273 if nodeId.buildId != self._buildId:
274 raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
275 return self._nodeIdMap[nodeId]
277 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
278 """Return all the `Quantum` associated with a `TaskDef`.
280 Parameters
281 ----------
282 taskDef : `TaskDef`
283 The `TaskDef` for which `Quantum` are to be queried
285 Returns
286 -------
287 frozenset of `Quantum`
288 The `set` of `Quantum` that is associated with the specified
289 `TaskDef`.
290 """
291 return frozenset(self._quanta[taskDef])
293 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
294 """Return all the `QuantumNodes` associated with a `TaskDef`.
296 Parameters
297 ----------
298 taskDef : `TaskDef`
299 The `TaskDef` for which `QuantumNode` objects are to be queried
301 Returns
302 -------
303 frozenset of `QuantumNodes`
304 The `frozenset` of `QuantumNodes` that is associated with the
305 specified `TaskDef`.
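        Examples
        --------
        Illustrative only; ``qgraph`` and ``taskDef`` are assumed to already
        exist::

            quanta = qgraph.getQuantaForTask(taskDef)  # frozenset of Quantum
            nodes = qgraph.getNodesForTask(taskDef)    # frozenset of QuantumNode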
306 """
307 return frozenset(self._taskToQuantumNode[taskDef])
309 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
310 """Find all tasks that have the specified dataset type name as an
311 input.
313 Parameters
314 ----------
315 datasetTypeName : `str`
316 A string representing the name of a dataset type to be queried,
317 can also accept a `DatasetTypeName` which is a `NewType` of str for
318 type safety in static type checking.
320 Returns
321 -------
322 tasks : iterable of `TaskDef`
323 `TaskDef` objects that have the specified `DatasetTypeName` as an
324 input; the iterable will be empty if no tasks use the specified
325 `DatasetTypeName` as an input.
327 Raises
328 ------
329 KeyError
330 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
331 """
332 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
334 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
335 """Find all tasks that have the specified dataset type name as an
336 output.
338 Parameters
339 ----------
340 datasetTypeName : `str`
341 A string representing the name of a dataset type to be queried,
342 can also accept a `DatasetTypeName` which is a `NewType` of str for
343 type safety in static type checking.
345 Returns
346 -------
347 `TaskDef` or `None`
348 `TaskDef` that outputs the `DatasetTypeName`, or `None` if
349 none of the tasks produce this `DatasetTypeName`.
351 Raises
352 ------
353 KeyError
354 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
355 """
356 return self._datasetDict.getProducer(datasetTypeName)
358 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
359 """Find all tasks that are associated with the specified dataset type
360 name.
362 Parameters
363 ----------
364 datasetTypeName : `str`
365 A string representing the name of a dataset type to be queried,
366 can also accept a `DatasetTypeName` which is a `NewType` of str for
367 type safety in static type checking.
369 Returns
370 -------
371 result : iterable of `TaskDef`
372 `TaskDef` objects that are associated with the specified
373 `DatasetTypeName`
375 Raises
376 ------
377 KeyError
378 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
379 """
380 return self._datasetDict.getAll(datasetTypeName)
382 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
383 """Determine which `TaskDef` objects in this graph are associated
384 with a `str` representing a task name (looks at the taskName property
385 of `TaskDef` objects).
387 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
388 multiple times in a graph with different labels.
390 Parameters
391 ----------
392 taskName : str
393 Name of a task to search for
395 Returns
396 -------
397 result : list of `TaskDef`
398 List of the `TaskDef` objects that have the name specified.
399 Multiple values are returned in the case that a task is used
400 multiple times with different labels.
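        Examples
        --------
        The search string is compared against the trailing component of each
        task's dotted ``taskName``, so a bare class name is sufficient.
        ``qgraph`` is assumed to exist and the task name is illustrative::

            for taskDef in qgraph.findTaskDefByName("CharacterizeImageTask"):
                print(taskDef.label)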
401 """
402 results = []
403 for task in self._quanta.keys():
404 split = task.taskName.split('.')
405 if split[-1] == taskName:
406 results.append(task)
407 return results
409 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
410 """Determine which `TaskDef` objects in this graph are associated
411 with a `str` representing a task's label.
413 Parameters
414 ----------
415 label : str
416 Label of a task to search for
418 Returns
419 -------
420 result : `TaskDef` or `None`
421 The `TaskDef` object that has the specified label, or `None`.
422 """
423 for task in self._quanta.keys():
424 if label == task.label:
425 return task
426 return None
428 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
429 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
431 Parameters
432 ----------
433 datasetTypeName : `str`
434 The name of the dataset type to search for as a string,
435 can also accept a `DatasetTypeName` which is a `NewType` of str for
436 type safety in static type checking.
438 Returns
439 -------
440 result : `set` of `Quantum`
441 A `set` of `Quantum` that contain the specified `DatasetTypeName`
443 Raises
444 ------
445 KeyError
446 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
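        Examples
        --------
        Illustrative only; ``qgraph`` is assumed to exist and the dataset type
        name is hypothetical::

            quanta = qgraph.findQuantaWithDSType("calexp")
            print(f"{len(quanta)} quanta involve 'calexp'")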
448 """
449 tasks = self._datasetDict.getAll(datasetTypeName)
450 result: Set[Quantum] = set()
451 result = result.union(*(self._quanta[task] for task in tasks))
452 return result
454 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
455 """Check if specified quantum appears in the graph as part of a node.
457 Parameters
458 ----------
459 quantum : `Quantum`
460 The quantum to search for
462 Returns
463 -------
464 `bool`
465 The result of searching for the quantum
466 """
467 for qset in self._quanta.values():
468 if quantum in qset:
469 return True
470 return False
472 def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
473 """Write out the graph as a dot graph.
475 Parameters
476 ----------
477 output : str or `io.BufferedIOBase`
478 Either a filesystem path to write to, or a file handle object
479 """
480 write_dot(self._connectedQuanta, output)
482 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
483 """Create a new graph object that contains the subset of the nodes
484 specified as input. Node number is preserved.
486 Parameters
487 ----------
488 nodes : `QuantumNode` or iterable of `QuantumNode`
490 Returns
491 -------
492 graph : instance of graph type
493 An instance of the type from which the subset was created
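        Examples
        --------
        Illustrative only; ``qgraph``, ``node`` and ``taskDef`` are assumed to
        already exist::

            single = qgraph.subset(node)  # graph containing one node
            perTask = qgraph.subset(qgraph.getNodesForTask(taskDef))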
494 """
495 if not isinstance(nodes, Iterable):
496 nodes = (nodes, )
497 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
498 quantumMap = defaultdict(set)
500 node: QuantumNode
501 for node in quantumSubgraph:
502 quantumMap[node.taskDef].add(node.quantum)
503 # Create an empty graph, and then populate it with custom mapping
504 newInst = type(self)({})
505 newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
506 _buildId=self._buildId)
507 return newInst
509 def subsetToConnected(self: _T) -> Tuple[_T, ...]:
510 """Generate a list of subgraphs where each is connected.
512 Returns
513 -------
514 result : tuple of `QuantumGraph`
515 A tuple of graphs that are each connected
516 """
517 return tuple(self.subset(connectedSet)
518 for connectedSet in nx.weakly_connected_components(self._connectedQuanta))
520 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
521 """Return a set of `QuantumNode` that are direct inputs to a specified
522 node.
524 Parameters
525 ----------
526 node : `QuantumNode`
527 The node of the graph for which inputs are to be determined
529 Returns
530 -------
531 set of `QuantumNode`
532 All the nodes that are direct inputs to the specified node
533 """
534 return set(pred for pred in self._connectedQuanta.predecessors(node))
536 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
537 """Return a set of `QuantumNode` that are direct outputs of a specified
538 node.
540 Parameters
541 ----------
542 node : `QuantumNode`
543 The node of the graph for which outputs are to be determined
545 Returns
546 -------
547 set of `QuantumNode`
548 All the nodes that are direct outputs of the specified node
549 """
550 return set(succ for succ in self._connectedQuanta.successors(node))
552 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
553 """Return a graph of `QuantumNode` that are direct inputs and outputs
554 of a specified node.
556 Parameters
557 ----------
558 node : `QuantumNode`
559 The node of the graph for which connected nodes are to be
560 determined.
562 Returns
563 -------
564 graph : graph of `QuantumNode`
565 All the nodes that are directly connected to the specified node
566 """
567 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
568 nodes.add(node)
569 return self.subset(nodes)
571 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
572 """Return a graph of the specified node and all the ancestor nodes
573 directly reachable by walking edges.
575 Parameters
576 ----------
577 node : `QuantumNode`
578 The node for which all ancestors are to be determined
580 Returns
581 -------
582 graph of `QuantumNode`
583 Graph of the node and all of its ancestors
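        Examples
        --------
        Illustrative only; ``qgraph`` and ``node`` are assumed to exist. The
        result is itself a graph of the same type, so it can be traversed or
        saved like any other::

            upstream = qgraph.determineAncestorsOfQuantumNode(node)
            print(len(upstream), "quanta must run at or before this node")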
584 """
585 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
586 predecessorNodes.add(node)
587 return self.subset(predecessorNodes)
589 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
590 """Check a graph for the presense of cycles and returns the edges of
591 any cycles found, or an empty list if there is no cycle.
593 Returns
594 -------
595 result : list of tuple of `QuantumNode`, `QuantumNode`
596 A list of any graph edges that form a cycle, or an empty list if
597 there is no cycle. An empty list is returned to support the
598 ``if graph.findCycle():`` idiom, as an empty list is falsy.
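        Examples
        --------
        Because the empty result is falsy, the return value can be used
        directly in a condition (``qgraph`` is assumed to exist)::

            if qgraph.findCycle():
                raise RuntimeError("QuantumGraph contains a cycle")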
599 """
600 try:
601 return nx.find_cycle(self._connectedQuanta)
602 except nx.NetworkXNoCycle:
603 return []
605 def saveUri(self, uri: Union[ButlerURI, str]):
606 """Save `QuantumGraph` to the specified URI.
608 Parameters
609 ----------
610 uri : `ButlerURI` or `str`
611 URI to where the graph should be saved.
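        Examples
        --------
        A save/load round-trip sketch; the path and ``universe`` are
        assumptions for illustration, not requirements of this method::

            qgraph.saveUri("/tmp/example.qgraph")
            restored = QuantumGraph.loadUri("/tmp/example.qgraph", universe)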
612 """
613 buffer = self._buildSaveObject()
614 butlerUri = ButlerURI(uri)
615 if butlerUri.getExtension() not in (".qgraph",):
616 raise TypeError(f"Can currently only save a graph in qgraph format, not {uri}")
617 butlerUri.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
619 @property
620 def metadata(self) -> Optional[MappingProxyType[str, Any]]:
621 """
622 """
623 if self._metadata is None:
624 return None
625 return MappingProxyType(self._metadata)
627 @classmethod
628 def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
629 nodes: Optional[Iterable[int]] = None,
630 graphID: Optional[BuildId] = None
631 ) -> QuantumGraph:
632 """Read `QuantumGraph` from a URI.
634 Parameters
635 ----------
636 uri : `ButlerURI` or `str`
637 URI from where to load the graph.
638 universe: `~lsst.daf.butler.DimensionUniverse`
639 DimensionUniverse instance, not used by the method itself but
640 needed to ensure that registry data structures are initialized.
641 nodes: iterable of `int` or None
642 Numbers that correspond to nodes in the graph. If specified, only
643 these nodes will be loaded. Defaults to None, in which case all
644 nodes will be loaded.
645 graphID : `str` or `None`
646 If specified this ID is verified against the loaded graph prior to
647 loading any Nodes. This defaults to None in which case no
648 validation is done.
650 Returns
651 -------
652 graph : `QuantumGraph`
653 Resulting QuantumGraph instance.
655 Raises
656 ------
657 TypeError
658 Raised if the file contains an instance of a type other than
659 QuantumGraph.
660 ValueError
661 Raised if one or more of the nodes requested is not in the
662 `QuantumGraph` or if graphID parameter does not match the graph
663 being loaded or if the supplied uri does not point at a valid
664 `QuantumGraph` save file.
667 Notes
668 -----
669 Reading Quanta from pickle requires the existence of a singleton
670 DimensionUniverse, which is usually instantiated during Registry
671 initialization. To make sure that a DimensionUniverse exists, this
672 method accepts a dummy DimensionUniverse argument.
673 """
674 uri = ButlerURI(uri)
675 # With ButlerURI we have the choice of always using a local file
676 # or reading in the bytes directly. Reading in bytes can be more
677 # efficient for reasonably-sized pickle files when the resource
678 # is remote. For now use the local file variant. For a local file
679 # as_local() does nothing.
681 if uri.getExtension() in (".pickle", ".pkl"):
682 with uri.as_local() as local, open(local.ospath, "rb") as fd:
683 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
684 qgraph = pickle.load(fd)
685 elif uri.getExtension() in ('.qgraph',):
686 with LoadHelper(uri) as loader:
687 qgraph = loader.load(nodes, graphID)
688 else:
689 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
690 if not isinstance(qgraph, QuantumGraph):
691 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
692 return qgraph
694 def save(self, file: BinaryIO):
695 """Save QuantumGraph to a file.
697 The graph is written in a custom, versioned binary format; this could
698 potentially change in the future if a better format is found.
700 Parameters
701 ----------
702 file : `io.BufferedIOBase`
703 File to write the data to, open in binary mode.
704 """
705 buffer = self._buildSaveObject()
706 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
708 def _buildSaveObject(self) -> bytearray:
709 # make some containers
710 pickleData = deque()
711 # node map is a list because json does not accept mapping keys that
712 # are not strings, so we store a list of key, value pairs that will
713 # be converted to a mapping on load
714 nodeMap = []
715 taskDefMap = {}
716 headerData = {}
717 protocol = 3
719 # Store the QuantumGraph BuildId; this will allow validating BuildIds
720 # at load time, prior to loading any QuantumNodes. Name chosen for
721 # unlikely conflicts.
722 headerData['GraphBuildID'] = self.graphID
723 headerData['Metadata'] = self._metadata
725 # counter for the number of bytes processed thus far
726 count = 0
727 # serialize out the task Defs recording the start and end bytes of each
728 # taskDef
729 for taskDef in self.taskGraph:
730 # compressing has very little impact on saving or load time, but
731 # a large impact on on-disk size, so it is worth doing
732 dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
733 taskDefMap[taskDef.label] = {"bytes": (count, count+len(dump))}
734 count += len(dump)
735 pickleData.append(dump)
737 headerData['TaskDefs'] = taskDefMap
739 # serialize the nodes, recording the start and end bytes of each node
740 for node in self:
741 node = copy.copy(node)
742 taskDef = node.taskDef
743 # Explicitly overload the "frozen-ness" of nodes to normalize out
744 # the taskDef; this saves a lot of space and load time. The label
745 # will be used to retrieve the taskDef from the taskDefMap upon load
746 #
747 # This strategy was chosen instead of creating a new class that
748 # looked just like a QuantumNode but containing a label in place of
749 # a TaskDef because it would be needlessly slow to construct a
750 # bunch of new objects only to immediately serialize and destroy
751 # them. This seems like an acceptable use of Python's dynamic
752 # nature in a controlled way for optimization and simplicity.
753 object.__setattr__(node, 'taskDef', taskDef.label)
754 # compressing has very little impact on saving or load time, but
755 # a large impact on on-disk size, so it is worth doing
756 dump = lzma.compress(pickle.dumps(node, protocol=protocol))
757 pickleData.append(dump)
758 nodeMap.append((int(node.nodeId.number), {"bytes": (count, count+len(dump))}))
759 count += len(dump)
761 # need to serialize this as a series of key, value tuples because
762 # json cannot use anything but strings as mapping keys
763 headerData['Nodes'] = nodeMap
765 # dump the headerData to json
766 header_encode = lzma.compress(json.dumps(headerData).encode())
768 # record the save format version, followed by the size of the header
769 # block encoded according to the format string for that version
770 save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)
772 fmt_string = STRUCT_FMT_STRING[SAVE_VERSION]
773 map_lengths = struct.pack(fmt_string, len(header_encode))
775 # write each component of the save out in a deterministic order
780 buffer = bytearray()
781 buffer.extend(MAGIC_BYTES)
782 buffer.extend(save_bytes)
783 buffer.extend(map_lengths)
784 buffer.extend(header_encode)
785 # Iterate over the length of pickleData, and for each element pop the
786 # leftmost element off the deque and write it out. This is to save
787 # memory: as data is added to the buffer object, it is removed
788 # from the container.
789 #
790 # Only this section needs to worry about memory pressure because
791 # everything else written to the buffer prior to this pickle data is
792 # only on the order of kilobytes to low numbers of megabytes.
793 while pickleData:
794 buffer.extend(pickleData.popleft())
795 return buffer
797 @classmethod
798 def load(cls, file: BinaryIO, universe: DimensionUniverse,
799 nodes: Optional[Iterable[int]] = None,
800 graphID: Optional[BuildId] = None
801 ) -> QuantumGraph:
802 """Read QuantumGraph from a file that was made by `save`.
804 Parameters
805 ----------
806 file : binary file-like object
807 File with the saved graph data, open in binary mode.
808 universe: `~lsst.daf.butler.DimensionUniverse`
809 DimensionUniverse instance, not used by the method itself but
810 needed to ensure that registry data structures are initialized.
811 nodes: iterable of `int` or None
812 Numbers that correspond to nodes in the graph. If specified, only
813 these nodes will be loaded. Defaults to None, in which case all
814 nodes will be loaded.
815 graphID : `str` or `None`
816 If specified this ID is verified against the loaded graph prior to
817 loading any Nodes. This defaults to None in which case no
818 validation is done.
820 Returns
821 -------
822 graph : `QuantumGraph`
823 Resulting QuantumGraph instance.
825 Raises
826 ------
827 TypeError
828 Raised if the file contains an instance of a type other than
829 QuantumGraph.
830 ValueError
831 Raised if one or more of the nodes requested is not in the
832 `QuantumGraph` or if graphID parameter does not match the graph
833 being loaded or if the supplied uri does not point at a valid
834 `QuantumGraph` save file.
836 Notes
837 -----
838 Reading Quanta from pickle requires the existence of a singleton
839 DimensionUniverse, which is usually instantiated during Registry
840 initialization. To make sure that a DimensionUniverse exists, this
841 method accepts a dummy DimensionUniverse argument.
842 """
843 # Try to see if the file handle contains pickle data; this will be
844 # removed in the future
845 try:
846 qgraph = pickle.load(file)
847 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
848 except pickle.UnpicklingError:
849 with LoadHelper(file) as loader: # type: ignore # needed because we don't have Protocols yet
850 qgraph = loader.load(nodes, graphID)
851 if not isinstance(qgraph, QuantumGraph):
852 raise TypeError(f"QuantumGraph save file contains an unexpected object type: {type(qgraph)}")
853 return qgraph
855 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
856 """Iterate over the `taskGraph` attribute in topological order
858 Yields
859 ------
860 taskDef : `TaskDef`
861 `TaskDef` objects in topological order
862 """
863 yield from nx.topological_sort(self.taskGraph)
865 @property
866 def graphID(self):
867 """Returns the ID generated by the graph at construction time
868 """
869 return self._buildId
871 def __iter__(self) -> Generator[QuantumNode, None, None]:
872 yield from nx.topological_sort(self._connectedQuanta)
874 def __len__(self) -> int:
875 return self._count
877 def __contains__(self, node: QuantumNode) -> bool:
878 return self._connectedQuanta.has_node(node)
880 def __getstate__(self) -> dict:
881 """Stores a compact form of the graph as a list of graph nodes, and a
882 tuple of task labels and task configs. The full graph can be
883 reconstructed with this information, and it preseves the ordering of
884 the graph ndoes.
885 """
886 return {"nodesList": list(self)}
888 def __setstate__(self, state: dict):
889 """Reconstructs the state of the graph from the information persisted
890 in getstate.
891 """
892 quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
893 quantumToNodeId: Dict[Quantum, NodeId] = {}
894 quantumNode: QuantumNode
895 for quantumNode in state['nodesList']:
896 quanta[quantumNode.taskDef].add(quantumNode.quantum)
897 quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
898 _buildId = quantumNode.nodeId.buildId if state['nodesList'] else None # type: ignore
899 self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)
901 def __eq__(self, other: object) -> bool:
902 if not isinstance(other, QuantumGraph):
903 return False
904 if len(self) != len(other):
905 return False
906 for node in self:
907 if node not in other:
908 return False
909 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
910 return False
911 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
912 return False
913 return list(self.taskGraph) == list(other.taskGraph)