Coverage for python / lsst / ctrl / bps / clustered_quantum_graph.py: 22%
180 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 08:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 08:47 +0000
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Class definitions for a Clustered QuantumGraph: a graph whose nodes are
clusters of quanta (`QuantaCluster`), each cluster holding the IDs of a subset
of the quanta in a full QuantumGraph.
"""

from __future__ import annotations

__all__ = ["ClusteredQuantumGraph", "QuantaCluster"]

import logging
import pickle
import re
import uuid
from collections import Counter, defaultdict
from pathlib import Path

from networkx import DiGraph, is_directed_acyclic_graph, is_isomorphic, topological_sort

from lsst.pipe.base.pipeline_graph import TaskImportMode
from lsst.pipe.base.quantum_graph import PredictedQuantumGraph, QuantumInfo
from lsst.utils.iteration import ensure_iterable

from .bps_draw import draw_networkx_dot

# Logger named after this module; used for debug/error messages below.
_LOG = logging.getLogger(__name__)
class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `~typing.Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /).

    Notes
    -----
    Equality and hashing are based solely on ``name``. A cluster also
    compares equal to a plain `str` equal to its name, so a cluster and its
    name can be used interchangeably as a key.
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Quantum IDs in insertion order (duplicates are not rejected here).
        self._qgraph_node_ids = []
        # Number of quanta added per task label.
        self._task_label_counts = Counter()
        # The given tags object is kept as-is (not copied).
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_info(
        cls, quantum_id: uuid.UUID, quantum_info: QuantumInfo, template: str
    ) -> QuantaCluster:
        """Create single quantum cluster from the given quantum information.

        Parameters
        ----------
        quantum_id : `uuid.UUID`
            ID of the quantum.
        quantum_info : `lsst.pipe.base.quantum_graph.QuantumInfo`
            Dictionary of additional information about the quantum. Must
            contain ``task_label`` and ``data_id`` entries.
        template : `str`
            Template for creating cluster name. Fields may reference
            ``label``, ``node_number`` or any required data ID key; fields
            not present are replaced by an empty string.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster (an instance of ``cls``) containing the
            given quantum. Its tags are the values used to fill the
            template (required data ID values plus ``label`` and
            ``node_number``).

        Raises
        ------
        TypeError, ValueError, IndexError, AttributeError
            Raised (after logging the template and values) if the template
            cannot be formatted.
        """
        label = quantum_info["task_label"]
        data_id = quantum_info["data_id"]

        # Gather info for name template into a dictionary.
        info = dict(data_id.required)
        info["label"] = label
        info["node_number"] = quantum_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict so missing fields
        # become empty strings (the defaultdict is a copy; info is untouched).
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except (TypeError, ValueError, IndexError, AttributeError):
            # E.g. positional fields ("{}"), invalid format specs, or
            # index/attribute lookups on a missing (empty) value.
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Collapse runs of underscores left behind by empty fields.
        name = re.sub(r"_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls so that subclasses construct instances of themselves.
        cluster = cls(name, label, info)
        cluster.add_quantum(quantum_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """Quantum IDs (`frozenset` [`uuid.UUID`]) belonging to this
        cluster.
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of quanta per task label in this cluster
        (`collections.Counter`).

        A new Counter is returned, so modifying it does not affect the
        cluster.
        """
        return Counter(self._task_label_counts)

    def add_quantum(self, node_id, task_label):
        """Add a quantum to this cluster.

        Parameters
        ----------
        node_id : `uuid.UUID`
            ID of the quantum to be added to cluster.
        task_label : `str`
            Task label of the quantum to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, only name equality, since names are
        # supposed to be unique. Comparing against a str compares names.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        return False

    def __hash__(self) -> int:
        # Must agree with __eq__: same hash as the plain name string.
        return hash(self.name)
