Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 21%
170 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-29 02:43 -0700
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Class definitions for a Clustered QuantumGraph where a node in the graph is
23a QuantumGraph.
24"""
26__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]
29import logging
30import pickle
31import re
32from collections import Counter, defaultdict
33from pathlib import Path
35from lsst.pipe.base import NodeId
36from lsst.utils.iteration import ensure_iterable
37from networkx import DiGraph
39from .bps_draw import draw_networkx_dot
40from .pre_transform import read_quantum_graph
42_LOG = logging.getLogger(__name__)
class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.  Defaults to an
        empty dictionary.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Ordered record of every QuantumGraph node id added to the cluster.
        self._qgraph_node_ids = []
        # Number of quanta added per task label.
        self._task_label_counts = Counter()
        # A fresh dict per instance (never a shared default).
        self.tags = tags if tags is not None else {}

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.  It is filled via
            `str.format_map` with the data id values (by dimension name)
            plus ``label`` and ``node_number``; unknown fields become "".

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum.  The
            template values are stored as the cluster's tags.

        Raises
        ------
        TypeError, ValueError
            Raised (after logging the template and values) if the
            template cannot be formatted.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict.  Malformed
        # templates (stray braces, positional fields) raise ValueError.
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except (TypeError, ValueError):
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Empty template fields can leave runs of underscores; collapse them.
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls so subclasses of QuantaCluster get instances of themselves.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster
        (`frozenset`)."""
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster
        (`collections.Counter`, a copy)."""
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode whose id and task label are added to the cluster.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since names
        # are supposed to be unique.  Comparing against a plain string is
        # allowed so clusters and cluster names are interchangeable as keys.
        if isinstance(other, str):
            return self.name == other
        if isinstance(other, QuantaCluster):
            return self.name == other.name
        # Let the other operand decide; Python falls back to identity.
        return NotImplemented

    def __hash__(self) -> int:
        # Consistent with __eq__: hash equals that of the name string.
        return hash(self.name)
class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.  If not given, `save` derives one from the
        ClusteredQuantumGraph filename.
    butler_uri : `str`, optional
        Location of butler repo used to create the QuantumGraph.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None, butler_uri=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Both arguments are optional; Path(None) raises TypeError, so only
        # resolve values that were actually given.  ``save`` treats a missing
        # QuantumGraph filename as "derive one from the save filename".
        self._quantum_graph_filename = (
            Path(qgraph_filename).resolve() if qgraph_filename is not None else None
        )
        self._butler_uri = Path(butler_uri).resolve() if butler_uri is not None else None
        # Nodes are cluster names; each node's "cluster" attribute holds the
        # QuantaCluster.  Edges are parent -> child dependencies.
        self._cluster_graph = DiGraph()

    def __str__(self):
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"butler_uri ={self._butler_uri},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={len(self._cluster_graph)})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster(s) to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.  A plain integer is combined
            with the QuantumGraph's graphID to build a NodeId.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Use the argument (previously the builtin ``id`` was passed).
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; a networkx NodeView is only an
        # iterable, so returning it directly makes iter(self) raise TypeError.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name; normalize up front so the
        # membership checks, messages and edge all use the same key.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, it is taken
            from the filename extension; only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if pickling fails.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, it is taken from
            the filename extension; only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not
            given, it is taken from the filename extension; only "pickle"
            is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load files produced by a trusted ``save``.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately
        try:
            cgraph._quantum_graph = read_quantum_graph(cgraph._quantum_graph_filename, cgraph._butler_uri)
        except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = read_quantum_graph(new_filename, cgraph._butler_uri)

        return cgraph