Coverage for python/lsst/pipe/base/graph/graph.py: 19%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23from lsst.daf.butler.core.datasets.type import DatasetType
24__all__ = ("QuantumGraph", "IncompatibleGraphError")
26import warnings
28from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse, DimensionRecordsAccumulator
30from collections import defaultdict, deque
32from itertools import chain
33import io
34import json
35import networkx as nx
36from networkx.drawing.nx_agraph import write_dot
37import os
38import pickle
39import lzma
40import struct
41import time
42import uuid
43from types import MappingProxyType
44from typing import (Any, DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, MutableMapping, Set,
45 Generator, Optional, Tuple, Union, TypeVar)
47from ..connections import iterConnections
48from ..pipeline import TaskDef
50from ._implDetails import _DatasetTracker, DatasetTypeName, _pruner
51from .quantumNode import QuantumNode, BuildId
52from ._loadHelpers import LoadHelper
53from ._versionDeserializers import DESERIALIZER_MAP
# Generic type variable bound to QuantumGraph so that methods returning a
# new graph (subset, pruneGraphFromRefs, ...) preserve the subclass type.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION: int = 3

# Strings used to describe the format for the preamble bytes in a file save.
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determining which
# loading code should be used for the rest of the file.
STRUCT_FMT_BASE: str = '>H'
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING: Dict[int, str] = {
    1: '>QQ',
    2: '>Q'
}

