Coverage for python/lsst/pipe/base/graph/graph.py: 16%
395 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-14 02:16 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-14 02:16 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("QuantumGraph", "IncompatibleGraphError")
25import io
26import json
27import lzma
28import os
29import pickle
30import struct
31import time
32import uuid
33import warnings
34from collections import defaultdict, deque
35from itertools import chain
36from types import MappingProxyType
37from typing import (
38 Any,
39 BinaryIO,
40 DefaultDict,
41 Deque,
42 Dict,
43 FrozenSet,
44 Generator,
45 Iterable,
46 List,
47 Mapping,
48 MutableMapping,
49 Optional,
50 Set,
51 Tuple,
52 TypeVar,
53 Union,
54)
56import networkx as nx
57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum
58from lsst.resources import ResourcePath, ResourcePathExpression
59from lsst.utils.introspection import get_full_type_name
60from networkx.drawing.nx_agraph import write_dot
62from ..connections import iterConnections
63from ..pipeline import TaskDef
64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner
65from ._loadHelpers import LoadHelper
66from ._versionDeserializers import DESERIALIZER_MAP
67from .quantumNode import BuildId, QuantumNode
# Type variable bound to QuantumGraph so subclass methods (subset, prune)
# return instances of the subclass, not the base class.
_T = TypeVar("_T", bound="QuantumGraph")

# Modify this constant any time the on-disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save.
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determining which
# loading code should be used for the rest of the file.
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# Magic bytes that identify a file as a QuantumGraph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Exception raised when a lookup by NodeId cannot be satisfied because
    the identifier is incompatible with this graph.
    """
101class QuantumGraph:
102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
104 This data structure represents a concrete workflow generated from a
105 `Pipeline`.
107 Parameters
108 ----------
109 quanta : Mapping of `TaskDef` to sets of `Quantum`
110 This maps tasks (and their configs) to the sets of data they are to
111 process.
112 metadata : Optional Mapping of `str` to primitives
113 This is an optional parameter of extra data to carry with the graph.
114 Entries in this mapping should be able to be serialized in JSON.
115 pruneRefs : iterable [ `DatasetRef` ], optional
116 Set of dataset refs to exclude from a graph.
117 initInputs : `Mapping`, optional
118 Maps tasks to their InitInput dataset refs. Dataset refs can be either
119 resolved or non-resolved. Presently the same dataset refs are included
120 in each `Quantum` for the same task.
121 initOutputs : `Mapping`, optional
122 Maps tasks to their InitOutput dataset refs. Dataset refs can be either
123 resolved or non-resolved. For intermediate resolved refs their dataset
124 ID must match ``initInputs`` and Quantum ``initInputs``.
125 globalInitOutputs : iterable [ `DatasetRef` ], optional
126 Dataset refs for some global objects produced by pipeline. These
127 objects include task configurations and package versions. Typically
128 they have an empty DataId, but there is no real restriction on what
129 can appear here.
130 registryDatasetTypes : iterable [ `DatasetType` ], optional
131 Dataset types which are used by this graph, their definitions must
132 match registry. If registry does not define dataset type yet, then
133 it should match one that will be created later.
135 Raises
136 ------
137 ValueError
138 Raised if the graph is pruned such that some tasks no longer have nodes
139 associated with them.
140 """
    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
        initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
        initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
        globalInitOutputs: Optional[Iterable[DatasetRef]] = None,
        registryDatasetTypes: Optional[Iterable[DatasetType]] = None,
    ):
        # All construction work is delegated to _buildGraphs, which is also
        # used by the code paths that rebuild a graph from an existing one
        # (subset / pruneGraphFromRefs).  Parameters are documented in the
        # class docstring.
        self._buildGraphs(
            quanta,
            metadata=metadata,
            pruneRefs=pruneRefs,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=registryDatasetTypes,
        )
