Coverage for python/lsst/pipe/base/graph/graph.py: 19%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23from lsst.daf.butler.core.datasets.type import DatasetType
25__all__ = ("QuantumGraph", "IncompatibleGraphError")
27import io
28import json
29import lzma
30import os
31import pickle
32import struct
33import time
34import uuid
35import warnings
36from collections import defaultdict, deque
37from itertools import chain
38from types import MappingProxyType
39from typing import (
40 Any,
41 DefaultDict,
42 Dict,
43 FrozenSet,
44 Generator,
45 Iterable,
46 List,
47 Mapping,
48 MutableMapping,
49 Optional,
50 Set,
51 Tuple,
52 TypeVar,
53 Union,
54)
56import networkx as nx
57from lsst.daf.butler import ButlerURI, DatasetRef, DimensionRecordsAccumulator, DimensionUniverse, Quantum
58from networkx.drawing.nx_agraph import write_dot
60from ..connections import iterConnections
61from ..pipeline import TaskDef
62from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner
63from ._loadHelpers import LoadHelper
64from ._versionDeserializers import DESERIALIZER_MAP
65from .quantumNode import BuildId, QuantumNode
# Type variable used for methods that return an instance of the (sub)class,
# e.g. `subset` and `pruneGraphFromRefs`.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save.
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading the version bytes and determining
# which loading code should be used for the rest of the file.
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be satisfied because the id is
    incompatible with this graph.
    """
99class QuantumGraph:
100 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
102 This data structure represents a concrete workflow generated from a
103 `Pipeline`.
105 Parameters
106 ----------
107 quanta : Mapping of `TaskDef` to sets of `Quantum`
108 This maps tasks (and their configs) to the sets of data they are to
109 process.
110 metadata : Optional Mapping of `str` to primitives
111 This is an optional parameter of extra data to carry with the graph.
112 Entries in this mapping should be able to be serialized in JSON.
114 Raises
115 ------
116 ValueError
117 Raised if the graph is pruned such that some tasks no longer have nodes
118 associated with them.
119 """
121 def __init__(
122 self,
123 quanta: Mapping[TaskDef, Set[Quantum]],
124 metadata: Optional[Mapping[str, Any]] = None,
125 pruneRefs: Optional[Iterable[DatasetRef]] = None,
126 ):
127 self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs)
129 def _buildGraphs(
130 self,
131 quanta: Mapping[TaskDef, Set[Quantum]],
132 *,
133 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
134 _buildId: Optional[BuildId] = None,
135 metadata: Optional[Mapping[str, Any]] = None,
136 pruneRefs: Optional[Iterable[DatasetRef]] = None,
137 ):
138 """Builds the graph that is used to store the relation between tasks,
139 and the graph that holds the relations between quanta
140 """
141 self._metadata = metadata
142 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
143 # Data structures used to identify relations between components;
144 # DatasetTypeName -> TaskDef for task,
145 # and DatasetRef -> QuantumNode for the quanta
146 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
147 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()
149 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
150 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
151 for taskDef, quantumSet in quanta.items():
152 connections = taskDef.connections
154 # For each type of connection in the task, add a key to the
155 # `_DatasetTracker` for the connections name, with a value of
156 # the TaskDef in the appropriate field
157 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
158 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)
160 for output in iterConnections(connections, ("outputs",)):
161 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)
163 # For each `Quantum` in the set of all `Quantum` for this task,
164 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
165 # of the individual datasets inside the `Quantum`, with a value of
166 # a newly created QuantumNode to the appropriate input/output
167 # field.
168 for quantum in quantumSet:
169 if _quantumToNodeId:
170 if (nodeId := _quantumToNodeId.get(quantum)) is None:
171 raise ValueError(
172 "If _quantuMToNodeNumber is not None, all quanta must have an "
173 "associated value in the mapping"
174 )
175 else:
176 nodeId = uuid.uuid4()
178 inits = quantum.initInputs.values()
179 inputs = quantum.inputs.values()
180 value = QuantumNode(quantum, taskDef, nodeId)
181 self._taskToQuantumNode[taskDef].add(value)
182 self._nodeIdMap[nodeId] = value
184 for dsRef in chain(inits, inputs):
185 # unfortunately, `Quantum` allows inits to be individual
186 # `DatasetRef`s or an Iterable of such, so there must
187 # be an instance check here
188 if isinstance(dsRef, Iterable):
189 for sub in dsRef:
190 if sub.isComponent():
191 sub = sub.makeCompositeRef()
192 self._datasetRefDict.addConsumer(sub, value)
193 else:
194 if dsRef.isComponent():
195 dsRef = dsRef.makeCompositeRef()
196 self._datasetRefDict.addConsumer(dsRef, value)
197 for dsRef in chain.from_iterable(quantum.outputs.values()):
198 self._datasetRefDict.addProducer(dsRef, value)
200 if pruneRefs is not None:
201 # track what refs were pruned and prune the graph
202 prunes = set()
203 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)
205 # recreate the taskToQuantumNode dict removing nodes that have been
206 # pruned. Keep track of task defs that now have no QuantumNodes
207 emptyTasks: Set[str] = set()
208 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
209 # accumulate all types
210 types_ = set()
211 # tracker for any pruneRefs that have caused tasks to have no nodes
212 # This helps the user find out what caused the issues seen.
213 culprits = set()
214 # Find all the types from the refs to prune
215 for r in pruneRefs:
216 types_.add(r.datasetType)
218 # For each of the tasks, and their associated nodes, remove any
219 # any nodes that were pruned. If there are no nodes associated
220 # with a task, record that task, and find out if that was due to
221 # a type from an input ref to prune.
222 for td, taskNodes in self._taskToQuantumNode.items():
223 diff = taskNodes.difference(prunes)
224 if len(diff) == 0:
225 if len(taskNodes) != 0:
226 tp: DatasetType
227 for tp in types_:
228 if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
229 tmpRefs
230 ).difference(pruneRefs):
231 culprits.add(tp.name)
232 emptyTasks.add(td.label)
233 newTaskToQuantumNode[td] = diff
235 # update the internal dict
236 self._taskToQuantumNode = newTaskToQuantumNode
238 if emptyTasks:
239 raise ValueError(
240 f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
241 f"after graph pruning; {', '.join(culprits)} caused over-pruning"
242 )
244 # Graph of quanta relations
245 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
246 self._count = len(self._connectedQuanta)
248 # Graph of task relations, used in various methods
249 self._taskGraph = self._datasetDict.makeNetworkXGraph()
251 # convert default dict into a regular to prevent accidental key
252 # insertion
253 self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph
    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta
282 @property
283 def inputQuanta(self) -> Iterable[QuantumNode]:
284 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
285 to the graph, meaning those nodes to not depend on any other nodes in
286 the graph.
288 Returns
289 -------
290 inputNodes : iterable of `QuantumNode`
291 A list of nodes that are inputs to the graph
292 """
293 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
295 @property
296 def outputQuanta(self) -> Iterable[QuantumNode]:
297 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
298 to the graph, meaning those nodes have no nodes that depend them in
299 the graph.
301 Returns
302 -------
303 outputNodes : iterable of `QuantumNode`
304 A list of nodes that are outputs of the graph
305 """
306 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
308 @property
309 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
310 """Return all the `DatasetTypeName` objects that are contained inside
311 the graph.
313 Returns
314 -------
315 tuple of `DatasetTypeName`
316 All the data set type names that are present in the graph
317 """
318 return tuple(self._datasetDict.keys())
320 @property
321 def isConnected(self) -> bool:
322 """Return True if all of the nodes in the graph are connected, ignores
323 directionality of connections.
324 """
325 return nx.is_weakly_connected(self._connectedQuanta)
327 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
328 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
329 and nodes which depend on them.
331 Parameters
332 ----------
333 refs : `Iterable` of `DatasetRef`
334 Refs which should be removed from resulting graph
336 Returns
337 -------
338 graph : `QuantumGraph`
339 A graph that has been pruned of specified refs and the nodes that
340 depend on them.
341 """
342 newInst = object.__new__(type(self))
343 quantumMap = defaultdict(set)
344 for node in self:
345 quantumMap[node.taskDef].add(node.quantum)
347 # convert to standard dict to prevent accidental key insertion
348 quantumMap = dict(quantumMap.items())
350 newInst._buildGraphs(
351 quantumMap,
352 _quantumToNodeId={n.quantum: n.nodeId for n in self},
353 metadata=self._metadata,
354 pruneRefs=refs,
355 )
356 return newInst
    def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `uuid.UUID`
            The id associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with the input id

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        """
        # Plain dict lookup; an unknown id raises KeyError as documented.
        return self._nodeIdMap[nodeId]