# magic bytes that help determine this is a graph save
MAGIC_BYTES: bytes = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be satisfied because the
    identifier is incompatible with this graph.
    """
90class QuantumGraph:
91 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
93 This data structure represents a concrete workflow generated from a
94 `Pipeline`.
96 Parameters
97 ----------
98 quanta : Mapping of `TaskDef` to sets of `Quantum`
99 This maps tasks (and their configs) to the sets of data they are to
100 process.
101 metadata : Optional Mapping of `str` to primitives
102 This is an optional parameter of extra data to carry with the graph.
103 Entries in this mapping should be able to be serialized in JSON.
105 Raises
106 ------
107 ValueError
108 Raised if the graph is pruned such that some tasks no longer have nodes
109 associated with them.
110 """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]],
                 metadata: Optional[Mapping[str, Any]] = None,
                 pruneRefs: Optional[Iterable[DatasetRef]] = None):
        # All construction work is delegated to _buildGraphs so that
        # alternate construction paths (e.g. subset, pruneGraphFromRefs)
        # can rebuild internal state on a bare instance without __init__.
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs)
116 def _buildGraphs(self,
117 quanta: Mapping[TaskDef, Set[Quantum]],
118 *,
119 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
120 _buildId: Optional[BuildId] = None,
121 metadata: Optional[Mapping[str, Any]] = None,
122 pruneRefs: Optional[Iterable[DatasetRef]] = None):
123 """Builds the graph that is used to store the relation between tasks,
124 and the graph that holds the relations between quanta
125 """
126 self._metadata = metadata
127 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
128 # Data structures used to identify relations between components;
129 # DatasetTypeName -> TaskDef for task,
130 # and DatasetRef -> QuantumNode for the quanta
131 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
132 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()
134 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
135 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
136 for taskDef, quantumSet in quanta.items():
137 connections = taskDef.connections
139 # For each type of connection in the task, add a key to the
140 # `_DatasetTracker` for the connections name, with a value of
141 # the TaskDef in the appropriate field
142 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
143 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)
145 for output in iterConnections(connections, ("outputs",)):
146 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)
148 # For each `Quantum` in the set of all `Quantum` for this task,
149 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
150 # of the individual datasets inside the `Quantum`, with a value of
151 # a newly created QuantumNode to the appropriate input/output
152 # field.
153 for quantum in quantumSet:
154 if _quantumToNodeId:
155 if (nodeId := _quantumToNodeId.get(quantum)) is None:
156 raise ValueError("If _quantuMToNodeNumber is not None, all quanta must have an "
157 "associated value in the mapping")
158 else:
159 nodeId = uuid.uuid4()
161 inits = quantum.initInputs.values()
162 inputs = quantum.inputs.values()
163 value = QuantumNode(quantum, taskDef, nodeId)
164 self._taskToQuantumNode[taskDef].add(value)
165 self._nodeIdMap[nodeId] = value
167 for dsRef in chain(inits, inputs):
168 # unfortunately, `Quantum` allows inits to be individual
169 # `DatasetRef`s or an Iterable of such, so there must
170 # be an instance check here
171 if isinstance(dsRef, Iterable):
172 for sub in dsRef:
173 if sub.isComponent():
174 sub = sub.makeCompositeRef()
175 self._datasetRefDict.addConsumer(sub, value)
176 else:
177 if dsRef.isComponent():
178 dsRef = dsRef.makeCompositeRef()
179 self._datasetRefDict.addConsumer(dsRef, value)
180 for dsRef in chain.from_iterable(quantum.outputs.values()):
181 self._datasetRefDict.addProducer(dsRef, value)
183 if pruneRefs is not None:
184 # track what refs were pruned and prune the graph
185 prunes = set()
186 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)
188 # recreate the taskToQuantumNode dict removing nodes that have been
189 # pruned. Keep track of task defs that now have no QuantumNodes
190 emptyTasks: Set[str] = set()
191 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
192 # accumulate all types
193 types_ = set()
194 # tracker for any pruneRefs that have caused tasks to have no nodes
195 # This helps the user find out what caused the issues seen.
196 culprits = set()
197 # Find all the types from the refs to prune
198 for r in pruneRefs:
199 types_.add(r.datasetType)
201 # For each of the tasks, and their associated nodes, remove any
202 # any nodes that were pruned. If there are no nodes associated
203 # with a task, record that task, and find out if that was due to
204 # a type from an input ref to prune.
205 for td, taskNodes in self._taskToQuantumNode.items():
206 diff = taskNodes.difference(prunes)
207 if len(diff) == 0:
208 if len(taskNodes) != 0:
209 tp: DatasetType
210 for tp in types_:
211 if ((tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not
212 set(tmpRefs).difference(pruneRefs)):
213 culprits.add(tp.name)
214 emptyTasks.add(td.label)
215 newTaskToQuantumNode[td] = diff
217 # update the internal dict
218 self._taskToQuantumNode = newTaskToQuantumNode
220 if emptyTasks:
221 raise ValueError(f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
222 f"after graph pruning; {', '.join(culprits)} caused over-pruning")
224 # Graph of quanta relations
225 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
226 self._count = len(self._connectedQuanta)
228 # Graph of task relations, used in various methods
229 self._taskGraph = self._datasetDict.makeNetworkXGraph()
231 # convert default dict into a regular to prevent accidental key
232 # insertion
233 self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
235 @property
236 def taskGraph(self) -> nx.DiGraph:
237 """Return a graph representing the relations between the tasks inside
238 the quantum graph.
240 Returns
241 -------
242 taskGraph : `networkx.Digraph`
243 Internal datastructure that holds relations of `TaskDef` objects
244 """
245 return self._taskGraph
247 @property
248 def graph(self) -> nx.DiGraph:
249 """Return a graph representing the relations between all the
250 `QuantumNode` objects. Largely it should be preferred to iterate
251 over, and use methods of this class, but sometimes direct access to
252 the networkx object may be helpful
254 Returns
255 -------
256 graph : `networkx.Digraph`
257 Internal datastructure that holds relations of `QuantumNode`
258 objects
259 """
260 return self._connectedQuanta
262 @property
263 def inputQuanta(self) -> Iterable[QuantumNode]:
264 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
265 to the graph, meaning those nodes to not depend on any other nodes in
266 the graph.
268 Returns
269 -------
270 inputNodes : iterable of `QuantumNode`
271 A list of nodes that are inputs to the graph
272 """
273 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
275 @property
276 def outputQuanta(self) -> Iterable[QuantumNode]:
277 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
278 to the graph, meaning those nodes have no nodes that depend them in
279 the graph.
281 Returns
282 -------
283 outputNodes : iterable of `QuantumNode`
284 A list of nodes that are outputs of the graph
285 """
286 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
288 @property
289 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
290 """Return all the `DatasetTypeName` objects that are contained inside
291 the graph.
293 Returns
294 -------
295 tuple of `DatasetTypeName`
296 All the data set type names that are present in the graph
297 """
298 return tuple(self._datasetDict.keys())
    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected,
        ignoring the directionality of connections (weak connectivity).
        """
        return nx.is_weakly_connected(self._connectedQuanta)
307 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
308 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
309 and nodes which depend on them.
311 Parameters
312 ----------
313 refs : `Iterable` of `DatasetRef`
314 Refs which should be removed from resulting graph
316 Returns
317 -------
318 graph : `QuantumGraph`
319 A graph that has been pruned of specified refs and the nodes that
320 depend on them.
321 """
322 newInst = object.__new__(type(self))
323 quantumMap = defaultdict(set)
324 for node in self:
325 quantumMap[node.taskDef].add(node.quantum)
327 # convert to standard dict to prevent accidental key insertion
328 quantumMap = dict(quantumMap.items())
330 newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in self},
331 metadata=self._metadata, pruneRefs=refs)
332 return newInst
334 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
335 """Lookup a `QuantumNode` from an id associated with the node.
337 Parameters
338 ----------
339 nodeId : `NodeId`
340 The number associated with a node
342 Returns
343 -------
344 node : `QuantumNode`
345 The node corresponding with input number
347 Raises
348 ------
349 KeyError
350 Raised if the requested nodeId is not in the graph.
351 """
352 return self._nodeIdMap[nodeId]
354 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
355 """Return all the `Quantum` associated with a `TaskDef`.
357 Parameters
358 ----------
359 taskDef : `TaskDef`
360 The `TaskDef` for which `Quantum` are to be queried
362 Returns
363 -------
364 frozenset of `Quantum`
365 The `set` of `Quantum` that is associated with the specified
366 `TaskDef`.
367 """
368 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef])
370 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
371 """Return all the `QuantumNodes` associated with a `TaskDef`.
373 Parameters
374 ----------
375 taskDef : `TaskDef`
376 The `TaskDef` for which `Quantum` are to be queried
378 Returns
379 -------
380 frozenset of `QuantumNodes`
381 The `frozenset` of `QuantumNodes` that is associated with the
382 specified `TaskDef`.
383 """
384 return frozenset(self._taskToQuantumNode[taskDef])
386 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
387 """Find all tasks that have the specified dataset type name as an
388 input.
390 Parameters
391 ----------
392 datasetTypeName : `str`
393 A string representing the name of a dataset type to be queried,
394 can also accept a `DatasetTypeName` which is a `NewType` of str for
395 type safety in static type checking.
397 Returns
398 -------
399 tasks : iterable of `TaskDef`
400 `TaskDef` objects that have the specified `DatasetTypeName` as an
401 input, list will be empty if no tasks use specified
402 `DatasetTypeName` as an input.
404 Raises
405 ------
406 KeyError
407 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
408 """
409 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
411 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
412 """Find all tasks that have the specified dataset type name as an
413 output.
415 Parameters
416 ----------
417 datasetTypeName : `str`
418 A string representing the name of a dataset type to be queried,
419 can also accept a `DatasetTypeName` which is a `NewType` of str for
420 type safety in static type checking.
422 Returns
423 -------
424 `TaskDef` or `None`
425 `TaskDef` that outputs `DatasetTypeName` as an output or None if
426 none of the tasks produce this `DatasetTypeName`.
428 Raises
429 ------
430 KeyError
431 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
432 """
433 return self._datasetDict.getProducer(datasetTypeName)
435 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
436 """Find all tasks that are associated with the specified dataset type
437 name.
439 Parameters
440 ----------
441 datasetTypeName : `str`
442 A string representing the name of a dataset type to be queried,
443 can also accept a `DatasetTypeName` which is a `NewType` of str for
444 type safety in static type checking.
446 Returns
447 -------
448 result : iterable of `TaskDef`
449 `TaskDef` objects that are associated with the specified
450 `DatasetTypeName`
452 Raises
453 ------
454 KeyError
455 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
456 """
457 return self._datasetDict.getAll(datasetTypeName)
459 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
460 """Determine which `TaskDef` objects in this graph are associated
461 with a `str` representing a task name (looks at the taskName property
462 of `TaskDef` objects).
464 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
465 multiple times in a graph with different labels.
467 Parameters
468 ----------
469 taskName : str
470 Name of a task to search for
472 Returns
473 -------
474 result : list of `TaskDef`
475 List of the `TaskDef` objects that have the name specified.
476 Multiple values are returned in the case that a task is used
477 multiple times with different labels.
478 """
479 results = []
480 for task in self._taskToQuantumNode.keys():
481 split = task.taskName.split('.')
482 if split[-1] == taskName:
483 results.append(task)
484 return results
486 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
487 """Determine which `TaskDef` objects in this graph are associated
488 with a `str` representing a tasks label.
490 Parameters
491 ----------
492 taskName : str
493 Name of a task to search for
495 Returns
496 -------
497 result : `TaskDef`
498 `TaskDef` objects that has the specified label.
499 """
500 for task in self._taskToQuantumNode.keys():
501 if label == task.label:
502 return task
503 return None
505 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
506 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
508 Parameters
509 ----------
510 datasetTypeName : `str`
511 The name of the dataset type to search for as a string,
512 can also accept a `DatasetTypeName` which is a `NewType` of str for
513 type safety in static type checking.
515 Returns
516 -------
517 result : `set` of `QuantumNode` objects
518 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
520 Raises
521 ------
522 KeyError
523 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
525 """
526 tasks = self._datasetDict.getAll(datasetTypeName)
527 result: Set[Quantum] = set()
528 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
529 return result
531 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
532 """Check if specified quantum appears in the graph as part of a node.
534 Parameters
535 ----------
536 quantum : `Quantum`
537 The quantum to search for
539 Returns
540 -------
541 `bool`
542 The result of searching for the quantum
543 """
544 for node in self:
545 if quantum == node.quantum:
546 return True
547 return False
    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : `str` or `io.BufferedIOBase`
            Either a filesystem path to write to, or a writable file handle
            object.
        """
        # Delegates to networkx's graphviz (pygraphviz) integration.
        write_dot(self._connectedQuanta, output)
