Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Class definitions for a Clustered QuantumGraph where a node in the graph is
a cluster of quanta (`QuantaCluster`), i.e., a subset of the QuantumGraph's
nodes identified by their node ids.
"""

# Public API of this module.
__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]
29import logging
30import re
31import pickle
32from collections import Counter, defaultdict
33from pathlib import Path
34from networkx import DiGraph
36from lsst.daf.butler import DimensionUniverse
37from lsst.utils.iteration import ensure_iterable
38from lsst.pipe.base import QuantumGraph, NodeId
39from .bps_draw import draw_networkx_dot
# Module-level logger named after this module.
_LOG = logging.getLogger(name=__name__)
class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node ids in insertion order (duplicates are not filtered here).
        self._qgraph_node_ids = []
        # Number of quanta added per task label.
        self._task_label_counts = Counter()
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.  It is filled with
            `str.format_map` using the quantum's data ID values plus
            ``label`` and ``node_number``; unknown fields become empty
            strings.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum.  The
            template values are also used as the cluster's tags.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id.number
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict.
        name = template.format_map(defaultdict(lambda: "", info))
        # Collapse runs of underscores (e.g. those left when missing fields
        # format as empty strings).
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls so that subclasses construct instances of their own type.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster
        (`frozenset` [`lsst.pipe.base.NodeId`]).
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster
        (`collections.Counter` [`str`], a copy).
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode to be added; its ``nodeId`` and ``taskDef.label``
            are recorded.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return f"QuantaCluster(name={self.name},label={self.label},tags={self.tags}," \
               f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since names
        # are supposed to be unique.  A plain string compares equal to the
        # cluster with that name (consistent with __hash__ below).
        if isinstance(other, str):
            return self.name == other
        if isinstance(other, QuantaCluster):
            return self.name == other.name
        # Unknown type: let Python try the reflected comparison and fall
        # back to identity instead of forcing False.
        return NotImplemented

    def __hash__(self) -> int:
        # Hash by name so that it agrees with __eq__.
        return hash(self.name)
class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.

    Internally, each node of the cluster graph is keyed by the cluster name
    and stores the `QuantaCluster` under the ``cluster`` node attribute;
    edges point from parent cluster to child cluster.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        self._quantum_graph_filename = qgraph_filename
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Test against None explicitly: both graphs support len(), so an
        # empty graph can be falsy and would otherwise be reported as None
        # instead of 0.
        qgraph_len = None if self._quantum_graph is None else len(self._quantum_graph)
        cqgraph_len = None if self._cluster_graph is None else len(self._cluster_graph)
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters.
        """
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph.
        """
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster(s) to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item to be added is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
            Clusters processed before the error remain in the graph.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        if name not in self._cluster_graph:
            raise KeyError(f"{self.name} does not have a cluster named {name}")

        # Pass the node view itself (it renders like a list) so the node
        # names are only formatted when debug logging is enabled; building a
        # list here on every call made clusters() quadratic.
        _LOG.debug("get_cluster nodes = %s", self._cluster_graph.nodes)
        attr = self._cluster_graph.nodes[name]
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.  A bare integer is combined
            with this QuantumGraph's ``graphID`` to form the NodeId.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Use the given number (not the builtin ``id``) for the NodeId.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; a networkx NodeView is only an
        # iterable, so returning it directly makes iter()/for-loops raise
        # TypeError.
        return iter(self._cluster_graph.nodes)

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name, so resolve names up front;
        # this also keeps error/log messages short.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)

        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, it is taken
            from the filename extension; only ``pickle`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file: detach it while pickling
            # and always reattach it, even if writing fails.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, it is taken
            from the filename extension; only ``dot`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not given,
            it is taken from the filename extension; only ``pickle`` is
            supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            dim_univ = DimensionUniverse()
            with open(filename, "rb") as fh:
                # NOTE(review): unpickling can execute arbitrary code; only
                # load files produced by a trusted source.
                cgraph = pickle.load(fh)

            # The QuantumGraph was saved separately
            cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename, dim_univ)
        return cgraph