378 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
379 """Return all the `Quantum` associated with a `TaskDef`.
381 Parameters
382 ----------
383 taskDef : `TaskDef`
384 The `TaskDef` for which `Quantum` are to be queried
386 Returns
387 -------
388 frozenset of `Quantum`
389 The `set` of `Quantum` that is associated with the specified
390 `TaskDef`.
391 """
392 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef])
394 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
395 """Return all the `QuantumNodes` associated with a `TaskDef`.
397 Parameters
398 ----------
399 taskDef : `TaskDef`
400 The `TaskDef` for which `Quantum` are to be queried
402 Returns
403 -------
404 frozenset of `QuantumNodes`
405 The `frozenset` of `QuantumNodes` that is associated with the
406 specified `TaskDef`.
407 """
408 return frozenset(self._taskToQuantumNode[taskDef])
410 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
411 """Find all tasks that have the specified dataset type name as an
412 input.
414 Parameters
415 ----------
416 datasetTypeName : `str`
417 A string representing the name of a dataset type to be queried,
418 can also accept a `DatasetTypeName` which is a `NewType` of str for
419 type safety in static type checking.
421 Returns
422 -------
423 tasks : iterable of `TaskDef`
424 `TaskDef` objects that have the specified `DatasetTypeName` as an
425 input, list will be empty if no tasks use specified
426 `DatasetTypeName` as an input.
428 Raises
429 ------
430 KeyError
431 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
432 """
433 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getProducer(datasetTypeName)
    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name, whether as a producer or a consumer.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getAll(datasetTypeName)
