Coverage for python/lsst/pipe/base/graph/graph.py: 18%
367 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-18 11:52 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25import io
26import json
27import lzma
28import os
29import pickle
30import struct
31import time
32import uuid
33import warnings
34from collections import defaultdict, deque
35from itertools import chain
36from types import MappingProxyType
37from typing import (
38 Any,
39 BinaryIO,
40 DefaultDict,
41 Deque,
42 Dict,
43 FrozenSet,
44 Generator,
45 Iterable,
46 List,
47 Mapping,
48 MutableMapping,
49 Optional,
50 Set,
51 Tuple,
52 TypeVar,
53 Union,
54)
56import networkx as nx
57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum
58from lsst.resources import ResourcePath, ResourcePathExpression
59from lsst.utils.introspection import get_full_type_name
60from networkx.drawing.nx_agraph import write_dot
62from ..connections import iterConnections
63from ..pipeline import TaskDef
64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner
65from ._loadHelpers import LoadHelper
66from ._versionDeserializers import DESERIALIZER_MAP
67from .quantumNode import BuildId, QuantumNode
# Type variable used so classmethods and subset-returning methods preserve
# the concrete (possibly subclassed) QuantumGraph type.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save.
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file.
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save; written at the very
# start of the file so loaders can cheaply reject non-graph files.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be performed because the node
    identifier is incompatible with this graph.
    """
101class QuantumGraph:
102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
104 This data structure represents a concrete workflow generated from a
105 `Pipeline`.
107 Parameters
108 ----------
109 quanta : Mapping of `TaskDef` to sets of `Quantum`
110 This maps tasks (and their configs) to the sets of data they are to
111 process.
112 metadata : Optional Mapping of `str` to primitives
113 This is an optional parameter of extra data to carry with the graph.
114 Entries in this mapping should be able to be serialized in JSON.
115 pruneRefs : iterable [ `DatasetRef` ], optional
116 Set of dataset refs to exclude from a graph.
117 initInputs : `Mapping`, optional
118 Maps tasks to their InitInput dataset refs. Dataset refs can be either
119 resolved or non-resolved. Presently the same dataset refs are included
120 in each `Quantum` for the same task.
121 initOutputs : `Mapping`, optional
122 Maps tasks to their InitOutput dataset refs. Dataset refs can be either
123 resolved or non-resolved. For intermediate resolved refs their dataset
124 ID must match ``initInputs`` and Quantum ``initInputs``.
126 Raises
127 ------
128 ValueError
129 Raised if the graph is pruned such that some tasks no longer have nodes
130 associated with them.
131 """
133 def __init__(
134 self,
135 quanta: Mapping[TaskDef, Set[Quantum]],
136 metadata: Optional[Mapping[str, Any]] = None,
137 pruneRefs: Optional[Iterable[DatasetRef]] = None,
138 universe: Optional[DimensionUniverse] = None,
139 initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
140 initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
141 ):
142 self._buildGraphs(
143 quanta,
144 metadata=metadata,
145 pruneRefs=pruneRefs,
146 universe=universe,
147 initInputs=initInputs,
148 initOutputs=initOutputs,
149 )
151 def _buildGraphs(
152 self,
153 quanta: Mapping[TaskDef, Set[Quantum]],
154 *,
155 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
156 _buildId: Optional[BuildId] = None,
157 metadata: Optional[Mapping[str, Any]] = None,
158 pruneRefs: Optional[Iterable[DatasetRef]] = None,
159 universe: Optional[DimensionUniverse] = None,
160 initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
161 initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
162 ) -> None:
163 """Builds the graph that is used to store the relation between tasks,
164 and the graph that holds the relations between quanta
165 """
166 self._metadata = metadata
167 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
168 # Data structures used to identify relations between components;
169 # DatasetTypeName -> TaskDef for task,
170 # and DatasetRef -> QuantumNode for the quanta
171 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
172 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()
174 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
175 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
176 for taskDef, quantumSet in quanta.items():
177 connections = taskDef.connections
179 # For each type of connection in the task, add a key to the
180 # `_DatasetTracker` for the connections name, with a value of
181 # the TaskDef in the appropriate field
182 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
183 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)
185 for output in iterConnections(connections, ("outputs",)):
186 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)
188 # For each `Quantum` in the set of all `Quantum` for this task,
189 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
190 # of the individual datasets inside the `Quantum`, with a value of
191 # a newly created QuantumNode to the appropriate input/output
192 # field.
193 for quantum in quantumSet:
194 if quantum.dataId is not None:
195 if universe is None:
196 universe = quantum.dataId.universe
197 elif universe != quantum.dataId.universe:
198 raise RuntimeError(
199 "Mismatched dimension universes in QuantumGraph construction: "
200 f"{universe} != {quantum.dataId.universe}. "
201 )
203 if _quantumToNodeId:
204 if (nodeId := _quantumToNodeId.get(quantum)) is None:
205 raise ValueError(
206 "If _quantuMToNodeNumber is not None, all quanta must have an "
207 "associated value in the mapping"
208 )
209 else:
210 nodeId = uuid.uuid4()
212 inits = quantum.initInputs.values()
213 inputs = quantum.inputs.values()
214 value = QuantumNode(quantum, taskDef, nodeId)
215 self._taskToQuantumNode[taskDef].add(value)
216 self._nodeIdMap[nodeId] = value
218 for dsRef in chain(inits, inputs):
219 # unfortunately, `Quantum` allows inits to be individual
220 # `DatasetRef`s or an Iterable of such, so there must
221 # be an instance check here
222 if isinstance(dsRef, Iterable):
223 for sub in dsRef:
224 if sub.isComponent():
225 sub = sub.makeCompositeRef()
226 self._datasetRefDict.addConsumer(sub, value)
227 else:
228 assert isinstance(dsRef, DatasetRef)
229 if dsRef.isComponent():
230 dsRef = dsRef.makeCompositeRef()
231 self._datasetRefDict.addConsumer(dsRef, value)
232 for dsRef in chain.from_iterable(quantum.outputs.values()):
233 self._datasetRefDict.addProducer(dsRef, value)
235 if pruneRefs is not None:
236 # track what refs were pruned and prune the graph
237 prunes: Set[QuantumNode] = set()
238 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)
240 # recreate the taskToQuantumNode dict removing nodes that have been
241 # pruned. Keep track of task defs that now have no QuantumNodes
242 emptyTasks: Set[str] = set()
243 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
244 # accumulate all types
245 types_ = set()
246 # tracker for any pruneRefs that have caused tasks to have no nodes
247 # This helps the user find out what caused the issues seen.
248 culprits = set()
249 # Find all the types from the refs to prune
250 for r in pruneRefs:
251 types_.add(r.datasetType)
253 # For each of the tasks, and their associated nodes, remove any
254 # any nodes that were pruned. If there are no nodes associated
255 # with a task, record that task, and find out if that was due to
256 # a type from an input ref to prune.
257 for td, taskNodes in self._taskToQuantumNode.items():
258 diff = taskNodes.difference(prunes)
259 if len(diff) == 0:
260 if len(taskNodes) != 0:
261 tp: DatasetType
262 for tp in types_:
263 if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
264 tmpRefs
265 ).difference(pruneRefs):
266 culprits.add(tp.name)
267 emptyTasks.add(td.label)
268 newTaskToQuantumNode[td] = diff
270 # update the internal dict
271 self._taskToQuantumNode = newTaskToQuantumNode
273 if emptyTasks:
274 raise ValueError(
275 f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
276 f"after graph pruning; {', '.join(culprits)} caused over-pruning"
277 )
279 # Dimension universe
280 if universe is None:
281 raise RuntimeError(
282 "Dimension universe or at least one quantum with a data ID "
283 "must be provided when constructing a QuantumGraph."
284 )
285 self._universe = universe
287 # Graph of quanta relations
288 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
289 self._count = len(self._connectedQuanta)
291 # Graph of task relations, used in various methods
292 self._taskGraph = self._datasetDict.makeNetworkXGraph()
294 # convert default dict into a regular to prevent accidental key
295 # insertion
296 self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
298 self._initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
299 self._initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}
300 if initInputs is not None:
301 self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
302 if initOutputs is not None:
303 self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
305 @property
306 def taskGraph(self) -> nx.DiGraph:
307 """Return a graph representing the relations between the tasks inside
308 the quantum graph.
310 Returns
311 -------
312 taskGraph : `networkx.Digraph`
313 Internal datastructure that holds relations of `TaskDef` objects
314 """
315 return self._taskGraph
317 @property
318 def graph(self) -> nx.DiGraph:
319 """Return a graph representing the relations between all the
320 `QuantumNode` objects. Largely it should be preferred to iterate
321 over, and use methods of this class, but sometimes direct access to
322 the networkx object may be helpful
324 Returns
325 -------
326 graph : `networkx.Digraph`
327 Internal datastructure that holds relations of `QuantumNode`
328 objects
329 """
330 return self._connectedQuanta
332 @property
333 def inputQuanta(self) -> Iterable[QuantumNode]:
334 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
335 to the graph, meaning those nodes to not depend on any other nodes in
336 the graph.
338 Returns
339 -------
340 inputNodes : iterable of `QuantumNode`
341 A list of nodes that are inputs to the graph
342 """
343 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
345 @property
346 def outputQuanta(self) -> Iterable[QuantumNode]:
347 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
348 to the graph, meaning those nodes have no nodes that depend them in
349 the graph.
351 Returns
352 -------
353 outputNodes : iterable of `QuantumNode`
354 A list of nodes that are outputs of the graph
355 """
356 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
358 @property
359 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
360 """Return all the `DatasetTypeName` objects that are contained inside
361 the graph.
363 Returns
364 -------
365 tuple of `DatasetTypeName`
366 All the data set type names that are present in the graph
367 """
368 return tuple(self._datasetDict.keys())
    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        # Weak connectivity treats the directed graph as undirected: every
        # node must be reachable from every other when edge direction is
        # ignored.
        return nx.is_weakly_connected(self._connectedQuanta)
