Coverage for python / lsst / pipe / base / graph / graph.py: 15%
586 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:47 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29__all__ = ("IncompatibleGraphError", "QuantumGraph")
31import datetime
32import getpass
33import io
34import json
35import lzma
36import os
37import struct
38import sys
39import time
40import uuid
41from collections import defaultdict, deque
42from collections.abc import Generator, Iterable, Iterator, Mapping, MutableMapping
43from itertools import chain
44from types import MappingProxyType
45from typing import Any, BinaryIO, TypeVar
47import networkx as nx
48from networkx.drawing.nx_agraph import write_dot
50import lsst.utils.logging
51from lsst.daf.butler import (
52 Config,
53 DatasetId,
54 DatasetRef,
55 DatasetType,
56 DimensionRecordsAccumulator,
57 DimensionUniverse,
58 LimitedButler,
59 Quantum,
60 QuantumBackedButler,
61)
62from lsst.daf.butler._rubin import generate_uuidv7
63from lsst.daf.butler.datastore.record_data import DatastoreRecordData
64from lsst.daf.butler.persistence_context import PersistenceContextVars
65from lsst.daf.butler.registry import ConflictingDefinitionError
66from lsst.resources import ResourcePath, ResourcePathExpression
67from lsst.utils.introspection import get_full_type_name
68from lsst.utils.packages import Packages
70from ..config import PipelineTaskConfig
71from ..connections import iterConnections
72from ..pipeline import TaskDef
73from ..pipeline_graph import PipelineGraph, compare_packages, log_config_mismatch
74from ._implDetails import DatasetTypeName, _DatasetTracker
75from ._loadHelpers import LoadHelper
76from ._versionDeserializers import DESERIALIZER_MAP
77from .graphSummary import QgraphSummary, QgraphTaskSummary
78from .quantumNode import BuildId, QuantumNode
# Type variable for methods (e.g. ``subset``) that return an instance of the
# same (possibly derived) QuantumGraph type they were called on.
_T = TypeVar("_T", bound="QuantumGraph")

# Module-level logger for this package.
_LOG = lsst.utils.logging.getLogger(__name__)

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Per-version struct formats for the remainder of the preamble, keyed by the
# save-file version read via STRUCT_FMT_BASE:
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
class IncompatibleGraphError(Exception):
    """Exception raised when a lookup by NodeId cannot be satisfied due to
    incompatibilities between graphs.
    """