559 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
560 """Create a new graph object that contains the subset of the nodes
561 specified as input. Node number is preserved.
563 Parameters
564 ----------
565 nodes : `QuantumNode` or iterable of `QuantumNode`
567 Returns
568 -------
569 graph : instance of graph type
570 An instance of the type from which the subset was created
571 """
572 if not isinstance(nodes, Iterable):
573 nodes = (nodes, )
574 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
575 quantumMap = defaultdict(set)
577 node: QuantumNode
578 for node in quantumSubgraph:
579 quantumMap[node.taskDef].add(node.quantum)
581 # convert to standard dict to prevent accidental key insertion
582 quantumMap = dict(quantumMap.items())
583 # Create an empty graph, and then populate it with custom mapping
584 newInst = type(self)({})
585 newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
586 _buildId=self._buildId, metadata=self._metadata)
587 return newInst
589 def subsetToConnected(self: _T) -> Tuple[_T, ...]:
590 """Generate a list of subgraphs where each is connected.
592 Returns
593 -------
594 result : list of `QuantumGraph`
595 A list of graphs that are each connected
596 """
597 return tuple(self.subset(connectedSet)
598 for connectedSet in nx.weakly_connected_components(self._connectedQuanta))
600 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
601 """Return a set of `QuantumNode` that are direct inputs to a specified
602 node.
604 Parameters
605 ----------
606 node : `QuantumNode`
607 The node of the graph for which inputs are to be determined
609 Returns
610 -------
611 set of `QuantumNode`
612 All the nodes that are direct inputs to specified node
613 """
614 return set(pred for pred in self._connectedQuanta.predecessors(node))
616 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
617 """Return a set of `QuantumNode` that are direct outputs of a specified
618 node.
620 Parameters
621 ----------
622 node : `QuantumNode`
623 The node of the graph for which outputs are to be determined
625 Returns
626 -------
627 set of `QuantumNode`
628 All the nodes that are direct outputs to specified node
629 """
630 return set(succ for succ in self._connectedQuanta.successors(node))
632 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
633 """Return a graph of `QuantumNode` that are direct inputs and outputs
634 of a specified node.
636 Parameters
637 ----------
638 node : `QuantumNode`
639 The node of the graph for which connected nodes are to be
640 determined.
642 Returns
643 -------
644 graph : graph of `QuantumNode`
645 All the nodes that are directly connected to specified node
646 """
647 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
648 nodes.add(node)
649 return self.subset(nodes)
651 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
652 """Return a graph of the specified node and all the ancestor nodes
653 directly reachable by walking edges.
655 Parameters
656 ----------
657 node : `QuantumNode`
658 The node for which all ansestors are to be determined
660 Returns
661 -------
662 graph of `QuantumNode`
663 Graph of node and all of its ansestors
664 """
665 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
666 predecessorNodes.add(node)
667 return self.subset(predecessorNodes)
669 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
670 """Check a graph for the presense of cycles and returns the edges of
671 any cycles found, or an empty list if there is no cycle.
673 Returns
674 -------
675 result : list of tuple of `QuantumNode`, `QuantumNode`
676 A list of any graph edges that form a cycle, or an empty list if
677 there is no cycle. Empty list to so support if graph.find_cycle()
678 syntax as an empty list is falsy.
679 """
680 try:
681 return nx.find_cycle(self._connectedQuanta)
682 except nx.NetworkXNoCycle:
683 return []
685 def saveUri(self, uri):
686 """Save `QuantumGraph` to the specified URI.
688 Parameters
689 ----------
690 uri : `ButlerURI` or `str`
691 URI to where the graph should be saved.
692 """
693 buffer = self._buildSaveObject()
694 butlerUri = ButlerURI(uri)
695 if butlerUri.getExtension() not in (".qgraph"):
696 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
697 butlerUri.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra metadata carried with the graph as a read-only mapping view,
        or `None` if no metadata was supplied at construction.
        """
        if self._metadata is None:
            return None
        # MappingProxyType is a read-only view; the underlying mapping is
        # not copied.
        return MappingProxyType(self._metadata)
707 @classmethod
708 def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
709 nodes: Optional[Iterable[Union[str, uuid.UUID]]] = None,
710 graphID: Optional[BuildId] = None,
711 minimumVersion: int = 3
712 ) -> QuantumGraph:
713 """Read `QuantumGraph` from a URI.
715 Parameters
716 ----------
717 uri : `ButlerURI` or `str`
718 URI from where to load the graph.
719 universe: `~lsst.daf.butler.DimensionUniverse`
720 DimensionUniverse instance, not used by the method itself but
721 needed to ensure that registry data structures are initialized.
722 nodes: iterable of `int` or None
723 Numbers that correspond to nodes in the graph. If specified, only
724 these nodes will be loaded. Defaults to None, in which case all
725 nodes will be loaded.
726 graphID : `str` or `None`
727 If specified this ID is verified against the loaded graph prior to
728 loading any Nodes. This defaults to None in which case no
729 validation is done.
730 minimumVersion : int
731 Minimum version of a save file to load. Set to -1 to load all
732 versions. Older versions may need to be loaded, and re-saved
733 to upgrade them to the latest format before they can be used in
734 production.
736 Returns
737 -------
738 graph : `QuantumGraph`
739 Resulting QuantumGraph instance.
741 Raises
742 ------
743 TypeError
744 Raised if pickle contains instance of a type other than
745 QuantumGraph.
746 ValueError
747 Raised if one or more of the nodes requested is not in the
748 `QuantumGraph` or if graphID parameter does not match the graph
749 being loaded or if the supplied uri does not point at a valid
750 `QuantumGraph` save file.
753 Notes
754 -----
755 Reading Quanta from pickle requires existence of singleton
756 DimensionUniverse which is usually instantiated during Registry
757 initialization. To make sure that DimensionUniverse exists this method
758 accepts dummy DimensionUniverse argument.
759 """
760 uri = ButlerURI(uri)
761 # With ButlerURI we have the choice of always using a local file
762 # or reading in the bytes directly. Reading in bytes can be more
763 # efficient for reasonably-sized pickle files when the resource
764 # is remote. For now use the local file variant. For a local file
765 # as_local() does nothing.
767 if uri.getExtension() in (".pickle", ".pkl"):
768 with uri.as_local() as local, open(local.ospath, "rb") as fd:
769 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
770 qgraph = pickle.load(fd)
771 elif uri.getExtension() in ('.qgraph'):
772 with LoadHelper(uri, minimumVersion) as loader:
773 qgraph = loader.load(universe, nodes, graphID)
774 else:
775 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
776 if not isinstance(qgraph, QuantumGraph):
777 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
778 return qgraph
780 @classmethod
781 def readHeader(cls, uri: Union[ButlerURI, str], minimumVersion: int = 3) -> Optional[str]:
782 """Read the header of a `QuantumGraph` pointed to by the uri parameter
783 and return it as a string.
785 Parameters
786 ----------
787 uri : `~lsst.daf.butler.ButlerURI` or `str`
788 The location of the `QuantumGraph` to load. If the argument is a
789 string, it must correspond to a valid `~lsst.daf.butler.ButlerURI`
790 path.
791 minimumVersion : int
792 Minimum version of a save file to load. Set to -1 to load all
793 versions. Older versions may need to be loaded, and re-saved
794 to upgrade them to the latest format before they can be used in
795 production.
797 Returns
798 -------
799 header : `str` or `None`
800 The header associated with the specified `QuantumGraph` it there is
801 one, else `None`.
803 Raises
804 ------
805 ValueError
806 Raised if `QuantuGraph` was saved as a pickle.
807 Raised if the extention of the file specified by uri is not a
808 `QuantumGraph` extention.
809 """
810 uri = ButlerURI(uri)
811 if uri.getExtension() in (".pickle", ".pkl"):
812 raise ValueError("Reading a header from a pickle save is not supported")
813 elif uri.getExtension() in ('.qgraph'):
814 return LoadHelper(uri, minimumVersion).readHeader()
815 else:
816 raise ValueError("Only know how to handle files saved as `qgraph`")
    def buildAndPrintHeader(self):
        """Create the header that would be used in a save of this object and
        print it to standard out.
        """
        # Only the header is needed; the serialized body is discarded.
        _, header = self._buildSaveObject(returnHeader=True)
        print(json.dumps(header))
    def save(self, file: io.BufferedIOBase):
        """Save QuantumGraph to a file.

        The graph is written in the versioned binary ``qgraph`` format
        produced by `_buildSaveObject` (magic bytes, version preamble,
        lzma-compressed JSON header and node data) — not pickle.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write serialized graph data to, open in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes
839 def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
840 # make some containers
841 jsonData = deque()
842 # node map is a list because json does not accept mapping keys that
843 # are not strings, so we store a list of key, value pairs that will
844 # be converted to a mapping on load
845 nodeMap = []
846 taskDefMap = {}
847 headerData = {}
849 # Store the QauntumGraph BuildId, this will allow validating BuildIds
850 # at load time, prior to loading any QuantumNodes. Name chosen for
851 # unlikely conflicts.
852 headerData['GraphBuildID'] = self.graphID
853 headerData['Metadata'] = self._metadata
855 # counter for the number of bytes processed thus far
856 count = 0
857 # serialize out the task Defs recording the start and end bytes of each
858 # taskDef
859 inverseLookup = self._datasetDict.inverse
860 taskDef: TaskDef
861 # sort by task label to ensure serialization happens in the same order
862 for taskDef in self.taskGraph:
863 # compressing has very little impact on saving or load time, but
864 # a large impact on on disk size, so it is worth doing
865 taskDescription = {}
866 # save the fully qualified name, as TaskDef not not require this,
867 # but by doing so can save space and is easier to transport
868 taskDescription['taskName'] = f"{taskDef.taskClass.__module__}.{taskDef.taskClass.__qualname__}"
869 # save the config as a text stream that will be un-persisted on the
870 # other end
871 stream = io.StringIO()
872 taskDef.config.saveToStream(stream)
873 taskDescription['config'] = stream.getvalue()
874 taskDescription['label'] = taskDef.label
876 inputs = []
877 outputs = []
879 # Determine the connection between all of tasks and save that in
880 # the header as a list of connections and edges in each task
881 # this will help in un-persisting, and possibly in a "quick view"
882 # method that does not require everything to be un-persisted
883 #
884 # Typing returns can't be parameter dependent
885 for connection in inverseLookup[taskDef]: # type: ignore
886 consumers = self._datasetDict.getConsumers(connection)
887 producer = self._datasetDict.getProducer(connection)
888 if taskDef in consumers:
889 # This checks if the task consumes the connection directly
890 # from the datastore or it is produced by another task
891 producerLabel = producer.label if producer is not None else "datastore"
892 inputs.append((producerLabel, connection))
893 elif taskDef not in consumers and producer is taskDef:
894 # If there are no consumers for this tasks produced
895 # connection, the output will be said to be the datastore
896 # in which case the for loop will be a zero length loop
897 if not consumers:
898 outputs.append(("datastore", connection))
899 for td in consumers:
900 outputs.append((td.label, connection))
902 # dump to json string, and encode that string to bytes and then
903 # conpress those bytes
904 dump = lzma.compress(json.dumps(taskDescription).encode())
905 # record the sizing and relation information
906 taskDefMap[taskDef.label] = {"bytes": (count, count+len(dump)),
907 "inputs": inputs,
908 "outputs": outputs}
909 count += len(dump)
910 jsonData.append(dump)
912 headerData['TaskDefs'] = taskDefMap
914 # serialize the nodes, recording the start and end bytes of each node
915 dimAccumulator = DimensionRecordsAccumulator()
916 for node in self:
917 # compressing has very little impact on saving or load time, but
918 # a large impact on on disk size, so it is worth doing
919 simpleNode = node.to_simple(accumulator=dimAccumulator)
921 dump = lzma.compress(simpleNode.json().encode())
922 jsonData.append(dump)
923 nodeMap.append((str(node.nodeId),
924 {"bytes": (count, count+len(dump)),
925 "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
926 "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)]}
927 ))
928 count += len(dump)
930 headerData['DimensionRecords'] = {key: value.dict() for key, value in
931 dimAccumulator.makeSerializedDimensionRecordMapping().items()}
933 # need to serialize this as a series of key,value tuples because of
934 # a limitation on how json cant do anyting but strings as keys
935 headerData['Nodes'] = nodeMap
937 # dump the headerData to json
938 header_encode = lzma.compress(json.dumps(headerData).encode())
940 # record the sizes as 2 unsigned long long numbers for a total of 16
941 # bytes
942 save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)
944 fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
945 map_lengths = struct.pack(fmt_string, len(header_encode))
947 # write each component of the save out in a deterministic order
948 # buffer = io.BytesIO()
949 # buffer.write(map_lengths)
950 # buffer.write(taskDef_pickle)
951 # buffer.write(map_pickle)
952 buffer = bytearray()
953 buffer.extend(MAGIC_BYTES)
954 buffer.extend(save_bytes)
955 buffer.extend(map_lengths)
956 buffer.extend(header_encode)
957 # Iterate over the length of pickleData, and for each element pop the
958 # leftmost element off the deque and write it out. This is to save
959 # memory, as the memory is added to the buffer object, it is removed
960 # from from the container.
961 #
962 # Only this section needs to worry about memory pressue because
963 # everything else written to the buffer prior to this pickle data is
964 # only on the order of kilobytes to low numbers of megabytes.
965 while jsonData:
966 buffer.extend(jsonData.popleft())
967 if returnHeader:
968 return buffer, headerData
969 else:
970 return buffer
972 @classmethod
973 def load(cls, file: io.IO[bytes], universe: DimensionUniverse,
974 nodes: Optional[Iterable[uuid.UUID]] = None,
975 graphID: Optional[BuildId] = None,
976 minimumVersion: int = 3
977 ) -> QuantumGraph:
978 """Read QuantumGraph from a file that was made by `save`.
980 Parameters
981 ----------
982 file : `io.IO` of bytes
983 File with pickle data open in binary mode.
984 universe: `~lsst.daf.butler.DimensionUniverse`
985 DimensionUniverse instance, not used by the method itself but
986 needed to ensure that registry data structures are initialized.
987 nodes: iterable of `int` or None
988 Numbers that correspond to nodes in the graph. If specified, only
989 these nodes will be loaded. Defaults to None, in which case all
990 nodes will be loaded.
991 graphID : `str` or `None`
992 If specified this ID is verified against the loaded graph prior to
993 loading any Nodes. This defaults to None in which case no
994 validation is done.
995 minimumVersion : int
996 Minimum version of a save file to load. Set to -1 to load all
997 versions. Older versions may need to be loaded, and re-saved
998 to upgrade them to the latest format before they can be used in
999 production.
1001 Returns
1002 -------
1003 graph : `QuantumGraph`
1004 Resulting QuantumGraph instance.
1006 Raises
1007 ------
1008 TypeError
1009 Raised if pickle contains instance of a type other than
1010 QuantumGraph.
1011 ValueError
1012 Raised if one or more of the nodes requested is not in the
1013 `QuantumGraph` or if graphID parameter does not match the graph
1014 being loaded or if the supplied uri does not point at a valid
1015 `QuantumGraph` save file.
1017 Notes
1018 -----
1019 Reading Quanta from pickle requires existence of singleton
1020 DimensionUniverse which is usually instantiated during Registry
1021 initialization. To make sure that DimensionUniverse exists this method
1022 accepts dummy DimensionUniverse argument.
1023 """
1024 # Try to see if the file handle contains pickle data, this will be
1025 # removed in the future
1026 try:
1027 qgraph = pickle.load(file)
1028 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
1029 except pickle.UnpicklingError:
1030 # needed because we don't have Protocols yet
1031 with LoadHelper(file, minimumVersion) as loader: # type: ignore
1032 qgraph = loader.load(universe, nodes, graphID)
1033 if not isinstance(qgraph, QuantumGraph):
1034 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
1035 return qgraph
1037 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1038 """Iterate over the `taskGraph` attribute in topological order
1040 Yields
1041 ------
1042 taskDef : `TaskDef`
1043 `TaskDef` objects in topological order
1044 """
1045 yield from nx.topological_sort(self.taskGraph)
1047 @property
1048 def graphID(self):
1049 """Returns the ID generated by the graph at construction time
1050 """
1051 return self._buildId
1053 def __iter__(self) -> Generator[QuantumNode, None, None]:
1054 yield from nx.topological_sort(self._connectedQuanta)
1056 def __len__(self) -> int:
1057 return self._count
1059 def __contains__(self, node: QuantumNode) -> bool:
1060 return self._connectedQuanta.has_node(node)
1062 def __getstate__(self) -> dict:
1063 """Stores a compact form of the graph as a list of graph nodes, and a
1064 tuple of task labels and task configs. The full graph can be
1065 reconstructed with this information, and it preseves the ordering of
1066 the graph ndoes.
1067 """
1068 universe: Optional[DimensionUniverse] = None
1069 for node in self:
1070 dId = node.quantum.dataId
1071 if dId is None:
1072 continue
1073 universe = dId.graph.universe
1074 return {"reduced": self._buildSaveObject(),
1075 'graphId': self._buildId, 'universe': universe}
1077 def __setstate__(self, state: dict):
1078 """Reconstructs the state of the graph from the information persisted
1079 in getstate.
1080 """
1081 buffer = io.BytesIO(state['reduced'])
1082 with LoadHelper(buffer, minimumVersion=3) as loader:
1083 qgraph = loader.load(state['universe'], graphID=state['graphId'])
1085 self._metadata = qgraph._metadata
1086 self._buildId = qgraph._buildId
1087 self._datasetDict = qgraph._datasetDict
1088 self._nodeIdMap = qgraph._nodeIdMap
1089 self._count = len(qgraph)
1090 self._taskToQuantumNode = qgraph._taskToQuantumNode
1091 self._taskGraph = qgraph._taskGraph
1092 self._connectedQuanta = qgraph._connectedQuanta
1094 def __eq__(self, other: object) -> bool:
1095 if not isinstance(other, QuantumGraph):
1096 return False
1097 if len(self) != len(other):
1098 return False
1099 for node in self:
1100 if node not in other:
1101 return False
1102 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1103 return False
1104 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1105 return False
1106 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1107 return False
1108 return set(self.taskGraph) == set(other.taskGraph)