483 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
484 """Determine which `TaskDef` objects in this graph are associated
485 with a `str` representing a task name (looks at the taskName property
486 of `TaskDef` objects).
488 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
489 multiple times in a graph with different labels.
491 Parameters
492 ----------
493 taskName : str
494 Name of a task to search for
496 Returns
497 -------
498 result : list of `TaskDef`
499 List of the `TaskDef` objects that have the name specified.
500 Multiple values are returned in the case that a task is used
501 multiple times with different labels.
502 """
503 results = []
504 for task in self._taskToQuantumNode.keys():
505 split = task.taskName.split(".")
506 if split[-1] == taskName:
507 results.append(task)
508 return results
510 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
511 """Determine which `TaskDef` objects in this graph are associated
512 with a `str` representing a tasks label.
514 Parameters
515 ----------
516 taskName : str
517 Name of a task to search for
519 Returns
520 -------
521 result : `TaskDef`
522 `TaskDef` objects that has the specified label.
523 """
524 for task in self._taskToQuantumNode.keys():
525 if label == task.label:
526 return task
527 return None
529 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
530 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
532 Parameters
533 ----------
534 datasetTypeName : `str`
535 The name of the dataset type to search for as a string,
536 can also accept a `DatasetTypeName` which is a `NewType` of str for
537 type safety in static type checking.
539 Returns
540 -------
541 result : `set` of `QuantumNode` objects
542 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
544 Raises
545 ------
546 KeyError
547 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
549 """
550 tasks = self._datasetDict.getAll(datasetTypeName)
551 result: Set[Quantum] = set()
552 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
553 return result
555 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
556 """Check if specified quantum appears in the graph as part of a node.
558 Parameters
559 ----------
560 quantum : `Quantum`
561 The quantum to search for
563 Returns
564 -------
565 `bool`
566 The result of searching for the quantum
567 """
568 for node in self:
569 if quantum == node.quantum:
570 return True
571 return False
    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        # Delegates to networkx's graphviz dot writer.
        write_dot(self._connectedQuanta, output)
    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node or nodes the returned graph should contain.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        # Normalize a single node into a one-element tuple.
        if not isinstance(nodes, Iterable):
            nodes = (nodes,)
        # Restrict the quanta graph to just the requested nodes.
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)

        # convert to standard dict to prevent accidental key insertion
        quantumMap = dict(quantumMap.items())
        # Create an empty graph, and then populate it with custom mapping;
        # existing node ids are preserved via _quantumToNodeId.
        newInst = type(self)({})
        newInst._buildGraphs(
            quantumMap,
            _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
            _buildId=self._buildId,
            metadata=self._metadata,
        )
        return newInst
617 def subsetToConnected(self: _T) -> Tuple[_T, ...]:
618 """Generate a list of subgraphs where each is connected.
620 Returns
621 -------
622 result : list of `QuantumGraph`
623 A list of graphs that are each connected
624 """
625 return tuple(
626 self.subset(connectedSet)
627 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
628 )
630 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
631 """Return a set of `QuantumNode` that are direct inputs to a specified
632 node.
634 Parameters
635 ----------
636 node : `QuantumNode`
637 The node of the graph for which inputs are to be determined
639 Returns
640 -------
641 set of `QuantumNode`
642 All the nodes that are direct inputs to specified node
643 """
644 return set(pred for pred in self._connectedQuanta.predecessors(node))
646 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
647 """Return a set of `QuantumNode` that are direct outputs of a specified
648 node.
650 Parameters
651 ----------
652 node : `QuantumNode`
653 The node of the graph for which outputs are to be determined
655 Returns
656 -------
657 set of `QuantumNode`
658 All the nodes that are direct outputs to specified node
659 """
660 return set(succ for succ in self._connectedQuanta.successors(node))
662 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
663 """Return a graph of `QuantumNode` that are direct inputs and outputs
664 of a specified node.
666 Parameters
667 ----------
668 node : `QuantumNode`
669 The node of the graph for which connected nodes are to be
670 determined.
672 Returns
673 -------
674 graph : graph of `QuantumNode`
675 All the nodes that are directly connected to specified node
676 """
677 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
678 nodes.add(node)
679 return self.subset(nodes)
681 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
682 """Return a graph of the specified node and all the ancestor nodes
683 directly reachable by walking edges.
685 Parameters
686 ----------
687 node : `QuantumNode`
688 The node for which all ansestors are to be determined
690 Returns
691 -------
692 graph of `QuantumNode`
693 Graph of node and all of its ansestors
694 """
695 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
696 predecessorNodes.add(node)
697 return self.subset(predecessorNodes)
699 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
700 """Check a graph for the presense of cycles and returns the edges of
701 any cycles found, or an empty list if there is no cycle.
703 Returns
704 -------
705 result : list of tuple of `QuantumNode`, `QuantumNode`
706 A list of any graph edges that form a cycle, or an empty list if
707 there is no cycle. Empty list to so support if graph.find_cycle()
708 syntax as an empty list is falsy.
709 """
710 try:
711 return nx.find_cycle(self._connectedQuanta)
712 except nx.NetworkXNoCycle:
713 return []
715 def saveUri(self, uri):
716 """Save `QuantumGraph` to the specified URI.
718 Parameters
719 ----------
720 uri : `ButlerURI` or `str`
721 URI to where the graph should be saved.
722 """
723 buffer = self._buildSaveObject()
724 butlerUri = ButlerURI(uri)
725 if butlerUri.getExtension() not in (".qgraph"):
726 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
727 butlerUri.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra data carried with the graph, exposed as a read-only mapping
        view, or `None` if no metadata was supplied at construction
        (`MappingProxyType` or `None`).
        """
        if self._metadata is None:
            return None
        return MappingProxyType(self._metadata)