164 def _buildGraphs(
165 self,
166 quanta: Mapping[TaskDef, Set[Quantum]],
167 *,
168 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
169 _buildId: Optional[BuildId] = None,
170 metadata: Optional[Mapping[str, Any]] = None,
171 pruneRefs: Optional[Iterable[DatasetRef]] = None,
172 universe: Optional[DimensionUniverse] = None,
173 initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
174 initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
175 globalInitOutputs: Optional[Iterable[DatasetRef]] = None,
176 registryDatasetTypes: Optional[Iterable[DatasetType]] = None,
177 ) -> None:
178 """Builds the graph that is used to store the relation between tasks,
179 and the graph that holds the relations between quanta
180 """
181 self._metadata = metadata
182 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
183 # Data structures used to identify relations between components;
184 # DatasetTypeName -> TaskDef for task,
185 # and DatasetRef -> QuantumNode for the quanta
186 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
187 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()
189 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
190 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
191 for taskDef, quantumSet in quanta.items():
192 connections = taskDef.connections
194 # For each type of connection in the task, add a key to the
195 # `_DatasetTracker` for the connections name, with a value of
196 # the TaskDef in the appropriate field
197 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
198 # Have to handle components in inputs.
199 dataset_name, _, _ = inpt.name.partition(".")
200 self._datasetDict.addConsumer(DatasetTypeName(dataset_name), taskDef)
202 for output in iterConnections(connections, ("outputs",)):
203 # Have to handle possible components in outputs.
204 dataset_name, _, _ = output.name.partition(".")
205 self._datasetDict.addProducer(DatasetTypeName(dataset_name), taskDef)
207 # For each `Quantum` in the set of all `Quantum` for this task,
208 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
209 # of the individual datasets inside the `Quantum`, with a value of
210 # a newly created QuantumNode to the appropriate input/output
211 # field.
212 for quantum in quantumSet:
213 if quantum.dataId is not None:
214 if universe is None:
215 universe = quantum.dataId.universe
216 elif universe != quantum.dataId.universe:
217 raise RuntimeError(
218 "Mismatched dimension universes in QuantumGraph construction: "
219 f"{universe} != {quantum.dataId.universe}. "
220 )
222 if _quantumToNodeId:
223 if (nodeId := _quantumToNodeId.get(quantum)) is None:
224 raise ValueError(
225 "If _quantuMToNodeNumber is not None, all quanta must have an "
226 "associated value in the mapping"
227 )
228 else:
229 nodeId = uuid.uuid4()
231 inits = quantum.initInputs.values()
232 inputs = quantum.inputs.values()
233 value = QuantumNode(quantum, taskDef, nodeId)
234 self._taskToQuantumNode[taskDef].add(value)
235 self._nodeIdMap[nodeId] = value
237 for dsRef in chain(inits, inputs):
238 # unfortunately, `Quantum` allows inits to be individual
239 # `DatasetRef`s or an Iterable of such, so there must
240 # be an instance check here
241 if isinstance(dsRef, Iterable):
242 for sub in dsRef:
243 if sub.isComponent():
244 sub = sub.makeCompositeRef()
245 self._datasetRefDict.addConsumer(sub, value)
246 else:
247 assert isinstance(dsRef, DatasetRef)
248 if dsRef.isComponent():
249 dsRef = dsRef.makeCompositeRef()
250 self._datasetRefDict.addConsumer(dsRef, value)
251 for dsRef in chain.from_iterable(quantum.outputs.values()):
252 self._datasetRefDict.addProducer(dsRef, value)
254 if pruneRefs is not None:
255 # track what refs were pruned and prune the graph
256 prunes: Set[QuantumNode] = set()
257 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)
259 # recreate the taskToQuantumNode dict removing nodes that have been
260 # pruned. Keep track of task defs that now have no QuantumNodes
261 emptyTasks: Set[str] = set()
262 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
263 # accumulate all types
264 types_ = set()
265 # tracker for any pruneRefs that have caused tasks to have no nodes
266 # This helps the user find out what caused the issues seen.
267 culprits = set()
268 # Find all the types from the refs to prune
269 for r in pruneRefs:
270 types_.add(r.datasetType)
272 # For each of the tasks, and their associated nodes, remove any
273 # any nodes that were pruned. If there are no nodes associated
274 # with a task, record that task, and find out if that was due to
275 # a type from an input ref to prune.
276 for td, taskNodes in self._taskToQuantumNode.items():
277 diff = taskNodes.difference(prunes)
278 if len(diff) == 0:
279 if len(taskNodes) != 0:
280 tp: DatasetType
281 for tp in types_:
282 if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
283 tmpRefs
284 ).difference(pruneRefs):
285 culprits.add(tp.name)
286 emptyTasks.add(td.label)
287 newTaskToQuantumNode[td] = diff
289 # update the internal dict
290 self._taskToQuantumNode = newTaskToQuantumNode
292 if emptyTasks:
293 raise ValueError(
294 f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
295 f"after graph pruning; {', '.join(culprits)} caused over-pruning"
296 )
298 # Dimension universe
299 if universe is None:
300 raise RuntimeError(
301 "Dimension universe or at least one quantum with a data ID "
302 "must be provided when constructing a QuantumGraph."
303 )
304 self._universe = universe
306 # Graph of quanta relations
307 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
308 self._count = len(self._connectedQuanta)
310 # Graph of task relations, used in various methods
311 self._taskGraph = self._datasetDict.makeNetworkXGraph()
313 # convert default dict into a regular to prevent accidental key
314 # insertion
315 self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
317 self._initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
318 self._initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}
319 self._globalInitOutputRefs: List[DatasetRef] = []
320 self._registryDatasetTypes: List[DatasetType] = []
321 if initInputs is not None:
322 self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
323 if initOutputs is not None:
324 self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
325 if globalInitOutputs is not None:
326 self._globalInitOutputRefs = list(globalInitOutputs)
327 if registryDatasetTypes is not None:
328 self._registryDatasetTypes = list(registryDatasetTypes)
    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal datastructure that holds relations of `TaskDef` objects.
        """
        return self._taskGraph
    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of, this class, but sometimes direct access to
        the networkx object may be helpful.

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects.
        """
        return self._connectedQuanta