113class QuantumGraph:
114 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects.
116 This data structure represents a concrete workflow generated from a
117 `Pipeline`.
119 Parameters
120 ----------
121 quanta : `~collections.abc.Mapping` [ `TaskDef`, \
122 `set` [ `~lsst.daf.butler.Quantum` ] ]
123 This maps tasks (and their configs) to the sets of data they are to
124 process.
125 metadata : Optional `~collections.abc.Mapping` of `str` to primitives
126 This is an optional parameter of extra data to carry with the graph.
127 Entries in this mapping should be able to be serialized in JSON.
128 universe : `~lsst.daf.butler.DimensionUniverse`, optional
129 The dimensions in which quanta can be defined. Need only be provided if
130 no quanta have data IDs.
131 initInputs : `~collections.abc.Mapping`, optional
132 Maps tasks to their InitInput dataset refs. Dataset refs can be either
133 resolved or non-resolved. Presently the same dataset refs are included
134 in each `~lsst.daf.butler.Quantum` for the same task.
135 initOutputs : `~collections.abc.Mapping`, optional
136 Maps tasks to their InitOutput dataset refs. Dataset refs can be either
137 resolved or non-resolved. For intermediate resolved refs their dataset
138 ID must match ``initInputs`` and Quantum ``initInputs``.
139 globalInitOutputs : `~collections.abc.Iterable` \
140 [ `~lsst.daf.butler.DatasetRef` ], optional
141 Dataset refs for some global objects produced by pipeline. These
142 objects include task configurations and package versions. Typically
143 they have an empty DataId, but there is no real restriction on what
144 can appear here.
145 registryDatasetTypes : `~collections.abc.Iterable` \
146 [ `~lsst.daf.butler.DatasetType` ], optional
147 Dataset types which are used by this graph, their definitions must
148 match registry. If registry does not define dataset type yet, then
149 it should match one that will be created later.
151 Raises
152 ------
153 ValueError
154 Raised if the graph is pruned such that some tasks no longer have nodes
155 associated with them.
156 """
158 def __init__(
159 self,
160 quanta: Mapping[TaskDef, set[Quantum]],
161 metadata: Mapping[str, Any] | None = None,
162 universe: DimensionUniverse | None = None,
163 initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
164 initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
165 globalInitOutputs: Iterable[DatasetRef] | None = None,
166 registryDatasetTypes: Iterable[DatasetType] | None = None,
167 ):
168 self._buildGraphs(
169 quanta,
170 metadata=metadata,
171 universe=universe,
172 initInputs=initInputs,
173 initOutputs=initOutputs,
174 globalInitOutputs=globalInitOutputs,
175 registryDatasetTypes=registryDatasetTypes,
176 )
178 def _buildGraphs(
179 self,
180 quanta: Mapping[TaskDef, set[Quantum]],
181 *,
182 _quantumToNodeId: Mapping[Quantum, uuid.UUID] | None = None,
183 _buildId: BuildId | None = None,
184 metadata: Mapping[str, Any] | None = None,
185 universe: DimensionUniverse | None = None,
186 initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
187 initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
188 globalInitOutputs: Iterable[DatasetRef] | None = None,
189 registryDatasetTypes: Iterable[DatasetType] | None = None,
190 ) -> None:
191 """Build the graph that is used to store the relation between tasks,
192 and the graph that holds the relations between quanta
193 """
194 # Save packages to metadata
195 self._metadata = dict(metadata) if metadata is not None else {}
196 self._metadata.setdefault("packages", Packages.fromSystem())
197 self._metadata.setdefault("user", getpass.getuser())
198 self._metadata.setdefault("time", f"{datetime.datetime.now()}")
199 self._metadata.setdefault("full_command", " ".join(sys.argv))
201 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
202 # Data structure used to identify relations between
203 # DatasetTypeName -> TaskDef.
204 self._datasetDict = _DatasetTracker(createInverse=True)
206 # Temporary graph that will have dataset UUIDs (as raw bytes) and
207 # QuantumNode objects as nodes; will be collapsed down to just quanta
208 # later.
209 bipartite_graph = nx.DiGraph()
211 self._nodeIdMap: dict[uuid.UUID, QuantumNode] = {}
212 self._taskToQuantumNode: MutableMapping[TaskDef, set[QuantumNode]] = defaultdict(set)
213 for taskDef, quantumSet in quanta.items():
214 connections = taskDef.connections
216 # For each type of connection in the task, add a key to the
217 # `_DatasetTracker` for the connections name, with a value of
218 # the TaskDef in the appropriate field
219 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
220 # Have to handle components in inputs.
221 dataset_name, _, _ = inpt.name.partition(".")
222 self._datasetDict.addConsumer(DatasetTypeName(dataset_name), taskDef)
224 for output in iterConnections(connections, ("outputs",)):
225 # Have to handle possible components in outputs.
226 dataset_name, _, _ = output.name.partition(".")
227 self._datasetDict.addProducer(DatasetTypeName(dataset_name), taskDef)
229 # For each `Quantum` in the set of all `Quantum` for this task,
230 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
231 # of the individual datasets inside the `Quantum`, with a value of
232 # a newly created QuantumNode to the appropriate input/output
233 # field.
234 for quantum in quantumSet:
235 if quantum.dataId is not None:
236 if universe is None:
237 universe = quantum.dataId.universe
238 elif universe != quantum.dataId.universe:
239 raise RuntimeError(
240 "Mismatched dimension universes in QuantumGraph construction: "
241 f"{universe} != {quantum.dataId.universe}. "
242 )
244 if _quantumToNodeId:
245 if (nodeId := _quantumToNodeId.get(quantum)) is None:
246 raise ValueError(
247 "If _quantuMToNodeNumber is not None, all quanta must have an "
248 "associated value in the mapping"
249 )
250 else:
251 nodeId = generate_uuidv7()
253 inits = quantum.initInputs.values()
254 inputs = quantum.inputs.values()
255 value = QuantumNode(quantum, taskDef, nodeId)
256 self._taskToQuantumNode[taskDef].add(value)
257 self._nodeIdMap[nodeId] = value
259 bipartite_graph.add_node(value, bipartite=0)
260 for dsRef in chain(inits, inputs):
261 # unfortunately, `Quantum` allows inits to be individual
262 # `DatasetRef`s or an Iterable of such, so there must
263 # be an instance check here
264 if isinstance(dsRef, Iterable):
265 for sub in dsRef:
266 bipartite_graph.add_node(sub.id.bytes, bipartite=1)
267 bipartite_graph.add_edge(sub.id.bytes, value)
268 else:
269 assert isinstance(dsRef, DatasetRef)
270 if dsRef.isComponent():
271 dsRef = dsRef.makeCompositeRef()
272 bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
273 bipartite_graph.add_edge(dsRef.id.bytes, value)
274 for dsRef in chain.from_iterable(quantum.outputs.values()):
275 bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
276 bipartite_graph.add_edge(value, dsRef.id.bytes)
278 # Dimension universe
279 if universe is None:
280 raise RuntimeError(
281 "Dimension universe or at least one quantum with a data ID "
282 "must be provided when constructing a QuantumGraph."
283 )
284 self._universe = universe
286 # Make graph of quanta relations, by projecting out the dataset nodes
287 # in the bipartite_graph, leaving just the quanta.
288 self._connectedQuanta = nx.algorithms.bipartite.projected_graph(
289 bipartite_graph, self._nodeIdMap.values()
290 )
291 self._count = len(self._connectedQuanta)
293 # Graph of task relations, used in various methods
294 self._taskGraph = self._datasetDict.makeNetworkXGraph()
296 # convert default dict into a regular to prevent accidental key
297 # insertion
298 self._taskToQuantumNode = dict(self._taskToQuantumNode.items())
300 self._initInputRefs: dict[str, list[DatasetRef]] = {}
301 self._initOutputRefs: dict[str, list[DatasetRef]] = {}
302 self._globalInitOutputRefs: list[DatasetRef] = []
303 self._registryDatasetTypes: list[DatasetType] = []
304 if initInputs is not None:
305 self._initInputRefs = {taskDef.label: list(refs) for taskDef, refs in initInputs.items()}
306 if initOutputs is not None:
307 self._initOutputRefs = {taskDef.label: list(refs) for taskDef, refs in initOutputs.items()}
308 if globalInitOutputs is not None:
309 self._globalInitOutputRefs = list(globalInitOutputs)
310 if registryDatasetTypes is not None:
311 self._registryDatasetTypes = list(registryDatasetTypes)
313 # PipelineGraph is current constructed on first use.
314 # TODO DM-40442: use PipelineGraph instead of TaskDef
315 # collections.
316 self._pipeline_graph: PipelineGraph | None = None
318 @property
319 def pipeline_graph(self) -> PipelineGraph:
320 """A graph representation of the tasks and dataset types in the quantum
321 graph.
322 """
323 if self._pipeline_graph is None:
324 # Construct into a temporary for strong exception safety.
325 pipeline_graph = PipelineGraph()
326 for task_def in self._taskToQuantumNode.keys():
327 pipeline_graph.add_task(
328 task_def.label, task_def.taskClass, task_def.config, connections=task_def.connections
329 )
330 dataset_types = {dataset_type.name: dataset_type for dataset_type in self._registryDatasetTypes}
331 pipeline_graph.resolve(dimensions=self._universe, dataset_types=dataset_types)
332 self._pipeline_graph = pipeline_graph
333 return self._pipeline_graph
335 def get_task_quanta(self, label: str) -> Mapping[uuid.UUID, Quantum]:
336 """Return the quanta associated with the given task label.
338 Parameters
339 ----------
340 label : `str`
341 Task label.
343 Returns
344 -------
345 quanta : `~collections.abc.Mapping` [ uuid.UUID, `Quantum` ]
346 Mapping from quantum ID to quantum. Empty if ``label`` does not
347 correspond to a task in this graph.
348 """
349 task_def = self.findTaskDefByLabel(label)
350 if not task_def:
351 return {}
352 return {node.nodeId: node.quantum for node in self.getNodesForTask(task_def)}
354 @property
355 def taskGraph(self) -> nx.DiGraph:
356 """A graph representing the relations between the tasks inside
357 the quantum graph (`networkx.DiGraph`).
358 """
359 return self._taskGraph
361 @property
362 def graph(self) -> nx.DiGraph:
363 """A graph representing the relations between all the `QuantumNode`
364 objects (`networkx.DiGraph`).
366 The graph should usually be iterated over, or passed to methods of this
367 class, but sometimes direct access to the ``networkx`` object may be
368 helpful.
369 """
370 return self._connectedQuanta
372 @property
373 def inputQuanta(self) -> Iterable[QuantumNode]:
374 """The nodes that are inputs to the graph (iterable [`QuantumNode`]).
376 These are the nodes that do not depend on any other nodes in the
377 graph.
378 """
379 return (q for q, n in self._connectedQuanta.in_degree if n == 0)
381 @property
382 def outputQuanta(self) -> Iterable[QuantumNode]:
383 """The nodes that are outputs of the graph (iterable [`QuantumNode`]).
385 These are the nodes that have no nodes that depend on them in the
386 graph.
387 """
388 return [q for q, n in self._connectedQuanta.out_degree if n == 0]
390 @property
391 def allDatasetTypes(self) -> tuple[DatasetTypeName, ...]:
392 """All the data set type names that are present in the graph
393 (`tuple` [`str`]).
395 These types do not include global init-outputs.
396 """
397 return tuple(self._datasetDict.keys())
399 @property
400 def isConnected(self) -> bool:
401 """Whether all of the nodes in the graph are connected, ignoring
402 directionality of connections (`bool`).
403 """
404 return nx.is_weakly_connected(self._connectedQuanta)
406 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
407 """Lookup a `QuantumNode` from an id associated with the node.
409 Parameters
410 ----------
411 nodeId : `NodeId`
412 The number associated with a node.
414 Returns
415 -------
416 node : `QuantumNode`
417 The node corresponding with input number.
419 Raises
420 ------
421 KeyError
422 Raised if the requested nodeId is not in the graph.
423 """
424 return self._nodeIdMap[nodeId]
426 def getQuantaForTask(self, taskDef: TaskDef) -> frozenset[Quantum]:
427 """Return all the `~lsst.daf.butler.Quantum` associated with a
428 `TaskDef`.
430 Parameters
431 ----------
432 taskDef : `TaskDef`
433 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
434 queried.
436 Returns
437 -------
438 quanta : `frozenset` of `~lsst.daf.butler.Quantum`
439 The `set` of `~lsst.daf.butler.Quantum` that is associated with the
440 specified `TaskDef`.
441 """
442 return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ()))
444 def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int:
445 """Return the number of `~lsst.daf.butler.Quantum` associated with
446 a `TaskDef`.
448 Parameters
449 ----------
450 taskDef : `TaskDef`
451 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
452 queried.
454 Returns
455 -------
456 count : `int`
457 The number of `~lsst.daf.butler.Quantum` that are associated with
458 the specified `TaskDef`.
459 """
460 return len(self._taskToQuantumNode.get(taskDef, ()))
462 def getNodesForTask(self, taskDef: TaskDef) -> frozenset[QuantumNode]:
463 r"""Return all the `QuantumNode`\s associated with a `TaskDef`.
465 Parameters
466 ----------
467 taskDef : `TaskDef`
468 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
469 queried.
471 Returns
472 -------
473 nodes : `frozenset` [ `QuantumNode` ]
474 A `frozenset` of `QuantumNode` that is associated with the
475 specified `TaskDef`.
476 """
477 return frozenset(self._taskToQuantumNode[taskDef])
479 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
480 """Find all tasks that have the specified dataset type name as an
481 input.
483 Parameters
484 ----------
485 datasetTypeName : `str`
486 A string representing the name of a dataset type to be queried,
487 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
488 `str` for type safety in static type checking.
490 Returns
491 -------
492 tasks : `~collections.abc.Iterable` [ `TaskDef` ]
493 `TaskDef` objects that have the specified `DatasetTypeName` as an
494 input, list will be empty if no tasks use specified
495 `DatasetTypeName` as an input.
497 Raises
498 ------
499 KeyError
500 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
501 """
502 return (c for c in self._datasetDict.getConsumers(datasetTypeName))
504 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> TaskDef | None:
505 """Find all tasks that have the specified dataset type name as an
506 output.
508 Parameters
509 ----------
510 datasetTypeName : `str`
511 A string representing the name of a dataset type to be queried,
512 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
513 `str` for type safety in static type checking.
515 Returns
516 -------
517 result : `TaskDef` or `None`
518 `TaskDef` that outputs `DatasetTypeName` as an output or `None` if
519 none of the tasks produce this `DatasetTypeName`.
521 Raises
522 ------
523 KeyError
524 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
525 """
526 return self._datasetDict.getProducer(datasetTypeName)
528 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
529 """Find all tasks that are associated with the specified dataset type
530 name.
532 Parameters
533 ----------
534 datasetTypeName : `str`
535 A string representing the name of a dataset type to be queried,
536 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
537 `str` for type safety in static type checking.
539 Returns
540 -------
541 result : `~collections.abc.Iterable` [`TaskDef`]
542 `TaskDef` objects that are associated with the specified
543 `DatasetTypeName`.
545 Raises
546 ------
547 KeyError
548 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
549 """
550 return self._datasetDict.getAll(datasetTypeName)
552 def findTaskDefByName(self, taskName: str) -> list[TaskDef]:
553 """Determine which `TaskDef` objects in this graph are associated
554 with a `str` representing a task name (looks at the ``taskName``
555 property of `TaskDef` objects).
557 Returns a list of `TaskDef` objects as a `PipelineTask` may appear
558 multiple times in a graph with different labels.
560 Parameters
561 ----------
562 taskName : `str`
563 Name of a task to search for.
565 Returns
566 -------
567 result : `list` of `TaskDef`
568 List of the `TaskDef` objects that have the name specified.
569 Multiple values are returned in the case that a task is used
570 multiple times with different labels.
571 """
572 results = []
573 for task in self._taskToQuantumNode:
574 split = task.taskName.split(".")
575 if split[-1] == taskName:
576 results.append(task)
577 return results
579 def findTaskDefByLabel(self, label: str) -> TaskDef | None:
580 """Determine which `TaskDef` objects in this graph are associated
581 with a `str` representing a tasks label.
583 Parameters
584 ----------
585 label : `str`
586 Name of a task to search for.
588 Returns
589 -------
590 result : `TaskDef`
591 `TaskDef` objects that has the specified label.
592 """
593 for task in self._taskToQuantumNode:
594 if label == task.label:
595 return task
596 return None
598 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> set[Quantum]:
599 r"""Return all the `~lsst.daf.butler.Quantum` that contain a specified
600 `DatasetTypeName`.
602 Parameters
603 ----------
604 datasetTypeName : `str`
605 The name of the dataset type to search for as a string,
606 can also accept a `DatasetTypeName` which is a `~typing.NewType` of
607 `str` for type safety in static type checking.
609 Returns
610 -------
611 result : `set` of `QuantumNode` objects
612 A `set` of `QuantumNode`\s that contain specified
613 `DatasetTypeName`.
615 Raises
616 ------
617 KeyError
618 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`.
619 """
620 tasks = self._datasetDict.getAll(datasetTypeName)
621 result: set[Quantum] = set()
622 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
623 return result
625 def checkQuantumInGraph(self, quantum: Quantum) -> bool:
626 """Check if specified quantum appears in the graph as part of a node.
628 Parameters
629 ----------
630 quantum : `lsst.daf.butler.Quantum`
631 The quantum to search for.
633 Returns
634 -------
635 in_graph : `bool`
636 The result of searching for the quantum.
637 """
638 return any(quantum == node.quantum for node in self)
640 def writeDotGraph(self, output: str | io.BufferedIOBase) -> None:
641 """Write out the graph as a dot graph.
643 Parameters
644 ----------
645 output : `str` or `io.BufferedIOBase`
646 Either a filesystem path to write to, or a file handle object.
647 """
648 write_dot(self._connectedQuanta, output)
650 def subset(self: _T, nodes: QuantumNode | Iterable[QuantumNode]) -> _T:
651 """Create a new graph object that contains the subset of the nodes
652 specified as input. Node number is preserved.
654 Parameters
655 ----------
656 nodes : `QuantumNode` or iterable of `QuantumNode`
657 Nodes from which to create subset.
659 Returns
660 -------
661 graph : instance of graph type
662 An instance of the type from which the subset was created.
663 """
664 if not isinstance(nodes, Iterable):
665 nodes = (nodes,)
666 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
667 quantumMap = defaultdict(set)
669 dataset_type_names: set[str] = set()
670 node: QuantumNode
671 for node in quantumSubgraph:
672 quantumMap[node.taskDef].add(node.quantum)
673 dataset_type_names.update(
674 dstype.name
675 for dstype in chain(
676 node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
677 )
678 )
680 # May need to trim dataset types from registryDatasetTypes.
681 for taskDef in quantumMap:
682 if refs := self.initOutputRefs(taskDef):
683 dataset_type_names.update(ref.datasetType.name for ref in refs)
684 dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
685 registryDatasetTypes = [
686 dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
687 ]
689 # convert to standard dict to prevent accidental key insertion
690 quantumDict: dict[TaskDef, set[Quantum]] = dict(quantumMap.items())
691 # Create an empty graph, and then populate it with custom mapping
692 newInst = type(self)({}, universe=self._universe)
693 # TODO: Do we need to copy initInputs/initOutputs?
694 newInst._buildGraphs(
695 quantumDict,
696 _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
697 _buildId=self._buildId,
698 metadata=self._metadata,
699 universe=self._universe,
700 globalInitOutputs=self._globalInitOutputRefs,
701 registryDatasetTypes=registryDatasetTypes,
702 )
703 return newInst
705 def subsetToConnected(self: _T) -> tuple[_T, ...]:
706 """Generate a list of subgraphs where each is connected.
708 Returns
709 -------
710 result : `list` of `QuantumGraph`
711 A list of graphs that are each connected.
712 """
713 return tuple(
714 self.subset(connectedSet)
715 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
716 )
718 def determineInputsToQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
719 """Return a set of `QuantumNode` that are direct inputs to a specified
720 node.
722 Parameters
723 ----------
724 node : `QuantumNode`
725 The node of the graph for which inputs are to be determined.
727 Returns
728 -------
729 inputs : `set` of `QuantumNode`
730 All the nodes that are direct inputs to specified node.
731 """
732 return set(self._connectedQuanta.predecessors(node))
734 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
735 """Return a set of `QuantumNode` that are direct outputs of a specified
736 node.
738 Parameters
739 ----------
740 node : `QuantumNode`
741 The node of the graph for which outputs are to be determined.
743 Returns
744 -------
745 outputs : `set` of `QuantumNode`
746 All the nodes that are direct outputs to specified node.
747 """
748 return set(self._connectedQuanta.successors(node))
750 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
751 """Return a graph of `QuantumNode` that are direct inputs and outputs
752 of a specified node.
754 Parameters
755 ----------
756 node : `QuantumNode`
757 The node of the graph for which connected nodes are to be
758 determined.
760 Returns
761 -------
762 graph : graph of `QuantumNode`
763 All the nodes that are directly connected to specified node.
764 """
765 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
766 nodes.add(node)
767 return self.subset(nodes)
769 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
770 """Return a graph of the specified node and all the ancestor nodes
771 directly reachable by walking edges.
773 Parameters
774 ----------
775 node : `QuantumNode`
776 The node for which all ancestors are to be determined.
778 Returns
779 -------
780 ancestors : graph of `QuantumNode`
781 Graph of node and all of its ancestors.
782 """
783 predecessorNodes = nx.ancestors(self._connectedQuanta, node)
784 predecessorNodes.add(node)
785 return self.subset(predecessorNodes)
    def findCycle(self) -> list[tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and return the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : `list` of `tuple` of [ `QuantumNode`, `QuantumNode` ]
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list (rather than an exception) is
            returned so callers can simply write ``if graph.findCycle():``,
            since an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []
803 def saveUri(self, uri: ResourcePathExpression) -> None:
804 """Save `QuantumGraph` to the specified URI.
806 Parameters
807 ----------
808 uri : convertible to `~lsst.resources.ResourcePath`
809 URI to where the graph should be saved.
810 """
811 path = ResourcePath(uri)
812 match path.getExtension():
813 case ".qgraph":
814 buffer = self._buildSaveObject()
815 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
816 case ".qg":
817 from ..quantum_graph import PredictedQuantumGraphComponents
819 pqg = PredictedQuantumGraphComponents.from_old_quantum_graph(self)
820 pqg.write(path)
821 case ext:
822 raise TypeError(f"Can currently only save a graph in .qgraph or .qg format, not {ext!r}.")
824 @property
825 def metadata(self) -> MappingProxyType[str, Any]:
826 """Extra data carried with the graph (mapping [`str`] or `None`).
828 The mapping is a dynamic view of this object's metadata. Values should
829 be able to be serialized in JSON.
830 """
831 return MappingProxyType(self._metadata)
833 def get_init_input_refs(self, task_label: str) -> list[DatasetRef]:
834 """Return the DatasetRefs for the given task's init inputs.
836 Parameters
837 ----------
838 task_label : `str`
839 Label of the task.
841 Returns
842 -------
843 refs : `list` [ `lsst.daf.butler.DatasetRef` ]
844 Dataset references. Guaranteed to be a new list, not internal
845 state.
846 """
847 return list(self._initInputRefs.get(task_label, ()))
849 def get_init_output_refs(self, task_label: str) -> list[DatasetRef]:
850 """Return the DatasetRefs for the given task's init outputs.
852 Parameters
853 ----------
854 task_label : `str`
855 Label of the task.
857 Returns
858 -------
859 refs : `list` [ `lsst.daf.butler.DatasetRef` ]
860 Dataset references. Guaranteed to be a new list, not internal
861 state.
862 """
863 return list(self._initOutputRefs.get(task_label, ()))
865 def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
866 """Return DatasetRefs for a given task InitInputs.
868 Parameters
869 ----------
870 taskDef : `TaskDef`
871 Task definition structure.
873 Returns
874 -------
875 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
876 DatasetRef for the task InitInput, can be `None`. This can return
877 either resolved or non-resolved reference.
878 """
879 return self._initInputRefs.get(taskDef.label)
881 def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
882 """Return DatasetRefs for a given task InitOutputs.
884 Parameters
885 ----------
886 taskDef : `TaskDef`
887 Task definition structure.
889 Returns
890 -------
891 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
892 DatasetRefs for the task InitOutput, can be `None`. This can return
893 either resolved or non-resolved reference. Resolved reference will
894 match Quantum's initInputs if this is an intermediate dataset type.
895 """
896 return self._initOutputRefs.get(taskDef.label)
898 def globalInitOutputRefs(self) -> list[DatasetRef]:
899 """Return DatasetRefs for global InitOutputs.
901 Returns
902 -------
903 refs : `list` [ `~lsst.daf.butler.DatasetRef` ]
904 DatasetRefs for global InitOutputs.
905 """
906 return self._globalInitOutputRefs
908 def registryDatasetTypes(self) -> list[DatasetType]:
909 """Return dataset types used by this graph, their definitions match
910 dataset types from registry.
912 Returns
913 -------
914 refs : `list` [ `~lsst.daf.butler.DatasetType` ]
915 Dataset types for this graph.
916 """
917 return self._registryDatasetTypes
919 @classmethod
920 def loadUri(
921 cls,
922 uri: ResourcePathExpression,
923 universe: DimensionUniverse | None = None,
924 nodes: Iterable[uuid.UUID | str] | None = None,
925 graphID: BuildId | None = None,
926 minimumVersion: int = 3,
927 ) -> QuantumGraph:
928 """Read `QuantumGraph` from a URI.
930 Parameters
931 ----------
932 uri : convertible to `~lsst.resources.ResourcePath`
933 URI from where to load the graph.
934 universe : `~lsst.daf.butler.DimensionUniverse`, optional
935 If `None` it is loaded from the `QuantumGraph`
936 saved structure. If supplied, the
937 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
938 will be validated against the supplied argument for compatibility.
939 nodes : `~collections.abc.Iterable` [ `uuid.UUID` | `str` ] or `None`
940 UUIDs that correspond to nodes in the graph. If specified, only
941 these nodes will be loaded. Defaults to None, in which case all
942 nodes will be loaded.
943 graphID : `str` or `None`
944 If specified this ID is verified against the loaded graph prior to
945 loading any Nodes. This defaults to None in which case no
946 validation is done.
947 minimumVersion : `int`
948 Minimum version of a save file to load. Set to -1 to load all
949 versions. Older versions may need to be loaded, and re-saved
950 to upgrade them to the latest format before they can be used in
951 production.
953 Returns
954 -------
955 graph : `QuantumGraph`
956 Resulting QuantumGraph instance.
958 Raises
959 ------
960 TypeError
961 Raised if file contains instance of a type other than
962 `QuantumGraph`.
963 ValueError
964 Raised if one or more of the nodes requested is not in the
965 `QuantumGraph` or if graphID parameter does not match the graph
966 being loaded or if the supplied uri does not point at a valid
967 `QuantumGraph` save file.
968 RuntimeError
969 Raise if Supplied `~lsst.daf.butler.DimensionUniverse` is not
970 compatible with the `~lsst.daf.butler.DimensionUniverse` saved in
971 the graph.
972 """
973 uri = ResourcePath(uri)
974 match uri.getExtension():
975 case ".qgraph":
976 with LoadHelper(uri, minimumVersion, fullRead=(nodes is None)) as loader:
977 qgraph = loader.load(universe, nodes, graphID)
978 case ".qg":
979 from ..quantum_graph import PredictedQuantumGraphReader
981 with PredictedQuantumGraphReader.open(uri, page_size=100000) as qgr:
982 quantum_ids = (
983 [uuid.UUID(q) if not isinstance(q, uuid.UUID) else q for q in nodes]
984 if nodes is not None
985 else None
986 )
987 qgr.read_execution_quanta(quantum_ids)
988 qgraph = qgr.finish().to_old_quantum_graph()
989 case _:
990 raise ValueError(f"Only know how to handle files saved as `.qgraph`, not {uri}")
991 if not isinstance(qgraph, QuantumGraph):
992 raise TypeError(f"QuantumGraph file {uri} contains unexpected object type: {type(qgraph)}")
993 return qgraph
995 @classmethod
996 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> str | None:
997 """Read the header of a `QuantumGraph` pointed to by the uri parameter
998 and return it as a string.
1000 Parameters
1001 ----------
1002 uri : convertible to `~lsst.resources.ResourcePath`
1003 The location of the `QuantumGraph` to load. If the argument is a
1004 string, it must correspond to a valid
1005 `~lsst.resources.ResourcePath` path.
1006 minimumVersion : `int`
1007 Minimum version of a save file to load. Set to -1 to load all
1008 versions. Older versions may need to be loaded, and re-saved
1009 to upgrade them to the latest format before they can be used in
1010 production.
1012 Returns
1013 -------
1014 header : `str` or `None`
1015 The header associated with the specified `QuantumGraph` it there is
1016 one, else `None`.
1018 Raises
1019 ------
1020 ValueError
1021 Raised if the extension of the file specified by uri is not a
1022 `QuantumGraph` extension.
1023 """
1024 uri = ResourcePath(uri)
1025 if uri.getExtension() in {".qgraph"}:
1026 return LoadHelper(uri, minimumVersion).readHeader()
1027 else:
1028 raise ValueError("Only know how to handle files saved as `.qgraph`")
1030 def buildAndPrintHeader(self) -> None:
1031 """Create a header that would be used in a save of this object and
1032 prints it out to standard out.
1033 """
1034 _, header = self._buildSaveObject(returnHeader=True)
1035 print(json.dumps(header))
1037 def save(self, file: BinaryIO) -> None:
1038 """Save QuantumGraph to a file.
1040 Parameters
1041 ----------
1042 file : `io.BufferedIOBase`
1043 File to write data open in binary mode.
1044 """
1045 buffer = self._buildSaveObject()
1046 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes
1048 def _buildSaveObject(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
1049 thing = PersistenceContextVars()
1050 result = thing.run(self._buildSaveObjectImpl, returnHeader)
1051 return result
    def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
        """Serialize this graph into an in-memory save-file image.

        Layout written, in order: MAGIC_BYTES, save version, header length,
        lzma-compressed JSON header, then the per-task and per-node
        lzma-compressed JSON payloads whose byte offsets are recorded in the
        header.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the uncompressed header mapping.

        Returns
        -------
        buffer : `bytearray` or `tuple` [ `bytearray`, `dict` ]
            The serialized graph, and additionally the header mapping when
            ``returnHeader`` is `True`.
        """
        # make some containers
        jsonData: deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # NOTE(review): an earlier comment claimed tasks are serialized in
        # label-sorted order, but iteration follows self.taskGraph's own
        # order — confirm that order is deterministic across processes.
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on disk size, so it is worth doing
            taskDescription: dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            if (refs := self._initInputRefs.get(taskDef.label)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef.label)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connection between all of tasks and save that in
            # the header as a list of connections and edges in each task
            # this will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this tasks produced
                    # connection, the output will be said to be the datastore
                    # in which case the for loop will be a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, and encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode(), preset=2)
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.model_dump_json().encode(), preset=2)
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        # Dimension records referenced by the nodes, deduplicated by the
        # accumulator during node serialization above.
        headerData["DimensionRecords"] = {
            key: value.model_dump()
            for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        if self._globalInitOutputRefs:
            headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

        if self._registryDatasetTypes:
            headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # record the sizes as 2 unsigned long long numbers for a total of 16
        # bytes
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer
1204 @classmethod
1205 def load(
1206 cls,
1207 file: BinaryIO,
1208 universe: DimensionUniverse | None = None,
1209 nodes: Iterable[uuid.UUID] | None = None,
1210 graphID: BuildId | None = None,
1211 minimumVersion: int = 3,
1212 ) -> QuantumGraph:
1213 """Read `QuantumGraph` from a file that was made by `save`.
1215 Parameters
1216 ----------
1217 file : `io.IO` of bytes
1218 File with data open in binary mode.
1219 universe : `~lsst.daf.butler.DimensionUniverse`, optional
1220 If `None` it is loaded from the `QuantumGraph`
1221 saved structure. If supplied, the
1222 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph`
1223 will be validated against the supplied argument for compatibility.
1224 nodes : `~collections.abc.Iterable` [`uuid.UUID`] or `None`
1225 UUIDs that correspond to nodes in the graph. If specified, only
1226 these nodes will be loaded. Defaults to None, in which case all
1227 nodes will be loaded.
1228 graphID : `str` or `None`
1229 If specified this ID is verified against the loaded graph prior to
1230 loading any Nodes. This defaults to None in which case no
1231 validation is done.
1232 minimumVersion : `int`
1233 Minimum version of a save file to load. Set to -1 to load all
1234 versions. Older versions may need to be loaded, and re-saved
1235 to upgrade them to the latest format before they can be used in
1236 production.
1238 Returns
1239 -------
1240 graph : `QuantumGraph`
1241 Resulting QuantumGraph instance.
1243 Raises
1244 ------
1245 TypeError
1246 Raised if data contains instance of a type other than
1247 `QuantumGraph`.
1248 ValueError
1249 Raised if one or more of the nodes requested is not in the
1250 `QuantumGraph` or if graphID parameter does not match the graph
1251 being loaded or if the supplied uri does not point at a valid
1252 `QuantumGraph` save file.
1253 """
1254 with LoadHelper(file, minimumVersion, fullRead=(nodes is None)) as loader:
1255 qgraph = loader.load(universe, nodes, graphID)
1256 if not isinstance(qgraph, QuantumGraph):
1257 raise TypeError(f"QuantumGraph file contains unexpected object type: {type(qgraph)}")
1258 return qgraph
1260 def iterTaskGraph(self) -> Generator[TaskDef]:
1261 """Iterate over the `taskGraph` attribute in topological order.
1263 Yields
1264 ------
1265 taskDef : `TaskDef`
1266 `TaskDef` objects in topological order.
1267 """
1268 yield from nx.topological_sort(self.taskGraph)
    def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_id: bool = False) -> None:
        """Change output run and dataset ID for each output dataset.

        Parameters
        ----------
        run : `str`
            New output run name.
        metadata_key : `str` or `None`
            Specifies metadata key corresponding to output run name to update
            with new run name. If `None` or if metadata is missing it is not
            updated. If metadata is present but key is missing, it will be
            added.
        update_graph_id : `bool`, optional
            If `True` then also update graph ID with a new unique value.
        """
        # Maps each original dataset ID to the new ID minted for the new run.
        # Filled while rewriting outputs (first pass), then consulted to
        # rewrite matching intermediate inputs (second pass).
        dataset_id_map: dict[DatasetId, DatasetId] = {}

        def _update_output_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: MutableMapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update a collection of `~lsst.daf.butler.DatasetRef` with new
            run and dataset IDs, recording the old-to-new ID mapping.
            """
            for ref in refs:
                # replace(run=...) also derives the new dataset ID.
                new_ref = ref.replace(run=run)
                dataset_id_map[ref.id] = new_ref.id
                yield new_ref

        def _update_intermediate_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: Mapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update intermediate references with new run and IDs. Only the
            references that appear in ``dataset_id_map`` are updated, others
            are returned unchanged.
            """
            for ref in refs:
                if dataset_id := dataset_id_map.get(ref.id):
                    ref = ref.replace(run=run, id=dataset_id)
                yield ref

        # Replace quantum output refs first.
        for node in self._connectedQuanta:
            quantum = node.quantum
            outputs = {
                dataset_type: tuple(_update_output_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.outputs.items()
            }
            # Quantum is rebuilt rather than mutated; only outputs change in
            # this pass.
            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=quantum.initInputs,
                inputs=quantum.inputs,
                outputs=outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initOutputRefs = {
            task_def: list(_update_output_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initOutputRefs.items()
        }
        self._globalInitOutputRefs = list(
            _update_output_refs(self._globalInitOutputRefs, run, dataset_id_map)
        )

        # Update all intermediates from their matching outputs.
        for node in self._connectedQuanta:
            quantum = node.quantum
            inputs = {
                dataset_type: tuple(_update_intermediate_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.inputs.items()
            }
            initInputs = list(_update_intermediate_refs(quantum.initInputs.values(), run, dataset_id_map))

            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=initInputs,
                inputs=inputs,
                outputs=quantum.outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initInputRefs = {
            task_def: list(_update_intermediate_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initInputRefs.items()
        }

        if update_graph_id:
            # Same timestamp + PID scheme used when a graph is first built.
            self._buildId = BuildId(f"{time.time()}-{os.getpid()}")

        # Update run in metadata if a key was given.
        if metadata_key is not None:
            self._metadata[metadata_key] = run
    @property
    def graphID(self) -> BuildId:
        """The ID generated by the graph at construction time (`str`).

        Regenerated by `updateRun` when called with ``update_graph_id=True``.
        """
        return self._buildId
    @property
    def universe(self) -> DimensionUniverse:
        """Dimension universe associated with this graph
        (`~lsst.daf.butler.DimensionUniverse`).

        Read-only view of the universe stored at construction/load time.
        """
        return self._universe
1378 def __iter__(self) -> Generator[QuantumNode]:
1379 yield from nx.topological_sort(self._connectedQuanta)
    def __len__(self) -> int:
        """Return the number of quantum nodes in the graph."""
        return self._count
    def __contains__(self, node: QuantumNode) -> bool:
        """Report whether ``node`` is one of this graph's quantum nodes."""
        # Delegate membership to the underlying networkx graph.
        return self._connectedQuanta.has_node(node)