377 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
378 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
379 and nodes which depend on them.
381 Parameters
382 ----------
383 refs : `Iterable` of `DatasetRef`
384 Refs which should be removed from resulting graph
386 Returns
387 -------
388 graph : `QuantumGraph`
389 A graph that has been pruned of specified refs and the nodes that
390 depend on them.
391 """
392 newInst = object.__new__(type(self))
393 quantumMap = defaultdict(set)
394 for node in self:
395 quantumMap[node.taskDef].add(node.quantum)
397 # convert to standard dict to prevent accidental key insertion
398 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
400 newInst._buildGraphs(
401 quantumDict,
402 _quantumToNodeId={n.quantum: n.nodeId for n in self},
403 metadata=self._metadata,
404 pruneRefs=refs,
405 universe=self._universe,
406 )
407 return newInst
409 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
410 """Lookup a `QuantumNode` from an id associated with the node.
412 Parameters
413 ----------
414 nodeId : `NodeId`
415 The number associated with a node
417 Returns
418 -------
419 node : `QuantumNode`
420 The node corresponding with input number
422 Raises
423 ------
424 KeyError
425 Raised if the requested nodeId is not in the graph.
426 """
427 return self._nodeIdMap[nodeId]
429 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
430 """Return all the `Quantum` associated with a `TaskDef`.
432 Parameters
433 ----------
434 taskDef : `TaskDef`
435 The `TaskDef` for which `Quantum` are to be queried
437 Returns
438 -------
439 frozenset of `Quantum`
440 The `set` of `Quantum` that is associated with the specified
441 `TaskDef`.
442 """
443 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef])
445 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
446 """Return all the `QuantumNodes` associated with a `TaskDef`.
448 Parameters
449 ----------
450 taskDef : `TaskDef`
451 The `TaskDef` for which `Quantum` are to be queried
453 Returns
454 -------
455 frozenset of `QuantumNodes`
456 The `frozenset` of `QuantumNodes` that is associated with the
457 specified `TaskDef`.
458 """
459 return frozenset(self._taskToQuantumNode[taskDef])
461 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
462 """Find all tasks that have the specified dataset type name as an
463 input.
465 Parameters
466 ----------
467 datasetTypeName : `str`
468 A string representing the name of a dataset type to be queried,
469 can also accept a `DatasetTypeName` which is a `NewType` of str for
470 type safety in static type checking.
472 Returns
473 -------
474 tasks : iterable of `TaskDef`
475 `TaskDef` objects that have the specified `DatasetTypeName` as an
476 input, list will be empty if no tasks use specified
477 `DatasetTypeName` as an input.
479 Raises
480 ------
481 KeyError
482 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
483 """
484 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
486 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
487 """Find all tasks that have the specified dataset type name as an
488 output.
490 Parameters
491 ----------
492 datasetTypeName : `str`
493 A string representing the name of a dataset type to be queried,
494 can also accept a `DatasetTypeName` which is a `NewType` of str for
495 type safety in static type checking.
497 Returns
498 -------
499 `TaskDef` or `None`
500 `TaskDef` that outputs `DatasetTypeName` as an output or None if
501 none of the tasks produce this `DatasetTypeName`.
503 Raises
504 ------
505 KeyError
506 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
507 """
508 return self._datasetDict.getProducer(datasetTypeName)
510 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
511 """Find all tasks that are associated with the specified dataset type
512 name.
514 Parameters
515 ----------
516 datasetTypeName : `str`
517 A string representing the name of a dataset type to be queried,
518 can also accept a `DatasetTypeName` which is a `NewType` of str for
519 type safety in static type checking.
521 Returns
522 -------
523 result : iterable of `TaskDef`
524 `TaskDef` objects that are associated with the specified
525 `DatasetTypeName`
527 Raises
528 ------
529 KeyError
530 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
531 """
532 return self._datasetDict.getAll(datasetTypeName)
534 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
535 """Determine which `TaskDef` objects in this graph are associated
536 with a `str` representing a task name (looks at the taskName property
537 of `TaskDef` objects).
539 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
540 multiple times in a graph with different labels.
542 Parameters
543 ----------
544 taskName : str
545 Name of a task to search for
547 Returns
548 -------
549 result : list of `TaskDef`
550 List of the `TaskDef` objects that have the name specified.
551 Multiple values are returned in the case that a task is used
552 multiple times with different labels.
553 """
554 results = []
555 for task in self._taskToQuantumNode.keys():
556 split = task.taskName.split(".")
557 if split[-1] == taskName:
558 results.append(task)
559 return results
561 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
562 """Determine which `TaskDef` objects in this graph are associated
563 with a `str` representing a tasks label.
565 Parameters
566 ----------
567 taskName : str
568 Name of a task to search for
570 Returns
571 -------
572 result : `TaskDef`
573 `TaskDef` objects that has the specified label.
574 """
575 for task in self._taskToQuantumNode.keys():
576 if label == task.label:
577 return task
578 return None
580 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
581 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
583 Parameters
584 ----------
585 datasetTypeName : `str`
586 The name of the dataset type to search for as a string,
587 can also accept a `DatasetTypeName` which is a `NewType` of str for
588 type safety in static type checking.
590 Returns
591 -------
592 result : `set` of `QuantumNode` objects
593 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
595 Raises
596 ------
597 KeyError
598 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
600 """
601 tasks = self._datasetDict.getAll(datasetTypeName)
602 result: Set[Quantum] = set()
603 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
604 return result
606 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
607 """Check if specified quantum appears in the graph as part of a node.
609 Parameters
610 ----------
611 quantum : `Quantum`
612 The quantum to search for
614 Returns
615 -------
616 `bool`
617 The result of searching for the quantum
618 """
619 for node in self:
620 if quantum == node.quantum:
621 return True
622 return False
    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        # Delegates entirely to networkx/pygraphviz; serializes the quanta
        # relation graph, not the task graph.
        write_dot(self._connectedQuanta, output)
