Coverage for python/lsst/pipe/base/graph/graph.py: 20%
391 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-16 10:46 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("QuantumGraph", "IncompatibleGraphError")
31import io
32import json
33import lzma
34import os
35import struct
36import time
37import uuid
38from collections import defaultdict, deque
39from collections.abc import Generator, Iterable, Iterator, Mapping, MutableMapping
40from itertools import chain
41from types import MappingProxyType
42from typing import Any, BinaryIO, TypeVar
44import networkx as nx
45from lsst.daf.butler import (
46 DatasetId,
47 DatasetRef,
48 DatasetType,
49 DimensionRecordsAccumulator,
50 DimensionUniverse,
51 Quantum,
52)
53from lsst.daf.butler.persistence_context import PersistenceContextVars
54from lsst.resources import ResourcePath, ResourcePathExpression
55from lsst.utils.introspection import get_full_type_name
56from networkx.drawing.nx_agraph import write_dot
58from ..connections import iterConnections
59from ..pipeline import TaskDef
60from ._implDetails import DatasetTypeName, _DatasetTracker
61from ._loadHelpers import LoadHelper
62from ._versionDeserializers import DESERIALIZER_MAP
63from .quantumNode import BuildId, QuantumNode
# Type variable bound to QuantumGraph so that subset/copy-style methods can
# return the (possibly subclassed) graph type of the receiver.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be performed because the node id
    is incompatible with this graph.
    """

    pass
class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects.

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : `~collections.abc.Mapping` [ `TaskDef`, \
            `set` [ `~lsst.daf.butler.Quantum` ] ]
        This maps tasks (and their configs) to the sets of data they are to
        process.
    metadata : Optional `~collections.abc.Mapping` of `str` to primitives
        This is an optional parameter of extra data to carry with the graph.
        Entries in this mapping should be able to be serialized in JSON.
    universe : `~lsst.daf.butler.DimensionUniverse`, optional
        The dimensions in which quanta can be defined. Need only be provided
        if no quanta have data IDs.
    initInputs : `~collections.abc.Mapping`, optional
        Maps tasks to their InitInput dataset refs. Dataset refs can be either
        resolved or non-resolved. Presently the same dataset refs are included
        in each `~lsst.daf.butler.Quantum` for the same task.
    initOutputs : `~collections.abc.Mapping`, optional
        Maps tasks to their InitOutput dataset refs. Dataset refs can be
        either resolved or non-resolved. For intermediate resolved refs their
        dataset ID must match ``initInputs`` and Quantum ``initInputs``.
    globalInitOutputs : iterable [ `~lsst.daf.butler.DatasetRef` ], optional
        Dataset refs for some global objects produced by pipeline. These
        objects include task configurations and package versions. Typically
        they have an empty DataId, but there is no real restriction on what
        can appear here.
    registryDatasetTypes : iterable [ `~lsst.daf.butler.DatasetType` ], \
            optional
        Dataset types which are used by this graph, their definitions must
        match registry. If registry does not define dataset type yet, then
        it should match one that will be created later.

    Raises
    ------
    ValueError
        Raised if the graph is pruned such that some tasks no longer have
        nodes associated with them.
    """