1387 def __getstate__(self) -> dict:
1388 """Store a compact form of the graph as a list of graph nodes, and a
1389 tuple of task labels and task configs. The full graph can be
1390 reconstructed with this information, and it preserves the ordering of
1391 the graph nodes.
1392 """
1393 universe: DimensionUniverse | None = None
1394 for node in self:
1395 dId = node.quantum.dataId
1396 if dId is None:
1397 continue
1398 universe = dId.universe
1399 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}
1401 def __setstate__(self, state: dict) -> None:
1402 """Reconstructs the state of the graph from the information persisted
1403 in getstate.
1404 """
1405 buffer = io.BytesIO(state["reduced"])
1406 with LoadHelper(buffer, minimumVersion=3) as loader:
1407 qgraph = loader.load(state["universe"], graphID=state["graphId"])
1409 self._metadata = qgraph._metadata
1410 self._buildId = qgraph._buildId
1411 self._datasetDict = qgraph._datasetDict
1412 self._nodeIdMap = qgraph._nodeIdMap
1413 self._count = len(qgraph)
1414 self._taskToQuantumNode = qgraph._taskToQuantumNode
1415 self._taskGraph = qgraph._taskGraph
1416 self._connectedQuanta = qgraph._connectedQuanta
1417 self._initInputRefs = qgraph._initInputRefs
1418 self._initOutputRefs = qgraph._initOutputRefs
1420 def __eq__(self, other: object) -> bool:
1421 if not isinstance(other, QuantumGraph):
1422 return False
1423 if len(self) != len(other):
1424 return False
1425 for node in self:
1426 if node not in other:
1427 return False
1428 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
1429 return False
1430 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
1431 return False
1432 if set(self.allDatasetTypes) != set(other.allDatasetTypes):
1433 return False
1434 return set(self.taskGraph) == set(other.taskGraph)
1436 def getSummary(self) -> QgraphSummary:
1437 """Create summary of graph.
1439 Returns
1440 -------
1441 summary : `QgraphSummary`
1442 Summary of QuantumGraph.
1443 """
1444 inCollection = self.metadata.get("input", None)
1445 if isinstance(inCollection, str):
1446 inCollection = [inCollection]
1447 summary = QgraphSummary(
1448 graphID=self.graphID,
1449 cmdLine=self.metadata.get("full_command", None),
1450 creationUTC=self.metadata.get("time", None),
1451 inputCollection=inCollection,
1452 outputCollection=self.metadata.get("output", None),
1453 outputRun=self.metadata.get("output_run", None),
1454 )
1455 for q in self:
1456 qts = summary.qgraphTaskSummaries.setdefault(
1457 q.taskDef.label, QgraphTaskSummary(taskLabel=q.taskDef.label)
1458 )
1459 qts.numQuanta += 1
1461 for k in q.quantum.inputs.keys():
1462 qts.numInputs[k.name] += 1
1464 for k in q.quantum.outputs.keys():
1465 qts.numOutputs[k.name] += 1
1467 return summary
    def make_init_qbb(
        self,
        butler_config: Config | ResourcePathExpression,
        *,
        config_search_paths: Iterable[str] | None = None,
    ) -> QuantumBackedButler:
        """Construct a quantum-backed butler suitable for reading and writing
        init input and init output datasets, respectively.

        This requires the full graph to have been loaded.

        Parameters
        ----------
        butler_config : `~lsst.daf.butler.Config` or \
                `~lsst.resources.ResourcePathExpression`
            A butler repository root, configuration filename, or configuration
            instance.
        config_search_paths : `~collections.abc.Iterable` [ `str` ], optional
            Additional search paths for butler configuration.

        Returns
        -------
        qbb : `~lsst.daf.butler.QuantumBackedButler`
            A limited butler that can ``get`` init-input datasets and ``put``
            init-output datasets.
        """
        universe = self.universe
        # Collect all init input/output dataset IDs.
        predicted_inputs: set[DatasetId] = set()
        predicted_outputs: set[DatasetId] = set()
        pipeline_graph = self.pipeline_graph
        for task_label in pipeline_graph.tasks:
            predicted_inputs.update(ref.id for ref in self.get_init_input_refs(task_label))
            predicted_outputs.update(ref.id for ref in self.get_init_output_refs(task_label))
        predicted_outputs.update(ref.id for ref in self.globalInitOutputRefs())
        # remove intermediates from inputs
        predicted_inputs -= predicted_outputs
        # Very inefficient way to extract datastore records from quantum graph,
        # we have to scan all quanta and look at their datastore records.
        datastore_records: dict[str, DatastoreRecordData] = {}
        for quantum_node in self:
            for store_name, records in quantum_node.quantum.datastore_records.items():
                # Keep only records for datasets this init butler will read.
                subset = records.subset(predicted_inputs)
                if subset is not None:
                    datastore_records.setdefault(store_name, DatastoreRecordData()).update(subset)

        dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}
        # Make butler from everything.
        return QuantumBackedButler.from_predicted(
            config=butler_config,
            predicted_inputs=predicted_inputs,
            predicted_outputs=predicted_outputs,
            dimensions=universe,
            datastore_records=datastore_records,
            search_paths=list(config_search_paths) if config_search_paths is not None else None,
            dataset_types=dataset_types,
        )
    def write_init_outputs(self, butler: LimitedButler, skip_existing: bool = True) -> None:
        """Write the init-output datasets for all tasks in the quantum graph.

        Parameters
        ----------
        butler : `lsst.daf.butler.LimitedButler`
            A limited butler data repository client.
        skip_existing : `bool`, optional
            If `True` (default) ignore init-outputs that already exist. If
            `False`, raise.

        Raises
        ------
        lsst.daf.butler.registry.ConflictingDefinitionError
            Raised if an init-output dataset already exists and
            ``skip_existing=False``.
        """
        # Extract init-input and init-output refs from the QG.
        input_refs: dict[str, DatasetRef] = {}
        output_refs: dict[str, DatasetRef] = {}
        for task_node in self.pipeline_graph.tasks.values():
            input_refs.update(
                {ref.datasetType.name: ref for ref in self.get_init_input_refs(task_node.label)}
            )
            # Config datasets are handled by write_configs, so they are
            # excluded here.
            output_refs.update(
                {
                    ref.datasetType.name: ref
                    for ref in self.get_init_output_refs(task_node.label)
                    if ref.datasetType.name != task_node.init.config_output.dataset_type_name
                }
            )
        # Drop (or reject) outputs that are already stored.
        for ref, is_stored in butler.stored_many(output_refs.values()).items():
            if is_stored:
                if not skip_existing:
                    raise ConflictingDefinitionError(f"Init-output dataset {ref} already exists.")
                # We'll `put` whatever's left in output_refs at the end.
                del output_refs[ref.datasetType.name]
        # Instantiate tasks, reading overall init-inputs and gathering
        # init-output in-memory objects.
        init_outputs: list[tuple[Any, DatasetType]] = []
        self.pipeline_graph.instantiate_tasks(
            get_init_input=lambda dataset_type: butler.get(
                input_refs[dataset_type.name].overrideStorageClass(dataset_type.storageClass)
            ),
            init_outputs=init_outputs,
        )
        # Write init-outputs that weren't already present.
        for obj, dataset_type in init_outputs:
            if new_ref := output_refs.get(dataset_type.name):
                assert new_ref.datasetType.storageClass_name == dataset_type.storageClass_name, (
                    "QG init refs should use task connection storage classes."
                )
                butler.put(obj, new_ref)
    def write_configs(self, butler: LimitedButler, compare_existing: bool = True) -> None:
        """Write the config datasets for all tasks in the quantum graph.

        Parameters
        ----------
        butler : `lsst.daf.butler.LimitedButler`
            A limited butler data repository client.
        compare_existing : `bool`, optional
            If `True` check configs that already exist for consistency. If
            `False`, always raise if configs already exist.

        Raises
        ------
        lsst.daf.butler.registry.ConflictingDefinitionError
            Raised if a config dataset already exists and
            ``compare_existing=False``, or if the existing config is not
            consistent with the config in the quantum graph.
        """
        to_put: list[tuple[PipelineTaskConfig, DatasetRef]] = []
        for task_node in self.pipeline_graph.tasks.values():
            dataset_type_name = task_node.init.config_output.dataset_type_name
            # Exactly one init-output ref per task is the config dataset;
            # the one-element unpack enforces that invariant.
            (ref,) = [  # noqa: UP027
                ref
                for ref in self.get_init_output_refs(task_node.label)
                if ref.datasetType.name == dataset_type_name
            ]
            try:
                old_config = butler.get(ref)
            except (LookupError, FileNotFoundError):
                # A missing dataset means "not written yet".
                old_config = None
            if old_config is not None:
                if not compare_existing:
                    raise ConflictingDefinitionError(f"Config dataset {ref} already exists.")
                if not task_node.config.compare(old_config, shortcut=False, output=log_config_mismatch):
                    raise ConflictingDefinitionError(
                        f"Config does not match existing task config {dataset_type_name!r} in "
                        "butler; tasks configurations must be consistent within the same run collection."
                    )
            else:
                to_put.append((task_node.config, ref))
        # We do writes at the end to minimize the mess we leave behind when we
        # raise an exception.
        for config, ref in to_put:
            butler.put(config, ref)
    def write_packages(self, butler: LimitedButler, compare_existing: bool = True) -> None:
        """Write the 'packages' dataset for the currently-active software
        versions.

        Parameters
        ----------
        butler : `lsst.daf.butler.LimitedButler`
            A limited butler data repository client.
        compare_existing : `bool`, optional
            If `True` check packages that already exist for consistency. If
            `False`, always raise if the packages dataset already exists.

        Raises
        ------
        lsst.daf.butler.registry.ConflictingDefinitionError
            Raised if the packages dataset already exists and is not consistent
            with the current packages.
        """
        new_packages = Packages.fromSystem()
        # The packages dataset is the single global init-output.
        (ref,) = self.globalInitOutputRefs()
        try:
            packages = butler.get(ref)
        except (LookupError, FileNotFoundError):
            packages = None
        if packages is not None:
            if not compare_existing:
                raise ConflictingDefinitionError(f"Packages dataset {ref} already exists.")
            if compare_packages(packages, new_packages):
                # NOTE(review): the existing ``packages`` object (presumably
                # merged with the new versions by compare_packages) is what
                # gets written here, not ``new_packages`` — confirm that
                # compare_packages updates its first argument in place.
                # have to remove existing dataset first; butler has no
                # replace option.
                butler.pruneDatasets([ref], unstore=True, purge=True)
                butler.put(packages, ref)
        else:
            butler.put(new_packages, ref)
