Coverage for python/lsst/ctrl/bps/quantum_clustering_funcs.py: 5%
126 statements
« prev ^ index » next coverage.py v6.4, created at 2022-05-28 09:32 +0000
« prev ^ index » next coverage.py v6.4, created at 2022-05-28 09:32 +0000
1# This file is part of ctrl_bps.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22"""Functions that convert QuantumGraph into ClusteredQuantumGraph.
23"""
24import logging
25import re
26from collections import defaultdict
28from lsst.pipe.base import NodeId
29from networkx import DiGraph, is_directed_acyclic_graph
31from . import ClusteredQuantumGraph, QuantaCluster
33_LOG = logging.getLogger(__name__)
def single_quantum_clustering(config, qgraph, name):
    """Build a ClusteredQuantumGraph in which each cluster holds one quantum.

    Parameters
    ----------
    config : `lsst.ctrl.bps.BpsConfig`
        BPS configuration.
    qgraph : `lsst.pipe.base.QuantumGraph`
        QuantumGraph whose quanta are each placed into their own cluster.
    name : `str`
        Name to give to the ClusteredQuantumGraph.

    Returns
    -------
    clustered_quantum : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph containing one cluster per quantum of the
        given QuantumGraph, with dependencies mirroring those between the
        quanta.
    """
    clustered = ClusteredQuantumGraph(
        name=name, qgraph=qgraph, qgraph_filename=config[".bps_defined.runQgraphFile"]
    )

    # Cluster name template per task label; the config is searched only the
    # first time a label is encountered.
    templates_by_label = {}

    # Cluster name per quantum node ID, needed later to wire dependencies.
    name_by_node_id = {}

    # First pass: one cluster for every quantum node.
    for node in qgraph:
        label = node.taskDef.label
        template = templates_by_label.get(label)
        if template is None:
            found, template_data_id = config.search(
                "templateDataId",
                opt={"curvals": {"curr_pipetask": label}, "replaceVars": False},
            )
            template = "{node_number}_{label}_" + template_data_id if found else "{node_number}"
            templates_by_label[label] = template

        single = QuantaCluster.from_quantum_node(node, template)
        name_by_node_id[node.nodeId] = single.name
        clustered.add_cluster(single)

    # Second pass: every quantum-to-quantum edge becomes a cluster-to-cluster
    # dependency.
    for parent in qgraph:
        for child in qgraph.determineOutputsOfQuantumNode(parent):
            clustered.add_dependency(name_by_node_id[parent.nodeId], name_by_node_id[child.nodeId])

    return clustered
def _check_clusters_tasks(cluster_config, taskGraph):
    """Validate the pipetask lists given in the cluster definitions.

    Parameters
    ----------
    cluster_config : `lsst.ctrl.bps.BpsConfig`
        The cluster section from the BPS configuration.
    taskGraph : `lsst.pipe.base.taskGraph`
        Directed graph of tasks.

    Returns
    -------
    task_labels : `set` [`str`]
        Set of task labels from the cluster definitions.

    Raises
    ------
    RuntimeError
        Raised if task label appears in more than one cluster def or
        if there's a cycle in the cluster defs.
    """
    # Collapse the task graph into a graph of clusters; a cycle in that
    # collapsed graph means the cluster definitions are unusable.
    cluster_of_task = {}
    task_labels = set()
    cluster_graph = DiGraph()

    # Clusters explicitly defined in the configuration.
    for cluster_label in cluster_config:
        _LOG.debug("cluster = %s", cluster_label)
        pipetasks = cluster_config[cluster_label]["pipetasks"]
        for raw_label in pipetasks.split(","):
            task_label = raw_label.strip()
            if task_label in task_labels:
                raise RuntimeError(
                    f"Task label {task_label} appears in more than one cluster definition. "
                    "Aborting submission."
                )
            task_labels.add(task_label)
            cluster_of_task[task_label] = cluster_label
        cluster_graph.add_node(cluster_label)

    # Any task not named in a cluster definition acts as its own cluster.
    for task in taskGraph:
        if task.label in task_labels:
            continue
        cluster_of_task[task.label] = task.label
        cluster_graph.add_node(task.label)

    # Project task edges onto clusters, dropping edges internal to a cluster.
    for edge in taskGraph.edges:
        parent_cluster = cluster_of_task[edge[0].label]
        child_cluster = cluster_of_task[edge[1].label]
        if parent_cluster != child_cluster:
            cluster_graph.add_edge(parent_cluster, child_cluster)

    _LOG.debug("clustered_task_graph.edges = %s", list(cluster_graph.edges))

    if not is_directed_acyclic_graph(cluster_graph):
        raise RuntimeError("Cluster pipetasks do not create a DAG")

    return task_labels
def _parse_equal_dimensions(equal_dims):
    """Parse an ``equalDimensions`` value into a list of dimension pairs.

    Parameters
    ----------
    equal_dims : `str` or `None`
        Comma-separated list of ``dim1:dim2`` pairs, or a false value if the
        option was not given.

    Returns
    -------
    pairs : `list` [`tuple` [`str`, `str`]]
        The ``(dim1, dim2)`` pairs with surrounding whitespace removed;
        empty if ``equal_dims`` is false.

    Raises
    ------
    ValueError
        Raised if an entry is not of the form ``dim1:dim2``.
    """
    if not equal_dims:
        return []

    pairs = []
    for pair in equal_dims.split(","):
        dim1, dim2 = (dim.strip() for dim in pair.split(":"))
        pairs.append((dim1, dim2))
    return pairs


def _gather_cluster_info(data_id_info, cluster_dims, equal_pairs):
    """Collect the dimension values used to name a quantum's cluster.

    Parameters
    ----------
    data_id_info : `dict`
        Mapping of dimension name to value for the quantum's data ID.
    cluster_dims : `list` [`str`]
        Names of the dimensions that define the cluster.
    equal_pairs : `list` [`tuple` [`str`, `str`]]
        Pairs of dimension names whose values may stand in for one another.

    Returns
    -------
    info : `dict`
        Value for every cluster dimension that could be determined.
    missing : `set` [`str`]
        Cluster dimensions for which no value could be determined.
    """
    info = {}
    missing = set()
    for dim_name in cluster_dims:
        _LOG.debug("dim_name = %s", dim_name)
        if dim_name in data_id_info:
            info[dim_name] = data_id_info[dim_name]
        else:
            missing.add(dim_name)

    # Substitute only for dimensions that are actually missing.  Testing
    # membership in ``missing`` (rather than in ``cluster_dims``) avoids a
    # KeyError from removing a dimension that was never missing, e.g. when the
    # data ID already has both dimensions of a pair or when a dimension was
    # filled in by an earlier pair.
    for dim1, dim2 in equal_pairs:
        if dim1 in missing and dim2 in data_id_info:
            info[dim1] = data_id_info[dim2]
            missing.discard(dim1)
        elif dim2 in missing and dim1 in data_id_info:
            info[dim2] = data_id_info[dim1]
            missing.discard(dim2)

    return info, missing


def dimension_clustering(config, qgraph, name):
    """Follow config instructions to make clusters based upon dimensions.

    Parameters
    ----------
    config : `lsst.ctrl.bps.BpsConfig`
        BPS configuration.
    qgraph : `lsst.pipe.base.QuantumGraph`
        QuantumGraph to break into clusters for ClusteredQuantumGraph.
    name : `str`
        Name to give to ClusteredQuantumGraph.

    Returns
    -------
    cqgraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
        ClusteredQuantumGraph with clustering as defined in config.

    Raises
    ------
    RuntimeError
        Raised if a quantum lacks a dimension required by its cluster
        definition, or if the cluster definitions fail the checks done by
        ``_check_clusters_tasks``.
    ValueError
        Raised if an ``equalDimensions`` entry is not of the form
        ``dim1:dim2``.
    """
    cqgraph = ClusteredQuantumGraph(
        name=name, qgraph=qgraph, qgraph_filename=config[".bps_defined.runQgraphFile"]
    )

    # Save mapping in order to create dependencies later.
    quantum_to_cluster = {}

    cluster_config = config["cluster"]

    # Validates the definitions and returns every task label they mention.
    task_labels = _check_clusters_tasks(cluster_config, qgraph.taskGraph)

    for cluster_label in cluster_config:
        _LOG.debug("cluster = %s", cluster_label)
        cluster_section = cluster_config[cluster_label]

        cluster_dims = []
        if "dimensions" in cluster_section:
            cluster_dims = [d.strip() for d in cluster_section["dimensions"].split(",")]
        _LOG.debug("cluster_dims = %s", cluster_dims)

        found, template = cluster_section.search("clusterTemplate", opt={"replaceVars": False})
        if not found:
            if cluster_dims:
                template = f"{cluster_label}_" + "_".join(f"{{{dim}}}" for dim in cluster_dims)
            else:
                template = cluster_label
        _LOG.debug("template = %s", template)

        # The equal-dimension pairs depend only on the cluster definition, so
        # parse them once here instead of once per quantum.
        equal_pairs = _parse_equal_dimensions(cluster_section.get("equalDimensions", None))

        cluster_tasks = [pt.strip() for pt in cluster_section["pipetasks"].split(",")]
        for task_label in cluster_tasks:
            # A label listed in the config need not be present in this
            # particular QuantumGraph; such labels are skipped.
            task_def = qgraph.findTaskDefByLabel(task_label)
            if task_def is None:
                continue

            # Determine cluster for each node.
            for qnode in qgraph.getNodesForTask(task_def):
                data_id_info = qnode.quantum.dataId.byName()

                # Gather info for cluster name template into a dictionary.
                info, missing_info = _gather_cluster_info(data_id_info, cluster_dims, equal_pairs)
                info["label"] = cluster_label
                _LOG.debug("info for template = %s", info)

                if missing_info:
                    # Sorted so the message is deterministic.
                    missing_names = ",".join(sorted(missing_info))
                    raise RuntimeError(
                        f"Quantum {qnode.nodeId} ({data_id_info}) missing dimensions "
                        f"{missing_names} required for cluster {cluster_label}"
                    )

                # Use dictionary plus template format string to create name.
                # To avoid key errors from generic patterns, use defaultdict
                # so unknown template fields become empty strings.
                cluster_name = template.format_map(defaultdict(str, info))
                cluster_name = re.sub("_+", "_", cluster_name)

                # Some dimensions contain slash which must be replaced.
                cluster_name = cluster_name.replace("/", "_")
                _LOG.debug("cluster_name = %s", cluster_name)

                # Save mapping for use when creating dependencies.
                quantum_to_cluster[qnode.nodeId] = cluster_name

                # Add cluster to the ClusteredQuantumGraph.
                # Saving NodeId instead of number because QuantumGraph API
                # requires it for creating per-job QuantumGraphs.
                if cluster_name in cqgraph:
                    cluster = cqgraph.get_cluster(cluster_name)
                else:
                    cluster = QuantaCluster(cluster_name, cluster_label, info)
                    cqgraph.add_cluster(cluster)
                cluster.add_quantum(qnode.nodeId, task_label)

    # Assume any task not handled above is supposed to be 1 cluster = 1 quantum.
    for task_def in qgraph.iterTaskGraph():
        if task_def.label in task_labels:
            continue

        _LOG.info("Creating 1-quantum clusters for task %s", task_def.label)
        found, template_data_id = config.search(
            "templateDataId", opt={"curvals": {"curr_pipetask": task_def.label}, "replaceVars": False}
        )
        if found:
            template = "{node_number}_{label}_" + template_data_id
        else:
            template = "{node_number}"

        for qnode in qgraph.getNodesForTask(task_def):
            cluster = QuantaCluster.from_quantum_node(qnode, template)
            cqgraph.add_cluster(cluster)
            quantum_to_cluster[qnode.nodeId] = cluster.name

    # Add cluster dependencies.
    for parent in qgraph:
        for child in qgraph.determineOutputsOfQuantumNode(parent):
            # Only the lookups are guarded so that a KeyError raised inside
            # add_dependency is not mistaken for a missing quantum.
            try:
                parent_cluster = quantum_to_cluster[parent.nodeId]
                child_cluster = quantum_to_cluster[child.nodeId]
            except KeyError as e:  # pragma: no cover
                # For debugging a problem internal to method.
                nid = NodeId(e.args[0], qgraph.graphID)
                qnode = qgraph.getQuantumNodeByNodeId(nid)
                _LOG.error(
                    "Quanta missing when clustering: %s, %s",
                    qnode.taskDef.label,
                    qnode.quantum.dataId.byName(),
                )
                raise

            # Quanta in the same cluster need no cluster-level dependency.
            if parent_cluster != child_cluster:
                cqgraph.add_dependency(parent_cluster, child_cluster)

    return cqgraph