Coverage for python / lsst / ctrl / bps / clustered_quantum_graph.py: 22%

180 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:35 +0000

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28"""Class definitions for a Clustered QuantumGraph where a node in the graph is 

29a QuantumGraph. 

30""" 

31 

32from __future__ import annotations 

33 

34__all__ = ["ClusteredQuantumGraph", "QuantaCluster"] 

35 

36import logging 

37import pickle 

38import re 

39import uuid 

40from collections import Counter, defaultdict 

41from pathlib import Path 

42 

43from networkx import DiGraph, is_directed_acyclic_graph, is_isomorphic, topological_sort 

44 

45from lsst.pipe.base.pipeline_graph import TaskImportMode 

46from lsst.pipe.base.quantum_graph import PredictedQuantumGraph, QuantumInfo 

47from lsst.utils.iteration import ensure_iterable 

48 

49from .bps_draw import draw_networkx_dot 

50 

# Module-wide logger, named after this module so it inherits package config.
_LOG = logging.getLogger(name=__name__)

52 

53 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `~typing.Any`], optional
        Arbitrary key/value pairs for the cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /).

    Notes
    -----
    Equality and hashing use only ``name``. A cluster also compares equal
    to a plain `str` holding its name.
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Quantum IDs in the order they were added.
        self._qgraph_node_ids = []
        # Number of quanta added per task label.
        self._task_label_counts = Counter()
        self.tags = tags if tags is not None else {}

    @classmethod
    def from_quantum_info(
        cls, quantum_id: uuid.UUID, quantum_info: QuantumInfo, template: str
    ) -> QuantaCluster:
        """Create single quantum cluster from the given quantum information.

        Parameters
        ----------
        quantum_id : `uuid.UUID`
            ID of the quantum.
        quantum_info : `lsst.pipe.base.quantum_graph.QuantumInfo`
            Dictionary of additional information about the quantum.
        template : `str`
            Template for creating cluster name.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum. The
            values used to fill the template are stored as its tags.

        Raises
        ------
        TypeError, ValueError
            Raised (after logging the template and values) if the
            template cannot be formatted.
        """
        label = quantum_info["task_label"]
        data_id = quantum_info["data_id"]

        # Values available to the name template: the data ID's required
        # values plus the task label and the quantum ID.
        info = dict(data_id.required)
        info["label"] = label
        info["node_number"] = quantum_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, missing keys format as "".
        try:
            name = template.format_map(defaultdict(str, info))
        except (TypeError, ValueError):
            # ValueError covers bad format specs / positional fields.
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Collapse runs of underscores (e.g. left by empty substitutions).
        name = re.sub(r"_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls so subclasses construct instances of themselves.
        cluster = cls(name, label, info)
        cluster.add_quantum(quantum_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """Quantum graph NodeIds corresponding to this cluster."""
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster (a copy)."""
        return Counter(self._task_label_counts)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `uuid.UUID`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since names
        # are supposed to be unique.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        return False

    def __hash__(self) -> int:
        # Consistent with __eq__: equal to its name, so hash like its name.
        return hash(self.name)

172 

173 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.quantum_graph.PredictedQuantumGraph`
        The quantum graph to be clustered.
    qgraph_filename : `str`
        Filename for given quantum graph.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Each node of the internal cluster graph is keyed by cluster name and
    stores its `QuantaCluster` under the ``"cluster"`` node attribute.
    Clusters refer to quanta only by their `uuid.UUID` IDs. Per-quantum
    information is looked up in the quantum-only networkx graph
    (`qxgraph`).
    """

    def __init__(self, name: str, qgraph: PredictedQuantumGraph, qgraph_filename: str):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        self._quantum_only_xgraph = qgraph.quantum_only_xgraph
        # Stored as an absolute path; ``load`` re-reads the quantum graph
        # from here.
        self._quantum_graph_filename = Path(qgraph_filename).resolve()
        # Nodes are cluster names; edges are parent -> child dependencies.
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Compare against None instead of relying on truthiness: an empty
        # graph is falsy and should report a length of 0, not None.
        # (``_quantum_graph`` is None while ``save`` is pickling.)
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    def __eq__(self, other):
        # Equality is structural: both the quanta graph and the cluster
        # graph must be isomorphic. No node/edge attributes are compared.
        # Defining __eq__ without __hash__ makes instances unhashable.
        if not isinstance(other, ClusteredQuantumGraph):
            return False
        if len(self) != len(other):
            return False
        return is_isomorphic(self.qxgraph, other.qxgraph) and is_isomorphic(
            self._cluster_graph, other._cluster_graph
        )

    @property
    def name(self) -> str:
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self) -> PredictedQuantumGraph:
        """The quantum graph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    @property
    def qxgraph(self) -> DiGraph:
        """A networkx graph of all quanta."""
        return self._quantum_only_xgraph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or \
                `~collections.abc.Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_info(self, id_: uuid.UUID) -> QuantumInfo:
        """Retrieve a quantum info dict from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `uuid.UUID`
            ID of the quantum to retrieve.

        Returns
        -------
        quantum_info : `lsst.pipe.base.quantum_graph.QuantumInfo`
            Quantum info dictionary for the given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a quantum with given ID.
        """
        return self._quantum_only_xgraph.nodes[id_]

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `~collections.abc.Iterator` [`str`]
            Iterator over names of clusters. Unlike `clusters`, no
            topological ordering is applied.
        """
        # __iter__ must return an iterator. A networkx NodeView is only an
        # iterable, so returning it would make iter(self) raise TypeError.
        return iter(self._cluster_graph)

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `~collections.abc.Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters in topological order.
        """
        return map(self.get_cluster, topological_sort(self._cluster_graph))

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `~collections.abc.Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `~collections.abc.Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Graph nodes are keyed by cluster name; normalize before lookups so
        # error and log messages show names, not full cluster dumps.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        for cluster_name in (pname, cname):
            if not self._cluster_graph.has_node(cluster_name):
                raise KeyError(f"{self.name} does not have a cluster named {cluster_name}")

        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.

        The quantum graph is assumed to have been saved separately.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data. If not given, it is taken
            from the file extension; only ``"pickle"`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if format_ == "pickle":
            # Don't save QuantumGraph in same file. Detach it for the dump
            # and always reattach it, even if pickling fails.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data. If not given, it is taken from
            the file extension; only ``"dot"`` is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream. If not given,
            it is taken from the file extension; only ``"pickle"`` is
            supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph: ClusteredQuantumGraph | None = None
        if format_ == "pickle":
            # NOTE: unpickling can execute arbitrary code; only load files
            # written by a trusted ``save``.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

            # The QuantumGraph was saved separately
            with PredictedQuantumGraph.open(
                cgraph._quantum_graph_filename, import_mode=TaskImportMode.DO_NOT_IMPORT
            ) as reader:
                reader.read_thin_graph()
                cgraph._quantum_graph = reader.finish()

        return cgraph

    def validate(self):
        """Check correctness of completed ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            If the ClusteredQuantumGraph is not valid.
        """
        # Check no cycles
        if not is_directed_acyclic_graph(self._cluster_graph):
            raise RuntimeError("ClusteredQuantumGraph is not a directed acyclic graph.")

        # Check that Quantum only in 1 cluster
        # Check cluster tags label matches cluster label
        node_ids = set()
        for cluster in self.clusters():
            if "label" in cluster.tags and cluster.tags["label"] != cluster.label:
                raise RuntimeError(
                    f"Label mismatch in cluster {cluster.name}: "
                    f"cluster={cluster.label} tags={cluster.tags['label']}"
                )

            for node_id in cluster.qgraph_node_ids:
                if node_id in node_ids:
                    raise RuntimeError(
                        f"Quantum {node_id} occurs in at least 2 clusters (one of which is {cluster.name})"
                    )
                node_ids.add(node_id)

        # Check that have all Quanta (assumes len(qgraph) counts quanta).
        quanta_count_qgraph = len(self._quantum_graph)
        quanta_count_cqgraph = len(node_ids)
        if quanta_count_qgraph != quanta_count_cqgraph:
            raise RuntimeError(
                f"Number of Quanta in clustered qgraph ({quanta_count_cqgraph}) does not equal number in"
                f" quantum graph ({quanta_count_qgraph})"
            )