357 @property
358 def inputQuanta(self) -> Iterable[QuantumNode]:
359 """Make a `list` of all `QuantumNode` objects that are 'input' nodes
360 to the graph, meaning those nodes to not depend on any other nodes in
361 the graph.
363 Returns
364 -------
365 inputNodes : iterable of `QuantumNode`
366 A list of nodes that are inputs to the graph
367 """
368 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
370 @property
371 def outputQuanta(self) -> Iterable[QuantumNode]:
372 """Make a `list` of all `QuantumNode` objects that are 'output' nodes
373 to the graph, meaning those nodes have no nodes that depend them in
374 the graph.
376 Returns
377 -------
378 outputNodes : iterable of `QuantumNode`
379 A list of nodes that are outputs of the graph
380 """
381 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the dataset type names that are present in the graph, not
            including global init-outputs.
        """
        # Keys of the name->TaskDef tracker are exactly the dataset type
        # names registered during _buildGraphs.
        return tuple(self._datasetDict.keys())
    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected,
        ignoring the directionality of connections (weak connectivity).
        """
        return nx.is_weakly_connected(self._connectedQuanta)
403 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
404 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
405 and nodes which depend on them.
407 Parameters
408 ----------
409 refs : `Iterable` of `DatasetRef`
410 Refs which should be removed from resulting graph
412 Returns
413 -------
414 graph : `QuantumGraph`
415 A graph that has been pruned of specified refs and the nodes that
416 depend on them.
417 """
418 newInst = object.__new__(type(self))
419 quantumMap = defaultdict(set)
420 for node in self:
421 quantumMap[node.taskDef].add(node.quantum)
423 # convert to standard dict to prevent accidental key insertion
424 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
426 # This should not change set of tasks in a graph, so we can keep the
427 # same registryDatasetTypes as in the original graph.
428 # TODO: Do we need to copy initInputs/initOutputs?
429 newInst._buildGraphs(
430 quantumDict,
431 _quantumToNodeId={n.quantum: n.nodeId for n in self},
432 metadata=self._metadata,
433 pruneRefs=refs,
434 universe=self._universe,
435 globalInitOutputs=self._globalInitOutputRefs,
436 registryDatasetTypes=self._registryDatasetTypes,
437 )
438 return newInst
    def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `uuid.UUID`
            The unique identifier associated with a node.

        Returns
        -------
        node : `QuantumNode`
            The node corresponding to the input identifier.

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        """
        return self._nodeIdMap[nodeId]
460 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
461 """Return all the `Quantum` associated with a `TaskDef`.
463 Parameters
464 ----------
465 taskDef : `TaskDef`
466 The `TaskDef` for which `Quantum` are to be queried
468 Returns
469 -------
470 frozenset of `Quantum`
471 The `set` of `Quantum` that is associated with the specified
472 `TaskDef`.
473 """
474 return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ()))
    def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int:
        """Return the number of `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be counted.

        Returns
        -------
        count : `int`
            The number of `Quantum` that are associated with the specified
            `TaskDef`; zero if the task is not present in the graph.
        """
        return len(self._taskToQuantumNode.get(taskDef, ()))
    def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
        """Return all the `QuantumNode`\\ s associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which nodes are to be queried.

        Returns
        -------
        frozenset of `QuantumNode`
            The `frozenset` of `QuantumNode` that is associated with the
            specified `TaskDef`.

        Raises
        ------
        KeyError
            Raised if ``taskDef`` is not in the graph (note: unlike
            `getQuantaForTask`, which returns an empty set in that case).
        """
        return frozenset(self._taskToQuantumNode[taskDef])
508 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
509 """Find all tasks that have the specified dataset type name as an
510 input.
512 Parameters
513 ----------
514 datasetTypeName : `str`
515 A string representing the name of a dataset type to be queried,
516 can also accept a `DatasetTypeName` which is a `NewType` of str for
517 type safety in static type checking.
519 Returns
520 -------
521 tasks : iterable of `TaskDef`
522 `TaskDef` objects that have the specified `DatasetTypeName` as an
523 input, list will be empty if no tasks use specified
524 `DatasetTypeName` as an input.
526 Raises
527 ------
528 KeyError
529 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
530 """
531 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str
            for type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or `None`
            if none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
        """
        return self._datasetDict.getProducer(datasetTypeName)
    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name, either as producer or as consumer.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str
            for type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
        """
        return self._datasetDict.getAll(datasetTypeName)
