Coverage for python/lsst/ctrl/bps/clustered_quantum_graph.py: 21%

170 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-22 02:14 -0700

1# This file is part of ctrl_bps. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22"""Class definitions for a Clustered QuantumGraph where a node in the graph is 

23a QuantumGraph. 

24""" 

25 

26__all__ = ["QuantaCluster", "ClusteredQuantumGraph"] 

27 

28 

29import logging 

30import pickle 

31import re 

32from collections import Counter, defaultdict 

33from pathlib import Path 

34 

35from lsst.pipe.base import NodeId 

36from lsst.utils.iteration import ensure_iterable 

37from networkx import DiGraph 

38 

39from .bps_draw import draw_networkx_dot 

40from .pre_transform import read_quantum_graph 

41 

42_LOG = logging.getLogger(__name__) 

43 

44 

class QuantaCluster:
    """Information about the cluster and Quanta belonging to it.

    Parameters
    ----------
    name : `str`
        Lookup key (logical file name) of file/directory.  Must
        be unique within ClusteredQuantumGraph.
    label : `str`
        Value used to group clusters.
    tags : `dict` [`str`, `Any`], optional
        Arbitrary key/value pairs for the cluster.  A new empty dictionary
        is used if not given.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /).

    Notes
    -----
    Equality and hashing use only ``name``.  A cluster also compares equal
    to a plain `str` holding its name, so it can be looked up by name in
    hashed containers.
    """

    def __init__(self, name, label, tags=None):
        if "/" in name:
            raise ValueError(f"Cluster's name cannot have a / ({name})")
        self.name = name
        self.label = label
        # Node ids in the order they were added (duplicates are not filtered).
        self._qgraph_node_ids = []
        # Number of quanta added per task label.
        self._task_label_counts = Counter()
        self.tags = {} if tags is None else tags

    @classmethod
    def from_quantum_node(cls, quantum_node, template):
        """Create single quantum cluster from given quantum node.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode for which to make into a single quantum cluster.
        template : `str`
            Template for creating cluster name.  It is filled via
            `str.format_map` with the quantum's data ID values plus
            ``label`` (task label) and ``node_number`` (node id).  Missing
            keys are replaced by empty strings.

        Returns
        -------
        cluster : `QuantaCluster`
            Newly created cluster containing the given quantum.  An
            instance of ``cls``, so subclasses get their own type.

        Raises
        ------
        ValueError
            Raised if the generated name contains a /, or if the template's
            format specification cannot be applied to a value.
        """
        label = quantum_node.taskDef.label
        node_id = quantum_node.nodeId
        data_id = quantum_node.quantum.dataId

        # Gather info for name template into a dictionary.
        info = data_id.byName()
        info["label"] = label
        info["node_number"] = node_id
        _LOG.debug("template = %s", template)
        _LOG.debug("info for template = %s", info)

        # Use dictionary plus template format string to create name.  To avoid
        # key errors from generic patterns, use defaultdict so missing keys
        # become empty strings.
        try:
            name = template.format_map(defaultdict(lambda: "", info))
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            # e.g. ``{visit:08d}`` applied to the "" substituted for a missing
            # key raises ValueError; log context for every such failure.
            _LOG.error("Problems creating cluster name. template='%s', info=%s", template, info)
            raise
        # Collapse runs of underscores left behind by empty substitutions.
        name = re.sub(r"_+", "_", name)
        _LOG.debug("template name = %s", name)

        cluster = cls(name, label, info)
        cluster.add_quantum(node_id, label)
        return cluster

    @property
    def qgraph_node_ids(self):
        """QuantumGraph NodeIds corresponding to this cluster
        (`frozenset` [`lsst.pipe.base.NodeId`]).
        """
        _LOG.debug("_qgraph_node_ids = %s", self._qgraph_node_ids)
        return frozenset(self._qgraph_node_ids)

    @property
    def quanta_counts(self):
        """Counts of Quanta per taskDef.label in this cluster
        (`collections.Counter`, a copy).
        """
        return Counter(self._task_label_counts)

    def add_quantum_node(self, quantum_node):
        """Add a quantumNode to this cluster.

        Parameters
        ----------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode whose node id and task label are recorded.
        """
        _LOG.debug("quantum_node = %s", quantum_node)
        _LOG.debug("quantum_node.nodeId = %s", quantum_node.nodeId)
        self.add_quantum(quantum_node.nodeId, quantum_node.taskDef.label)

    def add_quantum(self, node_id, task_label):
        """Add a quantum to this cluster given its node id and task label.

        Parameters
        ----------
        node_id : `lsst.pipe.base.NodeId`
            ID for quantumNode to be added to cluster.
        task_label : `str`
            Task label for quantumNode to be added to cluster.
        """
        self._qgraph_node_ids.append(node_id)
        self._task_label_counts[task_label] += 1

    def __str__(self):
        return (
            f"QuantaCluster(name={self.name},label={self.label},tags={self.tags},"
            f"counts={self.quanta_counts},ids={self.qgraph_node_ids})"
        )

    def __eq__(self, other: object) -> bool:
        # Doesn't check data equality, but only name equality since names
        # are supposed to be unique.  Comparing against a str lets the
        # cluster stand in for its name as a graph/dict key.
        if isinstance(other, str):
            return self.name == other

        if isinstance(other, QuantaCluster):
            return self.name == other.name

        return False

    def __hash__(self) -> int:
        # Must agree with __eq__: same hash as the bare name string.
        return hash(self.name)

172 

173 

class ClusteredQuantumGraph:
    """Graph where the data for a node is a subgraph of the full
    QuantumGraph represented by a list of node ids.

    Parameters
    ----------
    name : `str`
        Name to be given to the ClusteredQuantumGraph.
    qgraph : `lsst.pipe.base.QuantumGraph`
        The QuantumGraph to be clustered.
    qgraph_filename : `str`, optional
        Filename for given QuantumGraph if it has already been
        serialized.  Stored as an absolute path.  If not given, `save`
        derives one from the ClusteredQuantumGraph filename.
    butler_uri : `str`, optional
        Location of butler repo used to create the QuantumGraph.

    Raises
    ------
    ValueError
        Raised if invalid name (e.g., name contains /).

    Notes
    -----
    Using lsst.pipe.base.NodeId instead of integer because the QuantumGraph
    API requires them.  Chose skipping the repeated creation of objects to
    use API over totally minimized memory usage.
    """

    def __init__(self, name, qgraph, qgraph_filename=None, butler_uri=None):
        if "/" in name:
            raise ValueError(f"name cannot have a / ({name})")
        self._name = name
        self._quantum_graph = qgraph
        # The optional paths stay None when not given: Path(None) would raise
        # TypeError, and a None filename tells save() it must pick one.
        self._quantum_graph_filename = None
        if qgraph_filename is not None:
            self._quantum_graph_filename = Path(qgraph_filename).resolve()
        self._butler_uri = None
        if butler_uri is not None:
            self._butler_uri = Path(butler_uri).resolve()
        # Nodes are cluster names; the "cluster" node attribute holds the
        # QuantaCluster itself.
        self._cluster_graph = DiGraph()

    def __str__(self):
        # Use explicit None checks so empty graphs report 0 rather than None.
        qgraph_len = None if self._quantum_graph is None else len(self._quantum_graph)
        cqgraph_len = None if self._cluster_graph is None else len(self._cluster_graph)
        return (
            f"ClusteredQuantumGraph(name={self.name},"
            f"quantum_graph_filename={self._quantum_graph_filename},"
            f"butler_uri={self._butler_uri},"
            f"len(qgraph)={qgraph_len},"
            f"len(cqgraph)={cqgraph_len})"
        )

    def __len__(self):
        """Return the number of clusters."""
        return len(self._cluster_graph)

    @property
    def name(self):
        """The name of the ClusteredQuantumGraph."""
        return self._name

    @property
    def qgraph(self):
        """The QuantumGraph associated with this Clustered
        QuantumGraph.
        """
        return self._quantum_graph

    def add_cluster(self, clusters_for_adding):
        """Add a cluster of quanta as a node in the graph.

        Parameters
        ----------
        clusters_for_adding : `QuantaCluster` or `Iterable` [`QuantaCluster`]
            The cluster(s) to be added to the ClusteredQuantumGraph.

        Raises
        ------
        TypeError
            Raised if an item is not a `QuantaCluster`.
        KeyError
            Raised if a cluster with the same name already exists.
        """
        for cluster in ensure_iterable(clusters_for_adding):
            if not isinstance(cluster, QuantaCluster):
                raise TypeError(f"Must be type QuantaCluster (given: {type(cluster)})")

            if self._cluster_graph.has_node(cluster.name):
                raise KeyError(f"Cluster {cluster.name} already exists in ClusteredQuantumGraph")

            self._cluster_graph.add_node(cluster.name, cluster=cluster)

    def get_cluster(self, name):
        """Retrieve a cluster from the ClusteredQuantumGraph by name.

        Parameters
        ----------
        name : `str`
            Name of cluster to retrieve.

        Returns
        -------
        cluster : `QuantaCluster`
            QuantaCluster matching given name.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a cluster with given name.
        """
        try:
            attr = self._cluster_graph.nodes[name]
        except KeyError as ex:
            raise KeyError(f"{self.name} does not have a cluster named {name}") from ex
        return attr["cluster"]

    def get_quantum_node(self, id_):
        """Retrieve a QuantumNode from the ClusteredQuantumGraph by ID.

        Parameters
        ----------
        id_ : `lsst.pipe.base.NodeId` or int
            ID of the QuantumNode to retrieve.  A plain int is combined with
            the QuantumGraph's ``graphID`` to form a NodeId.

        Returns
        -------
        quantum_node : `lsst.pipe.base.QuantumNode`
            QuantumNode matching given ID.

        Raises
        ------
        KeyError
            Raised if the ClusteredQuantumGraph does not contain
            a QuantumNode with given ID.
        """
        node_id = id_
        if isinstance(id_, int):
            # Use the caller's number (previously the builtin ``id`` was
            # passed here by mistake).
            node_id = NodeId(id_, self._quantum_graph.graphID)
        _LOG.debug("get_quantum_node: node_id = %s", node_id)
        return self._quantum_graph.getQuantumNodeByNodeId(node_id)

    def __iter__(self):
        """Iterate over names of clusters.

        Returns
        -------
        names : `Iterator` [`str`]
            Iterator over names of clusters.
        """
        # __iter__ must return an iterator; a networkx NodeView is only
        # iterable, so wrap it with iter().
        return iter(self._cluster_graph.nodes())

    def clusters(self):
        """Iterate over clusters.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over clusters.
        """
        return map(self.get_cluster, self._cluster_graph.nodes())

    def successors(self, name):
        """Return clusters that are successors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the successors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over successors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.successors(name))

    def predecessors(self, name):
        """Return clusters that are predecessors of the cluster
        with the given name.

        Parameters
        ----------
        name : `str`
            Name of cluster for which need the predecessors.

        Returns
        -------
        clusters : `Iterator` [`lsst.ctrl.bps.QuantaCluster`]
            Iterator over predecessors of given cluster.
        """
        return map(self.get_cluster, self._cluster_graph.predecessors(name))

    def add_dependency(self, parent, child):
        """Add a directed dependency between a parent cluster and a child
        cluster.

        Parameters
        ----------
        parent : `str` or `QuantaCluster`
            Parent cluster.
        child : `str` or `QuantaCluster`
            Child cluster.

        Raises
        ------
        KeyError
            Raised if either the parent or child doesn't exist in the
            ClusteredQuantumGraph.
        """
        # Normalize to names first so the membership checks and the edge use
        # the same keys the nodes were added with.
        pname = parent.name if isinstance(parent, QuantaCluster) else parent
        cname = child.name if isinstance(child, QuantaCluster) else child

        if not self._cluster_graph.has_node(pname):
            raise KeyError(f"{self.name} does not have a cluster named {pname}")
        if not self._cluster_graph.has_node(cname):
            raise KeyError(f"{self.name} does not have a cluster named {cname}")
        _LOG.debug("add_dependency: adding edge %s %s", pname, cname)
        self._cluster_graph.add_edge(pname, cname)

    def __contains__(self, name):
        """Check if a cluster with given name is in this ClusteredQuantumGraph.

        Parameters
        ----------
        name : `str`
            Name of cluster to check.

        Returns
        -------
        found : `bool`
            Whether a cluster with given name is in this ClusteredQuantumGraph.
        """
        return self._cluster_graph.has_node(name)

    def save(self, filename, format_=None):
        """Save the ClusteredQuantumGraph in a format that is loadable.
        The QuantumGraph is saved separately if it hasn't already been
        serialized.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be serialized.
        format_ : `str`, optional
            Format in which to write the data.  If not given, it is taken
            from the filename extension.  Only "pickle" is supported.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # If format is None, try the file extension.
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        if not self._quantum_graph_filename:
            # Create filename based on given ClusteredQuantumGraph filename.
            self._quantum_graph_filename = path.with_suffix(".qgraph")

        # If QuantumGraph file doesn't already exist, save it.
        if not Path(self._quantum_graph_filename).exists():
            self._quantum_graph.saveUri(self._quantum_graph_filename)

        if format_ == "pickle":
            # Don't save QuantumGraph in same file: detach it temporarily and
            # always restore it, even if writing the pickle fails.
            tmp_qgraph = self._quantum_graph
            self._quantum_graph = None
            try:
                with open(filename, "wb") as fh:
                    pickle.dump(self, fh)
            finally:
                self._quantum_graph = tmp_qgraph

    def draw(self, filename, format_=None):
        """Draw the ClusteredQuantumGraph in a given format.

        Parameters
        ----------
        filename : `str`
            File to which the ClusteredQuantumGraph should be drawn.
        format_ : `str`, optional
            Format in which to draw the data.  If not given, it is taken
            from the filename extension.  Only "dot" is supported.

        Raises
        ------
        RuntimeError
            Raised if the draw format is not supported.
        """
        path = Path(filename)

        # If format is None, try the file extension.
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        draw_funcs = {"dot": draw_networkx_dot}
        if format_ in draw_funcs:
            draw_funcs[format_](self._cluster_graph, filename)
        else:
            raise RuntimeError(f"Unknown draw format ({format_})")

    @classmethod
    def load(cls, filename, format_=None):
        """Load a ClusteredQuantumGraph from the given file.

        Parameters
        ----------
        filename : `str`
            File from which to read the ClusteredQuantumGraph.
        format_ : `str`, optional
            Format of data to expect when loading from stream.  If not given,
            it is taken from the filename extension.  Only "pickle" is
            supported.

        Returns
        -------
        ClusteredQuantumGraph : `lsst.ctrl.bps.ClusteredQuantumGraph`
            ClusteredQuantumGraph workflow loaded from the given file.
            The QuantumGraph is loaded from its own file specified in
            the saved ClusteredQuantumGraph.

        Raises
        ------
        RuntimeError
            Raised if the format is not supported.
        """
        path = Path(filename)

        # If format is None, try the file extension.
        if format_ is None:
            format_ = path.suffix[1:]  # suffix includes the leading period

        if format_ not in {"pickle"}:
            raise RuntimeError(f"Unknown format ({format_})")

        cgraph = None
        if format_ == "pickle":
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # files from a trusted source (e.g., written by save()).
            with open(filename, "rb") as fh:
                cgraph = pickle.load(fh)

        # The QuantumGraph was saved separately.
        try:
            cgraph._quantum_graph = read_quantum_graph(cgraph._quantum_graph_filename, cgraph._butler_uri)
        except FileNotFoundError:
            # Files may have been moved together; try the same directory as
            # the ClusteredQuantumGraph file.
            new_filename = path.parent / Path(cgraph._quantum_graph_filename).name
            cgraph._quantum_graph = read_quantum_graph(new_filename, cgraph._butler_uri)

        return cgraph