Coverage for python/lsst/pipe/base/graph/graph.py: 21%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23from lsst.daf.butler.core.datasets.type import DatasetType
24__all__ = ("QuantumGraph", "IncompatibleGraphError")
26import warnings
28from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse
30from collections import defaultdict, deque
32from itertools import chain, count
33import io
34import json
35import networkx as nx
36from networkx.drawing.nx_agraph import write_dot
37import os
38import pickle
39import lzma
40import copy
41import struct
42import time
43from types import MappingProxyType
44from typing import (Any, DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator,
45 Optional, Tuple, Union, TypeVar)
47from ..connections import iterConnections
48from ..pipeline import TaskDef
50from ._implDetails import _DatasetTracker, DatasetTypeName, _pruner
51from .quantumNode import QuantumNode, NodeId, BuildId
52from ._loadHelpers import LoadHelper
# Generic type variable bound to QuantumGraph so methods such as ``subset``
# return the same (sub)class they were called on.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 2

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = '>H'
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header
STRUCT_FMT_STRING = {
    1: '>QQ',
    2: '>Q'
}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by `NodeId` cannot be satisfied because the id
    was produced by a different, incompatible, graph instance.
    """
class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    metadata : Optional Mapping of `str` to primitives
        This is an optional parameter of extra data to carry with the graph.
        Entries in this mapping should be able to be serialized in JSON.
    pruneRefs : Optional Iterable of `DatasetRef`
        Dataset refs to prune from the graph, together with any nodes that
        (transitively) depend on them.

    Raises
    ------
    ValueError
        Raised if the graph is pruned such that some tasks no longer have nodes
        associated with them.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]],
                 metadata: Optional[Mapping[str, Any]] = None,
                 pruneRefs: Optional[Iterable[DatasetRef]] = None):
        """Construct the graph by delegating to `_buildGraphs`, which is also
        used by alternate construction paths (`subset`, `__setstate__`) to
        rebuild state on a bare instance.
        """
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs)
    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None,
                     metadata: Optional[Mapping[str, Any]] = None,
                     pruneRefs: Optional[Iterable[DatasetRef]] = None):
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            Tasks mapped to the sets of data they are to process.
        _quantumToNodeId : Optional Mapping of `Quantum` to `NodeId`
            If supplied, every quantum must have an entry giving the `NodeId`
            to reuse; otherwise ids are generated sequentially.
        _buildId : Optional `BuildId`
            Existing build id to reuse (e.g. when subsetting); a new one is
            derived from the current time and pid when not supplied.
        metadata : Optional Mapping of `str` to primitives
            Extra JSON-serializable data to carry with the graph.
        pruneRefs : Optional Iterable of `DatasetRef`
            Refs to prune from the graph along with dependent nodes.
            NOTE(review): this iterable is consumed more than once below, so
            callers must pass a re-iterable container, not a one-shot
            generator — confirm against call sites.

        Raises
        ------
        ValueError
            Raised if pruning leaves one or more tasks with no nodes, or if
            ``_quantumToNodeId`` is missing an entry for some quantum.
        """
        self._metadata = metadata
        self._quanta = quanta
        # Build id defaults to a timestamp-pid pair, unique enough to detect
        # lookups against the wrong graph instance.
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if _quantumToNodeId:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        # NOTE(review): the message names "_quantuMToNodeNumber"
                        # but the parameter is _quantumToNodeId.
                        raise ValueError("If _quantuMToNodeNumber is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            if sub.isComponent():
                                sub = sub.makeCompositeRef()
                            self._datasetRefDict.addConsumer(sub, value)
                    else:
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        self._datasetRefDict.addConsumer(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addProducer(dsRef, value)

        if pruneRefs is not None:
            # track what refs were pruned and prune the graph
            prunes = set()
            _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)

            # recreate the taskToQuantumNode dict removing nodes that have been
            # pruned. Keep track of task defs that now have no QuantumNodes
            emptyTasks: Set[str] = set()
            newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
            # accumulate all types
            types_ = set()
            # tracker for any pruneRefs that have caused tasks to have no nodes
            # This helps the user find out what caused the issues seen.
            culprits = set()
            # Find all the types from the refs to prune
            for r in pruneRefs:
                types_.add(r.datasetType)

            # For each of the tasks, and their associated nodes, remove any
            # nodes that were pruned. If there are no nodes associated
            # with a task, record that task, and find out if that was due to
            # a type from an input ref to prune.
            for td, taskNodes in self._taskToQuantumNode.items():
                diff = taskNodes.difference(prunes)
                if len(diff) == 0:
                    if len(taskNodes) != 0:
                        tp: DatasetType
                        for tp in types_:
                            # A task is blamed on a pruned type when some
                            # surviving-before-prune node consumed that type
                            # and every one of those refs was pruned.
                            if ((tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not
                                    set(tmpRefs).difference(pruneRefs)):
                                culprits.add(tp.name)
                    emptyTasks.add(td.label)
                newTaskToQuantumNode[td] = diff

            # update the internal dict
            self._taskToQuantumNode = newTaskToQuantumNode

            if emptyTasks:
                raise ValueError(f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
                                 f"after graph pruning; {', '.join(culprits)} caused over-pruning")

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()
    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph
    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta
261 @property
262 def inputQuanta(self) -> Iterable[QuantumNode]:
263 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
264 to the graph, meaning those nodes to not depend on any other nodes in
265 the graph.
267 Returns
268 -------
269 inputNodes : iterable of `QuantumNode`
270 A list of nodes that are inputs to the graph
271 """
272 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
274 @property
275 def outputQuanta(self) -> Iterable[QuantumNode]:
276 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
277 to the graph, meaning those nodes have no nodes that depend them in
278 the graph.
280 Returns
281 -------
282 outputNodes : iterable of `QuantumNode`
283 A list of nodes that are outputs of the graph
284 """
285 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        # The tracker keys are the connection names registered in
        # _buildGraphs for every task.
        return tuple(self._datasetDict.keys())
    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)
306 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
307 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
308 and nodes which depend on them.
310 Parameters
311 ----------
312 refs : `Iterable` of `DatasetRef`
313 Refs which should be removed from resulting graph
315 Returns
316 -------
317 graph : `QuantumGraph`
318 A graph that has been pruned of specified refs and the nodes that
319 depend on them.
320 """
321 newInst = object.__new__(type(self))
322 quantumMap = defaultdict(set)
323 for node in self:
324 quantumMap[node.taskDef].add(node.quantum)
326 newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in self},
327 metadata=self._metadata, pruneRefs=refs)
328 return newInst
330 def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
331 """Lookup a `QuantumNode` from an id associated with the node.
333 Parameters
334 ----------
335 nodeId : `NodeId`
336 The number associated with a node
338 Returns
339 -------
340 node : `QuantumNode`
341 The node corresponding with input number
343 Raises
344 ------
345 IndexError
346 Raised if the requested nodeId is not in the graph.
347 IncompatibleGraphError
348 Raised if the nodeId was built with a different graph than is not
349 this instance (or a graph instance that produced this instance
350 through and operation such as subset)
351 """
352 if nodeId.buildId != self._buildId:
353 raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
354 return self._nodeIdMap[nodeId]
    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        # frozenset guards the internal mapping from caller mutation.
        return frozenset(self._quanta[taskDef])
    def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
        """Return all the `QuantumNodes` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `QuantumNodes`
            The `frozenset` of `QuantumNodes` that is associated with the
            specified `TaskDef`.
        """
        # frozenset guards the internal mapping from caller mutation.
        return frozenset(self._taskToQuantumNode[taskDef])
388 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
389 """Find all tasks that have the specified dataset type name as an
390 input.
392 Parameters
393 ----------
394 datasetTypeName : `str`
395 A string representing the name of a dataset type to be queried,
396 can also accept a `DatasetTypeName` which is a `NewType` of str for
397 type safety in static type checking.
399 Returns
400 -------
401 tasks : iterable of `TaskDef`
402 `TaskDef` objects that have the specified `DatasetTypeName` as an
403 input, list will be empty if no tasks use specified
404 `DatasetTypeName` as an input.
406 Raises
407 ------
408 KeyError
409 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
410 """
411 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getProducer(datasetTypeName)
    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        # "Associated" means producer or consumer; getAll returns both.
        return self._datasetDict.getAll(datasetTypeName)
461 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
462 """Determine which `TaskDef` objects in this graph are associated
463 with a `str` representing a task name (looks at the taskName property
464 of `TaskDef` objects).
466 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
467 multiple times in a graph with different labels.
469 Parameters
470 ----------
471 taskName : str
472 Name of a task to search for
474 Returns
475 -------
476 result : list of `TaskDef`
477 List of the `TaskDef` objects that have the name specified.
478 Multiple values are returned in the case that a task is used
479 multiple times with different labels.
480 """
481 results = []
482 for task in self._quanta.keys():
483 split = task.taskName.split('.')
484 if split[-1] == taskName:
485 results.append(task)
486 return results
488 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
489 """Determine which `TaskDef` objects in this graph are associated
490 with a `str` representing a tasks label.
492 Parameters
493 ----------
494 taskName : str
495 Name of a task to search for
497 Returns
498 -------
499 result : `TaskDef`
500 `TaskDef` objects that has the specified label.
501 """
502 for task in self._quanta.keys():
503 if label == task.label:
504 return task
505 return None
507 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
508 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
510 Parameters
511 ----------
512 datasetTypeName : `str`
513 The name of the dataset type to search for as a string,
514 can also accept a `DatasetTypeName` which is a `NewType` of str for
515 type safety in static type checking.
517 Returns
518 -------
519 result : `set` of `QuantumNode` objects
520 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
522 Raises
523 ------
524 KeyError
525 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
527 """
528 tasks = self._datasetDict.getAll(datasetTypeName)
529 result: Set[Quantum] = set()
530 result = result.union(*(self._quanta[task] for task in tasks))
531 return result
533 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
534 """Check if specified quantum appears in the graph as part of a node.
536 Parameters
537 ----------
538 quantum : `Quantum`
539 The quantum to search for
541 Returns
542 -------
543 `bool`
544 The result of searching for the quantum
545 """
546 for qset in self._quanta.values():
547 if quantum in qset:
548 return True
549 return False
    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        # Delegates entirely to networkx/pygraphviz dot writing.
        write_dot(self._connectedQuanta, output)
561 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
562 """Create a new graph object that contains the subset of the nodes
563 specified as input. Node number is preserved.
565 Parameters
566 ----------
567 nodes : `QuantumNode` or iterable of `QuantumNode`
569 Returns
570 -------
571 graph : instance of graph type
572 An instance of the type from which the subset was created
573 """
574 if not isinstance(nodes, Iterable):
575 nodes = (nodes, )
576 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
577 quantumMap = defaultdict(set)
579 node: QuantumNode
580 for node in quantumSubgraph:
581 quantumMap[node.taskDef].add(node.quantum)
582 # Create an empty graph, and then populate it with custom mapping
583 newInst = type(self)({})
584 newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
585 _buildId=self._buildId, metadata=self._metadata)
586 return newInst
588 def subsetToConnected(self: _T) -> Tuple[_T, ...]:
589 """Generate a list of subgraphs where each is connected.
591 Returns
592 -------
593 result : list of `QuantumGraph`
594 A list of graphs that are each connected
595 """
596 return tuple(self.subset(connectedSet)
597 for connectedSet in nx.weakly_connected_components(self._connectedQuanta))
599 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
600 """Return a set of `QuantumNode` that are direct inputs to a specified
601 node.
603 Parameters
604 ----------
605 node : `QuantumNode`
606 The node of the graph for which inputs are to be determined
608 Returns
609 -------
610 set of `QuantumNode`
611 All the nodes that are direct inputs to specified node
612 """
613 return set(pred for pred in self._connectedQuanta.predecessors(node))
615 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
616 """Return a set of `QuantumNode` that are direct outputs of a specified
617 node.
619 Parameters
620 ----------
621 node : `QuantumNode`
622 The node of the graph for which outputs are to be determined
624 Returns
625 -------
626 set of `QuantumNode`
627 All the nodes that are direct outputs to specified node
628 """
629 return set(succ for succ in self._connectedQuanta.successors(node))
631 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
632 """Return a graph of `QuantumNode` that are direct inputs and outputs
633 of a specified node.
635 Parameters
636 ----------
637 node : `QuantumNode`
638 The node of the graph for which connected nodes are to be
639 determined.
641 Returns
642 -------
643 graph : graph of `QuantumNode`
644 All the nodes that are directly connected to specified node
645 """
646 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
647 nodes.add(node)
648 return self.subset(nodes)
650 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
651 """Return a graph of the specified node and all the ancestor nodes
652 directly reachable by walking edges.
654 Parameters
655 ----------
656 node : `QuantumNode`
657 The node for which all ansestors are to be determined
659 Returns
660 -------
661 graph of `QuantumNode`
662 Graph of node and all of its ansestors
663 """
664 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
665 predecessorNodes.add(node)
666 return self.subset(predecessorNodes)
668 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
669 """Check a graph for the presense of cycles and returns the edges of
670 any cycles found, or an empty list if there is no cycle.
672 Returns
673 -------
674 result : list of tuple of `QuantumNode`, `QuantumNode`
675 A list of any graph edges that form a cycle, or an empty list if
676 there is no cycle. Empty list to so support if graph.find_cycle()
677 syntax as an empty list is falsy.
678 """
679 try:
680 return nx.find_cycle(self._connectedQuanta)
681 except nx.NetworkXNoCycle:
682 return []
684 def saveUri(self, uri):
685 """Save `QuantumGraph` to the specified URI.
687 Parameters
688 ----------
689 uri : `ButlerURI` or `str`
690 URI to where the graph should be saved.
691 """
692 buffer = self._buildSaveObject()
693 butlerUri = ButlerURI(uri)
694 if butlerUri.getExtension() not in (".qgraph"):
695 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
696 butlerUri.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra metadata associated with the graph at construction time, as
        a read-only mapping view, or `None` if no metadata was supplied.
        """
        if self._metadata is None:
            return None
        # MappingProxyType prevents callers from mutating the underlying
        # metadata mapping through this property.
        return MappingProxyType(self._metadata)
706 @classmethod
707 def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
708 nodes: Optional[Iterable[int]] = None,
709 graphID: Optional[BuildId] = None
710 ) -> QuantumGraph:
711 """Read `QuantumGraph` from a URI.
713 Parameters
714 ----------
715 uri : `ButlerURI` or `str`
716 URI from where to load the graph.
717 universe: `~lsst.daf.butler.DimensionUniverse`
718 DimensionUniverse instance, not used by the method itself but
719 needed to ensure that registry data structures are initialized.
720 nodes: iterable of `int` or None
721 Numbers that correspond to nodes in the graph. If specified, only
722 these nodes will be loaded. Defaults to None, in which case all
723 nodes will be loaded.
724 graphID : `str` or `None`
725 If specified this ID is verified against the loaded graph prior to
726 loading any Nodes. This defaults to None in which case no
727 validation is done.
729 Returns
730 -------
731 graph : `QuantumGraph`
732 Resulting QuantumGraph instance.
734 Raises
735 ------
736 TypeError
737 Raised if pickle contains instance of a type other than
738 QuantumGraph.
739 ValueError
740 Raised if one or more of the nodes requested is not in the
741 `QuantumGraph` or if graphID parameter does not match the graph
742 being loaded or if the supplied uri does not point at a valid
743 `QuantumGraph` save file.
746 Notes
747 -----
748 Reading Quanta from pickle requires existence of singleton
749 DimensionUniverse which is usually instantiated during Registry
750 initialization. To make sure that DimensionUniverse exists this method
751 accepts dummy DimensionUniverse argument.
752 """
753 uri = ButlerURI(uri)
754 # With ButlerURI we have the choice of always using a local file
755 # or reading in the bytes directly. Reading in bytes can be more
756 # efficient for reasonably-sized pickle files when the resource
757 # is remote. For now use the local file variant. For a local file
758 # as_local() does nothing.
760 if uri.getExtension() in (".pickle", ".pkl"):
761 with uri.as_local() as local, open(local.ospath, "rb") as fd:
762 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
763 qgraph = pickle.load(fd)
764 elif uri.getExtension() in ('.qgraph'):
765 with LoadHelper(uri) as loader:
766 qgraph = loader.load(nodes, graphID)
767 else:
768 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
769 if not isinstance(qgraph, QuantumGraph):
770 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
771 return qgraph
773 def save(self, file: io.IO[bytes]):
774 """Save QuantumGraph to a file.
776 Presently we store QuantumGraph in pickle format, this could
777 potentially change in the future if better format is found.
779 Parameters
780 ----------
781 file : `io.BufferedIOBase`
782 File to write pickle data open in binary mode.
783 """
784 buffer = self._buildSaveObject()
785 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
787 def _buildSaveObject(self) -> bytearray:
788 # make some containers
789 pickleData = deque()
790 # node map is a list because json does not accept mapping keys that
791 # are not strings, so we store a list of key, value pairs that will
792 # be converted to a mapping on load
793 nodeMap = []
794 taskDefMap = {}
795 headerData = {}
796 protocol = 3
798 # Store the QauntumGraph BuildId, this will allow validating BuildIds
799 # at load time, prior to loading any QuantumNodes. Name chosen for
800 # unlikely conflicts.
801 headerData['GraphBuildID'] = self.graphID
802 headerData['Metadata'] = self._metadata
804 # counter for the number of bytes processed thus far
805 count = 0
806 # serialize out the task Defs recording the start and end bytes of each
807 # taskDef
808 for taskDef in self.taskGraph:
809 # compressing has very little impact on saving or load time, but
810 # a large impact on on disk size, so it is worth doing
811 dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
812 taskDefMap[taskDef.label] = {"bytes": (count, count+len(dump))}
813 count += len(dump)
814 pickleData.append(dump)
816 headerData['TaskDefs'] = taskDefMap
818 # serialize the nodes, recording the start and end bytes of each node
819 for node in self:
820 node = copy.copy(node)
821 taskDef = node.taskDef
822 # Explicitly overload the "frozen-ness" of nodes to normalized out
823 # the taskDef, this saves a lot of space and load time. The label
824 # will be used to retrive the taskDef from the taskDefMap upon load
825 #
826 # This strategy was chosen instead of creating a new class that
827 # looked just like a QuantumNode but containing a label in place of
828 # a TaskDef because it would be needlessly slow to construct a
829 # bunch of new object to immediately serialize them and destroy the
830 # object. This seems like an acceptable use of Python's dynamic
831 # nature in a controlled way for optimization and simplicity.
832 object.__setattr__(node, 'taskDef', taskDef.label)
833 # compressing has very little impact on saving or load time, but
834 # a large impact on on disk size, so it is worth doing
835 dump = lzma.compress(pickle.dumps(node, protocol=protocol))
836 pickleData.append(dump)
837 nodeMap.append((int(node.nodeId.number), {"bytes": (count, count+len(dump))}))
838 count += len(dump)
840 # need to serialize this as a series of key,value tuples because of
841 # a limitation on how json cant do anyting but strings as keys
842 headerData['Nodes'] = nodeMap
844 # dump the headerData to json
845 header_encode = lzma.compress(json.dumps(headerData).encode())
847 # record the sizes as 2 unsigned long long numbers for a total of 16
848 # bytes
849 save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)
851 fmt_string = STRUCT_FMT_STRING[SAVE_VERSION]
852 map_lengths = struct.pack(fmt_string, len(header_encode))
854 # write each component of the save out in a deterministic order
855 # buffer = io.BytesIO()
856 # buffer.write(map_lengths)
857 # buffer.write(taskDef_pickle)
858 # buffer.write(map_pickle)
859 buffer = bytearray()
860 buffer.extend(MAGIC_BYTES)
861 buffer.extend(save_bytes)
862 buffer.extend(map_lengths)
863 buffer.extend(header_encode)
864 # Iterate over the length of pickleData, and for each element pop the
865 # leftmost element off the deque and write it out. This is to save
866 # memory, as the memory is added to the buffer object, it is removed
867 # from from the container.
868 #
869 # Only this section needs to worry about memory pressue because
870 # everything else written to the buffer prior to this pickle data is
871 # only on the order of kilobytes to low numbers of megabytes.
872 while pickleData:
873 buffer.extend(pickleData.popleft())
874 return buffer
876 @classmethod
877 def load(cls, file: io.IO[bytes], universe: DimensionUniverse,
878 nodes: Optional[Iterable[int]] = None,
879 graphID: Optional[BuildId] = None
880 ) -> QuantumGraph:
881 """Read QuantumGraph from a file that was made by `save`.
883 Parameters
884 ----------
885 file : `io.IO` of bytes
886 File with pickle data open in binary mode.
887 universe: `~lsst.daf.butler.DimensionUniverse`
888 DimensionUniverse instance, not used by the method itself but
889 needed to ensure that registry data structures are initialized.
890 nodes: iterable of `int` or None
891 Numbers that correspond to nodes in the graph. If specified, only
892 these nodes will be loaded. Defaults to None, in which case all
893 nodes will be loaded.
894 graphID : `str` or `None`
895 If specified this ID is verified against the loaded graph prior to
896 loading any Nodes. This defaults to None in which case no
897 validation is done.
899 Returns
900 -------
901 graph : `QuantumGraph`
902 Resulting QuantumGraph instance.
904 Raises
905 ------
906 TypeError
907 Raised if pickle contains instance of a type other than
908 QuantumGraph.
909 ValueError
910 Raised if one or more of the nodes requested is not in the
911 `QuantumGraph` or if graphID parameter does not match the graph
912 being loaded or if the supplied uri does not point at a valid
913 `QuantumGraph` save file.
915 Notes
916 -----
917 Reading Quanta from pickle requires existence of singleton
918 DimensionUniverse which is usually instantiated during Registry
919 initialization. To make sure that DimensionUniverse exists this method
920 accepts dummy DimensionUniverse argument.
921 """
922 # Try to see if the file handle contains pickle data, this will be
923 # removed in the future
924 try:
925 qgraph = pickle.load(file)
926 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
927 except pickle.UnpicklingError:
928 with LoadHelper(file) as loader: # type: ignore # needed because we don't have Protocols yet
929 qgraph = loader.load(nodes, graphID)
930 if not isinstance(qgraph, QuantumGraph):
931 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
932 return qgraph
    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        taskDef : `TaskDef`
            `TaskDef` objects in topological order
        """
        # Topological order guarantees producers come before consumers.
        yield from nx.topological_sort(self.taskGraph)
    @property
    def graphID(self) -> BuildId:
        """Returns the ID generated by the graph at construction time
        """
        return self._buildId
    def __iter__(self) -> Generator[QuantumNode, None, None]:
        """Iterate over all `QuantumNode` objects in topological order, so
        producers are always yielded before their consumers.
        """
        yield from nx.topological_sort(self._connectedQuanta)
    def __len__(self) -> int:
        """Return the number of quantum nodes, cached at build time."""
        return self._count
    def __contains__(self, node: QuantumNode) -> bool:
        """Return whether ``node`` is a node of the quanta graph."""
        return self._connectedQuanta.has_node(node)
    def __getstate__(self) -> dict:
        """Stores a compact form of the graph as a list of graph nodes, and a
        tuple of task labels and task configs. The full graph can be
        reconstructed with this information, and it preserves the ordering of
        the graph nodes.
        """
        # Topological iteration order of self is preserved in the list.
        return {"nodesList": list(self)}
    def __setstate__(self, state: dict):
        """Reconstructs the state of the graph from the information persisted
        in getstate.
        """
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        quantumNode: QuantumNode
        for quantumNode in state['nodesList']:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        # Recover the build id from the last-seen node; an empty node list
        # yields None, so a fresh build id will be generated.
        _buildId = quantumNode.nodeId.buildId if state['nodesList'] else None  # type: ignore
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)
980 def __eq__(self, other: object) -> bool:
981 if not isinstance(other, QuantumGraph):
982 return False
983 if len(self) != len(other):
984 return False
985 for node in self:
986 if node not in other:
987 return False
988 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
989 return False
990 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
991 return False
992 return list(self.taskGraph) == list(other.taskGraph)