Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 24%

174 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-09-02 09:44 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""Class definitions for a clustered QuantumGraph, i.e., a graph in which
each node is a cluster of quanta (a subset of the nodes of the full
QuantumGraph) and edges are dependencies between clusters.
"""

__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]

27 

28 

29import logging 

30import pickle 

31import re 

32from collections import Counter, defaultdict 

33from pathlib import Path 

34 

35from lsst.pipe.base import NodeId, QuantumGraph 

36from lsst.utils.iteration import ensure_iterable 

37from networkx import DiGraph, is_isomorphic, topological_sort 

38 

39from .bps_draw import draw_networkx_dot 

40 

_LOG = logging.getLogger(__name__)


class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Equality and hashing are based only on ``name`` (which is supposed to be
    unique), so a cluster also compares equal to a plain `str` holding its
    name.
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node ids in insertion order; exposed read-only as a frozenset.
        self._qgraph_node_ids = []
        # Number of quanta per task label.
        self._task_label_counts = Counter()
        self.tags = tags if tags is not None else {}

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.  Fields are filled from the
            quantum's data ID values plus ``label`` and ``node_number``;
            unknown fields become empty strings.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster (an instance of ``cls``) containing the
            given quantum.

        Raises
        ------
        TypeError, ValueError
            Raised (after logging) if the template cannot be formatted with
            the gathered values.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict so that missing
        # keys format as empty strings.
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except (TypeError, ValueError):
            # ValueError covers e.g. a numeric format spec such as {visit:08d}
            # applied to the "" placeholder of a missing key.
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Collapse runs of underscores left behind by empty placeholders.
        name = re.sub(r"_+", "_", name)
        _LOG.debug("template name = %s", name)

        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """Quantum graph NodeIds corresponding to this cluster."""
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster (a copy)."""
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode to be added to cluster.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since those
        # are supposed to be unique.  Comparing against a plain name string
        # is deliberately supported.
        if isinstance(other, str):
            return self.name == other
        if isinstance(other, QuantaCluster):
            return self.name == other.name
        # Let Python try the reflected comparison / fall back to identity.
        return NotImplemented

    def __hash__(self) -> int:
        # Consistent with __eq__: same hash as the cluster's name.
        return hash(self.name)

171 

172 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.  If not given, `save` chooses a filename next to the
        ClusteredQuantumGraph file.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.

    Nodes of the internal `networkx.DiGraph` are cluster names; each
    `QuantaCluster` is stored in its node's ``cluster`` attribute.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # qgraph_filename is optional (Path(None) would raise TypeError);
        # save() fills it in if it is still unset.
        self._quantum_graph_filename = Path(qgraph_filename).resolve() if qgraph_filename else None
        self._cluster_graph = DiGraph()

    def __str__(self):
        # _quantum_graph is temporarily None while pickling in save().
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    def __eq__(self, other):
        if not isinstance(other, ClusteredQuantumGraph):
            return False
        if len(self) != len(other):
            return False
        return self._quantum_graph == other._quantum_graph and is_isomorphic(
            self._cluster_graph, other._cluster_graph
        )

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.  An int is combined with this
            graph's QuantumGraph ``graphID`` to build a `NodeId`.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # A networkx NodeView is iterable but not an iterator; __iter__ must
        # return an actual iterator or iter()/for-loops raise TypeError.
        return iter(self._cluster_graph.nodes)

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters in topological order.
        """
        return map(self.get_cluster, topological_sort(self._cluster_graph))

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name, so normalize up front.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    @staticmethod
    def _get_format(path, format_):
        """Return ``format_``, or the extension of ``path`` if it is None."""
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period
        return format_

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, it is taken
            from the filename extension. Only ``pickle`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)
        format_ = self._get_format(path, format_)
        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        # Don't save QuantumGraph in same file.  Detach it for the dump and
        # always reattach it, even if pickling fails.
        tmp_qgraph = self._quantum_graph
        self._quantum_graph = None
        try:
            with open(path, "wb") as fh:
                pickle.dump(self, fh)
        finally:
            self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, it is taken
            from the filename extension. Only ``dot`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)
        format_ = self._get_format(path, format_)

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ not in draw_funcs:
            raise RuntimeError(f"Unknown draw format ({format_})")
        draw_funcs[format_](self._cluster_graph, filename)

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not given,
            it is taken from the filename extension. Only ``pickle`` is
            supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)
        format_ = cls._get_format(path, format_)
        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files from trusted sources (e.g., ones written by save()).
        with open(path, "rb") as fh:
            cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately
        try:
            cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename)
        except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = QuantumGraph.loadUri(new_filename)

        return cgraph