141 def __init__(
142 self,
143 quanta: Mapping[TaskDef, set[Quantum]],
144 metadata: Mapping[str, Any] | None = None,
145 universe: DimensionUniverse | None = None,
146 initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
147 initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
148 globalInitOutputs: Iterable[DatasetRef] | None = None,
149 registryDatasetTypes: Iterable[DatasetType] | None = None,
150 ):
151 self._buildGraphs(
152 quanta,
153 metadata=metadata,
154 universe=universe,
155 initInputs=initInputs,
156 initOutputs=initOutputs,
157 globalInitOutputs=globalInitOutputs,
158 registryDatasetTypes=registryDatasetTypes,
159 )
161 def _buildGraphs(
162 self,
163 quanta: Mapping[TaskDef, set[Quantum]],
164 *,
165 _quantumToNodeId: Mapping[Quantum, uuid.UUID] | None = None,
166 _buildId: BuildId | None = None,
167 metadata: Mapping[str, Any] | None = None,
168 universe: DimensionUniverse | None = None,
169 initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
170 initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
171 globalInitOutputs: Iterable[DatasetRef] | None = None,
172 registryDatasetTypes: Iterable[DatasetType] | None = None,
173 ) -> None:
174 """Build the graph that is used to store the relation between tasks,
175 and the graph that holds the relations between quanta
176 """
177 self._metadata = metadata
178 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
179 # Data structure used to identify relations between
180 # DatasetTypeName -> TaskDef.
181 self._datasetDict = _DatasetTracker(createInverse=True)
183 # Temporary graph that will have dataset UUIDs (as raw bytes) and
184 # QuantumNode objects as nodes; will be collapsed down to just quanta
185 # later.
186 bipartite_graph = nx.DiGraph()
188 self._nodeIdMap: dict[uuid.UUID, QuantumNode] = {}
189 self._taskToQuantumNode: MutableMapping[TaskDef, set[QuantumNode]] = defaultdict(set)
190 for taskDef, quantumSet in quanta.items():
191 connections = taskDef.connections
193 # For each type of connection in the task, add a key to the
194 # `_DatasetTracker` for the connections name, with a value of
195 # the TaskDef in the appropriate field
196 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
197 # Have to handle components in inputs.
198 dataset_name, _, _ = inpt.name.partition(".")
199 self._datasetDict.addConsumer(DatasetTypeName(dataset_name), taskDef)
201 for output in iterConnections(connections, ("outputs",)):
202 # Have to handle possible components in outputs.
203 dataset_name, _, _ = output.name.partition(".")
204 self._datasetDict.addProducer(DatasetTypeName(dataset_name), taskDef)
206 # For each `Quantum` in the set of all `Quantum` for this task,
207 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
208 # of the individual datasets inside the `Quantum`, with a value of
209 # a newly created QuantumNode to the appropriate input/output
210 # field.
211 for quantum in quantumSet:
212 if quantum.dataId is not None:
213 if universe is None:
214 universe = quantum.dataId.universe
215 elif universe != quantum.dataId.universe:
216 raise RuntimeError(
217 "Mismatched dimension universes in QuantumGraph construction: "
218 f"{universe} != {quantum.dataId.universe}. "
219 )
221 if _quantumToNodeId:
222 if (nodeId := _quantumToNodeId.get(quantum)) is None:
223 raise ValueError(
224 "If _quantuMToNodeNumber is not None, all quanta must have an "
225 "associated value in the mapping"
226 )
227 else:
228 nodeId = uuid.uuid4()
230 inits = quantum.initInputs.values()
231 inputs = quantum.inputs.values()
232 value = QuantumNode(quantum, taskDef, nodeId)
233 self._taskToQuantumNode[taskDef].add(value)
234 self._nodeIdMap[nodeId] = value
236 bipartite_graph.add_node(value, bipartite=0)
237 for dsRef in chain(inits, inputs):
238 # unfortunately, `Quantum` allows inits to be individual
239 # `DatasetRef`s or an Iterable of such, so there must
240 # be an instance check here
241 if isinstance(dsRef, Iterable):
242 for sub in dsRef:
243 bipartite_graph.add_node(sub.id.bytes, bipartite=1)
244 bipartite_graph.add_edge(sub.id.bytes, value)
245 else:
246 assert isinstance(dsRef, DatasetRef)
247 if dsRef.isComponent():
248 dsRef = dsRef.makeCompositeRef()
249 bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
250 bipartite_graph.add_edge(dsRef.id.bytes, value)
251 for dsRef in chain.from_iterable(quantum.outputs.values()):
252 bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
253 bipartite_graph.add_edge(value, dsRef.id.bytes)
255 # Dimension universe
256 if universe is None:
257 raise RuntimeError(
258 "Dimension universe or at least one quantum with a data ID "
259 "must be provided when constructing a QuantumGraph."
260 )
261 self._universe = universe
263 # Make graph of quanta relations, by projecting out the dataset nodes
264 # in the bipartite_graph, leaving just the quanta.
265 self._connectedQuanta = nx.algorithms.bipartite.projected_graph(
266 bipartite_graph, self._nodeIdMap.values()
267 )
268 self._count = len(self._connectedQuanta)
270 # Graph of task relations, used in various methods
271 self._taskGraph = self._datasetDict.makeNetworkXGraph()
273 # convert default dict into a regular to prevent accidental key
274 # insertion
275 self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
277 self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
278 self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
279 self._globalInitOutputRefs: list[DatasetRef] = []
280 self._registryDatasetTypes: list[DatasetType] = []
281 if initInputs is not None:
282 self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
283 if initOutputs is not None:
284 self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
285 if globalInitOutputs is not None:
286 self._globalInitOutputRefs = list(globalInitOutputs)
287 if registryDatasetTypes is not None:
288 self._registryDatasetTypes = list(registryDatasetTypes)
290 @property
291 def taskGraph(self) -> nx.DiGraph:
292 """A graph representing the relations between the tasks inside
293 the quantum graph (`networkx.DiGraph`).
294 """
295 return self._taskGraph
297 @property
298 def graph(self) -> nx.DiGraph:
299 """A graph representing the relations between all the `QuantumNode`
300 objects (`networkx.DiGraph`).
302 The graph should usually be iterated over, or passed to methods of this
303 class, but sometimes direct access to the ``networkx`` object may be
304 helpful.
305 """
306 return self._connectedQuanta
308 @property
309 def inputQuanta(self) -> Iterable[QuantumNode]:
310 """The nodes that are inputs to the graph (iterable [`QuantumNode`]).
312 These are the nodes that do not depend on any other nodes in the
313 graph.
314 """
315 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
317 @property
318 def outputQuanta(self) -> Iterable[QuantumNode]:
319 """The nodes that are outputs of the graph (iterable [`QuantumNode`]).
321 These are the nodes that have no nodes that depend on them in the
322 graph.
323 """
324 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
326 @property
327 def allDatasetTypes(self) -> tuple[DatasetTypeName, ...]:
328 """All the data set type names that are present in the graph
329 (`tuple` [`str`]).
331 These types do not include global init-outputs.
332 """
333 return tuple(self._datasetDict.keys())
335 @property
336 def isConnected(self) -> bool:
337 """Whether all of the nodes in the graph are connected, ignoring
338 directionality of connections (`bool`).
339 """
340 return nx.is_weakly_connected(self._connectedQuanta)
342 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
343 """Lookup a `QuantumNode` from an id associated with the node.
345 Parameters
346 ----------
347 nodeId : `NodeId`
348 The number associated with a node.
350 Returns
351 -------
352 node : `QuantumNode`
353 The node corresponding with input number.
355 Raises
356 ------
357 KeyError
358 Raised if the requested nodeId is not in the graph.
359 """
360 return self._nodeIdMap[nodeId]
362 def getQuantaForTask(self, taskDef: TaskDef) -> frozenset[Quantum]:
363 """Return all the `~lsst.daf.butler.Quantum` associated with a
364 `TaskDef`.
366 Parameters
367 ----------
368 taskDef : `TaskDef`
369 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
370 queried.
372 Returns
373 -------
374 quanta : `frozenset` of `~lsst.daf.butler.Quantum`
375 The `set` of `~lsst.daf.butler.Quantum` that is associated with the
376 specified `TaskDef`.
377 """
378 return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ()))
380 def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int:
381 """Return the number of `~lsst.daf.butler.Quantum` associated with
382 a `TaskDef`.
384 Parameters
385 ----------
386 taskDef : `TaskDef`
387 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
388 queried.
390 Returns
391 -------
392 count : `int`
393 The number of `~lsst.daf.butler.Quantum` that are associated with
394 the specified `TaskDef`.
395 """
396 return len(self._taskToQuantumNode.get(taskDef, ()))
398 def getNodesForTask(self, taskDef: TaskDef) -> frozenset[QuantumNode]:
399 r"""Return all the `QuantumNode`\s associated with a `TaskDef`.
401 Parameters
402 ----------
403 taskDef : `TaskDef`
404 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
405 queried.
407 Returns
408 -------
409 nodes : `frozenset` [ `QuantumNode` ]
410 A `frozenset` of `QuantumNode` that is associated with the
411 specified `TaskDef`.
412 """
413 return frozenset(self._taskToQuantumNode[taskDef])
415 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
416 """Find all tasks that have the specified dataset type name as an
417 input.
419 Parameters
420 ----------
421 datasetTypeName : `str`
422 A string representing the name of a dataset type to be queried,
423 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
424 `str` for type safety in static type checking.
426 Returns
427 -------
428 tasks : iterable of `TaskDef`
429 `TaskDef` objects that have the specified `DatasetTypeName` as an
430 input, list will be empty if no tasks use specified
431 `DatasetTypeName` as an input.
433 Raises
434 ------
435 KeyError
436 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
437 """
438 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
440 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> TaskDef | None:
441 """Find all tasks that have the specified dataset type name as an
442 output.
444 Parameters
445 ----------
446 datasetTypeName : `str`
447 A string representing the name of a dataset type to be queried,
448 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
449 `str` for type safety in static type checking.
451 Returns
452 -------
453 result : `TaskDef` or `None`
454 `TaskDef` that outputs `DatasetTypeName` as an output or `None` if
455 none of the tasks produce this `DatasetTypeName`.
457 Raises
458 ------
459 KeyError
460 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
461 """
462 return self._datasetDict.getProducer(datasetTypeName)
464 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
465 """Find all tasks that are associated with the specified dataset type
466 name.
468 Parameters
469 ----------
470 datasetTypeName : `str`
471 A string representing the name of a dataset type to be queried,
472 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
473 `str` for type safety in static type checking.
475 Returns
476 -------
477 result : iterable of `TaskDef`
478 `TaskDef` objects that are associated with the specified
479 `DatasetTypeName`.
481 Raises
482 ------
483 KeyError
484 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
485 """
486 return self._datasetDict.getAll(datasetTypeName)
488 def findTaskDefByName(self, taskName: str) -> list[TaskDef]:
489 """Determine which `TaskDef` objects in this graph are associated
490 with a `str` representing a task name (looks at the ``taskName``
491 property of `TaskDef` objects).
493 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
494 multiple times in a graph with different labels.
496 Parameters
497 ----------
498 taskName : `str`
499 Name of a task to search for.
501 Returns
502 -------
503 result : `list` of `TaskDef`
504 List of the `TaskDef` objects that have the name specified.
505 Multiple values are returned in the case that a task is used
506 multiple times with different labels.
507 """
508 results = []
509 for task in self._taskToQuantumNode:
510 split = task.taskName.split(".")
511 if split[-1] == taskName:
512 results.append(task)
513 return results
515 def findTaskDefByLabel(self, label: str) -> TaskDef | None:
516 """Determine which `TaskDef` objects in this graph are associated
517 with a `str` representing a tasks label.
519 Parameters
520 ----------
521 label : `str`
522 Name of a task to search for.
524 Returns
525 -------
526 result : `TaskDef`
527 `TaskDef` objects that has the specified label.
528 """
529 for task in self._taskToQuantumNode:
530 if label == task.label:
531 return task
532 return None
534 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> set[Quantum]:
535 r"""Return all the `~lsst.daf.butler.Quantum` that contain a specified
536 `DatasetTypeName`.
538 Parameters
539 ----------
540 datasetTypeName : `str`
541 The name of the dataset type to search for as a string,
542 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
543 `str` for type safety in static type checking.
545 Returns
546 -------
547 result : `set` of `QuantumNode` objects
548 A `set` of `QuantumNode`\s that contain specified
549 `DatasetTypeName`.
551 Raises
552 ------
553 KeyError
554 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
555 """
556 tasks = self._datasetDict.getAll(datasetTypeName)
557 result: set[Quantum] = set()
558 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
559 return result
561 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
562 """Check if specified quantum appears in the graph as part of a node.
564 Parameters
565 ----------
566 quantum : `lsst.daf.butler.Quantum`
567 The quantum to search for.
569 Returns
570 -------
571 in_graph : `bool`
572 The result of searching for the quantum.
573 """
574 return any(quantum == node.quantum for node in self)
576 def writeDotGraph(self, output: str | io.BufferedIOBase) -> None:
577 """Write out the graph as a dot graph.
579 Parameters
580 ----------
581 output : `str` or `io.BufferedIOBase`
582 Either a filesystem path to write to, or a file handle object.
583 """
584 write_dot(self._connectedQuanta, output)
586 def subset(self: _T, nodes: QuantumNode | Iterable[QuantumNode]) -> _T:
587 """Create a new graph object that contains the subset of the nodes
588 specified as input. Node number is preserved.
590 Parameters
591 ----------
592 nodes : `QuantumNode` or iterable of `QuantumNode`
593 Nodes from which to create subset.
595 Returns
596 -------
597 graph : instance of graph type
598 An instance of the type from which the subset was created.
599 """
600 if not isinstance(nodes, Iterable):
601 nodes = (nodes,)
602 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
603 quantumMap = defaultdict(set)
605 dataset_type_names: set[str] = set()
606 node: QuantumNode
607 for node in quantumSubgraph:
608 quantumMap[node.taskDef].add(node.quantum)
609 dataset_type_names.update(
610 dstype.name
611 for dstype in chain(
612 node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
613 )
614 )
616 # May need to trim dataset types from registryDatasetTypes.
617 for taskDef in quantumMap:
618 if refs := self.initOutputRefs(taskDef):
619 dataset_type_names.update(ref.datasetType.name for ref in refs)
620 dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
621 registryDatasetTypes = [
622 dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
623 ]
625 # convert to standard dict to prevent accidental key insertion
626 quantumDict: dict[TaskDef, set[Quantum]] = dict(quantumMap.items())
627 # Create an empty graph, and then populate it with custom mapping
628 newInst = type(self)({}, universe=self._universe)
629 # TODO: Do we need to copy initInputs/initOutputs?
630 newInst._buildGraphs(
631 quantumDict,
632 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
633 _buildId=self._buildId,
634 metadata=self._metadata,
635 universe=self._universe,
636 globalInitOutputs=self._globalInitOutputRefs,
637 registryDatasetTypes=registryDatasetTypes,
638 )
639 return newInst
641 def subsetToConnected(self: _T) -> tuple[_T, ...]:
642 """Generate a list of subgraphs where each is connected.
644 Returns
645 -------
646 result : `list` of `QuantumGraph`
647 A list of graphs that are each connected.
648 """
649 return tuple(
650 self.subset(connectedSet)
651 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
652 )
654 def determineInputsToQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
655 """Return a set of `QuantumNode` that are direct inputs to a specified
656 node.
658 Parameters
659 ----------
660 node : `QuantumNode`
661 The node of the graph for which inputs are to be determined.
663 Returns
664 -------
665 inputs : `set` of `QuantumNode`
666 All the nodes that are direct inputs to specified node.
667 """
668 return set(self._connectedQuanta.predecessors(node))
670 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
671 """Return a set of `QuantumNode` that are direct outputs of a specified
672 node.
674 Parameters
675 ----------
676 node : `QuantumNode`
677 The node of the graph for which outputs are to be determined.
679 Returns
680 -------
681 outputs : `set` of `QuantumNode`
682 All the nodes that are direct outputs to specified node.
683 """
684 return set(self._connectedQuanta.successors(node))
686 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
687 """Return a graph of `QuantumNode` that are direct inputs and outputs
688 of a specified node.
690 Parameters
691 ----------
692 node : `QuantumNode`
693 The node of the graph for which connected nodes are to be
694 determined.
696 Returns
697 -------
698 graph : graph of `QuantumNode`
699 All the nodes that are directly connected to specified node.
700 """
701 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
702 nodes.add(node)
703 return self.subset(nodes)
705 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
706 """Return a graph of the specified node and all the ancestor nodes
707 directly reachable by walking edges.
709 Parameters
710 ----------
711 node : `QuantumNode`
712 The node for which all ancestors are to be determined.
714 Returns
715 -------
716 ancestors : graph of `QuantumNode`
717 Graph of node and all of its ancestors.
718 """
719 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
720 predecessorNodes.add(node)
721 return self.subset(predecessorNodes)
723 def findCycle(self) -> list[tuple[QuantumNode, QuantumNode]]:
724 """Check a graph for the presense of cycles and returns the edges of
725 any cycles found, or an empty list if there is no cycle.
727 Returns
728 -------
729 result : `list` of `tuple` of [ `QuantumNode`, `QuantumNode` ]
730 A list of any graph edges that form a cycle, or an empty list if
731 there is no cycle. Empty list to so support if graph.find_cycle()
732 syntax as an empty list is falsy.
733 """
734 try:
735 return nx.find_cycle(self._connectedQuanta)
736 except nx.NetworkXNoCycle:
737 return []
739 def saveUri(self, uri: ResourcePathExpression) -> None:
740 """Save `QuantumGraph` to the specified URI.
742 Parameters
743 ----------
744 uri : convertible to `~lsst.resources.ResourcePath`
745 URI to where the graph should be saved.
746 """
747 buffer = self._buildSaveObject()
748 path = ResourcePath(uri)
749 if path.getExtension() not in (".qgraph"):
750 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
751 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
753 @property
754 def metadata(self) -> MappingProxyType[str, Any] | None:
755 """Extra data carried with the graph (mapping [`str`] or `None`).
757 The mapping is a dynamic view of this object's metadata. Values should
758 be able to be serialized in JSON.
759 """
760 if self._metadata is None:
761 return None
762 return MappingProxyType(self._metadata)
764 def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
765 """Return DatasetRefs for a given task InitInputs.
767 Parameters
768 ----------
769 taskDef : `TaskDef`
770 Task definition structure.
772 Returns
773 -------
774 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
775 DatasetRef for the task InitInput, can be `None`. This can return
776 either resolved or non-resolved reference.
777 """
778 return self._initInputRefs.get(taskDef)
780 def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
781 """Return DatasetRefs for a given task InitOutputs.
783 Parameters
784 ----------
785 taskDef : `TaskDef`
786 Task definition structure.
788 Returns
789 -------
790 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
791 DatasetRefs for the task InitOutput, can be `None`. This can return
792 either resolved or non-resolved reference. Resolved reference will
793 match Quantum's initInputs if this is an intermediate dataset type.
794 """
795 return self._initOutputRefs.get(taskDef)
797 def globalInitOutputRefs(self) -> list[DatasetRef]:
798 """Return DatasetRefs for global InitOutputs.
800 Returns
801 -------
802 refs : `list` [ `~lsst.daf.butler.DatasetRef` ]
803 DatasetRefs for global InitOutputs.
804 """
805 return self._globalInitOutputRefs
807 def registryDatasetTypes(self) -> list[DatasetType]:
808 """Return dataset types used by this graph, their definitions match
809 dataset types from registry.
811 Returns
812 -------
813 refs : `list` [ `~lsst.daf.butler.DatasetType` ]
814 Dataset types for this graph.
815 """
816 return self._registryDatasetTypes
818 @classmethod
819 def loadUri(
820 cls,
821 uri: ResourcePathExpression,
822 universe: DimensionUniverse | None = None,
823 nodes: Iterable[uuid.UUID] | None = None,
824 graphID: BuildId | None = None,
825 minimumVersion: int = 3,
826 ) -> QuantumGraph:
827 """Read `QuantumGraph` from a URI.
829 Parameters
830 ----------
831 uri : convertible to `~lsst.resources.ResourcePath`
832 URI from where to load the graph.
833 universe : `~lsst.daf.butler.DimensionUniverse`, optional
834 If `None` it is loaded from the `QuantumGraph`
835 saved structure. If supplied, the
836 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
837 will be validated against the supplied argument for compatibility.
838 nodes : iterable of `uuid.UUID` or `None`
839 UUIDs that correspond to nodes in the graph. If specified, only
840 these nodes will be loaded. Defaults to None, in which case all
841 nodes will be loaded.
842 graphID : `str` or `None`
843 If specified this ID is verified against the loaded graph prior to
844 loading any Nodes. This defaults to None in which case no
845 validation is done.
846 minimumVersion : `int`
847 Minimum version of a save file to load. Set to -1 to load all
848 versions. Older versions may need to be loaded, and re-saved
849 to upgrade them to the latest format before they can be used in
850 production.
852 Returns
853 -------
854 graph : `QuantumGraph`
855 Resulting QuantumGraph instance.
857 Raises
858 ------
859 TypeError
860 Raised if file contains instance of a type other than
861 `QuantumGraph`.
862 ValueError
863 Raised if one or more of the nodes requested is not in the
864 `QuantumGraph` or if graphID parameter does not match the graph
865 being loaded or if the supplied uri does not point at a valid
866 `QuantumGraph` save file.
867 RuntimeError
868 Raise if Supplied `~lsst.daf.butler.DimensionUniverse` is not
869 compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
870 the graph.
871 """
872 uri = ResourcePath(uri)
873 if uri.getExtension() in {".qgraph"}:
874 with LoadHelper(uri, minimumVersion) as loader:
875 qgraph = loader.load(universe, nodes, graphID)
876 else:
877 raise ValueError(f"Only know how to handle files saved as `.qgraph`, not {uri}")
878 if not isinstance(qgraph, QuantumGraph):
879 raise TypeError(f"QuantumGraph file {uri} contains unexpected object type: {type(qgraph)}")
880 return qgraph
882 @classmethod
883 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> str | None:
884 """Read the header of a `QuantumGraph` pointed to by the uri parameter
885 and return it as a string.
887 Parameters
888 ----------
889 uri : convertible to `~lsst.resources.ResourcePath`
890 The location of the `QuantumGraph` to load. If the argument is a
891 string, it must correspond to a valid
892 `~lsst.resources.ResourcePath` path.
893 minimumVersion : `int`
894 Minimum version of a save file to load. Set to -1 to load all
895 versions. Older versions may need to be loaded, and re-saved
896 to upgrade them to the latest format before they can be used in
897 production.
899 Returns
900 -------
901 header : `str` or `None`
902 The header associated with the specified `QuantumGraph` it there is
903 one, else `None`.
905 Raises
906 ------
907 ValueError
908 Raised if the extension of the file specified by uri is not a
909 `QuantumGraph` extension.
910 """
911 uri = ResourcePath(uri)
912 if uri.getExtension() in {".qgraph"}:
913 return LoadHelper(uri, minimumVersion).readHeader()
914 else:
915 raise ValueError("Only know how to handle files saved as `.qgraph`")
917 def buildAndPrintHeader(self) -> None:
918 """Create a header that would be used in a save of this object and
919 prints it out to standard out.
920 """
921 _, header = self._buildSaveObject(returnHeader=True)
922 print(json.dumps(header))
924 def save(self, file: BinaryIO) -> None:
925 """Save QuantumGraph to a file.
927 Parameters
928 ----------
929 file : `io.BufferedIOBase`
930 File to write data open in binary mode.
931 """
932 buffer = self._buildSaveObject()
933 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
935 def _buildSaveObject(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
936 thing = PersistenceContextVars()
937 result = thing.run(self._buildSaveObjectImpl, returnHeader)
938 return result
    def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
        """Serialize this graph into a `bytearray` suitable for writing to
        disk.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping that is
            embedded at the front of the serialized bytes.

        Returns
        -------
        buffer : `bytearray` or `tuple` [`bytearray`, `dict`]
            The serialized graph, optionally paired with the header mapping.
        """
        # make some containers
        jsonData: deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far; byte offsets
        # recorded below are relative to the end of the header section
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # iterate tasks in taskGraph order so serialization is reproducible
        # (NOTE(review): an earlier comment claimed a sort by task label; no
        # explicit sort happens here — confirm taskGraph order is stable)
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on disk size, so it is worth doing
            taskDescription: dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connections between all of the tasks and save
            # them in the header as a list of connections and edges for each
            # task. This will help in un-persisting, and possibly in a
            # "quick view" method that does not require everything to be
            # un-persisted.
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This task consumes the connection; record who produced
                    # it — "datastore" when it comes from outside the graph.
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output is said to go to the datastore,
                    # in which case the for loop below is a zero length loop.
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, encode that string to bytes, and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode(), preset=2)
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.model_dump_json().encode(), preset=2)
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        # dimension records accumulated while simplifying the nodes above
        headerData["DimensionRecords"] = {
            key: value.model_dump()
            for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        if self._globalInitOutputRefs:
            headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

        if self._registryDatasetTypes:
            headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

        # dump the headerData to json and compress it
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # pack the save format version number; its size is defined by
        # STRUCT_FMT_BASE (NOTE(review): an earlier comment claimed "2
        # unsigned long long numbers for a total of 16 bytes" — only the
        # version is packed here; confirm against the deserializer)
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        # pack the compressed-header length using the version-specific format
        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory: as the data is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer
1091 @classmethod
1092 def load(
1093 cls,
1094 file: BinaryIO,
1095 universe: DimensionUniverse | None = None,
1096 nodes: Iterable[uuid.UUID] | None = None,
1097 graphID: BuildId | None = None,
1098 minimumVersion: int = 3,
1099 ) -> QuantumGraph:
1100 """Read `QuantumGraph` from a file that was made by `save`.
1102 Parameters
1103 ----------
1104 file : `io.IO` of bytes
1105 File with data open in binary mode.
1106 universe : `~lsst.daf.butler.DimensionUniverse`, optional
1107 If `None` it is loaded from the `QuantumGraph`
1108 saved structure. If supplied, the
1109 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
1110 will be validated against the supplied argument for compatibility.
1111 nodes : iterable of `uuid.UUID` or `None`
1112 UUIDs that correspond to nodes in the graph. If specified, only
1113 these nodes will be loaded. Defaults to None, in which case all
1114 nodes will be loaded.
1115 graphID : `str` or `None`
1116 If specified this ID is verified against the loaded graph prior to
1117 loading any Nodes. This defaults to None in which case no
1118 validation is done.
1119 minimumVersion : `int`
1120 Minimum version of a save file to load. Set to -1 to load all
1121 versions. Older versions may need to be loaded, and re-saved
1122 to upgrade them to the latest format before they can be used in
1123 production.
1125 Returns
1126 -------
1127 graph : `QuantumGraph`
1128 Resulting QuantumGraph instance.
1130 Raises
1131 ------
1132 TypeError
1133 Raised if data contains instance of a type other than
1134 `QuantumGraph`.
1135 ValueError
1136 Raised if one or more of the nodes requested is not in the
1137 `QuantumGraph` or if graphID parameter does not match the graph
1138 being loaded or if the supplied uri does not point at a valid
1139 `QuantumGraph` save file.
1140 """
1141 with LoadHelper(file, minimumVersion) as loader:
1142 qgraph = loader.load(universe, nodes, graphID)
1143 if not isinstance(qgraph, QuantumGraph):
1144 raise TypeError(f"QuantumGraph file contains unexpected object type: {type(qgraph)}")
1145 return qgraph
1147 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1148 """Iterate over the `taskGraph` attribute in topological order.
1150 Yields
1151 ------
1152 taskDef : `TaskDef`
1153 `TaskDef` objects in topological order.
1154 """
1155 yield from nx.topological_sort(self.taskGraph)
    def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_id: bool = False) -> None:
        """Change output run and dataset ID for each output dataset.

        Parameters
        ----------
        run : `str`
            New output run name.
        metadata_key : `str` or `None`
            Specifies metadata key corresponding to output run name to update
            with new run name. If `None` or if metadata is missing it is not
            updated. If metadata is present but key is missing, it will be
            added.
        update_graph_id : `bool`, optional
            If `True` then also update graph ID with a new unique value.
        """
        # Maps each original dataset ID to the new ID produced by replacing
        # its run; filled while rewriting outputs (first pass) and consulted
        # while rewriting the matching intermediate inputs (second pass).
        dataset_id_map: dict[DatasetId, DatasetId] = {}

        def _update_output_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: MutableMapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update a collection of `~lsst.daf.butler.DatasetRef` with new
            run and dataset IDs, recording old->new IDs in
            ``dataset_id_map`` as a side effect.
            """
            for ref in refs:
                new_ref = ref.replace(run=run)
                dataset_id_map[ref.id] = new_ref.id
                yield new_ref

        def _update_intermediate_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: Mapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update intermediate references with new run and IDs. Only the
            references that appear in ``dataset_id_map`` are updated, others
            are returned unchanged.
            """
            for ref in refs:
                if dataset_id := dataset_id_map.get(ref.id):
                    ref = ref.replace(run=run, id=dataset_id)
                yield ref

        # Replace quantum output refs first.
        for node in self._connectedQuanta:
            quantum = node.quantum
            outputs = {
                dataset_type: tuple(_update_output_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.outputs.items()
            }
            # Quantum is immutable, so build a replacement with only the
            # outputs changed and swap it into the node.
            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=quantum.initInputs,
                inputs=quantum.inputs,
                outputs=outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        # Init-outputs and global init-outputs also get new run/IDs and
        # contribute to dataset_id_map.
        self._initOutputRefs = {
            task_def: list(_update_output_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initOutputRefs.items()
        }
        self._globalInitOutputRefs = list(
            _update_output_refs(self._globalInitOutputRefs, run, dataset_id_map)
        )

        # Update all intermediates from their matching outputs.
        for node in self._connectedQuanta:
            quantum = node.quantum
            inputs = {
                dataset_type: tuple(_update_intermediate_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.inputs.items()
            }
            initInputs = list(_update_intermediate_refs(quantum.initInputs.values(), run, dataset_id_map))

            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=initInputs,
                inputs=inputs,
                outputs=quantum.outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        # Per-task init-inputs may refer to other tasks' init-outputs, so
        # they are updated as intermediates as well.
        self._initInputRefs = {
            task_def: list(_update_intermediate_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initInputRefs.items()
        }

        if update_graph_id:
            # Same timestamp+pid recipe used when a graph is first built.
            self._buildId = BuildId(f"{time.time()}-{os.getpid()}")

        # Update metadata if present.
        if self._metadata is not None and metadata_key is not None:
            # Copy before mutating in case the mapping is shared/read-only.
            metadata = dict(self._metadata)
            metadata[metadata_key] = run
            self._metadata = metadata
1255 @property
1256 def graphID(self) -> BuildId:
1257 """The ID generated by the graph at construction time (`str`)."""
1258 return self._buildId
1260 @property
1261 def universe(self) -> DimensionUniverse:
1262 """Dimension universe associated with this graph
1263 (`~lsst.daf.butler.DimensionUniverse`).
1264 """
1265 return self._universe
1267 def __iter__(self) -> Generator[QuantumNode, None, None]:
1268 yield from nx.topological_sort(self._connectedQuanta)
1270 def __len__(self) -> int:
1271 return self._count
1273 def __contains__(self, node: QuantumNode) -> bool:
1274 return self._connectedQuanta.has_node(node)
1276 def __getstate__(self) -> dict:
1277 """Store a compact form of the graph as a list of graph nodes, and a
1278 tuple of task labels and task configs. The full graph can be
1279 reconstructed with this information, and it preserves the ordering of
1280 the graph nodes.
1281 """
1282 universe: DimensionUniverse | None = None
1283 for node in self:
1284 dId = node.quantum.dataId
1285 if dId is None:
1286 continue
1287 universe = dId.universe
1288 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
1290 def __setstate__(self, state: dict) -> None:
1291 """Reconstructs the state of the graph from the information persisted
1292 in getstate.
1293 """
1294 buffer = io.BytesIO(state["reduced"])
1295 with LoadHelper(buffer, minimumVersion=3) as loader:
1296 qgraph = loader.load(state["universe"], graphID=state["graphId"])
1298 self._metadata = qgraph._metadata
1299 self._buildId = qgraph._buildId
1300 self._datasetDict = qgraph._datasetDict
1301 self._nodeIdMap = qgraph._nodeIdMap
1302 self._count = len(qgraph)
1303 self._taskToQuantumNode = qgraph._taskToQuantumNode
1304 self._taskGraph = qgraph._taskGraph
1305 self._connectedQuanta = qgraph._connectedQuanta
1306 self._initInputRefs = qgraph._initInputRefs
1307 self._initOutputRefs = qgraph._initOutputRefs
1309 def __eq__(self, other: object) -> bool:
1310 if not isinstance(other, QuantumGraph):
1311 return False
1312 if len(self) != len(other):
1313 return False
1314 for node in self:
1315 if node not in other:
1316 return False
1317 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1318 return False
1319 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1320 return False
1321 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1322 return False
1323 return set(self.taskGraph) == set(other.taskGraph)