Coverage for python/lsst/pipe/base/quantum_graph_skeleton.py: 43%
136 statements
coverage.py v7.3.2, created at 2023-10-11 09:32 +0000
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""An under-construction version of QuantumGraph and various helper
classes.
"""

from __future__ import annotations

__all__ = (
    "QuantumGraphSkeleton",
    "QuantumKey",
    "TaskInitKey",
    "DatasetKey",
    "PrerequisiteDatasetKey",
)

from collections.abc import Iterable, Iterator, MutableMapping, Set
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple

import networkx
from lsst.daf.butler import DataCoordinate, DataIdValue, DatasetRef
from lsst.utils.logging import getLogger

if TYPE_CHECKING:
    pass

_LOG = getLogger(__name__)


class QuantumKey(NamedTuple):
    """Identifier type for quantum keys in a `QuantumGraphSkeleton`."""

    task_label: str
    """Label of the task in the pipeline."""

    data_id_values: tuple[DataIdValue, ...]
    """Data ID values of the quantum.

    Note that keys are fixed given `task_label`, so using only the values here
    speeds up comparisons.
    """

    is_task: ClassVar[Literal[True]] = True
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `True`).
    """


class TaskInitKey(NamedTuple):
    """Identifier type for task init keys in a `QuantumGraphSkeleton`."""

    task_label: str
    """Label of the task in the pipeline."""

    is_task: ClassVar[Literal[True]] = True
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `True`).
    """


class DatasetKey(NamedTuple):
    """Identifier type for dataset keys in a `QuantumGraphSkeleton`."""

    parent_dataset_type_name: str
    """Name of the dataset type (never a component)."""

    data_id_values: tuple[DataIdValue, ...]
    """Data ID values of the dataset.

    Note that keys are fixed given `parent_dataset_type_name`, so using only
    the values here speeds up comparisons.
    """

    is_task: ClassVar[Literal[False]] = False
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `False`).
    """

    is_prerequisite: ClassVar[Literal[False]] = False
    """Whether this node represents a prerequisite input dataset (always
    `False`).
    """


class PrerequisiteDatasetKey(NamedTuple):
    """Identifier type for prerequisite dataset keys in a
    `QuantumGraphSkeleton`.

    Unlike regular datasets, prerequisites are not actually required to come
    from a find-first search of `input_collections`, so we don't want to
    assume that the same data ID implies the same dataset. Happily we also
    don't need to search for them by data ID in the graph, so we can use the
    dataset ID (UUID) instead.
    """

    parent_dataset_type_name: str
    """Name of the dataset type (never a component)."""

    dataset_id_bytes: bytes
    """Dataset ID (UUID) as raw bytes."""

    is_task: ClassVar[Literal[False]] = False
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `False`).
    """

    is_prerequisite: ClassVar[Literal[True]] = True
    """Whether this node represents a prerequisite input dataset (always
    `True`).
    """
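

# A brief illustration of the key types above (the data ID values shown are
# made up, for illustration only). Because these are `NamedTuple` subclasses,
# instances hash and compare as plain tuples of their fields, which is what
# makes them cheap to use as networkx node keys:
#
#     key = QuantumKey("isr", (903342, "R22_S11"))
#     key.task_label                                   # "isr"
#     key.data_id_values                               # (903342, "R22_S11")
#     key.is_task                                      # True (a ClassVar, not a tuple field)
#     key == QuantumKey("isr", (903342, "R22_S11"))    # True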


class QuantumGraphSkeleton:
    """An under-construction quantum graph.

    QuantumGraphSkeleton is intended for use inside `QuantumGraphBuilder` and
    its subclasses.

    Parameters
    ----------
    task_labels : `~collections.abc.Iterable` [ `str` ]
        The labels of all tasks whose quanta may be included in the graph, in
        topological order.

    Notes
    -----
    QuantumGraphSkeleton models a bipartite version of the quantum graph, in
    which both quanta and datasets are represented as nodes and each type of
    node only has edges to the other type.

    Square-bracket (`getitem`) indexing returns a mutable mapping of a node's
    flexible attributes.

    The details of the `QuantumGraphSkeleton` API (e.g. which operations
    operate on multiple nodes vs. a single node) are set by what's actually
    needed by current quantum graph generation algorithms. New variants can be
    added as needed, but adding all operations that *might* be useful for some
    future algorithm seems premature.
    """

    def __init__(self, task_labels: Iterable[str]):
        self._tasks: dict[str, tuple[TaskInitKey, set[QuantumKey]]] = {}
        self._xgraph: networkx.DiGraph = networkx.DiGraph()
        self._global_init_outputs: set[DatasetKey] = set()
        for task_label in task_labels:
            task_init_key = TaskInitKey(task_label)
            self._tasks[task_label] = (task_init_key, set())
            self._xgraph.add_node(task_init_key)

    def __contains__(self, key: QuantumKey | TaskInitKey | DatasetKey | PrerequisiteDatasetKey) -> bool:
        return key in self._xgraph.nodes

    def __getitem__(
        self, key: QuantumKey | TaskInitKey | DatasetKey | PrerequisiteDatasetKey
    ) -> MutableMapping[str, Any]:
        return self._xgraph.nodes[key]

    @property
    def n_nodes(self) -> int:
        """The total number of nodes of all types."""
        return len(self._xgraph.nodes)

    @property
    def n_edges(self) -> int:
        """The total number of edges."""
        return len(self._xgraph.edges)

    def has_task(self, task_label: str) -> bool:
        """Test whether the given task is in this skeleton.

        Tasks are only added to the skeleton at initialization, but may be
        removed by `remove_task` if they end up having no quanta.
        """
        return task_label in self._tasks

    def get_task_init_node(self, task_label: str) -> TaskInitKey:
        """Return the graph node that represents a task's initialization."""
        return self._tasks[task_label][0]

    def get_quanta(self, task_label: str) -> Set[QuantumKey]:
        """Return the quanta for the given task label.

        Parameters
        ----------
        task_label : `str`
            Label for the task.

        Returns
        -------
        quanta : `~collections.abc.Set` [ `QuantumKey` ]
            A set-like object with the identifiers of all quanta for the given
            task. *The skeleton object's set of quanta must not be modified
            while iterating over this container; make a copy if mutation during
            iteration is necessary.*
        """
        return self._tasks[task_label][1]

    @property
    def global_init_outputs(self) -> Set[DatasetKey]:
        """The set of dataset nodes that are not associated with any task."""
        return self._global_init_outputs

    def iter_all_quanta(self) -> Iterator[QuantumKey]:
        """Iterate over all quanta from any task, in topological (but otherwise
        unspecified) order.
        """
        for _, quanta in self._tasks.values():
            yield from quanta

    def iter_outputs_of(self, quantum_key: QuantumKey | TaskInitKey) -> Iterator[DatasetKey]:
        """Iterate over the datasets produced by the given quantum."""
        return self._xgraph.successors(quantum_key)

    def iter_inputs_of(
        self, quantum_key: QuantumKey | TaskInitKey
    ) -> Iterator[DatasetKey | PrerequisiteDatasetKey]:
        """Iterate over the datasets consumed by the given quantum."""
        return self._xgraph.predecessors(quantum_key)

    def update(self, other: QuantumGraphSkeleton) -> None:
        """Copy all nodes from ``other`` to ``self``.

        The tasks in ``other`` must be a subset of the tasks in ``self`` (this
        method is expected to be used to populate a skeleton for a full graph
        from independent-subgraph skeletons).
        """
        for task_label, (_, quanta) in other._tasks.items():
            self._tasks[task_label][1].update(quanta)
        self._xgraph.update(other._xgraph)
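
    # A sketch of the merge pattern `update` is designed for (illustrative
    # only; `build_subgraph_skeleton` and `subsets_of` are hypothetical
    # stand-ins for whatever the builder actually does):
    #
    #     full = QuantumGraphSkeleton(all_task_labels)
    #     for labels in subsets_of(all_task_labels):
    #         full.update(build_subgraph_skeleton(labels))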

    def add_quantum_node(self, task_label: str, data_id: DataCoordinate, **attrs: Any) -> QuantumKey:
        """Add a new node representing a quantum."""
        key = QuantumKey(task_label, data_id.values_tuple())
        self._xgraph.add_node(key, data_id=data_id, **attrs)
        self._tasks[key.task_label][1].add(key)
        return key

    def add_dataset_node(
        self,
        parent_dataset_type_name: str,
        data_id: DataCoordinate,
        is_global_init_output: bool = False,
        **attrs: Any,
    ) -> DatasetKey:
        """Add a new node representing a dataset."""
        key = DatasetKey(parent_dataset_type_name, data_id.values_tuple())
        self._xgraph.add_node(key, data_id=data_id, **attrs)
        if is_global_init_output:
            assert isinstance(key, DatasetKey)
            self._global_init_outputs.add(key)
        return key

    def add_prerequisite_node(
        self,
        parent_dataset_type_name: str,
        ref: DatasetRef,
        **attrs: Any,
    ) -> PrerequisiteDatasetKey:
        """Add a new node representing a prerequisite input dataset."""
        key = PrerequisiteDatasetKey(parent_dataset_type_name, ref.id.bytes)
        self._xgraph.add_node(key, data_id=ref.dataId, ref=ref, **attrs)
        return key

    def remove_quantum_node(self, key: QuantumKey, remove_outputs: bool) -> None:
        """Remove a node representing a quantum.

        Parameters
        ----------
        key : `QuantumKey`
            Identifier for the node.
        remove_outputs : `bool`
            If `True`, also remove all dataset nodes produced by this quantum.
            If `False`, any such dataset nodes will become overall inputs.
        """
        _, quanta = self._tasks[key.task_label]
        quanta.remove(key)
        if remove_outputs:
            to_remove = list(self._xgraph.successors(key))
            to_remove.append(key)
            self._xgraph.remove_nodes_from(to_remove)
        else:
            self._xgraph.remove_node(key)

    def remove_dataset_nodes(self, keys: Iterable[DatasetKey | PrerequisiteDatasetKey]) -> None:
        """Remove nodes representing datasets."""
        self._xgraph.remove_nodes_from(keys)

    def remove_task(self, task_label: str) -> None:
        """Fully remove a task from the skeleton.

        All init-output datasets and quanta for the task must already have been
        removed.
        """
        task_init_key, quanta = self._tasks.pop(task_label)
        assert not quanta, "Cannot remove task unless all quanta have already been removed."
        assert not list(self._xgraph.successors(task_init_key))
        self._xgraph.remove_node(task_init_key)

    def add_input_edges(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_keys: Iterable[DatasetKey | PrerequisiteDatasetKey],
    ) -> None:
        """Add edges connecting datasets to a quantum that consumes them.

        Notes
        -----
        This must only be called if the task node has already been added.
        Use `add_input_edge` if this cannot be assumed.

        Dataset nodes that are not already present will be created.
        """
        assert task_key in self._xgraph
        self._xgraph.add_edges_from((dataset_key, task_key) for dataset_key in dataset_keys)

    def remove_input_edges(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_keys: Iterable[DatasetKey | PrerequisiteDatasetKey],
    ) -> None:
        """Remove edges connecting datasets to a quantum that consumes them."""
        self._xgraph.remove_edges_from((dataset_key, task_key) for dataset_key in dataset_keys)

    def add_input_edge(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_key: DatasetKey | PrerequisiteDatasetKey,
        ignore_unrecognized_quanta: bool = False,
    ) -> bool:
        """Add an edge connecting a dataset to a quantum that consumes it.

        Parameters
        ----------
        task_key : `QuantumKey` or `TaskInitKey`
            Identifier for the quantum node.
        dataset_key : `DatasetKey` or `PrerequisiteDatasetKey`
            Identifier for the dataset node.
        ignore_unrecognized_quanta : `bool`, optional
            If `True`, do nothing and return `False` if the quantum node is
            not already present. If `False` (default), the quantum node is
            assumed to be present.

        Returns
        -------
        added : `bool`
            `True` if an edge was actually added, `False` if the quantum was
            not recognized and the edge was not added as a result.

        Notes
        -----
        Dataset nodes that are not already present will be created.
        """
        if ignore_unrecognized_quanta and task_key not in self._xgraph:
            return False
        self._xgraph.add_edge(dataset_key, task_key)
        return True

    def add_output_edge(self, task_key: QuantumKey | TaskInitKey, dataset_key: DatasetKey) -> None:
        """Add an edge connecting a dataset to the quantum that produces it.

        Parameters
        ----------
        task_key : `QuantumKey` or `TaskInitKey`
            Identifier for the quantum node. Must identify a node already
            present in the graph.
        dataset_key : `DatasetKey`
            Identifier for the dataset node. Must identify a node already
            present in the graph.
        """
        assert task_key in self._xgraph
        assert dataset_key in self._xgraph
        self._xgraph.add_edge(task_key, dataset_key)

    def remove_orphan_datasets(self) -> None:
        """Remove any dataset nodes that do not have any edges."""
        for orphan in list(networkx.isolates(self._xgraph)):
            if not orphan.is_task and orphan not in self._global_init_outputs:
                self._xgraph.remove_node(orphan)

    def extract_overall_inputs(self) -> dict[DatasetKey | PrerequisiteDatasetKey, DatasetRef]:
        """Find overall input datasets.

        Returns
        -------
        datasets : `dict` [ `DatasetKey` or `PrerequisiteDatasetKey`, \
                `~lsst.daf.butler.DatasetRef` ]
            Overall-input datasets, including prerequisites and init-inputs.
        """
        result = {}
        for generation in networkx.algorithms.topological_generations(self._xgraph):
            for dataset_key in generation:
                if dataset_key.is_task:
                    continue
                try:
                    result[dataset_key] = self[dataset_key]["ref"]
                except KeyError:
                    raise AssertionError(
                        f"Logic bug in QG generation: dataset {dataset_key} was never resolved."
                    )
            break
        return result
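

# A minimal end-to-end sketch of how builder code might drive this API
# (illustrative only, not part of the original module: the task label, the
# dataset type names, and the use of an empty data ID are assumptions, and a
# real `QuantumGraphBuilder` supplies fully expanded data coordinates from the
# butler):
#
#     from lsst.daf.butler import DataCoordinate, DimensionUniverse
#
#     universe = DimensionUniverse()
#     data_id = DataCoordinate.standardize({}, universe=universe)
#     skeleton = QuantumGraphSkeleton(["isr"])
#     quantum = skeleton.add_quantum_node("isr", data_id)
#     raw = skeleton.add_dataset_node("raw", data_id)
#     output = skeleton.add_dataset_node("postISRCCD", data_id)
#     skeleton.add_input_edge(quantum, raw)
#     skeleton.add_output_edge(quantum, output)
#     list(skeleton.iter_inputs_of(quantum))   # -> [the "raw" dataset key]
#     skeleton.n_nodes                         # -> 4 (task-init node, quantum, two datasets)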