634 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
635 """Create a new graph object that contains the subset of the nodes
636 specified as input. Node number is preserved.
638 Parameters
639 ----------
640 nodes : `QuantumNode` or iterable of `QuantumNode`
642 Returns
643 -------
644 graph : instance of graph type
645 An instance of the type from which the subset was created
646 """
647 if not isinstance(nodes, Iterable):
648 nodes = (nodes,)
649 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
650 quantumMap = defaultdict(set)
652 node: QuantumNode
653 for node in quantumSubgraph:
654 quantumMap[node.taskDef].add(node.quantum)
656 # convert to standard dict to prevent accidental key insertion
657 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
658 # Create an empty graph, and then populate it with custom mapping
659 newInst = type(self)({}, universe=self._universe)
660 newInst._buildGraphs(
661 quantumDict,
662 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
663 _buildId=self._buildId,
664 metadata=self._metadata,
665 universe=self._universe,
666 )
667 return newInst
669 def subsetToConnected(self: _T) -> Tuple[_T, ...]:
670 """Generate a list of subgraphs where each is connected.
672 Returns
673 -------
674 result : list of `QuantumGraph`
675 A list of graphs that are each connected
676 """
677 return tuple(
678 self.subset(connectedSet)
679 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
680 )
682 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
683 """Return a set of `QuantumNode` that are direct inputs to a specified
684 node.
686 Parameters
687 ----------
688 node : `QuantumNode`
689 The node of the graph for which inputs are to be determined
691 Returns
692 -------
693 set of `QuantumNode`
694 All the nodes that are direct inputs to specified node
695 """
696 return set(pred for pred in self._connectedQuanta.predecessors(node))
698 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
699 """Return a set of `QuantumNode` that are direct outputs of a specified
700 node.
702 Parameters
703 ----------
704 node : `QuantumNode`
705 The node of the graph for which outputs are to be determined
707 Returns
708 -------
709 set of `QuantumNode`
710 All the nodes that are direct outputs to specified node
711 """
712 return set(succ for succ in self._connectedQuanta.successors(node))
714 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
715 """Return a graph of `QuantumNode` that are direct inputs and outputs
716 of a specified node.
718 Parameters
719 ----------
720 node : `QuantumNode`
721 The node of the graph for which connected nodes are to be
722 determined.
724 Returns
725 -------
726 graph : graph of `QuantumNode`
727 All the nodes that are directly connected to specified node
728 """
729 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
730 nodes.add(node)
731 return self.subset(nodes)
733 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
734 """Return a graph of the specified node and all the ancestor nodes
735 directly reachable by walking edges.
737 Parameters
738 ----------
739 node : `QuantumNode`
740 The node for which all ansestors are to be determined
742 Returns
743 -------
744 graph of `QuantumNode`
745 Graph of node and all of its ansestors
746 """
747 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
748 predecessorNodes.add(node)
749 return self.subset(predecessorNodes)
    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presense of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. Empty list to so support if graph.find_cycle()
            syntax as an empty list is falsy.
        """
        # networkx signals "no cycle" with an exception; translate that into
        # an empty (falsy) list so callers can simply test the result.
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []
767 def saveUri(self, uri: ResourcePathExpression) -> None:
768 """Save `QuantumGraph` to the specified URI.
770 Parameters
771 ----------
772 uri : convertible to `ResourcePath`
773 URI to where the graph should be saved.
774 """
775 buffer = self._buildSaveObject()
776 path = ResourcePath(uri)
777 if path.getExtension() not in (".qgraph"):
778 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
779 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra data carried with the graph as a read-only mapping view, or
        `None` if no metadata was supplied at construction time.
        """
        if self._metadata is None:
            return None
        # Wrap in MappingProxyType so callers cannot mutate the metadata.
        return MappingProxyType(self._metadata)
