Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 24%
174 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-10 03:38 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-10 03:38 -0700
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Class definitions for a Clustered QuantumGraph: a directed graph whose
nodes are clusters of quanta (`QuantaCluster` objects, each referencing a
subset of the node ids of a full QuantumGraph) and whose edges are
dependencies between those clusters.
"""

__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]


import logging
import pickle
import re
from collections import Counter, defaultdict
from pathlib import Path

from lsst.pipe.base import NodeId, QuantumGraph
from lsst.utils.iteration import ensure_iterable
from networkx import DiGraph, is_isomorphic, topological_sort

from .bps_draw import draw_networkx_dot

# Module-level logger used by both classes defined below.
_LOG = logging.getLogger(__name__)
class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /).

    Notes
    -----
    Equality and hashing are based only on ``name``. A cluster also compares
    equal to a plain `str` holding its name, so a cluster object can be used
    as a lookup key wherever the cluster name is expected.
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node ids are accumulated in a list; ``qgraph_node_ids`` exposes
        # them as a frozenset.
        self._qgraph_node_ids = []
        # Number of quanta added per task label.
        self._task_label_counts = Counter()
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name. It is filled with
            `str.format_map` using the quantum's required data ID values
            plus ``label`` and ``node_number``; unknown keys become "".

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster (of type ``cls``) containing the given
            quantum. Its tags are the values used to fill the template.

        Raises
        ------
        TypeError
            Raised (after logging) if the template cannot be formatted.
        ValueError
            Raised if the generated name contains a /.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = dict(data_id.required)
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, missing keys map to "".
        try:
            name = template.format_map(defaultdict(str, info))
        except TypeError:
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Collapse runs of underscores (e.g. those produced when a template
        # field was substituted with an empty string).
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use ``cls`` so subclasses get instances of their own type.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """Quantum graph NodeIds corresponding to this cluster
        (`frozenset` [`lsst.pipe.base.NodeId`]).
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster
        (`collections.Counter`, a copy of the internal counts).
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            Quantum node to add.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only
        # name equality since those are supposed
        # to be unique.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        return False

    def __hash__(self) -> int:
        # Must agree with __eq__: equal to its name, so hash like its name.
        return hash(self.name)
class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized. If not given, `save` derives one from the
        ClusteredQuantumGraph filename.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.

    Clusters are stored as nodes of a `networkx.DiGraph` keyed by cluster
    name, with the `QuantaCluster` object in the ``cluster`` node attribute.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given; otherwise `save` fills it in later.
        self._quantum_graph_filename = (
            Path(qgraph_filename).resolve() if qgraph_filename else None
        )
        self._cluster_graph = DiGraph()

    def __str__(self):
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={len(self._quantum_graph) if self._quantum_graph else None},"
            f"len(cqgraph)={len(self._cluster_graph) if self._cluster_graph else None})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    def __eq__(self, other):
        # Equal when the QuantumGraphs are equal and the cluster graphs have
        # the same structure (node/edge isomorphism; cluster contents are
        # not compared).
        if not isinstance(other, ClusteredQuantumGraph):
            return False
        if len(self) != len(other):
            return False
        return self._quantum_graph == other._quantum_graph and is_isomorphic(
            self._cluster_graph, other._cluster_graph
        )

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve. A plain int is combined with
            this QuantumGraph's ``graphID`` to build a NodeId.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Must use the argument ``id_``, not the builtin ``id``.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # A networkx NodeView is iterable but not an iterator; __iter__ must
        # return an actual iterator or iter()/for-loops raise TypeError.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters in topological order.
        """
        return map(self.get_cluster, topological_sort(self._cluster_graph))

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Normalize to cluster names first so checks, messages and the edge
        # all use the graph's node keys.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)

        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, it is taken
            from the filename extension; only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file: temporarily detach it so
            # only the cluster graph and the QuantumGraph filename are pickled.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing failed.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, it is taken
            from the filename extension; only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not given,
            it is taken from the filename extension; only "pickle" is
            supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # NOTE(review): unpickling executes arbitrary code; only load
            # files produced by `save` from a trusted source.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately
        try:
            cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename)
        except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = QuantumGraph.loadUri(new_filename)

        return cgraph