736 @classmethod
737 def loadUri(
738 cls,
739 uri: Union[ButlerURI, str],
740 universe: DimensionUniverse,
741 nodes: Optional[Iterable[Union[str, uuid.UUID]]] = None,
742 graphID: Optional[BuildId] = None,
743 minimumVersion: int = 3,
744 ) -> QuantumGraph:
745 """Read `QuantumGraph` from a URI.
747 Parameters
748 ----------
749 uri : `ButlerURI` or `str`
750 URI from where to load the graph.
751 universe: `~lsst.daf.butler.DimensionUniverse`
752 DimensionUniverse instance, not used by the method itself but
753 needed to ensure that registry data structures are initialized.
754 nodes: iterable of `int` or None
755 Numbers that correspond to nodes in the graph. If specified, only
756 these nodes will be loaded. Defaults to None, in which case all
757 nodes will be loaded.
758 graphID : `str` or `None`
759 If specified this ID is verified against the loaded graph prior to
760 loading any Nodes. This defaults to None in which case no
761 validation is done.
762 minimumVersion : int
763 Minimum version of a save file to load. Set to -1 to load all
764 versions. Older versions may need to be loaded, and re-saved
765 to upgrade them to the latest format before they can be used in
766 production.
768 Returns
769 -------
770 graph : `QuantumGraph`
771 Resulting QuantumGraph instance.
773 Raises
774 ------
775 TypeError
776 Raised if pickle contains instance of a type other than
777 QuantumGraph.
778 ValueError
779 Raised if one or more of the nodes requested is not in the
780 `QuantumGraph` or if graphID parameter does not match the graph
781 being loaded or if the supplied uri does not point at a valid
782 `QuantumGraph` save file.
785 Notes
786 -----
787 Reading Quanta from pickle requires existence of singleton
788 DimensionUniverse which is usually instantiated during Registry
789 initialization. To make sure that DimensionUniverse exists this method
790 accepts dummy DimensionUniverse argument.
791 """
792 uri = ButlerURI(uri)
793 # With ButlerURI we have the choice of always using a local file
794 # or reading in the bytes directly. Reading in bytes can be more
795 # efficient for reasonably-sized pickle files when the resource
796 # is remote. For now use the local file variant. For a local file
797 # as_local() does nothing.
799 if uri.getExtension() in (".pickle", ".pkl"):
800 with uri.as_local() as local, open(local.ospath, "rb") as fd:
801 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
802 qgraph = pickle.load(fd)
803 elif uri.getExtension() in (".qgraph"):
804 with LoadHelper(uri, minimumVersion) as loader:
805 qgraph = loader.load(universe, nodes, graphID)
806 else:
807 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
808 if not isinstance(qgraph, QuantumGraph):
809 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
810 return qgraph
812 @classmethod
813 def readHeader(cls, uri: Union[ButlerURI, str], minimumVersion: int = 3) -> Optional[str]:
814 """Read the header of a `QuantumGraph` pointed to by the uri parameter
815 and return it as a string.
817 Parameters
818 ----------
819 uri : `~lsst.daf.butler.ButlerURI` or `str`
820 The location of the `QuantumGraph` to load. If the argument is a
821 string, it must correspond to a valid `~lsst.daf.butler.ButlerURI`
822 path.
823 minimumVersion : int
824 Minimum version of a save file to load. Set to -1 to load all
825 versions. Older versions may need to be loaded, and re-saved
826 to upgrade them to the latest format before they can be used in
827 production.
829 Returns
830 -------
831 header : `str` or `None`
832 The header associated with the specified `QuantumGraph` it there is
833 one, else `None`.
835 Raises
836 ------
837 ValueError
838 Raised if `QuantuGraph` was saved as a pickle.
839 Raised if the extention of the file specified by uri is not a
840 `QuantumGraph` extention.
841 """
842 uri = ButlerURI(uri)
843 if uri.getExtension() in (".pickle", ".pkl"):
844 raise ValueError("Reading a header from a pickle save is not supported")
845 elif uri.getExtension() in (".qgraph"):
846 return LoadHelper(uri, minimumVersion).readHeader()
847 else:
848 raise ValueError("Only know how to handle files saved as `qgraph`")
    def buildAndPrintHeader(self) -> None:
        """Create the header that would be used in a save of this object and
        print it to standard out.
        """
        # returnHeader=True makes _buildSaveObject also return the header
        # mapping alongside the byte buffer; only the header is needed here.
        _, header = self._buildSaveObject(returnHeader=True)
        print(json.dumps(header))