581 def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
582 """Determine which `TaskDef` objects in this graph are associated
583 with a `str` representing a task name (looks at the taskName property
584 of `TaskDef` objects).
586 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
587 multiple times in a graph with different labels.
589 Parameters
590 ----------
591 taskName : str
592 Name of a task to search for
594 Returns
595 -------
596 result : list of `TaskDef`
597 List of the `TaskDef` objects that have the name specified.
598 Multiple values are returned in the case that a task is used
599 multiple times with different labels.
600 """
601 results = []
602 for task in self._taskToQuantumNode.keys():
603 split = task.taskName.split(".")
604 if split[-1] == taskName:
605 results.append(task)
606 return results
608 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
609 """Determine which `TaskDef` objects in this graph are associated
610 with a `str` representing a tasks label.
612 Parameters
613 ----------
614 taskName : str
615 Name of a task to search for
617 Returns
618 -------
619 result : `TaskDef`
620 `TaskDef` objects that has the specified label.
621 """
622 for task in self._taskToQuantumNode.keys():
623 if label == task.label:
624 return task
625 return None
627 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
628 """Return all the `Quantum` that contain a specified `DatasetTypeName`.
630 Parameters
631 ----------
632 datasetTypeName : `str`
633 The name of the dataset type to search for as a string,
634 can also accept a `DatasetTypeName` which is a `NewType` of str for
635 type safety in static type checking.
637 Returns
638 -------
639 result : `set` of `QuantumNode` objects
640 A `set` of `QuantumNode`s that contain specified `DatasetTypeName`
642 Raises
643 ------
644 KeyError
645 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
647 """
648 tasks = self._datasetDict.getAll(datasetTypeName)
649 result: Set[Quantum] = set()
650 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
651 return result
653 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
654 """Check if specified quantum appears in the graph as part of a node.
656 Parameters
657 ----------
658 quantum : `Quantum`
659 The quantum to search for
661 Returns
662 -------
663 `bool`
664 The result of searching for the quantum
665 """
666 for node in self:
667 if quantum == node.quantum:
668 return True
669 return False
    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : `str` or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object.
        """
        # Renders the quantum-node graph (not the task graph) via
        # networkx's graphviz writer.
        write_dot(self._connectedQuanta, output)
    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            Nodes of this graph to include in the subset.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created.
        """
        # Normalize a single node into a one-element tuple.
        if not isinstance(nodes, Iterable):
            nodes = (nodes,)
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)

        # Collect every dataset type name referenced by the retained quanta
        # so registryDatasetTypes can be trimmed to match the subset.
        dataset_type_names: set[str] = set()
        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            dataset_type_names.update(
                dstype.name
                for dstype in chain(
                    node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
                )
            )

        # May need to trim dataset types from registryDatasetTypes.
        for taskDef in quantumMap:
            if refs := self.initOutputRefs(taskDef):
                dataset_type_names.update(ref.datasetType.name for ref in refs)
        dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
        registryDatasetTypes = [
            dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
        ]

        # convert to standard dict to prevent accidental key insertion
        quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items())
        # Create an empty graph, and then populate it with custom mapping;
        # node IDs from this graph are preserved via _quantumToNodeId.
        newInst = type(self)({}, universe=self._universe)
        # TODO: Do we need to copy initInputs/initOutputs?
        newInst._buildGraphs(
            quantumDict,
            _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
            _buildId=self._buildId,
            metadata=self._metadata,
            universe=self._universe,
            globalInitOutputs=self._globalInitOutputRefs,
            registryDatasetTypes=registryDatasetTypes,
        )
        return newInst
    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : `tuple` of `QuantumGraph`
            A tuple of graphs that are each connected (ignoring edge
            direction), one per weakly-connected component of this graph.
        """
        return tuple(
            self.subset(connectedSet)
            for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
        )
748 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
749 """Return a set of `QuantumNode` that are direct inputs to a specified
750 node.
752 Parameters
753 ----------
754 node : `QuantumNode`
755 The node of the graph for which inputs are to be determined
757 Returns
758 -------
759 set of `QuantumNode`
760 All the nodes that are direct inputs to specified node
761 """
762 return set(pred for pred in self._connectedQuanta.predecessors(node))
764 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
765 """Return a set of `QuantumNode` that are direct outputs of a specified
766 node.
768 Parameters
769 ----------
770 node : `QuantumNode`
771 The node of the graph for which outputs are to be determined
773 Returns
774 -------
775 set of `QuantumNode`
776 All the nodes that are direct outputs to specified node
777 """
778 return set(succ for succ in self._connectedQuanta.successors(node))
780 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
781 """Return a graph of `QuantumNode` that are direct inputs and outputs
782 of a specified node.
784 Parameters
785 ----------
786 node : `QuantumNode`
787 The node of the graph for which connected nodes are to be
788 determined.
790 Returns
791 -------
792 graph : graph of `QuantumNode`
793 All the nodes that are directly connected to specified node
794 """
795 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
796 nodes.add(node)
797 return self.subset(nodes)
799 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
800 """Return a graph of the specified node and all the ancestor nodes
801 directly reachable by walking edges.
803 Parameters
804 ----------
805 node : `QuantumNode`
806 The node for which all ansestors are to be determined
808 Returns
809 -------
810 graph of `QuantumNode`
811 Graph of node and all of its ansestors
812 """
813 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
814 predecessorNodes.add(node)
815 return self.subset(predecessorNodes)
    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and return the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned (rather than
            propagating the networkx exception) so that
            ``if graph.findCycle():`` reads naturally — an empty list is
            falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []
833 def saveUri(self, uri: ResourcePathExpression) -> None:
834 """Save `QuantumGraph` to the specified URI.
836 Parameters
837 ----------
838 uri : convertible to `ResourcePath`
839 URI to where the graph should be saved.
840 """
841 buffer = self._buildSaveObject()
842 path = ResourcePath(uri)
843 if path.getExtension() not in (".qgraph"):
844 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
845 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra data carried with the graph, as a read-only mapping view,
        or `None` if no metadata was supplied at construction
        (`MappingProxyType` [`str`, `Any`] or `None`).
        """
        if self._metadata is None:
            return None
        return MappingProxyType(self._metadata)
    def initInputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
        """Return DatasetRefs for a given task InitInputs.

        Parameters
        ----------
        taskDef : `TaskDef`
            Task definition structure.

        Returns
        -------
        refs : `list` [ `DatasetRef` ] or `None`
            DatasetRefs for the task InitInputs, or `None` if the task has
            none recorded. The references can be either resolved or
            non-resolved.
        """
        return self._initInputRefs.get(taskDef)
    def initOutputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]:
        """Return DatasetRefs for a given task InitOutputs.

        Parameters
        ----------
        taskDef : `TaskDef`
            Task definition structure.

        Returns
        -------
        refs : `list` [ `DatasetRef` ] or `None`
            DatasetRefs for the task InitOutputs, or `None` if the task has
            none recorded. The references can be either resolved or
            non-resolved; a resolved reference will match the Quantum's
            ``initInputs`` if this is an intermediate dataset type.
        """
        return self._initOutputRefs.get(taskDef)
    def globalInitOutputRefs(self) -> List[DatasetRef]:
        """Return DatasetRefs for global InitOutputs.

        Returns
        -------
        refs : `list` [ `DatasetRef` ]
            DatasetRefs for global InitOutputs (e.g. task configurations and
            package versions); empty if none were supplied at construction.
        """
        return self._globalInitOutputRefs
    def registryDatasetTypes(self) -> List[DatasetType]:
        """Return dataset types used by this graph; their definitions match
        dataset types from the registry.

        Returns
        -------
        refs : `list` [ `DatasetType` ]
            Dataset types for this graph; empty if none were supplied at
            construction.
        """
        return self._registryDatasetTypes
