Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 23%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

166 statements  

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Class definitions for a Clustered QuantumGraph where a node in the graph is 

23a QuantumGraph. 

24""" 

25 

26__all__ = ["QuantaCluster", "ClusteredQuantumGraph"] 

27 

28 

29import logging 

30import pickle 

31import re 

32from collections import Counter, defaultdict 

33from pathlib import Path 

34 

35from lsst.daf.butler import DimensionUniverse 

36from lsst.pipe.base import NodeId, QuantumGraph 

37from lsst.utils.iteration import ensure_iterable 

38from networkx import DiGraph 

39 

40from .bps_draw import draw_networkx_dot 

41 

42_LOG = logging.getLogger(__name__) 

43 

44 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.  If not given, an
        empty `dict` is created for this cluster.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        self._qgraph_node_ids = []
        self._task_label_counts = Counter()
        self.tags = tags
        if self.tags is None:
            self.tags = {}

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum.  The
            values used to fill in the template are also stored as the
            cluster's tags.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict.
        name = template.format_map(defaultdict(lambda: "", info))
        # Missing keys leave consecutive underscores behind; collapse them.
        name = re.sub(r"_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls (not a hard-coded class) so subclasses get an instance
        # of their own type from this alternate constructor.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster
        (returned as a `frozenset`).
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster
        (returned as a copy, so callers cannot alter the internal counts).
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode whose ID and task label are recorded in the cluster.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster by ID and task label.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since those
        # are supposed to be unique.  Comparing equal to the plain name
        # (together with __hash__ below) lets a cluster stand in for its
        # name as a hash key.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        # Let Python try the reflected comparison; for unrelated types
        # the result of ``==`` is still False.
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.name)

168 

169 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.  If not given, a filename is chosen when the
        ClusteredQuantumGraph is saved.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them.  Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given; save() fills in a default when this is None.
        if qgraph_filename is None:
            self._quantum_graph_filename = None
        else:
            self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Compare against None explicitly: an empty graph is falsy but
        # its length (0) is still meaningful.
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if any given item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists in the
            ClusteredQuantumGraph.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Wrap the given integer (not the ``id`` builtin) in a NodeId.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator, not just an iterable view.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Normalize both ends to cluster names first so that lookups,
        # error messages, and the stored edge all use the same key.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)

        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data.  If not given, the format
            is taken from the extension of ``filename``.  Only "pickle"
            is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # If format is None, try extension.
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename.
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state even if writing failed.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data.  If not given, the format
            is taken from the extension of ``filename``.  Only "dot" is
            supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # If format is None, try extension.
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream.  If not
            given, the format is taken from the extension of ``filename``.
            Only "pickle" is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # If format is None, try extension.
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            dim_univ = DimensionUniverse()
            # NOTE: unpickling can execute arbitrary code; only load files
            # that come from a trusted source.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

            # The QuantumGraph was saved separately.
            try:
                cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename, dim_univ)
            except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
                new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
                cgraph._quantum_graph = QuantumGraph.loadUri(new_filename, dim_univ)

        return cgraph