Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 24%
174 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-02 09:44 +0000
« prev ^ index » next coverage.py v7.3.0, created at 2023-09-02 09:44 +0000
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Class definitions for a Clustered QuantumGraph where a node in the graph is
23a QuantumGraph.
24"""
26__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]
29import logging
30import pickle
31import re
32from collections import Counter, defaultdict
33from pathlib import Path
35from lsst.pipe.base import NodeId, QuantumGraph
36from lsst.utils.iteration import ensure_iterable
37from networkx import DiGraph, is_isomorphic, topological_sort
39from .bps_draw import draw_networkx_dot
41_LOG = logging.getLogger(__name__)
class QuantaCluster:
    """A named group of Quanta together with per-task bookkeeping.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster. When omitted, an
        empty dictionary is used.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node ids are kept in insertion order; per-label counts are
        # maintained alongside them by add_quantum.
        self._qgraph_node_ids = []
        self._task_label_counts = Counter()
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Build a cluster that holds exactly one quantum.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode to wrap in a single quantum cluster.
        template : `str`
            Format string used to generate the cluster name. Fields
            missing from the gathered info expand to an empty string.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum.
        """
        task_label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Values available to the name template: the data ID values by
        # name plus the task label and the node id.
        fields = data_id.byName()
        fields["label"] = task_label
        fields["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", fields)

        # A defaultdict keeps generic templates from raising KeyError on
        # fields that this particular quantum does not have.
        try:
            raw_name = template.format_map(defaultdict(lambda: "", fields))
        except TypeError:
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, fields)
            raise

        # Collapse runs of underscores (e.g., left behind by empty fields).
        cluster_name = re.sub("_+", "_", raw_name)
        _LOG.debug("template name = %s", cluster_name)

        new_cluster = QuantaCluster(cluster_name, task_label, fields)
        new_cluster.add_quantum(quantum_node.nodeId, task_label)
        return new_cluster

    @property
    def qgraph_node_ids(self):
        """Quantum graph NodeIds corresponding to this cluster, returned
        as a `frozenset`.
        """
        node_ids = self._qgraph_node_ids
        _LOG.debug("_qgraph_node_ids = %s", node_ids)
        return frozenset(node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster, returned
        as a new `collections.Counter`.
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            Node whose ``nodeId`` and ``taskDef.label`` are recorded.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        node_id = quantum_node.nodeId
        task_label = quantum_node.taskDef.label
        self.add_quantum(node_id, task_label)

    def add_quantum(self, node_id, task_label):
        """Record a quantum, given by id and task label, in this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        head = f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
        tail = f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        return head + tail

    def __eq__(self, other: object) -> bool:
        # Only names are compared (not the contained data) because names
        # are supposed to be unique. A plain string compares against
        # the cluster's name.
        if isinstance(other, str):
            other_name = other
        elif isinstance(other, QuantaCluster):
            other_name = other.name
        else:
            return False
        return self.name == other_name

    def __hash__(self) -> int:
        # Hash on the name so it stays consistent with __eq__.
        return hash(self.name)
class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized. When omitted, `save` derives a filename from the
        ClusteredQuantumGraph filename.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given; None marks "not serialized yet" for save().
        if qgraph_filename is None:
            self._quantum_graph_filename = None
        else:
            self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._cluster_graph = DiGraph()

    def __str__(self):
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={len(self._quantum_graph) if self._quantum_graph else None},"
            f"len(cqgraph)={len(self._cluster_graph) if self._cluster_graph else None})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    def __eq__(self, other):
        # Equal when the underlying QuantumGraphs compare equal and the
        # cluster graphs have the same structure (names are not compared).
        if not isinstance(other, ClusteredQuantumGraph):
            return False
        if len(self) != len(other):
            return False
        return self._quantum_graph == other._quantum_graph and is_isomorphic(
            self._cluster_graph, other._cluster_graph
        )

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster(s) to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if any given item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name was already added.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            # Nodes are keyed by cluster name; the cluster object itself is
            # stored as the node's "cluster" attribute.
            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Wrap the given integer (not the ``id`` builtin) in a NodeId.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # nodes() hands back a view, which is iterable but is not itself an
        # iterator; __iter__ must return a real iterator.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters in topological order.
        """
        return map(self.get_cluster, topological_sort(self._cluster_graph))

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        if not self._cluster_graph.has_node(parent):
            raise KeyError(f"{self.name} does not have a cluster named {parent}")
        if not self._cluster_graph.has_node(child):
            raise KeyError(f"{self.name} does not have a cluster named {child}")
        _LOG.debug("add_dependency: adding edge %s %s", parent, child)

        # Edges are always stored between cluster names.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. When omitted, it is taken
            from the filename's extension. Only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing failed.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be drawn.
        format_ : `str`, optional
            Format in which to draw the data. When omitted, it is taken
            from the filename's extension. Only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. When
            omitted, it is taken from the filename's extension. Only
            "pickle" is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # NOTE(review): unpickling executes arbitrary code from the
            # file; only load files that come from a trusted source.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately
        try:
            cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename)
        except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = QuantumGraph.loadUri(new_filename)

        return cgraph