@classmethod
def loadUri(
    cls,
    uri: ResourcePathExpression,
    universe: Optional[DimensionUniverse] = None,
    nodes: Optional[Iterable[uuid.UUID]] = None,
    graphID: Optional[BuildId] = None,
    minimumVersion: int = 3,
) -> QuantumGraph:
    """Read `QuantumGraph` from a URI.

    Parameters
    ----------
    uri : convertible to `ResourcePath`
        URI from where to load the graph.
    universe : `~lsst.daf.butler.DimensionUniverse`, optional
        DimensionUniverse instance, not used by the method itself but
        needed to ensure that registry data structures are initialized.
        If None it is loaded from the QuantumGraph saved structure. If
        supplied, the DimensionUniverse from the loaded `QuantumGraph`
        will be validated against the supplied argument for compatibility.
    nodes : iterable of `uuid.UUID` or None
        UUIDs that correspond to nodes in the graph. If specified, only
        these nodes will be loaded. Defaults to None, in which case all
        nodes will be loaded.
    graphID : `str` or `None`
        If specified this ID is verified against the loaded graph prior to
        loading any Nodes. This defaults to None in which case no
        validation is done.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format before they can be used in
        production.

    Returns
    -------
    graph : `QuantumGraph`
        Resulting QuantumGraph instance.

    Raises
    ------
    TypeError
        Raised if pickle contains instance of a type other than
        QuantumGraph.
    ValueError
        Raised if one or more of the nodes requested is not in the
        `QuantumGraph` or if graphID parameter does not match the graph
        being loaded or if the supplied uri does not point at a valid
        `QuantumGraph` save file.
    RuntimeError
        Raise if Supplied DimensionUniverse is not compatible with the
        DimensionUniverse saved in the graph.

    Notes
    -----
    Reading Quanta from pickle requires existence of singleton
    DimensionUniverse which is usually instantiated during Registry
    initialization. To make sure that DimensionUniverse exists this method
    accepts dummy DimensionUniverse argument.
    """
    uri = ResourcePath(uri)
    # With ResourcePath we have the choice of always using a local file
    # or reading in the bytes directly. Reading in bytes can be more
    # efficient for reasonably-sized files when the resource is remote.
    # For now use the local file variant. For a local file as_local()
    # does nothing.
    extension = uri.getExtension()
    if extension in (".pickle", ".pkl"):
        with uri.as_local() as local, open(local.ospath, "rb") as fd:
            warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
            qgraph = pickle.load(fd)
    # Bug fix: the original ``extension in (".qgraph")`` tested substring
    # membership in a plain string (parentheses alone do not make a
    # tuple), so e.g. a file with no extension ("" in ".qgraph" is True)
    # would incorrectly be routed to the qgraph loader.
    elif extension == ".qgraph":
        with LoadHelper(uri, minimumVersion) as loader:
            qgraph = loader.load(universe, nodes, graphID)
    else:
        raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
    if not isinstance(qgraph, QuantumGraph):
        raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
    return qgraph
