Coverage for python/lsst/pipe/base/graph/graph.py: 21%
389 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-27 02:40 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("QuantumGraph", "IncompatibleGraphError")
31import io
32import json
33import lzma
34import os
35import struct
36import time
37import uuid
38from collections import defaultdict, deque
39from collections.abc import Generator, Iterable, Iterator, Mapping, MutableMapping
40from itertools import chain
41from types import MappingProxyType
42from typing import Any, BinaryIO, TypeVar
44import networkx as nx
45from lsst.daf.butler import (
46 DatasetId,
47 DatasetRef,
48 DatasetType,
49 DimensionRecordsAccumulator,
50 DimensionUniverse,
51 Quantum,
52)
53from lsst.daf.butler.persistence_context import PersistenceContextVars
54from lsst.resources import ResourcePath, ResourcePathExpression
55from lsst.utils.introspection import get_full_type_name
56from lsst.utils.packages import Packages
57from networkx.drawing.nx_agraph import write_dot
59from ..connections import iterConnections
60from ..pipeline import TaskDef
61from ._implDetails import DatasetTypeName, _DatasetTracker
62from ._loadHelpers import LoadHelper
63from ._versionDeserializers import DESERIALIZER_MAP
64from .quantumNode import BuildId, QuantumNode
# Type variable bound to QuantumGraph so that methods such as ``subset`` can
# return an instance of the (possibly subclassed) type they were called on.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be performed because the
    identifier is incompatible with this graph.
    """
98class QuantumGraph:
99 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects.
101 This data structure represents a concrete workflow generated from a
102 `Pipeline`.
104 Parameters
105 ----------
106 quanta : `~collections.abc.Mapping` [ `TaskDef`, \
107 `set` [ `~lsst.daf.butler.Quantum` ] ]
108 This maps tasks (and their configs) to the sets of data they are to
109 process.
110 metadata : Optional `~collections.abc.Mapping` of `str` to primitives
111 This is an optional parameter of extra data to carry with the graph.
112 Entries in this mapping should be able to be serialized in JSON.
113 universe : `~lsst.daf.butler.DimensionUniverse`, optional
114 The dimensions in which quanta can be defined. Need only be provided if
115 no quanta have data IDs.
116 initInputs : `~collections.abc.Mapping`, optional
117 Maps tasks to their InitInput dataset refs. Dataset refs can be either
118 resolved or non-resolved. Presently the same dataset refs are included
119 in each `~lsst.daf.butler.Quantum` for the same task.
120 initOutputs : `~collections.abc.Mapping`, optional
121 Maps tasks to their InitOutput dataset refs. Dataset refs can be either
122 resolved or non-resolved. For intermediate resolved refs their dataset
123 ID must match ``initInputs`` and Quantum ``initInputs``.
124 globalInitOutputs : iterable [ `~lsst.daf.butler.DatasetRef` ], optional
125 Dataset refs for some global objects produced by pipeline. These
126 objects include task configurations and package versions. Typically
127 they have an empty DataId, but there is no real restriction on what
128 can appear here.
129 registryDatasetTypes : iterable [ `~lsst.daf.butler.DatasetType` ], \
130 optional
131 Dataset types which are used by this graph, their definitions must
132 match registry. If registry does not define dataset type yet, then
133 it should match one that will be created later.
135 Raises
136 ------
137 ValueError
138 Raised if the graph is pruned such that some tasks no longer have nodes
139 associated with them.
140 """
142 def __init__(
143 self,
144 quanta: Mapping[TaskDef, set[Quantum]],
145 metadata: Mapping[str, Any] | None = None,
146 universe: DimensionUniverse | None = None,
147 initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
148 initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
149 globalInitOutputs: Iterable[DatasetRef] | None = None,
150 registryDatasetTypes: Iterable[DatasetType] | None = None,
151 ):
152 self._buildGraphs(
153 quanta,
154 metadata=metadata,
155 universe=universe,
156 initInputs=initInputs,
157 initOutputs=initOutputs,
158 globalInitOutputs=globalInitOutputs,
159 registryDatasetTypes=registryDatasetTypes,
160 )
    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, set[Quantum]],
        *,
        _quantumToNodeId: Mapping[Quantum, uuid.UUID] | None = None,
        _buildId: BuildId | None = None,
        metadata: Mapping[str, Any] | None = None,
        universe: DimensionUniverse | None = None,
        initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        globalInitOutputs: Iterable[DatasetRef] | None = None,
        registryDatasetTypes: Iterable[DatasetType] | None = None,
    ) -> None:
        """Build the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : `~collections.abc.Mapping` [ `TaskDef`, `set` [ `Quantum` ] ]
            Maps tasks to the sets of quanta they are to process.
        _quantumToNodeId : `~collections.abc.Mapping`, optional
            Internal-use mapping from `Quantum` to a pre-assigned node UUID;
            when given, every quantum must appear in it (used by ``subset`` to
            preserve node ids).
        _buildId : `BuildId`, optional
            Internal-use build identifier; a new time/pid based one is created
            when not supplied.
        metadata, universe, initInputs, initOutputs, globalInitOutputs, \
        registryDatasetTypes
            See the class docstring for descriptions of these parameters.

        Raises
        ------
        RuntimeError
            Raised if quanta carry mismatched dimension universes, or if no
            universe can be determined at all.
        ValueError
            Raised if ``_quantumToNodeId`` is given but is missing an entry
            for some quantum.
        """
        # Save packages to metadata
        self._metadata = dict(metadata) if metadata is not None else {}
        self._metadata["packages"] = Packages.fromSystem()

        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structure used to identify relations between
        # DatasetTypeName -> TaskDef.
        self._datasetDict = _DatasetTracker(createInverse=True)

        # Temporary graph that will have dataset UUIDs (as raw bytes) and
        # QuantumNode objects as nodes; will be collapsed down to just quanta
        # later.
        bipartite_graph = nx.DiGraph()

        self._nodeIdMap: dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                # Have to handle components in inputs: only the parent
                # (composite) dataset type name is tracked.
                dataset_name, _, _ = inpt.name.partition(".")
                self._datasetDict.addConsumer(DatasetTypeName(dataset_name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                # Have to handle possible components in outputs.
                dataset_name, _, _ = output.name.partition(".")
                self._datasetDict.addProducer(DatasetTypeName(dataset_name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if quantum.dataId is not None:
                    # All data IDs in the graph must share one universe; the
                    # first one seen wins if none was supplied by the caller.
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}. "
                        )

                if _quantumToNodeId:
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                # bipartite=0 marks quantum nodes, bipartite=1 dataset nodes;
                # datasets are keyed by their UUID bytes.
                bipartite_graph.add_node(value, bipartite=0)
                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            bipartite_graph.add_node(sub.id.bytes, bipartite=1)
                            bipartite_graph.add_edge(sub.id.bytes, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        # Component refs are edges to their composite dataset.
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                        bipartite_graph.add_edge(dsRef.id.bytes, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                    bipartite_graph.add_edge(value, dsRef.id.bytes)

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Make graph of quanta relations, by projecting out the dataset nodes
        # in the bipartite_graph, leaving just the quanta.
        self._connectedQuanta = nx.algorithms.bipartite.projected_graph(
            bipartite_graph, self._nodeIdMap.values()
        )
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # convert default dict into a regular to prevent accidental key
        # insertion
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

        self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._globalInitOutputRefs: list[DatasetRef] = []
        self._registryDatasetTypes: list[DatasetType] = []
        if initInputs is not None:
            self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
        if initOutputs is not None:
            self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
        if globalInitOutputs is not None:
            self._globalInitOutputRefs = list(globalInitOutputs)
        if registryDatasetTypes is not None:
            self._registryDatasetTypes = list(registryDatasetTypes)
294 @property
295 def taskGraph(self) -> nx.DiGraph:
296 """A graph representing the relations between the tasks inside
297 the quantum graph (`networkx.DiGraph`).
298 """
299 return self._taskGraph
301 @property
302 def graph(self) -> nx.DiGraph:
303 """A graph representing the relations between all the `QuantumNode`
304 objects (`networkx.DiGraph`).
306 The graph should usually be iterated over, or passed to methods of this
307 class, but sometimes direct access to the ``networkx`` object may be
308 helpful.
309 """
310 return self._connectedQuanta
312 @property
313 def inputQuanta(self) -> Iterable[QuantumNode]:
314 """The nodes that are inputs to the graph (iterable [`QuantumNode`]).
316 These are the nodes that do not depend on any other nodes in the
317 graph.
318 """
319 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
321 @property
322 def outputQuanta(self) -> Iterable[QuantumNode]:
323 """The nodes that are outputs of the graph (iterable [`QuantumNode`]).
325 These are the nodes that have no nodes that depend on them in the
326 graph.
327 """
328 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
330 @property
331 def allDatasetTypes(self) -> tuple[DatasetTypeName, ...]:
332 """All the data set type names that are present in the graph
333 (`tuple` [`str`]).
335 These types do not include global init-outputs.
336 """
337 return tuple(self._datasetDict.keys())
    @property
    def isConnected(self) -> bool:
        """Whether all of the nodes in the graph are connected, ignoring
        directionality of connections (`bool`).
        """
        # Weak connectivity treats every edge as undirected, so this reports
        # whether the quantum graph forms a single component.
        return nx.is_weakly_connected(self._connectedQuanta)
346 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
347 """Lookup a `QuantumNode` from an id associated with the node.
349 Parameters
350 ----------
351 nodeId : `NodeId`
352 The number associated with a node.
354 Returns
355 -------
356 node : `QuantumNode`
357 The node corresponding with input number.
359 Raises
360 ------
361 KeyError
362 Raised if the requested nodeId is not in the graph.
363 """
364 return self._nodeIdMap[nodeId]
366 def getQuantaForTask(self, taskDef: TaskDef) -> frozenset[Quantum]:
367 """Return all the `~lsst.daf.butler.Quantum` associated with a
368 `TaskDef`.
370 Parameters
371 ----------
372 taskDef : `TaskDef`
373 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
374 queried.
376 Returns
377 -------
378 quanta : `frozenset` of `~lsst.daf.butler.Quantum`
379 The `set` of `~lsst.daf.butler.Quantum` that is associated with the
380 specified `TaskDef`.
381 """
382 return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ()))
384 def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int:
385 """Return the number of `~lsst.daf.butler.Quantum` associated with
386 a `TaskDef`.
388 Parameters
389 ----------
390 taskDef : `TaskDef`
391 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
392 queried.
394 Returns
395 -------
396 count : `int`
397 The number of `~lsst.daf.butler.Quantum` that are associated with
398 the specified `TaskDef`.
399 """
400 return len(self._taskToQuantumNode.get(taskDef, ()))
402 def getNodesForTask(self, taskDef: TaskDef) -> frozenset[QuantumNode]:
403 r"""Return all the `QuantumNode`\s associated with a `TaskDef`.
405 Parameters
406 ----------
407 taskDef : `TaskDef`
408 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
409 queried.
411 Returns
412 -------
413 nodes : `frozenset` [ `QuantumNode` ]
414 A `frozenset` of `QuantumNode` that is associated with the
415 specified `TaskDef`.
416 """
417 return frozenset(self._taskToQuantumNode[taskDef])
419 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
420 """Find all tasks that have the specified dataset type name as an
421 input.
423 Parameters
424 ----------
425 datasetTypeName : `str`
426 A string representing the name of a dataset type to be queried,
427 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
428 `str` for type safety in static type checking.
430 Returns
431 -------
432 tasks : iterable of `TaskDef`
433 `TaskDef` objects that have the specified `DatasetTypeName` as an
434 input, list will be empty if no tasks use specified
435 `DatasetTypeName` as an input.
437 Raises
438 ------
439 KeyError
440 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
441 """
442 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
444 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> TaskDef | None:
445 """Find all tasks that have the specified dataset type name as an
446 output.
448 Parameters
449 ----------
450 datasetTypeName : `str`
451 A string representing the name of a dataset type to be queried,
452 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
453 `str` for type safety in static type checking.
455 Returns
456 -------
457 result : `TaskDef` or `None`
458 `TaskDef` that outputs `DatasetTypeName` as an output or `None` if
459 none of the tasks produce this `DatasetTypeName`.
461 Raises
462 ------
463 KeyError
464 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
465 """
466 return self._datasetDict.getProducer(datasetTypeName)
468 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
469 """Find all tasks that are associated with the specified dataset type
470 name.
472 Parameters
473 ----------
474 datasetTypeName : `str`
475 A string representing the name of a dataset type to be queried,
476 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
477 `str` for type safety in static type checking.
479 Returns
480 -------
481 result : iterable of `TaskDef`
482 `TaskDef` objects that are associated with the specified
483 `DatasetTypeName`.
485 Raises
486 ------
487 KeyError
488 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
489 """
490 return self._datasetDict.getAll(datasetTypeName)
492 def findTaskDefByName(self, taskName: str) -> list[TaskDef]:
493 """Determine which `TaskDef` objects in this graph are associated
494 with a `str` representing a task name (looks at the ``taskName``
495 property of `TaskDef` objects).
497 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
498 multiple times in a graph with different labels.
500 Parameters
501 ----------
502 taskName : `str`
503 Name of a task to search for.
505 Returns
506 -------
507 result : `list` of `TaskDef`
508 List of the `TaskDef` objects that have the name specified.
509 Multiple values are returned in the case that a task is used
510 multiple times with different labels.
511 """
512 results = []
513 for task in self._taskToQuantumNode:
514 split = task.taskName.split(".")
515 if split[-1] == taskName:
516 results.append(task)
517 return results
519 def findTaskDefByLabel(self, label: str) -> TaskDef | None:
520 """Determine which `TaskDef` objects in this graph are associated
521 with a `str` representing a tasks label.
523 Parameters
524 ----------
525 label : `str`
526 Name of a task to search for.
528 Returns
529 -------
530 result : `TaskDef`
531 `TaskDef` objects that has the specified label.
532 """
533 for task in self._taskToQuantumNode:
534 if label == task.label:
535 return task
536 return None
538 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> set[Quantum]:
539 r"""Return all the `~lsst.daf.butler.Quantum` that contain a specified
540 `DatasetTypeName`.
542 Parameters
543 ----------
544 datasetTypeName : `str`
545 The name of the dataset type to search for as a string,
546 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
547 `str` for type safety in static type checking.
549 Returns
550 -------
551 result : `set` of `QuantumNode` objects
552 A `set` of `QuantumNode`\s that contain specified
553 `DatasetTypeName`.
555 Raises
556 ------
557 KeyError
558 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
559 """
560 tasks = self._datasetDict.getAll(datasetTypeName)
561 result: set[Quantum] = set()
562 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
563 return result
565 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
566 """Check if specified quantum appears in the graph as part of a node.
568 Parameters
569 ----------
570 quantum : `lsst.daf.butler.Quantum`
571 The quantum to search for.
573 Returns
574 -------
575 in_graph : `bool`
576 The result of searching for the quantum.
577 """
578 return any(quantum == node.quantum for node in self)
    def writeDotGraph(self, output: str | io.BufferedIOBase) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : `str` or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object.
        """
        # Delegates entirely to networkx's graphviz dot writer.
        write_dot(self._connectedQuanta, output)
590 def subset(self: _T, nodes: QuantumNode | Iterable[QuantumNode]) -> _T:
591 """Create a new graph object that contains the subset of the nodes
592 specified as input. Node number is preserved.
594 Parameters
595 ----------
596 nodes : `QuantumNode` or iterable of `QuantumNode`
597 Nodes from which to create subset.
599 Returns
600 -------
601 graph : instance of graph type
602 An instance of the type from which the subset was created.
603 """
604 if not isinstance(nodes, Iterable):
605 nodes = (nodes,)
606 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
607 quantumMap = defaultdict(set)
609 dataset_type_names: set[str] = set()
610 node: QuantumNode
611 for node in quantumSubgraph:
612 quantumMap[node.taskDef].add(node.quantum)
613 dataset_type_names.update(
614 dstype.name
615 for dstype in chain(
616 node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
617 )
618 )
620 # May need to trim dataset types from registryDatasetTypes.
621 for taskDef in quantumMap:
622 if refs := self.initOutputRefs(taskDef):
623 dataset_type_names.update(ref.datasetType.name for ref in refs)
624 dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
625 registryDatasetTypes = [
626 dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
627 ]
629 # convert to standard dict to prevent accidental key insertion
630 quantumDict: dict[TaskDef, set[Quantum]] = dict(quantumMap.items())
631 # Create an empty graph, and then populate it with custom mapping
632 newInst = type(self)({}, universe=self._universe)
633 # TODO: Do we need to copy initInputs/initOutputs?
634 newInst._buildGraphs(
635 quantumDict,
636 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
637 _buildId=self._buildId,
638 metadata=self._metadata,
639 universe=self._universe,
640 globalInitOutputs=self._globalInitOutputRefs,
641 registryDatasetTypes=registryDatasetTypes,
642 )
643 return newInst
645 def subsetToConnected(self: _T) -> tuple[_T, ...]:
646 """Generate a list of subgraphs where each is connected.
648 Returns
649 -------
650 result : `list` of `QuantumGraph`
651 A list of graphs that are each connected.
652 """
653 return tuple(
654 self.subset(connectedSet)
655 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
656 )
658 def determineInputsToQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
659 """Return a set of `QuantumNode` that are direct inputs to a specified
660 node.
662 Parameters
663 ----------
664 node : `QuantumNode`
665 The node of the graph for which inputs are to be determined.
667 Returns
668 -------
669 inputs : `set` of `QuantumNode`
670 All the nodes that are direct inputs to specified node.
671 """
672 return set(self._connectedQuanta.predecessors(node))
674 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
675 """Return a set of `QuantumNode` that are direct outputs of a specified
676 node.
678 Parameters
679 ----------
680 node : `QuantumNode`
681 The node of the graph for which outputs are to be determined.
683 Returns
684 -------
685 outputs : `set` of `QuantumNode`
686 All the nodes that are direct outputs to specified node.
687 """
688 return set(self._connectedQuanta.successors(node))
690 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
691 """Return a graph of `QuantumNode` that are direct inputs and outputs
692 of a specified node.
694 Parameters
695 ----------
696 node : `QuantumNode`
697 The node of the graph for which connected nodes are to be
698 determined.
700 Returns
701 -------
702 graph : graph of `QuantumNode`
703 All the nodes that are directly connected to specified node.
704 """
705 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
706 nodes.add(node)
707 return self.subset(nodes)
709 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
710 """Return a graph of the specified node and all the ancestor nodes
711 directly reachable by walking edges.
713 Parameters
714 ----------
715 node : `QuantumNode`
716 The node for which all ancestors are to be determined.
718 Returns
719 -------
720 ancestors : graph of `QuantumNode`
721 Graph of node and all of its ancestors.
722 """
723 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
724 predecessorNodes.add(node)
725 return self.subset(predecessorNodes)
    def findCycle(self) -> list[tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : `list` of `tuple` of [ `QuantumNode`, `QuantumNode` ]
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. Empty list to so support if graph.find_cycle()
            syntax as an empty list is falsy.
        """
        # networkx signals "no cycle" with an exception; translate that into
        # a falsy empty list so callers can write ``if graph.findCycle():``.
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []
743 def saveUri(self, uri: ResourcePathExpression) -> None:
744 """Save `QuantumGraph` to the specified URI.
746 Parameters
747 ----------
748 uri : convertible to `~lsst.resources.ResourcePath`
749 URI to where the graph should be saved.
750 """
751 buffer = self._buildSaveObject()
752 path = ResourcePath(uri)
753 if path.getExtension() not in (".qgraph"):
754 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
755 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
757 @property
758 def metadata(self) -> MappingProxyType[str, Any] | None:
759 """Extra data carried with the graph (mapping [`str`] or `None`).
761 The mapping is a dynamic view of this object's metadata. Values should
762 be able to be serialized in JSON.
763 """
764 return MappingProxyType(self._metadata)
766 def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
767 """Return DatasetRefs for a given task InitInputs.
769 Parameters
770 ----------
771 taskDef : `TaskDef`
772 Task definition structure.
774 Returns
775 -------
776 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
777 DatasetRef for the task InitInput, can be `None`. This can return
778 either resolved or non-resolved reference.
779 """
780 return self._initInputRefs.get(taskDef)
782 def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
783 """Return DatasetRefs for a given task InitOutputs.
785 Parameters
786 ----------
787 taskDef : `TaskDef`
788 Task definition structure.
790 Returns
791 -------
792 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
793 DatasetRefs for the task InitOutput, can be `None`. This can return
794 either resolved or non-resolved reference. Resolved reference will
795 match Quantum's initInputs if this is an intermediate dataset type.
796 """
797 return self._initOutputRefs.get(taskDef)
799 def globalInitOutputRefs(self) -> list[DatasetRef]:
800 """Return DatasetRefs for global InitOutputs.
802 Returns
803 -------
804 refs : `list` [ `~lsst.daf.butler.DatasetRef` ]
805 DatasetRefs for global InitOutputs.
806 """
807 return self._globalInitOutputRefs
809 def registryDatasetTypes(self) -> list[DatasetType]:
810 """Return dataset types used by this graph, their definitions match
811 dataset types from registry.
813 Returns
814 -------
815 refs : `list` [ `~lsst.daf.butler.DatasetType` ]
816 Dataset types for this graph.
817 """
818 return self._registryDatasetTypes
820 @classmethod
821 def loadUri(
822 cls,
823 uri: ResourcePathExpression,
824 universe: DimensionUniverse | None = None,
825 nodes: Iterable[uuid.UUID] | None = None,
826 graphID: BuildId | None = None,
827 minimumVersion: int = 3,
828 ) -> QuantumGraph:
829 """Read `QuantumGraph` from a URI.
831 Parameters
832 ----------
833 uri : convertible to `~lsst.resources.ResourcePath`
834 URI from where to load the graph.
835 universe : `~lsst.daf.butler.DimensionUniverse`, optional
836 If `None` it is loaded from the `QuantumGraph`
837 saved structure. If supplied, the
838 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
839 will be validated against the supplied argument for compatibility.
840 nodes : iterable of `uuid.UUID` or `None`
841 UUIDs that correspond to nodes in the graph. If specified, only
842 these nodes will be loaded. Defaults to None, in which case all
843 nodes will be loaded.
844 graphID : `str` or `None`
845 If specified this ID is verified against the loaded graph prior to
846 loading any Nodes. This defaults to None in which case no
847 validation is done.
848 minimumVersion : `int`
849 Minimum version of a save file to load. Set to -1 to load all
850 versions. Older versions may need to be loaded, and re-saved
851 to upgrade them to the latest format before they can be used in
852 production.
854 Returns
855 -------
856 graph : `QuantumGraph`
857 Resulting QuantumGraph instance.
859 Raises
860 ------
861 TypeError
862 Raised if file contains instance of a type other than
863 `QuantumGraph`.
864 ValueError
865 Raised if one or more of the nodes requested is not in the
866 `QuantumGraph` or if graphID parameter does not match the graph
867 being loaded or if the supplied uri does not point at a valid
868 `QuantumGraph` save file.
869 RuntimeError
870 Raise if Supplied `~lsst.daf.butler.DimensionUniverse` is not
871 compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
872 the graph.
873 """
874 uri = ResourcePath(uri)
875 if uri.getExtension() in {".qgraph"}:
876 with LoadHelper(uri, minimumVersion) as loader:
877 qgraph = loader.load(universe, nodes, graphID)
878 else:
879 raise ValueError(f"Only know how to handle files saved as `.qgraph`, not {uri}")
880 if not isinstance(qgraph, QuantumGraph):
881 raise TypeError(f"QuantumGraph file {uri} contains unexpected object type: {type(qgraph)}")
882 return qgraph
884 @classmethod
885 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> str | None:
886 """Read the header of a `QuantumGraph` pointed to by the uri parameter
887 and return it as a string.
889 Parameters
890 ----------
891 uri : convertible to `~lsst.resources.ResourcePath`
892 The location of the `QuantumGraph` to load. If the argument is a
893 string, it must correspond to a valid
894 `~lsst.resources.ResourcePath` path.
895 minimumVersion : `int`
896 Minimum version of a save file to load. Set to -1 to load all
897 versions. Older versions may need to be loaded, and re-saved
898 to upgrade them to the latest format before they can be used in
899 production.
901 Returns
902 -------
903 header : `str` or `None`
904 The header associated with the specified `QuantumGraph` it there is
905 one, else `None`.
907 Raises
908 ------
909 ValueError
910 Raised if the extension of the file specified by uri is not a
911 `QuantumGraph` extension.
912 """
913 uri = ResourcePath(uri)
914 if uri.getExtension() in {".qgraph"}:
915 return LoadHelper(uri, minimumVersion).readHeader()
916 else:
917 raise ValueError("Only know how to handle files saved as `.qgraph`")
919 def buildAndPrintHeader(self) -> None:
920 """Create a header that would be used in a save of this object and
921 prints it out to standard out.
922 """
923 _, header = self._buildSaveObject(returnHeader=True)
924 print(json.dumps(header))
926 def save(self, file: BinaryIO) -> None:
927 """Save QuantumGraph to a file.
929 Parameters
930 ----------
931 file : `io.BufferedIOBase`
932 File to write data open in binary mode.
933 """
934 buffer = self._buildSaveObject()
935 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
937 def _buildSaveObject(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
938 thing = PersistenceContextVars()
939 result = thing.run(self._buildSaveObjectImpl, returnHeader)
940 return result
    def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
        """Serialize this graph into the on-disk byte layout used by `save`.

        The buffer is laid out as: magic bytes, save version, header length,
        the lzma-compressed JSON header, then one lzma-compressed JSON
        payload per taskDef followed by one per node. The header records the
        byte range of every payload so pieces can be loaded individually.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping.

        Returns
        -------
        buffer : `bytearray` or `tuple` [`bytearray`, `dict`]
            The serialized bytes, plus the header mapping when
            ``returnHeader`` is `True`.
        """
        # make some containers
        jsonData: deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # NOTE(review): an earlier comment claimed tasks are sorted by label,
        # but iteration follows self.taskGraph order directly — confirm that
        # order is deterministic across saves.
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription: dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            # Init input/output refs are optional; record them only when the
            # task has them.
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connection between all of tasks and save that in
            # the header as a list of connections and edges in each task
            # this will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                # NOTE(review): the ``taskDef not in consumers`` test below is
                # redundant (this is the elif of that very check) but is kept
                # byte-identical.
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this tasks produced
                    # connection, the output will be said to be the datastore
                    # in which case the for loop will be a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, and encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode(), preset=2)
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.model_dump_json().encode(), preset=2)
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        # Dimension records referenced by the nodes are accumulated and
        # deduplicated into a single mapping stored in the header.
        headerData["DimensionRecords"] = {
            key: value.model_dump()
            for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        if self._globalInitOutputRefs:
            headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

        if self._registryDatasetTypes:
            headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # record the sizes as 2 unsigned long long numbers for a total of 16
        # bytes
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer
1093 @classmethod
1094 def load(
1095 cls,
1096 file: BinaryIO,
1097 universe: DimensionUniverse | None = None,
1098 nodes: Iterable[uuid.UUID] | None = None,
1099 graphID: BuildId | None = None,
1100 minimumVersion: int = 3,
1101 ) -> QuantumGraph:
1102 """Read `QuantumGraph` from a file that was made by `save`.
1104 Parameters
1105 ----------
1106 file : `io.IO` of bytes
1107 File with data open in binary mode.
1108 universe : `~lsst.daf.butler.DimensionUniverse`, optional
1109 If `None` it is loaded from the `QuantumGraph`
1110 saved structure. If supplied, the
1111 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
1112 will be validated against the supplied argument for compatibility.
1113 nodes : iterable of `uuid.UUID` or `None`
1114 UUIDs that correspond to nodes in the graph. If specified, only
1115 these nodes will be loaded. Defaults to None, in which case all
1116 nodes will be loaded.
1117 graphID : `str` or `None`
1118 If specified this ID is verified against the loaded graph prior to
1119 loading any Nodes. This defaults to None in which case no
1120 validation is done.
1121 minimumVersion : `int`
1122 Minimum version of a save file to load. Set to -1 to load all
1123 versions. Older versions may need to be loaded, and re-saved
1124 to upgrade them to the latest format before they can be used in
1125 production.
1127 Returns
1128 -------
1129 graph : `QuantumGraph`
1130 Resulting QuantumGraph instance.
1132 Raises
1133 ------
1134 TypeError
1135 Raised if data contains instance of a type other than
1136 `QuantumGraph`.
1137 ValueError
1138 Raised if one or more of the nodes requested is not in the
1139 `QuantumGraph` or if graphID parameter does not match the graph
1140 being loaded or if the supplied uri does not point at a valid
1141 `QuantumGraph` save file.
1142 """
1143 with LoadHelper(file, minimumVersion) as loader:
1144 qgraph = loader.load(universe, nodes, graphID)
1145 if not isinstance(qgraph, QuantumGraph):
1146 raise TypeError(f"QuantumGraph file contains unexpected object type: {type(qgraph)}")
1147 return qgraph
1149 def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
1150 """Iterate over the `taskGraph` attribute in topological order.
1152 Yields
1153 ------
1154 taskDef : `TaskDef`
1155 `TaskDef` objects in topological order.
1156 """
1157 yield from nx.topological_sort(self.taskGraph)
    def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_id: bool = False) -> None:
        """Change output run and dataset ID for each output dataset.

        Parameters
        ----------
        run : `str`
            New output run name.
        metadata_key : `str` or `None`
            Specifies metadata key corresponding to output run name to update
            with new run name. If `None` or if metadata is missing it is not
            updated. If metadata is present but key is missing, it will be
            added.
        update_graph_id : `bool`, optional
            If `True` then also update graph ID with a new unique value.
        """
        # Maps each original dataset ID to the new ID produced when its
        # output ref's run is replaced; the second pass uses this to rewrite
        # intermediates consistently.
        dataset_id_map: dict[DatasetId, DatasetId] = {}

        def _update_output_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: MutableMapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update a collection of `~lsst.daf.butler.DatasetRef` with new
            run and dataset IDs.
            """
            for ref in refs:
                # Replacing the run yields a ref with a new dataset ID;
                # remember old->new so consumers can be updated to match.
                new_ref = ref.replace(run=run)
                dataset_id_map[ref.id] = new_ref.id
                yield new_ref

        def _update_intermediate_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: Mapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update intermediate references with new run and IDs. Only the
            references that appear in ``dataset_id_map`` are updated, others
            are returned unchanged.
            """
            for ref in refs:
                if dataset_id := dataset_id_map.get(ref.id):
                    ref = ref.replace(run=run, id=dataset_id)
                yield ref

        # Replace quantum output refs first. This pass must run to completion
        # before intermediates are rewritten, so dataset_id_map is complete.
        for node in self._connectedQuanta:
            quantum = node.quantum
            outputs = {
                dataset_type: tuple(_update_output_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.outputs.items()
            }
            # Quantum is rebuilt rather than mutated; only outputs change in
            # this pass.
            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=quantum.initInputs,
                inputs=quantum.inputs,
                outputs=outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initOutputRefs = {
            task_def: list(_update_output_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initOutputRefs.items()
        }
        self._globalInitOutputRefs = list(
            _update_output_refs(self._globalInitOutputRefs, run, dataset_id_map)
        )

        # Update all intermediates from their matching outputs.
        for node in self._connectedQuanta:
            quantum = node.quantum
            inputs = {
                dataset_type: tuple(_update_intermediate_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.inputs.items()
            }
            initInputs = list(_update_intermediate_refs(quantum.initInputs.values(), run, dataset_id_map))

            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=initInputs,
                inputs=inputs,
                outputs=quantum.outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initInputRefs = {
            task_def: list(_update_intermediate_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initInputRefs.items()
        }

        if update_graph_id:
            self._buildId = BuildId(f"{time.time()}-{os.getpid()}")

        # Record the new run name in metadata if a key was supplied.
        if metadata_key is not None:
            self._metadata[metadata_key] = run
1255 @property
1256 def graphID(self) -> BuildId:
1257 """The ID generated by the graph at construction time (`str`)."""
1258 return self._buildId
1260 @property
1261 def universe(self) -> DimensionUniverse:
1262 """Dimension universe associated with this graph
1263 (`~lsst.daf.butler.DimensionUniverse`).
1264 """
1265 return self._universe
1267 def __iter__(self) -> Generator[QuantumNode, None, None]:
1268 yield from nx.topological_sort(self._connectedQuanta)
1270 def __len__(self) -> int:
1271 return self._count
1273 def __contains__(self, node: QuantumNode) -> bool:
1274 return self._connectedQuanta.has_node(node)
1276 def __getstate__(self) -> dict:
1277 """Store a compact form of the graph as a list of graph nodes, and a
1278 tuple of task labels and task configs. The full graph can be
1279 reconstructed with this information, and it preserves the ordering of
1280 the graph nodes.
1281 """
1282 universe: DimensionUniverse | None = None
1283 for node in self:
1284 dId = node.quantum.dataId
1285 if dId is None:
1286 continue
1287 universe = dId.universe
1288 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
1290 def __setstate__(self, state: dict) -> None:
1291 """Reconstructs the state of the graph from the information persisted
1292 in getstate.
1293 """
1294 buffer = io.BytesIO(state["reduced"])
1295 with LoadHelper(buffer, minimumVersion=3) as loader:
1296 qgraph = loader.load(state["universe"], graphID=state["graphId"])
1298 self._metadata = qgraph._metadata
1299 self._buildId = qgraph._buildId
1300 self._datasetDict = qgraph._datasetDict
1301 self._nodeIdMap = qgraph._nodeIdMap
1302 self._count = len(qgraph)
1303 self._taskToQuantumNode = qgraph._taskToQuantumNode
1304 self._taskGraph = qgraph._taskGraph
1305 self._connectedQuanta = qgraph._connectedQuanta
1306 self._initInputRefs = qgraph._initInputRefs
1307 self._initOutputRefs = qgraph._initOutputRefs
1309 def __eq__(self, other: object) -> bool:
1310 if not isinstance(other, QuantumGraph):
1311 return False
1312 if len(self) != len(other):
1313 return False
1314 for node in self:
1315 if node not in other:
1316 return False
1317 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1318 return False
1319 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1320 return False
1321 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1322 return False
1323 return set(self.taskGraph) == set(other.taskGraph)