1661 def init_output_run(self, butler: LimitedButler, existing: bool = True) -> None:
1662 """Initialize a new output RUN collection by writing init-output
1663 datasets (including configs and packages).
1665 Parameters
1666 ----------
1667 butler : `lsst.daf.butler.LimitedButler`
1668 A limited butler data repository client.
1669 existing : `bool`, optional
1670 If `True` check or ignore outputs that already exist. If
1671 `False`, always raise if an output dataset already exists.
1673 Raises
1674 ------
1675 lsst.daf.butler.registry.ConflictingDefinitionError
1676 Raised if there are existing init output datasets, and either
1677 ``existing=False`` or their contents are not compatible with this
1678 graph.
1679 """
1680 self.write_configs(butler, compare_existing=existing)
1681 self.write_packages(butler, compare_existing=existing)
1682 self.write_init_outputs(butler, skip_existing=existing)
1684 def get_refs(
1685 self,
1686 *,
1687 include_init_inputs: bool = False,
1688 include_inputs: bool = False,
1689 include_intermediates: bool | None = None,
1690 include_init_outputs: bool = False,
1691 include_outputs: bool = False,
1692 conform_outputs: bool = True,
1693 ) -> tuple[set[DatasetRef], dict[str, DatastoreRecordData]]:
1694 """Get the requested dataset refs from the graph.
1696 Parameters
1697 ----------
1698 include_init_inputs : `bool`, optional
1699 Include init inputs.
1700 include_inputs : `bool`, optional
1701 Include inputs.
1702 include_intermediates : `bool` or `None`, optional
1703 If `None`, no special handling for intermediates is performed.
1704 If `True` intermediates are calculated even if other flags
1705 do not request datasets. If `False` intermediates will be removed
1706 from any results.
1707 include_init_outputs : `bool`, optional
1708 Include init outpus.
1709 include_outputs : `bool`, optional
1710 Include outputs.
1711 conform_outputs : `bool`, optional
1712 Whether any outputs found should have their dataset types conformed
1713 with the registry dataset types.
1715 Returns
1716 -------
1717 refs : `set` [ `lsst.daf.butler.DatasetRef` ]
1718 The requested dataset refs found in the graph.
1719 datastore_records : `dict` [ `str`, \
1720 `lsst.daf.butler.datastore.record_data.DatastoreRecordData` ]
1721 Any datastore records found.
1723 Notes
1724 -----
1725 Conforming and requesting inputs and outputs can result in the same
1726 dataset appearing in the results twice with differing storage classes.
1727 """
1728 datastore_records: dict[str, DatastoreRecordData] = {}
1729 init_input_refs: set[DatasetRef] = set()
1730 init_output_refs: set[DatasetRef] = set(self.globalInitOutputRefs())
1732 if include_intermediates is True:
1733 # Need to enable inputs and outputs even if not explicitly
1734 # requested.
1735 request_include_init_inputs = True
1736 request_include_inputs = True
1737 request_include_init_outputs = True
1738 request_include_outputs = True
1739 else:
1740 request_include_init_inputs = include_init_inputs
1741 request_include_inputs = include_inputs
1742 request_include_init_outputs = include_init_outputs
1743 request_include_outputs = include_outputs
1745 if request_include_init_inputs or request_include_init_outputs:
1746 for task_def in self.iterTaskGraph():
1747 if request_include_init_inputs:
1748 if in_refs := self.initInputRefs(task_def):
1749 init_input_refs.update(in_refs)
1750 if request_include_init_outputs:
1751 if out_refs := self.initOutputRefs(task_def):
1752 init_output_refs.update(out_refs)
1754 input_refs: set[DatasetRef] = set()
1755 output_refs: set[DatasetRef] = set()
1757 for qnode in self:
1758 if request_include_inputs:
1759 for other_refs in qnode.quantum.inputs.values():
1760 input_refs.update(other_refs)
1761 # Inputs can come with datastore records.
1762 for store_name, records in qnode.quantum.datastore_records.items():
1763 datastore_records.setdefault(store_name, DatastoreRecordData()).update(records)
1764 if request_include_outputs:
1765 for other_refs in qnode.quantum.outputs.values():
1766 output_refs.update(other_refs)
1768 # Intermediates are the intersection of inputs and outputs. Must do
1769 # this analysis before conforming since dataset type changes will
1770 # change set membership.
1771 inter_msg = ""
1772 intermediates = set()
1773 if include_intermediates is not None:
1774 intermediates = (input_refs | init_input_refs) & (output_refs | init_output_refs)
1776 if include_intermediates is False:
1777 # Remove intermediates from results.
1778 init_input_refs -= intermediates
1779 input_refs -= intermediates
1780 init_output_refs -= intermediates
1781 output_refs -= intermediates
1782 inter_msg = f"; Intermediates removed: {len(intermediates)}"
1783 intermediates = set()
1784 elif include_intermediates is True:
1785 # Do not mention intermediates if all the input/output flags
1786 # would have resulted in them anyhow.
1787 if (
1788 (request_include_init_inputs is not include_init_inputs)
1789 or (request_include_inputs is not include_inputs)
1790 or (request_include_init_outputs is not include_init_outputs)
1791 or (request_include_outputs is not include_outputs)
1792 ):
1793 inter_msg = f"; including intermediates: {len(intermediates)}"
1795 # Assign intermediates to the relevant category.
1796 if not include_init_inputs:
1797 init_input_refs &= intermediates
1798 if not include_inputs:
1799 input_refs &= intermediates
1800 if not include_init_outputs:
1801 init_output_refs &= intermediates
1802 if not include_outputs:
1803 output_refs &= intermediates
1805 # Conforming can result in an input ref and an output ref appearing
1806 # in the returned results that are identical apart from storage class.
1807 if conform_outputs:
1808 # Get data repository definitions from the QuantumGraph; these can
1809 # have different storage classes than those in the quanta.
1810 dataset_types = {dstype.name: dstype for dstype in self.registryDatasetTypes()}
1812 def _update_ref(ref: DatasetRef) -> DatasetRef:
1813 internal_dataset_type = dataset_types.get(ref.datasetType.name, ref.datasetType)
1814 if internal_dataset_type.storageClass_name != ref.datasetType.storageClass_name:
1815 ref = ref.replace(storage_class=internal_dataset_type.storageClass_name)
1816 return ref
1818 # Convert output_refs to the data repository storage classes, too.
1819 output_refs = {_update_ref(ref) for ref in output_refs}
1820 init_output_refs = {_update_ref(ref) for ref in init_output_refs}
1822 _LOG.verbose(
1823 "Found the following datasets. InitInputs: %d; Inputs: %d; InitOutputs: %s; Outputs: %d%s",
1824 len(init_input_refs),
1825 len(input_refs),
1826 len(init_output_refs),
1827 len(output_refs),
1828 inter_msg,
1829 )
1831 refs = input_refs | init_input_refs | init_output_refs | output_refs
1832 return refs, datastore_records