857 def save(self, file: io.IO[bytes]):
858 """Save QuantumGraph to a file.
860 Presently we store QuantumGraph in pickle format, this could
861 potentially change in the future if better format is found.
863 Parameters
864 ----------
865 file : `io.BufferedIOBase`
866 File to write pickle data open in binary mode.
867 """
868 buffer = self._buildSaveObject()
869 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
871 def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
872 # make some containers
873 jsonData = deque()
874 # node map is a list because json does not accept mapping keys that
875 # are not strings, so we store a list of key, value pairs that will
876 # be converted to a mapping on load
877 nodeMap = []
878 taskDefMap = {}
879 headerData = {}
881 # Store the QauntumGraph BuildId, this will allow validating BuildIds
882 # at load time, prior to loading any QuantumNodes. Name chosen for
883 # unlikely conflicts.
884 headerData["GraphBuildID"] = self.graphID
885 headerData["Metadata"] = self._metadata
887 # counter for the number of bytes processed thus far
888 count = 0
889 # serialize out the task Defs recording the start and end bytes of each
890 # taskDef
891 inverseLookup = self._datasetDict.inverse
892 taskDef: TaskDef
893 # sort by task label to ensure serialization happens in the same order
894 for taskDef in self.taskGraph:
895 # compressing has very little impact on saving or load time, but
896 # a large impact on on disk size, so it is worth doing
897 taskDescription = {}
898 # save the fully qualified name, as TaskDef not not require this,
899 # but by doing so can save space and is easier to transport
900 taskDescription["taskName"] = f"{taskDef.taskClass.__module__}.{taskDef.taskClass.__qualname__}"
901 # save the config as a text stream that will be un-persisted on the
902 # other end
903 stream = io.StringIO()
904 taskDef.config.saveToStream(stream)
905 taskDescription["config"] = stream.getvalue()
906 taskDescription["label"] = taskDef.label
908 inputs = []
909 outputs = []
911 # Determine the connection between all of tasks and save that in
912 # the header as a list of connections and edges in each task
913 # this will help in un-persisting, and possibly in a "quick view"
914 # method that does not require everything to be un-persisted
915 #
916 # Typing returns can't be parameter dependent
917 for connection in inverseLookup[taskDef]: # type: ignore
918 consumers = self._datasetDict.getConsumers(connection)
919 producer = self._datasetDict.getProducer(connection)
920 if taskDef in consumers:
921 # This checks if the task consumes the connection directly
922 # from the datastore or it is produced by another task
923 producerLabel = producer.label if producer is not None else "datastore"
924 inputs.append((producerLabel, connection))
925 elif taskDef not in consumers and producer is taskDef:
926 # If there are no consumers for this tasks produced
927 # connection, the output will be said to be the datastore
928 # in which case the for loop will be a zero length loop
929 if not consumers:
930 outputs.append(("datastore", connection))
931 for td in consumers:
932 outputs.append((td.label, connection))
934 # dump to json string, and encode that string to bytes and then
935 # conpress those bytes
936 dump = lzma.compress(json.dumps(taskDescription).encode())
937 # record the sizing and relation information
938 taskDefMap[taskDef.label] = {
939 "bytes": (count, count + len(dump)),
940 "inputs": inputs,
941 "outputs": outputs,
942 }
943 count += len(dump)
944 jsonData.append(dump)
946 headerData["TaskDefs"] = taskDefMap
948 # serialize the nodes, recording the start and end bytes of each node
949 dimAccumulator = DimensionRecordsAccumulator()
950 for node in self:
951 # compressing has very little impact on saving or load time, but
952 # a large impact on on disk size, so it is worth doing
953 simpleNode = node.to_simple(accumulator=dimAccumulator)
955 dump = lzma.compress(simpleNode.json().encode())
956 jsonData.append(dump)
957 nodeMap.append(
958 (
959 str(node.nodeId),
960 {
961 "bytes": (count, count + len(dump)),
962 "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
963 "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
964 },
965 )
966 )
967 count += len(dump)
969 headerData["DimensionRecords"] = {
970 key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
971 }
973 # need to serialize this as a series of key,value tuples because of
974 # a limitation on how json cant do anyting but strings as keys
975 headerData["Nodes"] = nodeMap
977 # dump the headerData to json
978 header_encode = lzma.compress(json.dumps(headerData).encode())
980 # record the sizes as 2 unsigned long long numbers for a total of 16
981 # bytes
982 save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)
984 fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
985 map_lengths = struct.pack(fmt_string, len(header_encode))
987 # write each component of the save out in a deterministic order
988 # buffer = io.BytesIO()
989 # buffer.write(map_lengths)
990 # buffer.write(taskDef_pickle)
991 # buffer.write(map_pickle)
992 buffer = bytearray()
993 buffer.extend(MAGIC_BYTES)
994 buffer.extend(save_bytes)
995 buffer.extend(map_lengths)
996 buffer.extend(header_encode)
997 # Iterate over the length of pickleData, and for each element pop the
998 # leftmost element off the deque and write it out. This is to save
999 # memory, as the memory is added to the buffer object, it is removed
1000 # from from the container.
1001 #
1002 # Only this section needs to worry about memory pressue because
1003 # everything else written to the buffer prior to this pickle data is
1004 # only on the order of kilobytes to low numbers of megabytes.
1005 while jsonData:
1006 buffer.extend(jsonData.popleft())
1007 if returnHeader:
1008 return buffer, headerData
1009 else:
1010 return buffer
1012 @classmethod
1013 def load(
1014 cls,
1015 file: io.IO[bytes],
1016 universe: DimensionUniverse,
1017 nodes: Optional[Iterable[uuid.UUID]] = None,
1018 graphID: Optional[BuildId] = None,
1019 minimumVersion: int = 3,
1020 ) -> QuantumGraph:
1021 """Read QuantumGraph from a file that was made by `save`.
1023 Parameters
1024 ----------
1025 file : `io.IO` of bytes
1026 File with pickle data open in binary mode.
1027 universe: `~lsst.daf.butler.DimensionUniverse`
1028 DimensionUniverse instance, not used by the method itself but
1029 needed to ensure that registry data structures are initialized.
1030 nodes: iterable of `int` or None
1031 Numbers that correspond to nodes in the graph. If specified, only
1032 these nodes will be loaded. Defaults to None, in which case all
1033 nodes will be loaded.
1034 graphID : `str` or `None`
1035 If specified this ID is verified against the loaded graph prior to
1036 loading any Nodes. This defaults to None in which case no
1037 validation is done.
1038 minimumVersion : int
1039 Minimum version of a save file to load. Set to -1 to load all
1040 versions. Older versions may need to be loaded, and re-saved
1041 to upgrade them to the latest format before they can be used in
1042 production.
1044 Returns
1045 -------
1046 graph : `QuantumGraph`
1047 Resulting QuantumGraph instance.
1049 Raises
1050 ------
1051 TypeError
1052 Raised if pickle contains instance of a type other than
1053 QuantumGraph.
1054 ValueError
1055 Raised if one or more of the nodes requested is not in the
1056 `QuantumGraph` or if graphID parameter does not match the graph
1057 being loaded or if the supplied uri does not point at a valid
1058 `QuantumGraph` save file.
1060 Notes
1061 -----
1062 Reading Quanta from pickle requires existence of singleton
1063 DimensionUniverse which is usually instantiated during Registry
1064 initialization. To make sure that DimensionUniverse exists this method
1065 accepts dummy DimensionUniverse argument.
1066 """
1067 # Try to see if the file handle contains pickle data, this will be
1068 # removed in the future
1069 try:
1070 qgraph = pickle.load(file)
1071 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
1072 except pickle.UnpicklingError:
1073 # needed because we don't have Protocols yet
1074 with LoadHelper(file, minimumVersion) as loader: # type: ignore
1075 qgraph = loader.load(universe, nodes, graphID)
1076 if not isinstance(qgraph, QuantumGraph):
1077 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
1078 return qgraph
1080 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1081 """Iterate over the `taskGraph` attribute in topological order
1083 Yields
1084 ------
1085 taskDef : `TaskDef`
1086 `TaskDef` objects in topological order
1087 """
1088 yield from nx.topological_sort(self.taskGraph)
1090 @property
1091 def graphID(self):
1092 """Returns the ID generated by the graph at construction time"""
1093 return self._buildId
1095 def __iter__(self) -> Generator[QuantumNode, None, None]:
1096 yield from nx.topological_sort(self._connectedQuanta)
1098 def __len__(self) -> int:
1099 return self._count
1101 def __contains__(self, node: QuantumNode) -> bool:
1102 return self._connectedQuanta.has_node(node)
1104 def __getstate__(self) -> dict:
1105 """Stores a compact form of the graph as a list of graph nodes, and a
1106 tuple of task labels and task configs. The full graph can be
1107 reconstructed with this information, and it preseves the ordering of
1108 the graph ndoes.
1109 """
1110 universe: Optional[DimensionUniverse] = None
1111 for node in self:
1112 dId = node.quantum.dataId
1113 if dId is None:
1114 continue
1115 universe = dId.graph.universe
1116 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
1118 def __setstate__(self, state: dict):
1119 """Reconstructs the state of the graph from the information persisted
1120 in getstate.
1121 """
1122 buffer = io.BytesIO(state["reduced"])
1123 with LoadHelper(buffer, minimumVersion=3) as loader:
1124 qgraph = loader.load(state["universe"], graphID=state["graphId"])
1126 self._metadata = qgraph._metadata
1127 self._buildId = qgraph._buildId
1128 self._datasetDict = qgraph._datasetDict
1129 self._nodeIdMap = qgraph._nodeIdMap
1130 self._count = len(qgraph)
1131 self._taskToQuantumNode = qgraph._taskToQuantumNode
1132 self._taskGraph = qgraph._taskGraph
1133 self._connectedQuanta = qgraph._connectedQuanta
1135 def __eq__(self, other: object) -> bool:
1136 if not isinstance(other, QuantumGraph):
1137 return False
1138 if len(self) != len(other):
1139 return False
1140 for node in self:
1141 if node not in other:
1142 return False
1143 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1144 return False
1145 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1146 return False
1147 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1148 return False
1149 return set(self.taskGraph) == set(other.taskGraph)