788 def initInputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
789 """Return DatasetRefs for a given task InitInputs.
791 Parameters
792 ----------
793 taskDef : `TaskDef`
794 Task definition structure.
796 Returns
797 -------
798 refs : `list` [ `DatasetRef` ] or None
799 DatasetRef for the task InitInput, can be `None`. This can return
800 either resolved or non-resolved reference.
801 """
802 return self._initInputRefs.get(taskDef)
804 def initOutputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
805 """Return DatasetRefs for a given task InitOutputs.
807 Parameters
808 ----------
809 taskDef : `TaskDef`
810 Task definition structure.
812 Returns
813 -------
814 refs : `list` [ `DatasetRef` ] or None
815 DatasetRefs for the task InitOutput, can be `None`. This can return
816 either resolved or non-resolved reference. Resolved reference will
817 match Quantum's initInputs if this is an intermediate dataset type.
818 """
819 return self._initOutputRefs.get(taskDef)
821 @classmethod
822 def loadUri(
823 cls,
824 uri: ResourcePathExpression,
825 universe: Optional[DimensionUniverse] = None,
826 nodes: Optional[Iterable[uuid.UUID]] = None,
827 graphID: Optional[BuildId] = None,
828 minimumVersion: int = 3,
829 ) -> QuantumGraph:
830 """Read `QuantumGraph` from a URI.
832 Parameters
833 ----------
834 uri : convertible to `ResourcePath`
835 URI from where to load the graph.
836 universe: `~lsst.daf.butler.DimensionUniverse` optional
837 DimensionUniverse instance, not used by the method itself but
838 needed to ensure that registry data structures are initialized.
839 If None it is loaded from the QuantumGraph saved structure. If
840 supplied, the DimensionUniverse from the loaded `QuantumGraph`
841 will be validated against the supplied argument for compatibility.
842 nodes: iterable of `int` or None
843 Numbers that correspond to nodes in the graph. If specified, only
844 these nodes will be loaded. Defaults to None, in which case all
845 nodes will be loaded.
846 graphID : `str` or `None`
847 If specified this ID is verified against the loaded graph prior to
848 loading any Nodes. This defaults to None in which case no
849 validation is done.
850 minimumVersion : int
851 Minimum version of a save file to load. Set to -1 to load all
852 versions. Older versions may need to be loaded, and re-saved
853 to upgrade them to the latest format before they can be used in
854 production.
856 Returns
857 -------
858 graph : `QuantumGraph`
859 Resulting QuantumGraph instance.
861 Raises
862 ------
863 TypeError
864 Raised if pickle contains instance of a type other than
865 QuantumGraph.
866 ValueError
867 Raised if one or more of the nodes requested is not in the
868 `QuantumGraph` or if graphID parameter does not match the graph
869 being loaded or if the supplied uri does not point at a valid
870 `QuantumGraph` save file.
871 RuntimeError
872 Raise if Supplied DimensionUniverse is not compatible with the
873 DimensionUniverse saved in the graph
876 Notes
877 -----
878 Reading Quanta from pickle requires existence of singleton
879 DimensionUniverse which is usually instantiated during Registry
880 initialization. To make sure that DimensionUniverse exists this method
881 accepts dummy DimensionUniverse argument.
882 """
883 uri = ResourcePath(uri)
884 # With ResourcePath we have the choice of always using a local file
885 # or reading in the bytes directly. Reading in bytes can be more
886 # efficient for reasonably-sized pickle files when the resource
887 # is remote. For now use the local file variant. For a local file
888 # as_local() does nothing.
890 if uri.getExtension() in (".pickle", ".pkl"):
891 with uri.as_local() as local, open(local.ospath, "rb") as fd:
892 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
893 qgraph = pickle.load(fd)
894 elif uri.getExtension() in (".qgraph"):
895 with LoadHelper(uri, minimumVersion) as loader:
896 qgraph = loader.load(universe, nodes, graphID)
897 else:
898 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
899 if not isinstance(qgraph, QuantumGraph):
900 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
901 return qgraph
903 @classmethod
904 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]:
905 """Read the header of a `QuantumGraph` pointed to by the uri parameter
906 and return it as a string.
908 Parameters
909 ----------
910 uri : convertible to `ResourcePath`
911 The location of the `QuantumGraph` to load. If the argument is a
912 string, it must correspond to a valid `ResourcePath` path.
913 minimumVersion : int
914 Minimum version of a save file to load. Set to -1 to load all
915 versions. Older versions may need to be loaded, and re-saved
916 to upgrade them to the latest format before they can be used in
917 production.
919 Returns
920 -------
921 header : `str` or `None`
922 The header associated with the specified `QuantumGraph` it there is
923 one, else `None`.
925 Raises
926 ------
927 ValueError
928 Raised if `QuantuGraph` was saved as a pickle.
929 Raised if the extention of the file specified by uri is not a
930 `QuantumGraph` extention.
931 """
932 uri = ResourcePath(uri)
933 if uri.getExtension() in (".pickle", ".pkl"):
934 raise ValueError("Reading a header from a pickle save is not supported")
935 elif uri.getExtension() in (".qgraph"):
936 return LoadHelper(uri, minimumVersion).readHeader()
937 else:
938 raise ValueError("Only know how to handle files saved as `qgraph`")
940 def buildAndPrintHeader(self) -> None:
941 """Creates a header that would be used in a save of this object and
942 prints it out to standard out.
943 """
944 _, header = self._buildSaveObject(returnHeader=True)
945 print(json.dumps(header))
947 def save(self, file: BinaryIO) -> None:
948 """Save QuantumGraph to a file.
950 Parameters
951 ----------
952 file : `io.BufferedIOBase`
953 File to write pickle data open in binary mode.
954 """
955 buffer = self._buildSaveObject()
956 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into an in-memory byte buffer.

        The layout of the buffer is: magic bytes, save-format version,
        packed header length, lzma-compressed JSON header, then one
        lzma-compressed JSON blob per task definition followed by one per
        quantum node. The header records the byte offsets of every blob so
        that individual tasks/nodes can be loaded without reading the whole
        file.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping that
            was embedded in the buffer. Defaults to `False`.

        Returns
        -------
        buffer : `bytearray`
            The serialized graph.
        header : `dict`
            The header data; only returned when ``returnHeader`` is `True`.
        """
        # Make some containers.
        jsonData: Deque[bytes] = deque()
        # The node map is a list because JSON does not accept mapping keys
        # that are not strings, so we store a list of (key, value) pairs
        # that will be converted to a mapping on load.
        nodeMap: List[Tuple[str, Dict[str, Any]]] = []
        taskDefMap: Dict[str, Any] = {}
        headerData: Dict[str, Any] = {}

        # Store the QuantumGraph BuildId; this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with.
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # Counter for the number of bytes processed thus far; every blob's
        # (start, end) offsets recorded in the header derive from it.
        count = 0
        # Serialize out the task defs, recording the start and end bytes of
        # each taskDef.
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # Iterate tasks in taskGraph order so serialization happens in a
        # reproducible order.
        for taskDef in self.taskGraph:
            # Compressing has very little impact on save or load time, but
            # a large impact on on-disk size, so it is worth doing.
            taskDescription: Dict[str, Any] = {}
            # Save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # Save the config as a text stream that will be un-persisted on
            # the other end.
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            # Init input/output refs are optional; only record them if
            # present for this task.
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connections between all of the tasks and save
            # that in the header as a list of connections and edges in each
            # task. This will help in un-persisting, and possibly in a
            # "quick view" method that does not require everything to be
            # un-persisted.
            #
            # Typing returns can't be parameter dependent.
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection
                    # directly from the datastore or if it is produced by
                    # another task.
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output will be said to be the
                    # datastore, in which case the for loop below will be a
                    # zero-length loop.
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # Dump to a JSON string, encode that string to bytes, and then
            # compress those bytes.
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # Record the sizing and relation information.
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # Serialize the nodes, recording the start and end bytes of each
        # node.
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # Compressing has very little impact on save or load time, but
            # a large impact on on-disk size, so it is worth doing.
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        # Dimension records referenced by the nodes are accumulated once and
        # stored by key so each record is saved a single time.
        headerData["DimensionRecords"] = {
            key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # This needs to be serialized as a series of (key, value) tuples
        # because of a limitation that JSON can't have anything but strings
        # as mapping keys.
        headerData["Nodes"] = nodeMap

        # Dump the headerData to JSON.
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # Record the save version, packed per STRUCT_FMT_BASE.
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        # The header length is packed with the format dictated by the
        # current save version's deserializer.
        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # Write each component of the save out in a deterministic order.
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory: as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer
1106 @classmethod
1107 def load(
1108 cls,
1109 file: BinaryIO,
1110 universe: Optional[DimensionUniverse] = None,
1111 nodes: Optional[Iterable[uuid.UUID]] = None,
1112 graphID: Optional[BuildId] = None,
1113 minimumVersion: int = 3,
1114 ) -> QuantumGraph:
1115 """Read QuantumGraph from a file that was made by `save`.
1117 Parameters
1118 ----------
1119 file : `io.IO` of bytes
1120 File with pickle data open in binary mode.
1121 universe: `~lsst.daf.butler.DimensionUniverse`, optional
1122 DimensionUniverse instance, not used by the method itself but
1123 needed to ensure that registry data structures are initialized.
1124 If None it is loaded from the QuantumGraph saved structure. If
1125 supplied, the DimensionUniverse from the loaded `QuantumGraph`
1126 will be validated against the supplied argument for compatibility.
1127 nodes: iterable of `int` or None
1128 Numbers that correspond to nodes in the graph. If specified, only
1129 these nodes will be loaded. Defaults to None, in which case all
1130 nodes will be loaded.
1131 graphID : `str` or `None`
1132 If specified this ID is verified against the loaded graph prior to
1133 loading any Nodes. This defaults to None in which case no
1134 validation is done.
1135 minimumVersion : int
1136 Minimum version of a save file to load. Set to -1 to load all
1137 versions. Older versions may need to be loaded, and re-saved
1138 to upgrade them to the latest format before they can be used in
1139 production.
1141 Returns
1142 -------
1143 graph : `QuantumGraph`
1144 Resulting QuantumGraph instance.
1146 Raises
1147 ------
1148 TypeError
1149 Raised if pickle contains instance of a type other than
1150 QuantumGraph.
1151 ValueError
1152 Raised if one or more of the nodes requested is not in the
1153 `QuantumGraph` or if graphID parameter does not match the graph
1154 being loaded or if the supplied uri does not point at a valid
1155 `QuantumGraph` save file.
1157 Notes
1158 -----
1159 Reading Quanta from pickle requires existence of singleton
1160 DimensionUniverse which is usually instantiated during Registry
1161 initialization. To make sure that DimensionUniverse exists this method
1162 accepts dummy DimensionUniverse argument.
1163 """
1164 # Try to see if the file handle contains pickle data, this will be
1165 # removed in the future
1166 try:
1167 qgraph = pickle.load(file)
1168 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
1169 except pickle.UnpicklingError:
1170 with LoadHelper(file, minimumVersion) as loader:
1171 qgraph = loader.load(universe, nodes, graphID)
1172 if not isinstance(qgraph, QuantumGraph):
1173 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
1174 return qgraph
1176 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1177 """Iterate over the `taskGraph` attribute in topological order
1179 Yields
1180 ------
1181 taskDef : `TaskDef`
1182 `TaskDef` objects in topological order
1183 """
1184 yield from nx.topological_sort(self.taskGraph)
1186 @property
1187 def graphID(self) -> BuildId:
1188 """Returns the ID generated by the graph at construction time"""
1189 return self._buildId
1191 def __iter__(self) -> Generator[QuantumNode, None, None]:
1192 yield from nx.topological_sort(self._connectedQuanta)
1194 def __len__(self) -> int:
1195 return self._count
1197 def __contains__(self, node: QuantumNode) -> bool:
1198 return self._connectedQuanta.has_node(node)
1200 def __getstate__(self) -> dict:
1201 """Stores a compact form of the graph as a list of graph nodes, and a
1202 tuple of task labels and task configs. The full graph can be
1203 reconstructed with this information, and it preseves the ordering of
1204 the graph ndoes.
1205 """
1206 universe: Optional[DimensionUniverse] = None
1207 for node in self:
1208 dId = node.quantum.dataId
1209 if dId is None:
1210 continue
1211 universe = dId.graph.universe
1212 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
1214 def __setstate__(self, state: dict) -> None:
1215 """Reconstructs the state of the graph from the information persisted
1216 in getstate.
1217 """
1218 buffer = io.BytesIO(state["reduced"])
1219 with LoadHelper(buffer, minimumVersion=3) as loader:
1220 qgraph = loader.load(state["universe"], graphID=state["graphId"])
1222 self._metadata = qgraph._metadata
1223 self._buildId = qgraph._buildId
1224 self._datasetDict = qgraph._datasetDict
1225 self._nodeIdMap = qgraph._nodeIdMap
1226 self._count = len(qgraph)
1227 self._taskToQuantumNode = qgraph._taskToQuantumNode
1228 self._taskGraph = qgraph._taskGraph
1229 self._connectedQuanta = qgraph._connectedQuanta
1230 self._initInputRefs = qgraph._initInputRefs
1231 self._initOutputRefs = qgraph._initOutputRefs
1233 def __eq__(self, other: object) -> bool:
1234 if not isinstance(other, QuantumGraph):
1235 return False
1236 if len(self) != len(other):
1237 return False
1238 for node in self:
1239 if node not in other:
1240 return False
1241 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1242 return False
1243 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1244 return False
1245 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1246 return False
1247 return set(self.taskGraph) == set(other.taskGraph)