@classmethod
def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]:
    """Read the header of a `QuantumGraph` pointed to by the uri parameter
    and return it as a string.

    Parameters
    ----------
    uri : convertible to `ResourcePath`
        The location of the `QuantumGraph` to load. If the argument is a
        string, it must correspond to a valid `ResourcePath` path.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format before they can be used in
        production.

    Returns
    -------
    header : `str` or `None`
        The header associated with the specified `QuantumGraph` if there
        is one, else `None`.

    Raises
    ------
    ValueError
        Raised if `QuantumGraph` was saved as a pickle.
        Raised if the extension of the file specified by uri is not a
        `QuantumGraph` extension.
    """
    uri = ResourcePath(uri)
    extension = uri.getExtension()
    if extension in (".pickle", ".pkl"):
        raise ValueError("Reading a header from a pickle save is not supported")
    # Bug fix: ``in (".qgraph")`` was a substring test against a plain
    # string (not a tuple), so unrelated extensions such as "" passed.
    elif extension == ".qgraph":
        return LoadHelper(uri, minimumVersion).readHeader()
    else:
        raise ValueError("Only know how to handle files saved as `qgraph`")
def buildAndPrintHeader(self) -> None:
    """Build the header that a save of this object would produce and
    print it to standard out as JSON.
    """
    _buffer, header = self._buildSaveObject(returnHeader=True)
    serialized = json.dumps(header)
    print(serialized)
def save(self, file: BinaryIO) -> None:
    """Save QuantumGraph to a file.

    Parameters
    ----------
    file : `io.BufferedIOBase`
        File opened in binary mode to which the serialized graph bytes
        are written.
    """
    serialized = self._buildSaveObject()
    # bytearray is accepted anywhere bytes is; the annotation on write()
    # is just narrower than what it supports.
    file.write(serialized)  # type: ignore
