Coverage for python/lsst/pipe/base/quantum_graph_skeleton.py: 43%

136 statements  

coverage.py v7.4.4, created at 2024-04-19 11:28 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""An under-construction version of QuantumGraph and various helper
classes.
"""

from __future__ import annotations

__all__ = (
    "QuantumGraphSkeleton",
    "QuantumKey",
    "TaskInitKey",
    "DatasetKey",
    "PrerequisiteDatasetKey",
)

from collections.abc import Iterable, Iterator, MutableMapping, Set
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple

import networkx
from lsst.daf.butler import DataCoordinate, DataIdValue, DatasetRef
from lsst.utils.logging import getLogger

if TYPE_CHECKING:
    pass

_LOG = getLogger(__name__)


class QuantumKey(NamedTuple):
    """Identifier type for quantum keys in a `QuantumGraphSkeleton`."""

    task_label: str
    """Label of the task in the pipeline."""

    data_id_values: tuple[DataIdValue, ...]
    """Data ID values of the quantum.

    Note that keys are fixed given `task_label`, so using only the values here
    speeds up comparisons.
    """

    is_task: ClassVar[Literal[True]] = True
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `True`).
    """


class TaskInitKey(NamedTuple):
    """Identifier type for task init keys in a `QuantumGraphSkeleton`."""

    task_label: str
    """Label of the task in the pipeline."""

    is_task: ClassVar[Literal[True]] = True
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `True`).
    """


class DatasetKey(NamedTuple):
    """Identifier type for dataset keys in a `QuantumGraphSkeleton`."""

    parent_dataset_type_name: str
    """Name of the dataset type (never a component)."""

    data_id_values: tuple[DataIdValue, ...]
    """Data ID values of the dataset.

    Note that keys are fixed given `parent_dataset_type_name`, so using only
    the values here speeds up comparisons.
    """

    is_task: ClassVar[Literal[False]] = False
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `False`).
    """

    is_prerequisite: ClassVar[Literal[False]] = False
    """Whether this node represents a prerequisite input dataset (always
    `False`).
    """


class PrerequisiteDatasetKey(NamedTuple):
    """Identifier type for prerequisite dataset keys in a
    `QuantumGraphSkeleton`.

    Unlike regular datasets, prerequisites are not actually required to come
    from a find-first search of `input_collections`, so we don't want to
    assume that the same data ID implies the same dataset. Happily we also
    don't need to search for them by data ID in the graph, so we can use the
    dataset ID (UUID) instead.
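
    For example, the key for a prerequisite dataset is built from the raw
    bytes of its UUID (an illustrative sketch, not from the original
    documentation; ``ref`` is an assumed `~lsst.daf.butler.DatasetRef` for a
    ``bias`` dataset)::

        key = PrerequisiteDatasetKey("bias", ref.id.bytes)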

116 """ 

117 

118 parent_dataset_type_name: str 

119 """Name of the dataset type (never a component).""" 

120 

121 dataset_id_bytes: bytes 

122 """Dataset ID (UUID) as raw bytes.""" 

123 

124 is_task: ClassVar[Literal[False]] = False 

125 """Whether this node represents a quantum or task initialization rather 

126 than a dataset (always `False`). 

127 """ 

128 

129 is_prerequisite: ClassVar[Literal[True]] = True 

130 

131 

132class QuantumGraphSkeleton: 

133 """An under-construction quantum graph. 

134 

135 QuantumGraphSkeleton is intended for use inside `QuantumGraphBuilder` and 

136 its subclasses. 

137 

138 Parameters 

139 ---------- 

140 task_labels : `~collections.abc.Iterable` [ `str` ] 

141 The labels of all tasks whose quanta may be included in the graph, in 

142 topological order. 

143 

144 Notes 

145 ----- 

146 QuantumGraphSkeleton models a bipartite version of the quantum graph, in 

147 which both quanta and datasets are represented as nodes and each type of 

148 node only has edges to the other type. 

149 

150 Square-bracket (`getitem`) indexing returns a mutable mapping of a node's 

151 flexible attributes. 

152 

153 The details of the `QuantumGraphSkeleton` API (e.g. which operations 

154 operate on multiple nodes vs. a single node) are set by what's actually 

155 needed by current quantum graph generation algorithms. New variants can be 

156 added as needed, but adding all operations that *might* be useful for some 

157 future algorithm seems premature. 
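
    Examples
    --------
    A minimal, illustrative sketch (not taken from the original
    documentation): ``universe`` is an assumed
    `~lsst.daf.butler.DimensionUniverse`, and the ``"isr"`` / ``"raw"``
    labels and empty data ID are placeholders::

        from lsst.daf.butler import DataCoordinate

        skeleton = QuantumGraphSkeleton(["isr"])
        data_id = DataCoordinate.make_empty(universe)
        quantum_key = skeleton.add_quantum_node("isr", data_id)
        dataset_key = skeleton.add_dataset_node("raw", data_id)
        skeleton.add_input_edge(quantum_key, dataset_key)
        # Square-bracket indexing exposes a node's flexible attributes.
        assert skeleton[quantum_key]["data_id"] == data_id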

158 """ 

159 

160 def __init__(self, task_labels: Iterable[str]): 

161 self._tasks: dict[str, tuple[TaskInitKey, set[QuantumKey]]] = {} 

162 self._xgraph: networkx.DiGraph = networkx.DiGraph() 

163 self._global_init_outputs: set[DatasetKey] = set() 

164 for task_label in task_labels: 

165 task_init_key = TaskInitKey(task_label) 

166 self._tasks[task_label] = (task_init_key, set()) 

167 self._xgraph.add_node(task_init_key) 

168 

169 def __contains__(self, key: QuantumKey | TaskInitKey | DatasetKey | PrerequisiteDatasetKey) -> bool: 

170 return key in self._xgraph.nodes 

171 

172 def __getitem__( 

173 self, key: QuantumKey | TaskInitKey | DatasetKey | PrerequisiteDatasetKey 

174 ) -> MutableMapping[str, Any]: 

175 return self._xgraph.nodes[key] 

176 

177 @property 

178 def n_nodes(self) -> int: 

179 """The total number of nodes of all types.""" 

180 return len(self._xgraph.nodes) 

181 

182 @property 

183 def n_edges(self) -> int: 

184 """The total number of edges.""" 

185 return len(self._xgraph.edges) 

186 

187 def has_task(self, task_label: str) -> bool: 

188 """Test whether the given task is in this skeleton. 

189 

190 Tasks are only added to the skeleton at initialization, but may be 

191 removed by `remove_task` if they end up having no quanta. 

192 

193 Parameters 

194 ---------- 

195 task_label : `str` 

196 Task to check for. 

197 

198 Returns 

199 ------- 

200 has : `bool` 

201 `True` if the task is in this skeleton. 

202 """ 

203 return task_label in self._tasks 

204 

205 def get_task_init_node(self, task_label: str) -> TaskInitKey: 

206 """Return the graph node that represents a task's initialization. 

207 

208 Parameters 

209 ---------- 

210 task_label : `str` 

211 The task label to use. 

212 

213 Returns 

214 ------- 

215 node : `TaskInitKey` 

216 The graph node representing this task's initialization. 

217 """ 

218 return self._tasks[task_label][0] 

219 

220 def get_quanta(self, task_label: str) -> Set[QuantumKey]: 

221 """Return the quanta for the given task label. 

222 

223 Parameters 

224 ---------- 

225 task_label : `str` 

226 Label for the task. 

227 

228 Returns 

229 ------- 

230 quanta : `~collections.abc.Set` [ `QuantumKey` ] 

231 A set-like object with the identifiers of all quanta for the given 

232 task. *The skeleton object's set of quanta must not be modified 

233 while iterating over this container; make a copy if mutation during 

234 iteration is necessary*. 
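
        Examples
        --------
        An illustrative sketch of safe mutation during iteration (``"isr"``
        is an assumed task label); copying with `list` avoids modifying the
        live set while it is being iterated::

            for quantum_key in list(skeleton.get_quanta("isr")):
                skeleton.remove_quantum_node(quantum_key, remove_outputs=True)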

235 """ 

236 return self._tasks[task_label][1] 

237 

238 @property 

239 def global_init_outputs(self) -> Set[DatasetKey]: 

240 """The set of dataset nodes that are not associated with any task.""" 

241 return self._global_init_outputs 

242 

243 def iter_all_quanta(self) -> Iterator[QuantumKey]: 

244 """Iterate over all quanta from any task, in topological (but otherwise 

245 unspecified) order. 

246 """ 

247 for _, quanta in self._tasks.values(): 

248 yield from quanta 

249 

250 def iter_outputs_of(self, quantum_key: QuantumKey | TaskInitKey) -> Iterator[DatasetKey]: 

251 """Iterate over the datasets produced by the given quantum. 

252 

253 Parameters 

254 ---------- 

255 quantum_key : `QuantumKey` or `TaskInitKey` 

256 Quantum to iterate over. 

257 

258 Returns 

259 ------- 

260 datasets : `~collections.abc.Iterator` of `DatasetKey` 

261 Datasets produced by the given quanta. 

262 """ 

263 return self._xgraph.successors(quantum_key) 

264 

265 def iter_inputs_of( 

266 self, quantum_key: QuantumKey | TaskInitKey 

267 ) -> Iterator[DatasetKey | PrerequisiteDatasetKey]: 

268 """Iterate over the datasets consumed by the given quantum. 

269 

270 Parameters 

271 ---------- 

272 quantum_key : `QuantumKey` or `TaskInitKey` 

273 Quantum to iterate over. 

274 

275 Returns 

276 ------- 

277 datasets : `~collections.abc.Iterator` of `DatasetKey` \ 

278 or `PrequisiteDatasetKey` 

279 Datasets consumed by the given quanta. 

280 """ 

281 return self._xgraph.predecessors(quantum_key) 

282 

283 def update(self, other: QuantumGraphSkeleton) -> None: 

284 """Copy all nodes from ``other`` to ``self``. 

285 

286 Parameters 

287 ---------- 

288 other : `QuantumGraphSkeleton` 

289 Source of nodes. The tasks in ``other`` must be a subset of the 

290 tasks in ``self`` (this method is expected to be used to populate 

291 a skeleton for a full from independent-subgraph skeletons). 
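
        Examples
        --------
        An illustrative sketch of merging a per-subgraph skeleton into a
        full one (``"isr"``, ``"calibrate"``, and ``data_id`` are assumed
        placeholders)::

            full = QuantumGraphSkeleton(["isr", "calibrate"])
            subgraph = QuantumGraphSkeleton(["isr"])
            subgraph.add_quantum_node("isr", data_id)
            full.update(subgraph)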

292 """ 

293 for task_label, (_, quanta) in other._tasks.items(): 

294 self._tasks[task_label][1].update(quanta) 

295 self._xgraph.update(other._xgraph) 

296 

297 def add_quantum_node(self, task_label: str, data_id: DataCoordinate, **attrs: Any) -> QuantumKey: 

298 """Add a new node representing a quantum. 

299 

300 Parameters 

301 ---------- 

302 task_label : `str` 

303 Name of task. 

304 data_id : `~lsst.daf.butler.DataCoordinate` 

305 The data ID of the quantum. 

306 **attrs : `~typing.Any` 

307 Additional attributes. 

308 """ 

309 key = QuantumKey(task_label, data_id.required_values) 

310 self._xgraph.add_node(key, data_id=data_id, **attrs) 

311 self._tasks[key.task_label][1].add(key) 

312 return key 

313 

314 def add_dataset_node( 

315 self, 

316 parent_dataset_type_name: str, 

317 data_id: DataCoordinate, 

318 is_global_init_output: bool = False, 

319 **attrs: Any, 

320 ) -> DatasetKey: 

321 """Add a new node representing a dataset. 

322 

323 Parameters 

324 ---------- 

325 parent_dataset_type_name : `str` 

326 Name of the parent dataset type. 

327 data_id : `~lsst.daf.butler.DataCoordinate` 

328 The dataset data ID. 

329 is_global_init_output : `bool`, optional 

330 Whether this dataset is a global init output. 

331 **attrs : `~typing.Any` 

332 Additional attributes for the node. 

333 """ 

334 key = DatasetKey(parent_dataset_type_name, data_id.required_values) 

335 self._xgraph.add_node(key, data_id=data_id, **attrs) 

336 if is_global_init_output: 

337 assert isinstance(key, DatasetKey) 

338 self._global_init_outputs.add(key) 

339 return key 

340 

341 def add_prerequisite_node( 

342 self, 

343 parent_dataset_type_name: str, 

344 ref: DatasetRef, 

345 **attrs: Any, 

346 ) -> PrerequisiteDatasetKey: 

347 """Add a new node representing a prerequisite input dataset. 

348 

349 Parameters 

350 ---------- 

351 parent_dataset_type_name : `str` 

352 Name of the parent dataset type. 

353 ref : `~lsst.daf.butler.DatasetRef` 

354 The dataset ref of the pre-requisite. 

355 **attrs : `~typing.Any` 

356 Additional attributes for the node. 

357 """ 

358 key = PrerequisiteDatasetKey(parent_dataset_type_name, ref.id.bytes) 

359 self._xgraph.add_node(key, data_id=ref.dataId, ref=ref, **attrs) 

360 return key 

361 

    def remove_quantum_node(self, key: QuantumKey, remove_outputs: bool) -> None:
        """Remove a node representing a quantum.

        Parameters
        ----------
        key : `QuantumKey`
            Identifier for the node.
        remove_outputs : `bool`
            If `True`, also remove all dataset nodes produced by this quantum.
            If `False`, any such dataset nodes will become overall inputs.
        """
        _, quanta = self._tasks[key.task_label]
        quanta.remove(key)
        if remove_outputs:
            to_remove = list(self._xgraph.successors(key))
            to_remove.append(key)
            self._xgraph.remove_nodes_from(to_remove)
        else:
            self._xgraph.remove_node(key)

    def remove_dataset_nodes(self, keys: Iterable[DatasetKey | PrerequisiteDatasetKey]) -> None:
        """Remove nodes representing datasets.

        Parameters
        ----------
        keys : `~collections.abc.Iterable` of `DatasetKey` \
                or `PrerequisiteDatasetKey`
            Nodes to remove.
        """
        self._xgraph.remove_nodes_from(keys)

    def remove_task(self, task_label: str) -> None:
        """Fully remove a task from the skeleton.

        All init-output datasets and quanta for the task must already have
        been removed.

        Parameters
        ----------
        task_label : `str`
            Name of task to remove.
        """
        task_init_key, quanta = self._tasks.pop(task_label)
        assert not quanta, "Cannot remove task unless all quanta have already been removed."
        assert not list(self._xgraph.successors(task_init_key))
        self._xgraph.remove_node(task_init_key)

    def add_input_edges(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_keys: Iterable[DatasetKey | PrerequisiteDatasetKey],
    ) -> None:
        """Add edges connecting datasets to a quantum that consumes them.

        Parameters
        ----------
        task_key : `QuantumKey` or `TaskInitKey`
            Quantum to connect.
        dataset_keys : `~collections.abc.Iterable` of `DatasetKey` \
                or `PrerequisiteDatasetKey`
            Datasets to join to the quantum.

        Notes
        -----
        This must only be called if the task node has already been added.
        Use `add_input_edge` if this cannot be assumed.

        Dataset nodes that are not already present will be created.
        """
        assert task_key in self._xgraph
        self._xgraph.add_edges_from((dataset_key, task_key) for dataset_key in dataset_keys)

    def remove_input_edges(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_keys: Iterable[DatasetKey | PrerequisiteDatasetKey],
    ) -> None:
        """Remove edges connecting datasets to a quantum that consumes them.

        Parameters
        ----------
        task_key : `QuantumKey` or `TaskInitKey`
            Quantum to disconnect.
        dataset_keys : `~collections.abc.Iterable` of `DatasetKey` \
                or `PrerequisiteDatasetKey`
            Datasets to remove from the quantum.
        """
        self._xgraph.remove_edges_from((dataset_key, task_key) for dataset_key in dataset_keys)

    def add_input_edge(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_key: DatasetKey | PrerequisiteDatasetKey,
        ignore_unrecognized_quanta: bool = False,
    ) -> bool:
        """Add an edge connecting a dataset to a quantum that consumes it.

        Parameters
        ----------
        task_key : `QuantumKey` or `TaskInitKey`
            Identifier for the quantum node.
        dataset_key : `DatasetKey` or `PrerequisiteDatasetKey`
            Identifier for the dataset node.
        ignore_unrecognized_quanta : `bool`, optional
            If `True`, do nothing and return `False` when the quantum node is
            not already present. If `False` (default), the quantum node is
            assumed to be present.

        Returns
        -------
        added : `bool`
            `True` if an edge was actually added, `False` if the quantum was
            not recognized and the edge was not added as a result.

        Notes
        -----
        Dataset nodes that are not already present will be created.
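
        Examples
        --------
        An illustrative sketch of the two behaviors (``quantum_key``,
        ``missing_quantum_key``, and ``dataset_key`` are assumed
        placeholders)::

            # Quantum node already present: the edge is added and `True`
            # is returned.
            skeleton.add_input_edge(quantum_key, dataset_key)

            # Quantum node absent and ignore_unrecognized_quanta=True:
            # nothing is added and `False` is returned.
            skeleton.add_input_edge(
                missing_quantum_key, dataset_key, ignore_unrecognized_quanta=True
            )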

478 """ 

479 if ignore_unrecognized_quanta and task_key not in self._xgraph: 

480 return False 

481 self._xgraph.add_edge(dataset_key, task_key) 

482 return True 

483 

484 def add_output_edge(self, task_key: QuantumKey | TaskInitKey, dataset_key: DatasetKey) -> None: 

485 """Add an edge connecting a dataset to the quantum that produces it. 

486 

487 Parameters 

488 ---------- 

489 task_key : `QuantumKey` or `TaskInitKey` 

490 Identifier for the quantum node. Must identify a node already 

491 present in the graph. 

492 dataset_key : `DatasetKey` 

493 Identifier for the dataset node. Must identify a node already 

494 present in the graph. 

495 """ 

496 assert task_key in self._xgraph 

497 assert dataset_key in self._xgraph 

498 self._xgraph.add_edge(task_key, dataset_key) 

499 

500 def remove_orphan_datasets(self) -> None: 

501 """Remove any dataset nodes that do not have any edges.""" 

502 for orphan in list(networkx.isolates(self._xgraph)): 

503 if not orphan.is_task and orphan not in self._global_init_outputs: 

504 self._xgraph.remove_node(orphan) 

505 

506 def extract_overall_inputs(self) -> dict[DatasetKey | PrerequisiteDatasetKey, DatasetRef]: 

507 """Find overall input datasets. 

508 

509 Returns 

510 ------- 

511 datasets : `dict` [ `DatasetKey` or `PrerequisiteDatasetKey`, \ 

512 `~lsst.daf.butler.DatasetRef` ] 

513 Overall-input datasets, including prerequisites and init-inputs. 

514 """ 

515 result = {} 

516 for generation in networkx.algorithms.topological_generations(self._xgraph): 

517 for dataset_key in generation: 

518 if dataset_key.is_task: 

519 continue 

520 try: 

521 result[dataset_key] = self[dataset_key]["ref"] 

522 except KeyError: 

523 raise AssertionError( 

524 f"Logic bug in QG generation: dataset {dataset_key} was never resolved." 

525 ) 

526 break 

527 return result