Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 21%
168 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-01 02:12 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-01 02:12 -0700
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Class definitions for a Clustered QuantumGraph where a node in the graph is
23a QuantumGraph.
24"""
# Public API of this module.
__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]


import logging
import pickle
import re
from collections import Counter, defaultdict
from pathlib import Path

from lsst.pipe.base import NodeId, QuantumGraph
from lsst.utils.iteration import ensure_iterable
from networkx import DiGraph

from .bps_draw import draw_networkx_dot

# Module-level logger shared by the classes defined in this file.
_LOG = logging.getLogger(__name__)
class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster. The given dictionary
        is stored as-is (not copied); an empty one is created if omitted.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node ids are kept in insertion order; the public view is a
        # frozenset (see qgraph_node_ids).
        self._qgraph_node_ids = []
        self._task_label_counts = Counter()
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template (`str.format` style) for creating cluster name.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum. When called
            on a subclass, an instance of that subclass is returned.

        Raises
        ------
        TypeError
            Re-raised (after logging) if formatting the template fails.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId

        # Gather info for name template into a dictionary.
        info = quantum_node.quantum.dataId.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict so that unknown
        # keys are replaced with an empty string.
        try:
            name = template.format_map(defaultdict(str, info))
        except TypeError:
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise

        # Collapse runs of underscores (e.g., left behind by empty
        # substitutions) into a single underscore.
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls rather than the hard-coded class name so that subclasses
        # get instances of their own type.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster."""
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster."""
        # Return a copy so callers cannot alter the internal counts.
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode whose nodeId and task label are recorded in
            this cluster.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only
        # name equality since those are supposed
        # to be unique.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        # Let Python try the reflected comparison; the final result of
        # ``==`` with an unrelated type is still False.
        return NotImplemented

    def __hash__(self) -> int:
        # Hash like the name so a cluster and its name are interchangeable
        # as dictionary keys, consistent with __eq__.
        return hash(self.name)
class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized. If not given, a filename is chosen when the
        ClusteredQuantumGraph is saved.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them.  Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve when a filename was
        # actually given; save() fills in a default otherwise.
        if qgraph_filename is None:
            self._quantum_graph_filename = None
        else:
            self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Compare against None explicitly so that an empty graph reports a
        # length of 0 instead of None.
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if any given item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists in the
            ClusteredQuantumGraph.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            # Nodes are keyed by cluster name; the cluster object itself is
            # stored as the node's "cluster" attribute.
            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Wrap the integer given by the caller (not the ``id`` builtin)
            # in a NodeId tied to this QuantumGraph.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; wrap the object returned by
        # nodes() instead of returning it directly.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # The membership checks accept either a name or a QuantaCluster
        # because QuantaCluster hashes and compares like its name.
        if not self._cluster_graph.has_node(parent):
            raise KeyError(f"{self.name} does not have a cluster named {parent}")
        if not self._cluster_graph.has_node(child):
            raise KeyError(f"{self.name} does not have a cluster named {child}")
        _LOG.debug("add_dependency: adding edge %s %s", parent, child)

        # Edges are always stored between cluster names.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, the extension
            of ``filename`` is used. Only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing fails.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, the extension
            of ``filename`` is used. Only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ not in draw_funcs:
            raise RuntimeError(f"Unknown draw format ({format_})")
        draw_funcs[format_](self._cluster_graph, filename)

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not
            given, the extension of ``filename`` is used. Only "pickle"
            is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # NOTE(review): unpickling can execute arbitrary code; only
            # load files that come from a trusted source.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

            # The QuantumGraph was saved separately
            try:
                cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename)
            except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
                new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
                cgraph._quantum_graph = QuantumGraph.loadUri(new_filename)

        return cgraph