def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
    """Serialize this graph into the on-disk save-format byte layout.

    The layout is: magic bytes, save version, packed header length, the
    lzma-compressed JSON header, then one lzma-compressed JSON blob per
    task definition followed by one per quantum node.  The header records
    the (start, end) byte offsets of every blob so they can be loaded
    individually.

    Parameters
    ----------
    returnHeader : `bool`, optional
        If `True`, also return the (uncompressed) header mapping that was
        serialized into the buffer.

    Returns
    -------
    buffer : `bytearray` or `tuple` [ `bytearray`, `dict` ]
        The serialized graph bytes, plus the header mapping when
        ``returnHeader`` is `True`.
    """
    # make some containers
    jsonData: Deque[bytes] = deque()
    # node map is a list because json does not accept mapping keys that
    # are not strings, so we store a list of key, value pairs that will
    # be converted to a mapping on load
    nodeMap = []
    taskDefMap = {}
    headerData: Dict[str, Any] = {}

    # Store the QuantumGraph BuildId, this will allow validating BuildIds
    # at load time, prior to loading any QuantumNodes. Name chosen for
    # unlikely conflicts.
    headerData["GraphBuildID"] = self.graphID
    headerData["Metadata"] = self._metadata

    # Store the universe this graph was created with
    universeConfig = self._universe.dimensionConfig
    headerData["universe"] = universeConfig.toDict()

    # counter for the number of bytes processed thus far; every blob's
    # (start, end) offsets recorded below are relative to the first blob,
    # so append order must match the offsets exactly.
    count = 0
    # serialize out the task Defs recording the start and end bytes of each
    # taskDef
    inverseLookup = self._datasetDict.inverse
    taskDef: TaskDef
    # NOTE(review): an earlier comment claimed tasks are sorted by label
    # here, but iteration follows self.taskGraph's own order — confirm
    # that order is deterministic if reproducible saves are required.
    for taskDef in self.taskGraph:
        # compressing has very little impact on saving or load time, but
        # a large impact on on disk size, so it is worth doing
        taskDescription: Dict[str, Any] = {}
        # save the fully qualified name.
        taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
        # save the config as a text stream that will be un-persisted on the
        # other end
        stream = io.StringIO()
        taskDef.config.saveToStream(stream)
        taskDescription["config"] = stream.getvalue()
        taskDescription["label"] = taskDef.label
        # Init-input/output refs are optional per task; only persist the
        # entries that exist.
        if (refs := self._initInputRefs.get(taskDef)) is not None:
            taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
        if (refs := self._initOutputRefs.get(taskDef)) is not None:
            taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

        inputs = []
        outputs = []

        # Determine the connection between all of tasks and save that in
        # the header as a list of connections and edges in each task
        # this will help in un-persisting, and possibly in a "quick view"
        # method that does not require everything to be un-persisted
        #
        # Typing returns can't be parameter dependent
        for connection in inverseLookup[taskDef]:  # type: ignore
            consumers = self._datasetDict.getConsumers(connection)
            producer = self._datasetDict.getProducer(connection)
            if taskDef in consumers:
                # This checks if the task consumes the connection directly
                # from the datastore or it is produced by another task
                producerLabel = producer.label if producer is not None else "datastore"
                inputs.append((producerLabel, connection))
            elif taskDef not in consumers and producer is taskDef:
                # If there are no consumers for this tasks produced
                # connection, the output will be said to be the datastore
                # in which case the for loop will be a zero length loop
                if not consumers:
                    outputs.append(("datastore", connection))
                for td in consumers:
                    outputs.append((td.label, connection))

        # dump to json string, and encode that string to bytes and then
        # compress those bytes
        dump = lzma.compress(json.dumps(taskDescription).encode())
        # record the sizing and relation information
        taskDefMap[taskDef.label] = {
            "bytes": (count, count + len(dump)),
            "inputs": inputs,
            "outputs": outputs,
        }
        count += len(dump)
        jsonData.append(dump)

    headerData["TaskDefs"] = taskDefMap

    # serialize the nodes, recording the start and end bytes of each node;
    # the accumulator deduplicates dimension records shared between nodes
    # so they are stored once in the header instead of per node.
    dimAccumulator = DimensionRecordsAccumulator()
    for node in self:
        # compressing has very little impact on saving or load time, but
        # a large impact on on disk size, so it is worth doing
        simpleNode = node.to_simple(accumulator=dimAccumulator)

        dump = lzma.compress(simpleNode.json().encode())
        jsonData.append(dump)
        nodeMap.append(
            (
                str(node.nodeId),
                {
                    "bytes": (count, count + len(dump)),
                    "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                    "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                },
            )
        )
        count += len(dump)

    headerData["DimensionRecords"] = {
        key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
    }

    # need to serialize this as a series of key,value tuples because of
    # a limitation on how json cant do anything but strings as keys
    headerData["Nodes"] = nodeMap

    if self._globalInitOutputRefs:
        headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

    if self._registryDatasetTypes:
        headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

    # dump the headerData to json
    header_encode = lzma.compress(json.dumps(headerData).encode())

    # record the save format version number
    save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

    # record the header length using the current version's format string
    fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
    map_lengths = struct.pack(fmt_string, len(header_encode))

    # write each component of the save out in a deterministic order
    buffer = bytearray()
    buffer.extend(MAGIC_BYTES)
    buffer.extend(save_bytes)
    buffer.extend(map_lengths)
    buffer.extend(header_encode)
    # Iterate over the length of pickleData, and for each element pop the
    # leftmost element off the deque and write it out. This is to save
    # memory, as the memory is added to the buffer object, it is removed
    # from from the container.
    #
    # Only this section needs to worry about memory pressure because
    # everything else written to the buffer prior to this pickle data is
    # only on the order of kilobytes to low numbers of megabytes.
    while jsonData:
        buffer.extend(jsonData.popleft())
    if returnHeader:
        return buffer, headerData
    else:
        return buffer
@classmethod
def load(
    cls,
    file: BinaryIO,
    universe: Optional[DimensionUniverse] = None,
    nodes: Optional[Iterable[uuid.UUID]] = None,
    graphID: Optional[BuildId] = None,
    minimumVersion: int = 3,
) -> QuantumGraph:
    """Read QuantumGraph from a file that was made by `save`.

    Parameters
    ----------
    file : `io.IO` of bytes
        File with serialized graph data open in binary mode.
    universe : `~lsst.daf.butler.DimensionUniverse`, optional
        DimensionUniverse instance, not used by the method itself but
        needed to ensure that registry data structures are initialized.
        If None it is loaded from the QuantumGraph saved structure. If
        supplied, the DimensionUniverse from the loaded `QuantumGraph`
        will be validated against the supplied argument for compatibility.
    nodes : iterable of `uuid.UUID` or None
        UUIDs that correspond to nodes in the graph. If specified, only
        these nodes will be loaded. Defaults to None, in which case all
        nodes will be loaded.
    graphID : `str` or `None`
        If specified this ID is verified against the loaded graph prior to
        loading any Nodes. This defaults to None in which case no
        validation is done.
    minimumVersion : `int`
        Minimum version of a save file to load. Set to -1 to load all
        versions. Older versions may need to be loaded, and re-saved
        to upgrade them to the latest format before they can be used in
        production.

    Returns
    -------
    graph : `QuantumGraph`
        Resulting QuantumGraph instance.

    Raises
    ------
    TypeError
        Raised if pickle contains instance of a type other than
        QuantumGraph.
    ValueError
        Raised if one or more of the nodes requested is not in the
        `QuantumGraph` or if graphID parameter does not match the graph
        being loaded or if the supplied file does not contain a valid
        `QuantumGraph` save.

    Notes
    -----
    Reading Quanta from pickle requires existence of singleton
    DimensionUniverse which is usually instantiated during Registry
    initialization. To make sure that DimensionUniverse exists this method
    accepts dummy DimensionUniverse argument.
    """
    # Try to see if the file handle contains pickle data; this fallback
    # will be removed in the future.
    try:
        qgraph = pickle.load(file)
        warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
    except pickle.UnpicklingError:
        # NOTE(review): pickle.load may have consumed some bytes before
        # failing; this assumes LoadHelper (re)positions the handle
        # itself — confirm against LoadHelper's implementation.
        with LoadHelper(file, minimumVersion) as loader:
            qgraph = loader.load(universe, nodes, graphID)
    if not isinstance(qgraph, QuantumGraph):
        # Bug fix: message previously read "file has contains".
        raise TypeError(f"QuantumGraph pickle file contains unexpected object type: {type(qgraph)}")
    return qgraph
