Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 23%
170 statements
« prev ^ index » next coverage.py v6.4, created at 2022-05-28 09:32 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-05-28 09:32 +0000
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Class definitions for a Clustered QuantumGraph where a node in the graph is
23a QuantumGraph.
24"""
26__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]
29import logging
30import pickle
31import re
32from collections import Counter, defaultdict
33from pathlib import Path
35from lsst.daf.butler import DimensionUniverse
36from lsst.pipe.base import NodeId, QuantumGraph
37from lsst.utils.iteration import ensure_iterable
38from networkx import DiGraph
40from .bps_draw import draw_networkx_dot
42_LOG = logging.getLogger(__name__)
class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster. A new empty
        dictionary is used if not given.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        self._qgraph_node_ids = []
        self._task_label_counts = Counter()
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name. Keys available to the
            template are the quantum's data ID values plus ``label`` and
            ``node_number``; unknown keys are replaced by an empty string.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster (an instance of the class this method
            is called on) containing the given quantum.

        Raises
        ------
        TypeError
            Re-raised (after logging) if formatting the template fails
            with a TypeError.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict so missing keys
        # format as an empty string.
        try:
            name = template.format_map(defaultdict(str, info))
        except TypeError:
            _LOG.error("Problems creating cluster name.  template='%s', info=%s", template, info)
            raise
        # Empty substitutions can leave runs of underscores; collapse them.
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls (not the hard-coded class) so subclasses get instances of
        # their own type from this alternate constructor.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster
        (returned as a `frozenset`).
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster
        (returned as a copy of the internal `collections.Counter`).
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode whose node ID and task label are recorded
            in this cluster.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantum to this cluster by node ID and task label.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only
        # name equality since those are supposed
        # to be unique.  Comparing against a plain string compares
        # that string to the cluster's name.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        return False

    def __hash__(self) -> int:
        # Hash on the name only, consistent with __eq__.
        return hash(self.name)
class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized. If not given, a filename is derived when the
        ClusteredQuantumGraph is saved.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them.  Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given; save() fills in a default when this is None.
        self._quantum_graph_filename = None
        if qgraph_filename is not None:
            self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._cluster_graph = DiGraph()

    def __str__(self):
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={len(self._quantum_graph) if self._quantum_graph else None},"
            f"len(cqgraph)={len(self._cluster_graph) if self._cluster_graph else None})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster(s) to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if any given item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists in the
            ClusteredQuantumGraph.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            # Nodes are keyed by cluster name; the cluster object itself is
            # stored as the node's "cluster" attribute.
            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve. An `int` is wrapped in a
            `NodeId` together with the QuantumGraph's ``graphID``.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Use the argument ``id_``, not the ``id`` builtin function.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; the node view itself is only
        # an iterable, so wrap it explicitly.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name, so reduce both arguments
        # to names before checking membership and adding the edge.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)

        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data.  If not given, the format
            is taken from the extension of ``filename``.  Only "pickle"
            is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing failed, so the
                # in-memory object never loses its QuantumGraph.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data.  If not given, the format
            is taken from the extension of ``filename``.  Only "dot"
            is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream.  If not
            given, the format is taken from the extension of ``filename``.
            Only "pickle" is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            dim_univ = DimensionUniverse()
            # NOTE(security): unpickling executes arbitrary code embedded in
            # the file; only load files from trusted sources.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

            # The QuantumGraph was saved separately
            try:
                cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename, dim_univ)
            except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
                new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
                cgraph._quantum_graph = QuantumGraph.loadUri(new_filename, dim_univ)

        return cgraph