Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 21%

168 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-04-04 09:17 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

"""Class definitions for a Clustered QuantumGraph, i.e. a graph in which each
node is a cluster of quanta (a subset of the nodes of a full QuantumGraph).
"""

# Public names exported by ``from ... import *``.
__all__ = ["QuantaCluster", "ClusteredQuantumGraph"]

27 

28 

29import logging 

30import pickle 

31import re 

32from collections import Counter, defaultdict 

33from pathlib import Path 

34 

35from lsst.pipe.base import NodeId, QuantumGraph 

36from lsst.utils.iteration import ensure_iterable 

37from networkx import DiGraph 

38 

39from .bps_draw import draw_networkx_dot 

40 

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

42 

43 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /).

    Notes
    -----
    Equality and hashing use only ``name``. A cluster also compares equal
    to a plain `str` holding its name, so a cluster and its name can be
    used interchangeably as a key (e.g., for graph-node lookups).
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # NodeIds in the order they were added (no de-duplication here).
        self._qgraph_node_ids = []
        # Number of quanta added per task label.
        self._task_label_counts = Counter()
        self.tags = tags if tags is not None else {}

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name. It is filled via
            `str.format_map` with the quantum's data ID values plus
            ``label`` and ``node_number``; unknown keys become "".

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum. The
            template values are stored as the cluster's tags.

        Raises
        ------
        TypeError
            Re-raised (after logging) if the template cannot be formatted.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict.
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except TypeError:
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Collapse runs of underscores, e.g. those left where missing
        # template keys were replaced by empty strings.
        name = re.sub("_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls so that subclasses get instances of their own type.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster
        (`frozenset`)."""
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster
        (`collections.Counter`, a copy)."""
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode to be added; its nodeId and task label are recorded.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since those
        # are supposed to be unique.
        if isinstance(other, QuantaCluster):
            return self.name == other.name
        if isinstance(other, str):
            return self.name == other
        # Let Python try the reflected comparison / identity fallback.
        return NotImplemented

    def __hash__(self) -> int:
        # Consistent with __eq__: equal names give equal hashes, including
        # against a plain str name.
        return hash(self.name)

171 

172 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized. If not given, `save` chooses a filename next to the
        ClusteredQuantumGraph file and writes the QuantumGraph there.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them. Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a real filename; an
        # unset filename is filled in by save().
        self._quantum_graph_filename = Path(qgraph_filename).resolve() if qgraph_filename else None
        # Nodes are cluster names; each node's "cluster" attribute holds the
        # QuantaCluster object.
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Test for None explicitly: an empty graph defines __len__ and is
        # therefore falsy, but its length (0) is still meaningful.
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if any item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve. An int is combined with this
            QuantumGraph's graphID to build a NodeId.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Use the given number (previously the builtin ``id`` function
            # was passed by mistake).
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # A NodeView is iterable but is not itself an iterator; __iter__ must
        # return an iterator or iter()/for-loops raise TypeError.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name, so normalize up front.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    @staticmethod
    def _infer_format(path, format_):
        """Return ``format_``, or the extension of ``path`` if it is None."""
        if format_ is None:
            # Path.suffix includes the leading period.
            format_ = path.suffix[1:]
        return format_

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, it is taken
            from the filename extension; only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)
        format_ = self._infer_format(path, format_)

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename.
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it.
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing fails.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, it is taken
            from the filename extension; only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)
        format_ = self._infer_format(path, format_)

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not
            given, it is taken from the filename extension; only "pickle"
            is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)
        format_ = cls._infer_format(path, format_)

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # NOTE(review): unpickling can execute arbitrary code; only load
            # files from trusted sources.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately.
        try:
            cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename)
        except FileNotFoundError:
            # Try the same directory as the ClusteredQuantumGraph file.
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = QuantumGraph.loadUri(new_filename)

        return cgraph