class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.quantum_graph.PredictedQuantumGraph`
        The quantum graph to be clustered.
    qgraph_filename : `str`
        Filename for given quantum graph.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Nodes of the internal cluster graph are keyed by cluster name and store
    the `QuantaCluster` under the ``cluster`` node attribute. Quanta are
    referenced by their `uuid.UUID` IDs. `save` does not pickle the full
    quantum graph (the quantum-only networkx graph is pickled); `load`
    re-reads the quantum graph from the resolved ``qgraph_filename``.
    """

    def __init__(self, name: str, qgraph: PredictedQuantumGraph, qgraph_filename: str):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        self._quantum_only_xgraph = qgraph.quantum_only_xgraph
        # Stored as an absolute path so load() can find it later.
        self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Compare against None explicitly: an empty graph is falsy but its
        # length (0) is still meaningful.
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    def __eq__(self, other):
        # Structural comparison only: is_isomorphic is called without a
        # node_match, so node names and attributes are not compared.
        # NOTE: defining __eq__ without __hash__ makes instances unhashable.
        if not isinstance(other, ClusteredQuantumGraph):
            return False
        if len(self) != len(other):
            return False
        return is_isomorphic(self.qxgraph, other.qxgraph) and is_isomorphic(
            self._cluster_graph, other._cluster_graph
        )

    @property
    def name(self) -> str:
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self) -> PredictedQuantumGraph:
        """The quantum graph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    @property
    def qxgraph(self) -> DiGraph:
        """A networkx graph of all quanta."""
        return self._quantum_only_xgraph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or \
                `~collections.abc.Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
            Clusters processed before the failing item remain added.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_info(self, id_: uuid.UUID) -> QuantumInfo:
        """Retrieve a quantum info dict from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `uuid.UUID`
            ID of the quantum to retrieve.

        Returns
        -------
        quantum_info : `lsst.pipe.base.quantum_graph.QuantumInfo`
            Quantum info dictionary for the given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a quantum with given ID.
        """
        return self._quantum_only_xgraph.nodes[id_]

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `~collections.abc.Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; a networkx NodeView is only an
        # iterable, so returning it directly makes iter() raise TypeError.
        return iter(self._cluster_graph)

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `~collections.abc.Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters in topological order.
        """
        return map(self.get_cluster, topological_sort(self._cluster_graph))

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `~collections.abc.Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `~collections.abc.Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {parent}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {child}")
        _LOG.debug("add_dependency: adding edge %s %s", parent, child)

        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.

        The quantum graph is assumed to have been saved separately.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, it is taken
            from the filename extension; only ``pickle`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        # Don't save QuantumGraph in same file (load() re-reads it from its
        # own file). Restore it even if writing fails so this object is not
        # left without its quantum graph.
        tmp_qgraph = self._quantum_graph
        self._quantum_graph = None
        try:
            with open(filename, "wb") as fh:
                pickle.dump(self, fh)
        finally:
            self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, it is taken
            from the filename extension; only ``dot`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not
            given, it is taken from the filename extension; only
            ``pickle`` is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        # NOTE(security): unpickling can execute arbitrary code; only load
        # files from trusted sources.
        with open(filename, "rb") as fh:
            cgraph: ClusteredQuantumGraph = pickle.load(fh)

        # The QuantumGraph was saved separately
        with PredictedQuantumGraph.open(
            cgraph._quantum_graph_filename, import_mode=TaskImportMode.DO_NOT_IMPORT
        ) as reader:
            reader.read_thin_graph()
            cgraph._quantum_graph = reader.finish()

        return cgraph

    def validate(self):
        """Check correctness of completed ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            If the ClusteredQuantumGraph is not valid.
        """
        # Check no cycles
        if not is_directed_acyclic_graph(self._cluster_graph):
            raise RuntimeError("ClusteredQuantumGraph is not a directed acyclic graph.")

        # Check that Quantum only in 1 cluster
        # Check cluster tags label matches cluster label
        node_ids = set()
        for cluster in self.clusters():
            if "label" in cluster.tags and cluster.tags["label"] != cluster.label:
                raise RuntimeError(
                    f"Label mismatch in cluster {cluster.name}: "
                    f"cluster={cluster.label} tags={cluster.tags['label']}"
                )

            for node_id in cluster.qgraph_node_ids:
                if node_id in node_ids:
                    raise RuntimeError(
                        f"Quantum {node_id} occurs in at least 2 clusters (one of which is {cluster.name})"
                    )
                node_ids.add(node_id)

        # Check that have all Quanta
        quanta_count_qgraph = len(self._quantum_graph)
        quanta_count_cqgraph = len(node_ids)
        if quanta_count_qgraph != quanta_count_cqgraph:
            raise RuntimeError(
                f"Number of Quanta in clustered qgraph ({quanta_count_cqgraph}) does not equal number in"
                f" quantum graph ({quanta_count_qgraph})"
            )