Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 24%

174 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-19 10:54 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

"""Class definitions for a Clustered QuantumGraph, a graph in which each node
is a QuantaCluster: a named group of quanta taken from the full QuantumGraph
and referenced by their QuantumGraph node IDs.
"""

# Public API of this module.
__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]

33 

34 

35import logging 

36import pickle 

37import re 

38from collections import Counter, defaultdict 

39from pathlib import Path 

40 

41from lsst.pipe.base import NodeId, QuantumGraph 

42from lsst.utils.iteration import ensure_iterable 

43from networkx import DiGraph, is_isomorphic, topological_sort 

44 

45from .bps_draw import draw_networkx_dot 

46 

# Module-level logger; the classes below emit debug/error messages through it.
_LOG = logging.getLogger(__name__)

48 

49 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Equality and hashing use only ``name``, so a cluster compares equal
    to another cluster with the same name and to the bare name string.
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node IDs in insertion order; exposed as a frozenset by the
        # ``qgraph_node_ids`` property.
        self._qgraph_node_ids = []
        # Number of quanta per task label in this cluster.
        self._task_label_counts = Counter()
        self.tags = tags
        if self.tags is None:
            self.tags = {}

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.  It is filled with
            ``str.format_map`` using the data ID values plus ``label`` and
            ``node_number``; unknown keys become empty strings.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster (an instance of ``cls``) containing the
            given quantum.

        Raises
        ------
        TypeError
            Raised (after logging) if the template cannot be formatted with
            the gathered values.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict (missing keys
        # format as empty strings; the defaultdict is a copy so ``info`` is
        # not modified by the lookups).
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except TypeError:
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Collapse runs of underscores, e.g. those left by empty fields.
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use ``cls`` so subclasses get an instance of their own type.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """Quantum graph NodeIds corresponding to this cluster."""
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster (a copy)."""
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode to be added; its ``nodeId`` and ``taskDef.label``
            are recorded.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only
        # name equality since those are supposed
        # to be unique.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        return False

    def __hash__(self) -> int:
        # Consistent with __eq__: hashes like the bare name string, which
        # lets a cluster be used to look up its name in dicts/graphs.
        return hash(self.name)

177 

178 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.  If not given, ``save`` derives one from the
        ClusteredQuantumGraph filename.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given; ``save`` fills in a default when this is None.
        self._quantum_graph_filename = (
            Path(qgraph_filename).resolve() if qgraph_filename is not None else None
        )
        # Nodes are cluster names; each node's "cluster" attribute holds the
        # QuantaCluster itself.
        self._cluster_graph = DiGraph()

    def __str__(self):
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={len(self._quantum_graph) if self._quantum_graph else None},"
            f"len(cqgraph)={len(self._cluster_graph) if self._cluster_graph else None})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    def __eq__(self, other):
        if not isinstance(other, ClusteredQuantumGraph):
            return False
        if len(self) != len(other):
            return False
        # Cluster graphs are compared structurally (isomorphism), not by
        # node names or cluster contents.
        return self._quantum_graph == other._quantum_graph and is_isomorphic(
            self._cluster_graph, other._cluster_graph
        )

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster(s) to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Wrap the plain integer (not the builtin ``id``) in a NodeId
            # tied to this QuantumGraph.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # A networkx NodeView is iterable but not itself an iterator, and
        # ``__iter__`` must return an iterator for ``iter()``/``for`` to work.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters in topological order.
        """
        return map(self.get_cluster, topological_sort(self._cluster_graph))

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {parent}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {child}")
        _LOG.debug("add_dependency: adding edge %s %s", parent, child)
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. Defaults to the filename's
            extension; only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file; detach it temporarily and
            # always reattach it, even if pickling fails.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be drawn.
        format_ : `str`, optional
            Format in which to draw the data. Defaults to the filename's
            extension; only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. Defaults to
            the filename's extension; only "pickle" is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files written by a trusted ``save``.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately
        try:
            cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename)
        except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = QuantumGraph.loadUri(new_filename)

        return cgraph

516 return cgraph