Coverage for python/lsst/pipe/base/graph/graph.py: 20%
391 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-30 12:09 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-30 12:09 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("QuantumGraph", "IncompatibleGraphError")
31import io
32import json
33import lzma
34import os
35import struct
36import time
37import uuid
38from collections import defaultdict, deque
39from collections.abc import Generator, Iterable, Iterator, Mapping, MutableMapping
40from itertools import chain
41from types import MappingProxyType
42from typing import Any, BinaryIO, TypeVar
44import networkx as nx
45from lsst.daf.butler import (
46 DatasetId,
47 DatasetRef,
48 DatasetType,
49 DimensionRecordsAccumulator,
50 DimensionUniverse,
51 Quantum,
52)
53from lsst.daf.butler.persistence_context import PersistenceContextVars
54from lsst.resources import ResourcePath, ResourcePathExpression
55from lsst.utils.introspection import get_full_type_name
56from networkx.drawing.nx_agraph import write_dot
58from ..connections import iterConnections
59from ..pipeline import TaskDef
60from ._implDetails import DatasetTypeName, _DatasetTracker
61from ._loadHelpers import LoadHelper
62from ._versionDeserializers import DESERIALIZER_MAP
63from .quantumNode import BuildId, QuantumNode
# Type variable for methods (e.g. subset) that return an instance of the
# (possibly derived) class they are called on.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be satisfied because the node
    id belongs to an incompatible graph.
    """
97class QuantumGraph:
98 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects
100 This data structure represents a concrete workflow generated from a
101 `Pipeline`.
103 Parameters
104 ----------
105 quanta : `~collections.abc.Mapping` [ `TaskDef`, \
106 `set` [ `~lsst.daf.butler.Quantum` ] ]
107 This maps tasks (and their configs) to the sets of data they are to
108 process.
109 metadata : Optional `~collections.abc.Mapping` of `str` to primitives
110 This is an optional parameter of extra data to carry with the graph.
111 Entries in this mapping should be able to be serialized in JSON.
112 universe : `~lsst.daf.butler.DimensionUniverse`, optional
113 The dimensions in which quanta can be defined. Need only be provided if
114 no quanta have data IDs.
115 initInputs : `~collections.abc.Mapping`, optional
116 Maps tasks to their InitInput dataset refs. Dataset refs can be either
117 resolved or non-resolved. Presently the same dataset refs are included
118 in each `~lsst.daf.butler.Quantum` for the same task.
119 initOutputs : `~collections.abc.Mapping`, optional
120 Maps tasks to their InitOutput dataset refs. Dataset refs can be either
121 resolved or non-resolved. For intermediate resolved refs their dataset
122 ID must match ``initInputs`` and Quantum ``initInputs``.
123 globalInitOutputs : iterable [ `~lsst.daf.butler.DatasetRef` ], optional
124 Dataset refs for some global objects produced by pipeline. These
125 objects include task configurations and package versions. Typically
126 they have an empty DataId, but there is no real restriction on what
127 can appear here.
128 registryDatasetTypes : iterable [ `~lsst.daf.butler.DatasetType` ], \
129 optional
130 Dataset types which are used by this graph, their definitions must
131 match registry. If registry does not define dataset type yet, then
132 it should match one that will be created later.
134 Raises
135 ------
136 ValueError
137 Raised if the graph is pruned such that some tasks no longer have nodes
138 associated with them.
139 """
141 def __init__(
142 self,
143 quanta: Mapping[TaskDef, set[Quantum]],
144 metadata: Mapping[str, Any] | None = None,
145 universe: DimensionUniverse | None = None,
146 initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
147 initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
148 globalInitOutputs: Iterable[DatasetRef] | None = None,
149 registryDatasetTypes: Iterable[DatasetType] | None = None,
150 ):
151 self._buildGraphs(
152 quanta,
153 metadata=metadata,
154 universe=universe,
155 initInputs=initInputs,
156 initOutputs=initOutputs,
157 globalInitOutputs=globalInitOutputs,
158 registryDatasetTypes=registryDatasetTypes,
159 )
    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, set[Quantum]],
        *,
        _quantumToNodeId: Mapping[Quantum, uuid.UUID] | None = None,
        _buildId: BuildId | None = None,
        metadata: Mapping[str, Any] | None = None,
        universe: DimensionUniverse | None = None,
        initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        globalInitOutputs: Iterable[DatasetRef] | None = None,
        registryDatasetTypes: Iterable[DatasetType] | None = None,
    ) -> None:
        """Build the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : `~collections.abc.Mapping` [ `TaskDef`, `set` [ `Quantum` ] ]
            Maps tasks to the sets of quanta they are to process.
        _quantumToNodeId : `~collections.abc.Mapping`, optional
            Private: maps each `Quantum` to a pre-existing node UUID so node
            identity is preserved when rebuilding (used by `subset`). When
            supplied, every quantum must have an entry.
        _buildId : `BuildId`, optional
            Private: reuse an existing build ID instead of generating one.
        metadata, universe, initInputs, initOutputs, globalInitOutputs, \
        registryDatasetTypes
            See the class docstring.

        Raises
        ------
        RuntimeError
            Raised if quanta carry mismatched dimension universes, or if no
            universe is supplied and no quantum has a data ID.
        ValueError
            Raised if ``_quantumToNodeId`` is given but lacks an entry for
            some quantum.
        """
        self._metadata = metadata
        # Default build ID combines time and pid to tell graphs apart.
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structure used to identify relations between
        # DatasetTypeName -> TaskDef.
        self._datasetDict = _DatasetTracker(createInverse=True)

        # Temporary graph that will have dataset UUIDs (as raw bytes) and
        # QuantumNode objects as nodes; will be collapsed down to just quanta
        # later.
        bipartite_graph = nx.DiGraph()

        self._nodeIdMap: dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                # Have to handle components in inputs: only the part before
                # the first "." is the dataset type name.
                dataset_name, _, _ = inpt.name.partition(".")
                self._datasetDict.addConsumer(DatasetTypeName(dataset_name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                # Have to handle possible components in outputs.
                dataset_name, _, _ = output.name.partition(".")
                self._datasetDict.addProducer(DatasetTypeName(dataset_name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if quantum.dataId is not None:
                    # Infer the universe from the first data ID seen and
                    # require that every later one agrees with it.
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}. "
                        )

                if _quantumToNodeId:
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                # Quanta are bipartite=0 nodes and dataset-UUID bytes are
                # bipartite=1 nodes; edges run dataset -> consuming quantum
                # and producing quantum -> dataset.
                bipartite_graph.add_node(value, bipartite=0)
                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            bipartite_graph.add_node(sub.id.bytes, bipartite=1)
                            bipartite_graph.add_edge(sub.id.bytes, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        # For a component ref, attach the edge to its parent
                        # composite ref instead.
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                        bipartite_graph.add_edge(dsRef.id.bytes, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                    bipartite_graph.add_edge(value, dsRef.id.bytes)

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Make graph of quanta relations, by projecting out the dataset nodes
        # in the bipartite_graph, leaving just the quanta.
        self._connectedQuanta = nx.algorithms.bipartite.projected_graph(
            bipartite_graph, self._nodeIdMap.values()
        )
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # convert default dict into a regular to prevent accidental key
        # insertion
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

        self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._globalInitOutputRefs: list[DatasetRef] = []
        self._registryDatasetTypes: list[DatasetType] = []
        if initInputs is not None:
            self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
        if initOutputs is not None:
            self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
        if globalInitOutputs is not None:
            self._globalInitOutputRefs = list(globalInitOutputs)
        if registryDatasetTypes is not None:
            self._registryDatasetTypes = list(registryDatasetTypes)
290 @property
291 def taskGraph(self) -> nx.DiGraph:
292 """A graph representing the relations between the tasks inside
293 the quantum graph (`networkx.DiGraph`).
294 """
295 return self._taskGraph
297 @property
298 def graph(self) -> nx.DiGraph:
299 """A graph representing the relations between all the `QuantumNode`
300 objects (`networkx.DiGraph`).
302 The graph should usually be iterated over, or passed to methods of this
303 class, but sometimes direct access to the ``networkx`` object may be
304 helpful.
305 """
306 return self._connectedQuanta
308 @property
309 def inputQuanta(self) -> Iterable[QuantumNode]:
310 """The nodes that are inputs to the graph (iterable [`QuantumNode`]).
312 These are the nodes that do not depend on any other nodes in the
313 graph.
314 """
315 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
317 @property
318 def outputQuanta(self) -> Iterable[QuantumNode]:
319 """The nodes that are outputs of the graph (iterable [`QuantumNode`]).
321 These are the nodes that have no nodes that depend on them in the
322 graph.
323 """
324 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
326 @property
327 def allDatasetTypes(self) -> tuple[DatasetTypeName, ...]:
328 """All the data set type names that are present in the graph
329 (`tuple` [`str`]).
331 These types do not include global init-outputs.
332 """
333 return tuple(self._datasetDict.keys())
335 @property
336 def isConnected(self) -> bool:
337 """Whether all of the nodes in the graph are connected, ignoring
338 directionality of connections (`bool`).
339 """
340 return nx.is_weakly_connected(self._connectedQuanta)
342 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
343 """Lookup a `QuantumNode` from an id associated with the node.
345 Parameters
346 ----------
347 nodeId : `NodeId`
348 The number associated with a node
350 Returns
351 -------
352 node : `QuantumNode`
353 The node corresponding with input number
355 Raises
356 ------
357 KeyError
358 Raised if the requested nodeId is not in the graph.
359 """
360 return self._nodeIdMap[nodeId]
362 def getQuantaForTask(self, taskDef: TaskDef) -> frozenset[Quantum]:
363 """Return all the `~lsst.daf.butler.Quantum` associated with a
364 `TaskDef`.
366 Parameters
367 ----------
368 taskDef : `TaskDef`
369 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
370 queried.
372 Returns
373 -------
374 quanta : `frozenset` of `~lsst.daf.butler.Quantum`
375 The `set` of `~lsst.daf.butler.Quantum` that is associated with the
376 specified `TaskDef`.
377 """
378 return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ()))
380 def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int:
381 """Return the number of `~lsst.daf.butler.Quantum` associated with
382 a `TaskDef`.
384 Parameters
385 ----------
386 taskDef : `TaskDef`
387 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
388 queried.
390 Returns
391 -------
392 count : `int`
393 The number of `~lsst.daf.butler.Quantum` that are associated with
394 the specified `TaskDef`.
395 """
396 return len(self._taskToQuantumNode.get(taskDef, ()))
398 def getNodesForTask(self, taskDef: TaskDef) -> frozenset[QuantumNode]:
399 r"""Return all the `QuantumNode`\s associated with a `TaskDef`.
401 Parameters
402 ----------
403 taskDef : `TaskDef`
404 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
405 queried.
407 Returns
408 -------
409 nodes : `frozenset` [ `QuantumNode` ]
410 A `frozenset` of `QuantumNode` that is associated with the
411 specified `TaskDef`.
412 """
413 return frozenset(self._taskToQuantumNode[taskDef])
415 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
416 """Find all tasks that have the specified dataset type name as an
417 input.
419 Parameters
420 ----------
421 datasetTypeName : `str`
422 A string representing the name of a dataset type to be queried,
423 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
424 `str` for type safety in static type checking.
426 Returns
427 -------
428 tasks : iterable of `TaskDef`
429 `TaskDef` objects that have the specified `DatasetTypeName` as an
430 input, list will be empty if no tasks use specified
431 `DatasetTypeName` as an input.
433 Raises
434 ------
435 KeyError
436 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
437 """
438 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
440 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> TaskDef | None:
441 """Find all tasks that have the specified dataset type name as an
442 output.
444 Parameters
445 ----------
446 datasetTypeName : `str`
447 A string representing the name of a dataset type to be queried,
448 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
449 `str` for type safety in static type checking.
451 Returns
452 -------
453 result : `TaskDef` or `None`
454 `TaskDef` that outputs `DatasetTypeName` as an output or `None` if
455 none of the tasks produce this `DatasetTypeName`.
457 Raises
458 ------
459 KeyError
460 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
461 """
462 return self._datasetDict.getProducer(datasetTypeName)
464 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
465 """Find all tasks that are associated with the specified dataset type
466 name.
468 Parameters
469 ----------
470 datasetTypeName : `str`
471 A string representing the name of a dataset type to be queried,
472 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
473 `str` for type safety in static type checking.
475 Returns
476 -------
477 result : iterable of `TaskDef`
478 `TaskDef` objects that are associated with the specified
479 `DatasetTypeName`.
481 Raises
482 ------
483 KeyError
484 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
485 """
486 return self._datasetDict.getAll(datasetTypeName)
488 def findTaskDefByName(self, taskName: str) -> list[TaskDef]:
489 """Determine which `TaskDef` objects in this graph are associated
490 with a `str` representing a task name (looks at the ``taskName``
491 property of `TaskDef` objects).
493 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
494 multiple times in a graph with different labels.
496 Parameters
497 ----------
498 taskName : `str`
499 Name of a task to search for.
501 Returns
502 -------
503 result : `list` of `TaskDef`
504 List of the `TaskDef` objects that have the name specified.
505 Multiple values are returned in the case that a task is used
506 multiple times with different labels.
507 """
508 results = []
509 for task in self._taskToQuantumNode:
510 split = task.taskName.split(".")
511 if split[-1] == taskName:
512 results.append(task)
513 return results
515 def findTaskDefByLabel(self, label: str) -> TaskDef | None:
516 """Determine which `TaskDef` objects in this graph are associated
517 with a `str` representing a tasks label.
519 Parameters
520 ----------
521 taskName : `str`
522 Name of a task to search for
524 Returns
525 -------
526 result : `TaskDef`
527 `TaskDef` objects that has the specified label.
528 """
529 for task in self._taskToQuantumNode:
530 if label == task.label:
531 return task
532 return None
534 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> set[Quantum]:
535 r"""Return all the `~lsst.daf.butler.Quantum` that contain a specified
536 `DatasetTypeName`.
538 Parameters
539 ----------
540 datasetTypeName : `str`
541 The name of the dataset type to search for as a string,
542 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
543 `str` for type safety in static type checking.
545 Returns
546 -------
547 result : `set` of `QuantumNode` objects
548 A `set` of `QuantumNode`\s that contain specified
549 `DatasetTypeName`.
551 Raises
552 ------
553 KeyError
554 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
556 """
557 tasks = self._datasetDict.getAll(datasetTypeName)
558 result: set[Quantum] = set()
559 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
560 return result
562 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
563 """Check if specified quantum appears in the graph as part of a node.
565 Parameters
566 ----------
567 quantum : `lsst.daf.butler.Quantum`
568 The quantum to search for.
570 Returns
571 -------
572 in_graph : `bool`
573 The result of searching for the quantum.
574 """
575 return any(quantum == node.quantum for node in self)
577 def writeDotGraph(self, output: str | io.BufferedIOBase) -> None:
578 """Write out the graph as a dot graph.
580 Parameters
581 ----------
582 output : `str` or `io.BufferedIOBase`
583 Either a filesystem path to write to, or a file handle object.
584 """
585 write_dot(self._connectedQuanta, output)
587 def subset(self: _T, nodes: QuantumNode | Iterable[QuantumNode]) -> _T:
588 """Create a new graph object that contains the subset of the nodes
589 specified as input. Node number is preserved.
591 Parameters
592 ----------
593 nodes : `QuantumNode` or iterable of `QuantumNode`
594 Nodes from which to create subset.
596 Returns
597 -------
598 graph : instance of graph type
599 An instance of the type from which the subset was created.
600 """
601 if not isinstance(nodes, Iterable):
602 nodes = (nodes,)
603 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
604 quantumMap = defaultdict(set)
606 dataset_type_names: set[str] = set()
607 node: QuantumNode
608 for node in quantumSubgraph:
609 quantumMap[node.taskDef].add(node.quantum)
610 dataset_type_names.update(
611 dstype.name
612 for dstype in chain(
613 node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
614 )
615 )
617 # May need to trim dataset types from registryDatasetTypes.
618 for taskDef in quantumMap:
619 if refs := self.initOutputRefs(taskDef):
620 dataset_type_names.update(ref.datasetType.name for ref in refs)
621 dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
622 registryDatasetTypes = [
623 dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
624 ]
626 # convert to standard dict to prevent accidental key insertion
627 quantumDict: dict[TaskDef, set[Quantum]] = dict(quantumMap.items())
628 # Create an empty graph, and then populate it with custom mapping
629 newInst = type(self)({}, universe=self._universe)
630 # TODO: Do we need to copy initInputs/initOutputs?
631 newInst._buildGraphs(
632 quantumDict,
633 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
634 _buildId=self._buildId,
635 metadata=self._metadata,
636 universe=self._universe,
637 globalInitOutputs=self._globalInitOutputRefs,
638 registryDatasetTypes=registryDatasetTypes,
639 )
640 return newInst
642 def subsetToConnected(self: _T) -> tuple[_T, ...]:
643 """Generate a list of subgraphs where each is connected.
645 Returns
646 -------
647 result : `list` of `QuantumGraph`
648 A list of graphs that are each connected.
649 """
650 return tuple(
651 self.subset(connectedSet)
652 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
653 )
655 def determineInputsToQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
656 """Return a set of `QuantumNode` that are direct inputs to a specified
657 node.
659 Parameters
660 ----------
661 node : `QuantumNode`
662 The node of the graph for which inputs are to be determined.
664 Returns
665 -------
666 inputs : `set` of `QuantumNode`
667 All the nodes that are direct inputs to specified node.
668 """
669 return set(self._connectedQuanta.predecessors(node))
671 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
672 """Return a set of `QuantumNode` that are direct outputs of a specified
673 node.
675 Parameters
676 ----------
677 node : `QuantumNode`
678 The node of the graph for which outputs are to be determined.
680 Returns
681 -------
682 outputs : `set` of `QuantumNode`
683 All the nodes that are direct outputs to specified node.
684 """
685 return set(self._connectedQuanta.successors(node))
687 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
688 """Return a graph of `QuantumNode` that are direct inputs and outputs
689 of a specified node.
691 Parameters
692 ----------
693 node : `QuantumNode`
694 The node of the graph for which connected nodes are to be
695 determined.
697 Returns
698 -------
699 graph : graph of `QuantumNode`
700 All the nodes that are directly connected to specified node.
701 """
702 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
703 nodes.add(node)
704 return self.subset(nodes)
706 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
707 """Return a graph of the specified node and all the ancestor nodes
708 directly reachable by walking edges.
710 Parameters
711 ----------
712 node : `QuantumNode`
713 The node for which all ancestors are to be determined
715 Returns
716 -------
717 ancestors : graph of `QuantumNode`
718 Graph of node and all of its ancestors.
719 """
720 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
721 predecessorNodes.add(node)
722 return self.subset(predecessorNodes)
724 def findCycle(self) -> list[tuple[QuantumNode, QuantumNode]]:
725 """Check a graph for the presense of cycles and returns the edges of
726 any cycles found, or an empty list if there is no cycle.
728 Returns
729 -------
730 result : `list` of `tuple` of [ `QuantumNode`, `QuantumNode` ]
731 A list of any graph edges that form a cycle, or an empty list if
732 there is no cycle. Empty list to so support if graph.find_cycle()
733 syntax as an empty list is falsy.
734 """
735 try:
736 return nx.find_cycle(self._connectedQuanta)
737 except nx.NetworkXNoCycle:
738 return []
740 def saveUri(self, uri: ResourcePathExpression) -> None:
741 """Save `QuantumGraph` to the specified URI.
743 Parameters
744 ----------
745 uri : convertible to `~lsst.resources.ResourcePath`
746 URI to where the graph should be saved.
747 """
748 buffer = self._buildSaveObject()
749 path = ResourcePath(uri)
750 if path.getExtension() not in (".qgraph"):
751 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
752 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
754 @property
755 def metadata(self) -> MappingProxyType[str, Any] | None:
756 """Extra data carried with the graph (mapping [`str`] or `None`).
758 The mapping is a dynamic view of this object's metadata. Values should
759 be able to be serialized in JSON.
760 """
761 if self._metadata is None:
762 return None
763 return MappingProxyType(self._metadata)
765 def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
766 """Return DatasetRefs for a given task InitInputs.
768 Parameters
769 ----------
770 taskDef : `TaskDef`
771 Task definition structure.
773 Returns
774 -------
775 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
776 DatasetRef for the task InitInput, can be `None`. This can return
777 either resolved or non-resolved reference.
778 """
779 return self._initInputRefs.get(taskDef)
781 def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
782 """Return DatasetRefs for a given task InitOutputs.
784 Parameters
785 ----------
786 taskDef : `TaskDef`
787 Task definition structure.
789 Returns
790 -------
791 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
792 DatasetRefs for the task InitOutput, can be `None`. This can return
793 either resolved or non-resolved reference. Resolved reference will
794 match Quantum's initInputs if this is an intermediate dataset type.
795 """
796 return self._initOutputRefs.get(taskDef)
798 def globalInitOutputRefs(self) -> list[DatasetRef]:
799 """Return DatasetRefs for global InitOutputs.
801 Returns
802 -------
803 refs : `list` [ `~lsst.daf.butler.DatasetRef` ]
804 DatasetRefs for global InitOutputs.
805 """
806 return self._globalInitOutputRefs
808 def registryDatasetTypes(self) -> list[DatasetType]:
809 """Return dataset types used by this graph, their definitions match
810 dataset types from registry.
812 Returns
813 -------
814 refs : `list` [ `~lsst.daf.butler.DatasetType` ]
815 Dataset types for this graph.
816 """
817 return self._registryDatasetTypes
819 @classmethod
820 def loadUri(
821 cls,
822 uri: ResourcePathExpression,
823 universe: DimensionUniverse | None = None,
824 nodes: Iterable[uuid.UUID] | None = None,
825 graphID: BuildId | None = None,
826 minimumVersion: int = 3,
827 ) -> QuantumGraph:
828 """Read `QuantumGraph` from a URI.
830 Parameters
831 ----------
832 uri : convertible to `~lsst.resources.ResourcePath`
833 URI from where to load the graph.
834 universe : `~lsst.daf.butler.DimensionUniverse`, optional
835 If `None` it is loaded from the `QuantumGraph`
836 saved structure. If supplied, the
837 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
838 will be validated against the supplied argument for compatibility.
839 nodes : iterable of `uuid.UUID` or `None`
840 UUIDs that correspond to nodes in the graph. If specified, only
841 these nodes will be loaded. Defaults to None, in which case all
842 nodes will be loaded.
843 graphID : `str` or `None`
844 If specified this ID is verified against the loaded graph prior to
845 loading any Nodes. This defaults to None in which case no
846 validation is done.
847 minimumVersion : `int`
848 Minimum version of a save file to load. Set to -1 to load all
849 versions. Older versions may need to be loaded, and re-saved
850 to upgrade them to the latest format before they can be used in
851 production.
853 Returns
854 -------
855 graph : `QuantumGraph`
856 Resulting QuantumGraph instance.
858 Raises
859 ------
860 TypeError
861 Raised if file contains instance of a type other than
862 `QuantumGraph`.
863 ValueError
864 Raised if one or more of the nodes requested is not in the
865 `QuantumGraph` or if graphID parameter does not match the graph
866 being loaded or if the supplied uri does not point at a valid
867 `QuantumGraph` save file.
868 RuntimeError
869 Raise if Supplied `~lsst.daf.butler.DimensionUniverse` is not
870 compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
871 the graph.
872 """
873 uri = ResourcePath(uri)
874 if uri.getExtension() in {".qgraph"}:
875 with LoadHelper(uri, minimumVersion) as loader:
876 qgraph = loader.load(universe, nodes, graphID)
877 else:
878 raise ValueError(f"Only know how to handle files saved as `.qgraph`, not {uri}")
879 if not isinstance(qgraph, QuantumGraph):
880 raise TypeError(f"QuantumGraph file {uri} contains unexpected object type: {type(qgraph)}")
881 return qgraph
883 @classmethod
884 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> str | None:
885 """Read the header of a `QuantumGraph` pointed to by the uri parameter
886 and return it as a string.
888 Parameters
889 ----------
890 uri : convertible to `~lsst.resources.ResourcePath`
891 The location of the `QuantumGraph` to load. If the argument is a
892 string, it must correspond to a valid
893 `~lsst.resources.ResourcePath` path.
894 minimumVersion : `int`
895 Minimum version of a save file to load. Set to -1 to load all
896 versions. Older versions may need to be loaded, and re-saved
897 to upgrade them to the latest format before they can be used in
898 production.
900 Returns
901 -------
902 header : `str` or `None`
903 The header associated with the specified `QuantumGraph` it there is
904 one, else `None`.
906 Raises
907 ------
908 ValueError
909 Raised if the extension of the file specified by uri is not a
910 `QuantumGraph` extension.
911 """
912 uri = ResourcePath(uri)
913 if uri.getExtension() in {".qgraph"}:
914 return LoadHelper(uri, minimumVersion).readHeader()
915 else:
916 raise ValueError("Only know how to handle files saved as `.qgraph`")
918 def buildAndPrintHeader(self) -> None:
919 """Create a header that would be used in a save of this object and
920 prints it out to standard out.
921 """
922 _, header = self._buildSaveObject(returnHeader=True)
923 print(json.dumps(header))
925 def save(self, file: BinaryIO) -> None:
926 """Save QuantumGraph to a file.
928 Parameters
929 ----------
930 file : `io.BufferedIOBase`
931 File to write data open in binary mode.
932 """
933 buffer = self._buildSaveObject()
934 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
936 def _buildSaveObject(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
937 thing = PersistenceContextVars()
938 result = thing.run(self._buildSaveObjectImpl, returnHeader)
939 return result
    def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
        """Build the serialized byte representation of this graph.

        The on-disk layout is: magic bytes, save-format version, header
        length, lzma-compressed JSON header, then one lzma-compressed JSON
        blob per task definition followed by one per quantum node. The
        header records the (start, end) byte offsets of every blob so that
        individual tasks/nodes can be loaded without reading everything.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the header mapping embedded at the front
            of the buffer.

        Returns
        -------
        buffer : `bytearray`
            The complete serialized graph.
        headerData : `dict`
            The header mapping; only returned when ``returnHeader`` is
            `True`.
        """
        # make some containers
        jsonData: deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap: list[tuple[str, dict[str, Any]]] = []
        taskDefMap: dict[str, dict[str, Any]] = {}
        headerData: dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far; every blob's
        # (start, end) offsets recorded below are relative to the first blob,
        # not to the start of the file.
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # Iterate tasks in taskGraph order. NOTE(review): an earlier comment
        # here claimed the tasks are sorted by label; the loop actually
        # relies on taskGraph's own iteration order -- confirm determinism.
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription: dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            # Init-input/output refs are optional per task; only record them
            # when present so the saved JSON stays compact.
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connection between all of tasks and save that in
            # the header as a list of connections and edges in each task
            # this will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this tasks produced
                    # connection, the output will be said to be the datastore
                    # in which case the for loop will be a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, and encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode(), preset=2)
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode(), preset=2)
            jsonData.append(dump)
            # Record each node's byte span plus its graph neighbors so edges
            # can be rebuilt on load without reading every node body.
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        # Dimension records are deduplicated by the accumulator while the
        # nodes were serialized above, so they are stored once in the header.
        headerData["DimensionRecords"] = {
            key: value.model_dump()
            for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json cant do anything but strings as keys
        headerData["Nodes"] = nodeMap

        if self._globalInitOutputRefs:
            headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

        if self._registryDatasetTypes:
            headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # record the save format version number
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        # the struct format for the current save version encodes the length
        # of the compressed header so the loader knows how much to read
        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer
1092 @classmethod
1093 def load(
1094 cls,
1095 file: BinaryIO,
1096 universe: DimensionUniverse | None = None,
1097 nodes: Iterable[uuid.UUID] | None = None,
1098 graphID: BuildId | None = None,
1099 minimumVersion: int = 3,
1100 ) -> QuantumGraph:
1101 """Read `QuantumGraph` from a file that was made by `save`.
1103 Parameters
1104 ----------
1105 file : `io.IO` of bytes
1106 File with data open in binary mode.
1107 universe : `~lsst.daf.butler.DimensionUniverse`, optional
1108 If `None` it is loaded from the `QuantumGraph`
1109 saved structure. If supplied, the
1110 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
1111 will be validated against the supplied argument for compatibility.
1112 nodes : iterable of `uuid.UUID` or `None`
1113 UUIDs that correspond to nodes in the graph. If specified, only
1114 these nodes will be loaded. Defaults to None, in which case all
1115 nodes will be loaded.
1116 graphID : `str` or `None`
1117 If specified this ID is verified against the loaded graph prior to
1118 loading any Nodes. This defaults to None in which case no
1119 validation is done.
1120 minimumVersion : `int`
1121 Minimum version of a save file to load. Set to -1 to load all
1122 versions. Older versions may need to be loaded, and re-saved
1123 to upgrade them to the latest format before they can be used in
1124 production.
1126 Returns
1127 -------
1128 graph : `QuantumGraph`
1129 Resulting QuantumGraph instance.
1131 Raises
1132 ------
1133 TypeError
1134 Raised if data contains instance of a type other than
1135 `QuantumGraph`.
1136 ValueError
1137 Raised if one or more of the nodes requested is not in the
1138 `QuantumGraph` or if graphID parameter does not match the graph
1139 being loaded or if the supplied uri does not point at a valid
1140 `QuantumGraph` save file.
1141 """
1142 with LoadHelper(file, minimumVersion) as loader:
1143 qgraph = loader.load(universe, nodes, graphID)
1144 if not isinstance(qgraph, QuantumGraph):
1145 raise TypeError(f"QuantumGraph file contains unexpected object type: {type(qgraph)}")
1146 return qgraph
1148 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1149 """Iterate over the `taskGraph` attribute in topological order
1151 Yields
1152 ------
1153 taskDef : `TaskDef`
1154 `TaskDef` objects in topological order
1155 """
1156 yield from nx.topological_sort(self.taskGraph)
    def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_id: bool = False) -> None:
        """Change output run and dataset ID for each output dataset.

        The update proceeds in two passes: first every task *output* ref is
        moved to the new run (recording the old-to-new dataset ID mapping),
        then every *input* ref that matches a rewritten output — i.e. an
        intermediate dataset — is updated from that mapping. Pure inputs
        (datasets produced outside this graph) are left untouched.

        Parameters
        ----------
        run : `str`
            New output run name.
        metadata_key : `str` or `None`
            Specifies metadata key corresponding to output run name to update
            with new run name. If `None` or if metadata is missing it is not
            updated. If metadata is present but key is missing, it will be
            added.
        update_graph_id : `bool`, optional
            If `True` then also update graph ID with a new unique value.
        """
        # Maps original dataset ID -> new dataset ID generated when the
        # corresponding output ref was moved to the new run.
        dataset_id_map: dict[DatasetId, DatasetId] = {}

        def _update_output_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: MutableMapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update a collection of `~lsst.daf.butler.DatasetRef` with new
            run and dataset IDs.

            Records every old->new dataset ID pair in ``dataset_id_map`` so
            consumers of these datasets can be rewritten to match.
            """
            for ref in refs:
                new_ref = ref.replace(run=run)
                dataset_id_map[ref.id] = new_ref.id
                yield new_ref

        def _update_intermediate_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: Mapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update intermediate references with new run and IDs. Only the
            references that appear in ``dataset_id_map`` are updated, others
            are returned unchanged.
            """
            for ref in refs:
                # NOTE(review): a falsy dataset ID would be treated as
                # missing here; assumes IDs are UUIDs (never falsy) —
                # confirm.
                if dataset_id := dataset_id_map.get(ref.id):
                    ref = ref.replace(run=run, id=dataset_id)
                yield ref

        # Replace quantum output refs first.
        for node in self._connectedQuanta:
            quantum = node.quantum
            outputs = {
                dataset_type: tuple(_update_output_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.outputs.items()
            }
            # Quantum is immutable, so rebuild it with the new outputs and
            # swap it into the node in place.
            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=quantum.initInputs,
                inputs=quantum.inputs,
                outputs=outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        # Per-task and global init-outputs are outputs too; rewrite them and
        # record their IDs before the intermediate pass below.
        self._initOutputRefs = {
            task_def: list(_update_output_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initOutputRefs.items()
        }
        self._globalInitOutputRefs = list(
            _update_output_refs(self._globalInitOutputRefs, run, dataset_id_map)
        )

        # Update all intermediates from their matching outputs.
        for node in self._connectedQuanta:
            quantum = node.quantum
            inputs = {
                dataset_type: tuple(_update_intermediate_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.inputs.items()
            }
            initInputs = list(_update_intermediate_refs(quantum.initInputs.values(), run, dataset_id_map))

            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=initInputs,
                inputs=inputs,
                outputs=quantum.outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        # Task init-inputs may be init-outputs of upstream tasks, so they get
        # the intermediate treatment as well.
        self._initInputRefs = {
            task_def: list(_update_intermediate_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initInputRefs.items()
        }

        if update_graph_id:
            self._buildId = BuildId(f"{time.time()}-{os.getpid()}")

        # Update metadata if present.
        if self._metadata is not None and metadata_key is not None:
            metadata = dict(self._metadata)
            metadata[metadata_key] = run
            self._metadata = metadata
1256 @property
1257 def graphID(self) -> BuildId:
1258 """The ID generated by the graph at construction time (`str`)."""
1259 return self._buildId
1261 @property
1262 def universe(self) -> DimensionUniverse:
1263 """Dimension universe associated with this graph
1264 (`~lsst.daf.butler.DimensionUniverse`).
1265 """
1266 return self._universe
1268 def __iter__(self) -> Generator[QuantumNode, None, None]:
1269 yield from nx.topological_sort(self._connectedQuanta)
1271 def __len__(self) -> int:
1272 return self._count
1274 def __contains__(self, node: QuantumNode) -> bool:
1275 return self._connectedQuanta.has_node(node)
1277 def __getstate__(self) -> dict:
1278 """Store a compact form of the graph as a list of graph nodes, and a
1279 tuple of task labels and task configs. The full graph can be
1280 reconstructed with this information, and it preserves the ordering of
1281 the graph nodes.
1282 """
1283 universe: DimensionUniverse | None = None
1284 for node in self:
1285 dId = node.quantum.dataId
1286 if dId is None:
1287 continue
1288 universe = dId.universe
1289 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
1291 def __setstate__(self, state: dict) -> None:
1292 """Reconstructs the state of the graph from the information persisted
1293 in getstate.
1294 """
1295 buffer = io.BytesIO(state["reduced"])
1296 with LoadHelper(buffer, minimumVersion=3) as loader:
1297 qgraph = loader.load(state["universe"], graphID=state["graphId"])
1299 self._metadata = qgraph._metadata
1300 self._buildId = qgraph._buildId
1301 self._datasetDict = qgraph._datasetDict
1302 self._nodeIdMap = qgraph._nodeIdMap
1303 self._count = len(qgraph)
1304 self._taskToQuantumNode = qgraph._taskToQuantumNode
1305 self._taskGraph = qgraph._taskGraph
1306 self._connectedQuanta = qgraph._connectedQuanta
1307 self._initInputRefs = qgraph._initInputRefs
1308 self._initOutputRefs = qgraph._initOutputRefs
1310 def __eq__(self, other: object) -> bool:
1311 if not isinstance(other, QuantumGraph):
1312 return False
1313 if len(self) != len(other):
1314 return False
1315 for node in self:
1316 if node not in other:
1317 return False
1318 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1319 return False
1320 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1321 return False
1322 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1323 return False
1324 return set(self.taskGraph) == set(other.taskGraph)