Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 23%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

166 statements  

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Class definitions for a Clustered QuantumGraph where a node in the graph is 

23a QuantumGraph. 

24""" 

25 

26__all__ = ["QuantaCluster", "ClusteredQuantumGraph"] 

27 

28 

29import logging 

30import re 

31import pickle 

32from collections import Counter, defaultdict 

33from pathlib import Path 

34from networkx import DiGraph 

35 

36from lsst.daf.butler import DimensionUniverse 

37from lsst.daf.butler.core.utils import iterable 

38from lsst.pipe.base import QuantumGraph, NodeId 

39from .bps_draw import draw_networkx_dot 

40 

41 

42_LOG = logging.getLogger(__name__) 

43 

44 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.  An empty dictionary
        is used when not given.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if '/' in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        self._qgraph_node_ids = []
        self._task_label_counts = Counter()
        self.tags = tags
        if self.tags is None:
            self.tags = {}

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id.number
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict.
        name = template.format_map(defaultdict(lambda: "", info))
        # Missing template keys become empty strings, which can leave runs
        # of underscores behind; collapse each run to a single underscore.
        name = re.sub(r"_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Construct through cls so that subclasses get instances of
        # themselves from this alternate constructor.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster,
        returned as a `frozenset`.
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster, returned
        as a copy of the internal `collections.Counter`.
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode whose node ID and task label are recorded.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return f"QuantaCluster(name={self.name},label={self.label},tags={self.tags}," \
               f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only
        # name equality since those are supposed
        # to be unique.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        # Let Python try the reflected comparison; when nothing handles
        # it, == still evaluates to False.
        return NotImplemented

    def __hash__(self) -> int:
        # Hash matches that of the name so a cluster and its name are
        # interchangeable as dictionary keys.
        return hash(self.name)

167 

168 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them.  Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if '/' in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given.  None means "not serialized yet" (see save).
        if qgraph_filename is None:
            self._quantum_graph_filename = None
        else:
            self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Compare against None explicitly so an empty graph reports a
        # length of 0 instead of None.
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return f"ClusteredQuantumGraph(name={self.name}," \
               f"quantum_graph_filename={self._quantum_graph_filename}," \
               f"len(qgraph)={qgraph_len}," \
               f"len(cqgraph)={cqgraph_len})"

    def __len__(self):
        """Return the number of clusters.
        """
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph.
        """
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if a given item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists in the
            ClusteredQuantumGraph.
        """
        for cluster in iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            # The graph node is keyed by cluster name; the cluster object
            # itself is stored as the node's "cluster" attribute.
            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr['cluster']

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Build the NodeId from the given integer (id_), not from the
            # builtin function ``id``.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; the node view is only an
        # iterable, so wrap it.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name, so reduce either form of
        # argument to a name before touching the graph.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)

        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data.  If not given, it is taken
            from the extension of ``filename``.  Only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix('.qgraph')

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing failed.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be drawn.
        format_ : `str`, optional
            Format in which to draw the data.  If not given, it is taken
            from the extension of ``filename``.  Only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ not in draw_funcs:
            raise RuntimeError(f"Unknown draw format ({format_})")
        draw_funcs[format_](self._cluster_graph, filename)

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream.  If not
            given, it is taken from the extension of ``filename``.  Only
            "pickle" is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            dim_univ = DimensionUniverse()
            # NOTE(security): unpickling can execute arbitrary code, so
            # only load files that come from a trusted source.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

            # The QuantumGraph was saved separately
            try:
                cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename, dim_univ)
            except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
                new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
                cgraph._quantum_graph = QuantumGraph.loadUri(new_filename, dim_univ)

        return cgraph