Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 24%

174 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-20 11:11 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Class definitions for a Clustered QuantumGraph where a node in the graph is 

29a QuantumGraph. 

30""" 

31 

# Names exported by ``from ... import *``: the public API of this module.
__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]

33 

34 

35import logging 

36import pickle 

37import re 

38from collections import Counter, defaultdict 

39from pathlib import Path 

40 

41from lsst.pipe.base import NodeId, QuantumGraph 

42from lsst.utils.iteration import ensure_iterable 

43from networkx import DiGraph, is_isomorphic, topological_sort 

44 

45from .bps_draw import draw_networkx_dot 

46 

# Module-level logger; the classes in this module emit their debug and error
# messages through it.
_LOG = logging.getLogger(__name__)

48 

49 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /).

    Notes
    -----
    Equality and hashing depend only on ``name`` (names are supposed to be
    unique).  A cluster also compares equal to a `str` holding its name, so
    a cluster can be used to look up an entry keyed by its name.
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node ids in insertion order; exposed read-only via qgraph_node_ids.
        self._qgraph_node_ids = []
        # Number of quanta per task label; exposed as a copy via quanta_counts.
        self._task_label_counts = Counter()
        # Fresh dict per instance so clusters never share tag storage.
        self.tags = tags if tags is not None else {}

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template (`str.format` syntax) for creating the cluster name.
            Fields may reference the keys of the quantum's required data ID,
            ``label`` and ``node_number``; unknown fields become empty
            strings.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster (an instance of ``cls``) containing the
            given quantum.  The values used to fill the template become the
            cluster's tags.

        Raises
        ------
        TypeError
            Raised (after logging an error) if the template cannot be
            formatted with the gathered values.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = dict(data_id.required)
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, missing keys map to "".
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except TypeError:
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Empty substitutions can leave runs of underscores; collapse them.
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls so that subclasses construct instances of themselves.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """Quantum graph NodeIds corresponding to this cluster
        (`frozenset`).
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster (a new
        `collections.Counter`, so callers cannot mutate internal state).
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            Quantum node to add.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since names
        # are supposed to be unique.  Comparing against a plain str lets a
        # cluster be used to look up entries keyed by its name.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        return False

    def __hash__(self) -> int:
        # Must agree with __eq__: equal objects (including the name str)
        # hash identically.
        return hash(self.name)

177 

178 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.  If not given, `save` writes the QuantumGraph to a
        ``.qgraph`` file next to the ClusteredQuantumGraph file.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.

    Internally the clusters are stored in a `networkx.DiGraph` whose nodes
    are cluster names, each carrying its `QuantaCluster` under the
    ``"cluster"`` node attribute.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given.  None tells save() to choose a filename itself.
        self._quantum_graph_filename = Path(qgraph_filename).resolve() if qgraph_filename else None
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Compare against None rather than truthiness so empty graphs
        # report a length of 0 instead of None.
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    def __eq__(self, other):
        if not isinstance(other, ClusteredQuantumGraph):
            return False
        if len(self) != len(other):
            return False
        # Cluster graphs are compared structurally (isomorphism), not by
        # cluster names.
        return self._quantum_graph == other._quantum_graph and is_isomorphic(
            self._cluster_graph, other._cluster_graph
        )

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.  An int is combined with the
            QuantumGraph's graphID to form a NodeId.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Use the given number, not the builtin ``id`` function.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; a networkx NodeView is only an
        # iterable, so wrap it with iter().
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters in topological order.
        """
        return map(self.get_cluster, topological_sort(self._cluster_graph))

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name; normalize up front so the
        # checks, messages and edge all use names.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    @staticmethod
    def _resolve_format(path, format_):
        """Return ``format_``, or the extension of ``path`` (without its
        leading period) when ``format_`` is `None`.
        """
        if format_ is None:
            return path.suffix[1:]  # suffix includes the leading period
        return format_

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data.  If not given, it is taken
            from the file extension.  Only ``"pickle"`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)
        format_ = self._resolve_format(path, format_)

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing fails.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data.  If not given, it is taken
            from the file extension.  Only ``"dot"`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)
        format_ = self._resolve_format(path, format_)

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream.  If not
            given, it is taken from the file extension.  Only ``"pickle"``
            is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.

        Notes
        -----
        Unpickling can execute arbitrary code; only load files from
        trusted sources.
        """
        path = Path(filename)
        format_ = cls._resolve_format(path, format_)

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # SECURITY: pickle.load on untrusted files can run arbitrary code.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately
        try:
            cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename)
        except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = QuantumGraph.loadUri(new_filename)

        return cgraph