Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 23%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

170 statements  

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Class definitions for a Clustered QuantumGraph where a node in the graph is 

23a QuantumGraph. 

24""" 

25 

26__all__ = ["QuantaCluster", "ClusteredQuantumGraph"] 

27 

28 

29import logging 

30import pickle 

31import re 

32from collections import Counter, defaultdict 

33from pathlib import Path 

34 

35from lsst.daf.butler import DimensionUniverse 

36from lsst.pipe.base import NodeId, QuantumGraph 

37from lsst.utils.iteration import ensure_iterable 

38from networkx import DiGraph 

39 

40from .bps_draw import draw_networkx_dot 

41 

42_LOG = logging.getLogger(__name__) 

43 

44 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory. Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster. The given dictionary
        is stored as-is (not copied); an empty one is created if omitted.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        self._qgraph_node_ids = []
        self._task_label_counts = Counter()
        # Create the dict per instance rather than using a mutable default.
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum. It is an
            instance of the class this method was called on.

        Raises
        ------
        TypeError
            Re-raised (after logging) if formatting the template fails
            with a TypeError.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name. To avoid
        # key errors from generic patterns, use defaultdict.
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except TypeError:
            _LOG.error("Problems creating cluster name.  template='%s', info=%s", template, info)
            raise
        # Missing template keys become empty strings, which can leave runs
        # of underscores; collapse each run to a single underscore.
        name = re.sub(r"_+", "_", name)
        _LOG.debug("template name = %s", name)

        # Use cls (not the hard-coded class) so subclasses get their own type.
        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster."""
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster."""
        # Return a copy so callers cannot modify the internal counter.
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode whose nodeId and task label are recorded in this
            cluster.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only
        # name equality since those are supposed
        # to be unique.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        # Let Python try the reflected comparison; ``==`` still evaluates
        # to False for unrelated types.
        return NotImplemented

    def __hash__(self) -> int:
        # Hash by name so a cluster and its name hash identically,
        # consistent with __eq__ accepting a plain string.
        return hash(self.name)

172 

173 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /)

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them.  Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # Path(None) raises TypeError, so only resolve a filename that was
        # actually given; save() picks one later if this stays None.
        self._quantum_graph_filename = None
        if qgraph_filename is not None:
            self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Compare against None rather than truthiness so that an empty
        # graph reports a length of 0 instead of None.
        qgraph_len = len(self._quantum_graph) if self._quantum_graph is not None else None
        cqgraph_len = len(self._cluster_graph) if self._cluster_graph is not None else None
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if a given item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            # Node key is the cluster name; the object rides along as a
            # node attribute.
            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Wrap the given integer (not the ``id`` builtin) in a NodeId.
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return a real iterator, so wrap whatever
        # iterable nodes() returns.
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        if not self._cluster_graph.has_node(parent):
            raise KeyError(f"{self.name} does not have a cluster named {parent}")
        if not self._cluster_graph.has_node(child):
            raise KeyError(f"{self.name} does not have a cluster named {child}")
        _LOG.debug("add_dependency: adding edge %s %s", parent, child)

        # Edges are stored between cluster names, so reduce any
        # QuantaCluster argument to its name.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data.  If not given, the format is
            taken from the filename extension.  Only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it:
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                # Return to original state, even if writing failed.
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to draw the data.  If not given, the format is
            taken from the filename extension.  Only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream.  If not
            given, the format is taken from the filename extension.  Only
            "pickle" is supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # if format is None, try extension
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            dim_univ = DimensionUniverse()
            # NOTE(review): unpickling can execute arbitrary code; only load
            # files from a trusted source.
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

            # The QuantumGraph was saved separately
            try:
                cgraph._quantum_graph = QuantumGraph.loadUri(cgraph._quantum_graph_filename, dim_univ)
            except FileNotFoundError:  # Try same path as ClusteredQuantumGraph
                new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
                cgraph._quantum_graph = QuantumGraph.loadUri(new_filename, dim_univ)

        return cgraph