def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
    """Iterate over the `taskGraph` attribute in topological order.

    Yields
    ------
    taskDef : `TaskDef`
        `TaskDef` objects in topological order.
    """
    for task in nx.topological_sort(self.taskGraph):
        yield task
@property
def graphID(self) -> BuildId:
    """The ID generated for this graph at construction time."""
    build_id = self._buildId
    return build_id
@property
def universe(self) -> DimensionUniverse:
    """The dimension universe associated with this graph."""
    dimension_universe = self._universe
    return dimension_universe
def __iter__(self) -> Generator[QuantumNode, None, None]:
    # Quantum nodes are yielded in topological (dependency-respecting)
    # order over the connected-quanta graph.
    for quantum_node in nx.topological_sort(self._connectedQuanta):
        yield quantum_node
def __len__(self) -> int:
    """Return the number of quantum nodes in the graph."""
    return self._count
def __contains__(self, node: QuantumNode) -> bool:
    """Return whether ``node`` is one of this graph's quantum nodes."""
    return self._connectedQuanta.has_node(node)
def __getstate__(self) -> dict:
    """Produce a picklable, compact form of the graph.

    The graph is reduced to its serialized save-format bytes plus the
    build ID and dimension universe; `__setstate__` reconstructs the full
    object (including node ordering) from that payload.
    """
    universe: Optional[DimensionUniverse] = None
    # Pick up a universe from the nodes that carry a dataId.  The scan
    # deliberately covers every node, so the value from the last node
    # with a dataId wins (all nodes are expected to share one universe).
    for quantum_node in self:
        data_id = quantum_node.quantum.dataId
        if data_id is not None:
            universe = data_id.graph.universe
    return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
def __setstate__(self, state: dict) -> None:
    """Reconstruct the state of the graph from the information persisted
    by `__getstate__`.

    Parameters
    ----------
    state : `dict`
        Mapping with the serialized graph bytes (``"reduced"``), build ID
        (``"graphId"``) and dimension universe (``"universe"``) produced
        by `__getstate__`.
    """
    buffer = io.BytesIO(state["reduced"])
    with LoadHelper(buffer, minimumVersion=3) as loader:
        qgraph = loader.load(state["universe"], graphID=state["graphId"])

    # Copy every attribute of the freshly loaded graph onto self.
    self._metadata = qgraph._metadata
    self._buildId = qgraph._buildId
    self._datasetDict = qgraph._datasetDict
    self._nodeIdMap = qgraph._nodeIdMap
    self._count = len(qgraph)
    self._taskToQuantumNode = qgraph._taskToQuantumNode
    self._taskGraph = qgraph._taskGraph
    self._connectedQuanta = qgraph._connectedQuanta
    self._initInputRefs = qgraph._initInputRefs
    self._initOutputRefs = qgraph._initOutputRefs
    # Bug fix: these attributes were previously not restored, which lost
    # global init-output refs and registry dataset types on unpickling
    # and left ``self._universe`` unset, breaking the ``universe``
    # property of an unpickled graph.
    self._universe = qgraph._universe
    self._globalInitOutputRefs = qgraph._globalInitOutputRefs
    self._registryDatasetTypes = qgraph._registryDatasetTypes
def __eq__(self, other: object) -> bool:
    """Two graphs are equal when they contain the same nodes with the
    same connectivity, the same dataset types, and the same tasks."""
    if not isinstance(other, QuantumGraph):
        return False
    if len(self) != len(other):
        return False
    for node in self:
        # Each node must exist in both graphs with identical incoming
        # and outgoing edges; short-circuit on the first mismatch.
        if (
            node not in other
            or self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node)
            or self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node)
        ):
            return False
    if set(self.allDatasetTypes) != set(other.allDatasetTypes):
        return False
    return set(self.taskGraph) == set(other.taskGraph)