Coverage for python/lsst/pipe/base/quantum_graph_skeleton.py: 43%

136 statements  

coverage.py v7.3.0, created at 2023-08-31 09:39 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""An under-construction version of QuantumGraph and various helper
classes.
"""

from __future__ import annotations

__all__ = (
    "QuantumGraphSkeleton",
    "QuantumKey",
    "TaskInitKey",
    "DatasetKey",
    "PrerequisiteDatasetKey",
)

from collections.abc import Iterable, Iterator, MutableMapping, Set
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple

import networkx
from lsst.daf.butler import DataCoordinate, DataIdValue, DatasetRef
from lsst.utils.logging import getLogger

if TYPE_CHECKING:
    pass

_LOG = getLogger(__name__)


class QuantumKey(NamedTuple):
    """Identifier type for quantum keys in a `QuantumGraphSkeleton`."""

    task_label: str
    """Label of the task in the pipeline."""

    data_id_values: tuple[DataIdValue, ...]
    """Data ID values of the quantum.

    Note that keys are fixed given `task_label`, so using only the values here
    speeds up comparisons.
    """

    is_task: ClassVar[Literal[True]] = True
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `True`).
    """


class TaskInitKey(NamedTuple):
    """Identifier type for task init keys in a `QuantumGraphSkeleton`."""

    task_label: str
    """Label of the task in the pipeline."""

    is_task: ClassVar[Literal[True]] = True
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `True`).
    """


class DatasetKey(NamedTuple):
    """Identifier type for dataset keys in a `QuantumGraphSkeleton`."""

    parent_dataset_type_name: str
    """Name of the dataset type (never a component)."""

    data_id_values: tuple[DataIdValue, ...]
    """Data ID values of the dataset.

    Note that keys are fixed given `parent_dataset_type_name`, so using only
    the values here speeds up comparisons.
    """

    is_task: ClassVar[Literal[False]] = False
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `False`).
    """

    is_prerequisite: ClassVar[Literal[False]] = False


class PrerequisiteDatasetKey(NamedTuple):
    """Identifier type for prerequisite dataset keys in a
    `QuantumGraphSkeleton`.

    Unlike regular datasets, prerequisites are not actually required to come
    from a find-first search of `input_collections`, so we don't want to
    assume that the same data ID implies the same dataset. Happily we also
    don't need to search for them by data ID in the graph, so we can use the
    dataset ID (UUID) instead.
    """

    parent_dataset_type_name: str
    """Name of the dataset type (never a component)."""

    dataset_id_bytes: bytes
    """Dataset ID (UUID) as raw bytes."""

    is_task: ClassVar[Literal[False]] = False
    """Whether this node represents a quantum or task initialization rather
    than a dataset (always `False`).
    """

    is_prerequisite: ClassVar[Literal[True]] = True


class QuantumGraphSkeleton:
    """An under-construction quantum graph.

    QuantumGraphSkeleton is intended for use inside `QuantumGraphBuilder` and
    its subclasses.

    Parameters
    ----------
    task_labels : `~collections.abc.Iterable` [ `str` ]
        The labels of all tasks whose quanta may be included in the graph, in
        topological order.

    Notes
    -----
    QuantumGraphSkeleton models a bipartite version of the quantum graph, in
    which both quanta and datasets are represented as nodes and each type of
    node only has edges to the other type.

    Square-bracket (`getitem`) indexing returns a mutable mapping of a node's
    flexible attributes.

    The details of the `QuantumGraphSkeleton` API (e.g. which operations
    operate on multiple nodes vs. a single node) are set by what's actually
    needed by current quantum graph generation algorithms. New variants can be
    added as needed, but adding all operations that *might* be useful for some
    future algorithm seems premature.
    """

    def __init__(self, task_labels: Iterable[str]):
        self._tasks: dict[str, tuple[TaskInitKey, set[QuantumKey]]] = {}
        self._xgraph: networkx.DiGraph = networkx.DiGraph()
        self._global_init_outputs: set[DatasetKey] = set()
        for task_label in task_labels:
            task_init_key = TaskInitKey(task_label)
            self._tasks[task_label] = (task_init_key, set())
            self._xgraph.add_node(task_init_key)

    def __contains__(self, key: QuantumKey | TaskInitKey | DatasetKey | PrerequisiteDatasetKey) -> bool:
        return key in self._xgraph.nodes

    def __getitem__(
        self, key: QuantumKey | TaskInitKey | DatasetKey | PrerequisiteDatasetKey
    ) -> MutableMapping[str, Any]:
        return self._xgraph.nodes[key]

    @property
    def n_nodes(self) -> int:
        """The total number of nodes of all types."""
        return len(self._xgraph.nodes)

    @property
    def n_edges(self) -> int:
        """The total number of edges."""
        return len(self._xgraph.edges)

    def has_task(self, task_label: str) -> bool:
        """Test whether the given task is in this skeleton.

        Tasks are only added to the skeleton at initialization, but may be
        removed by `remove_task` if they end up having no quanta.
        """
        return task_label in self._tasks

    def get_task_init_node(self, task_label: str) -> TaskInitKey:
        """Return the graph node that represents a task's initialization."""
        return self._tasks[task_label][0]

    def get_quanta(self, task_label: str) -> Set[QuantumKey]:
        """Return the quanta for the given task label.

        Parameters
        ----------
        task_label : `str`
            Label for the task.

        Returns
        -------
        quanta : `~collections.abc.Set` [ `QuantumKey` ]
            A set-like object with the identifiers of all quanta for the given
            task. *The skeleton object's set of quanta must not be modified
            while iterating over this container; make a copy if mutation during
            iteration is necessary.*
        """
        return self._tasks[task_label][1]

    @property
    def global_init_outputs(self) -> Set[DatasetKey]:
        """The set of dataset nodes that are not associated with any task."""
        return self._global_init_outputs

    def iter_all_quanta(self) -> Iterator[QuantumKey]:
        """Iterate over all quanta from any task, in topological (but otherwise
        unspecified) order.
        """
        for _, quanta in self._tasks.values():
            yield from quanta

    def iter_outputs_of(self, quantum_key: QuantumKey | TaskInitKey) -> Iterator[DatasetKey]:
        """Iterate over the datasets produced by the given quantum."""
        return self._xgraph.successors(quantum_key)

    def iter_inputs_of(
        self, quantum_key: QuantumKey | TaskInitKey
    ) -> Iterator[DatasetKey | PrerequisiteDatasetKey]:
        """Iterate over the datasets consumed by the given quantum."""
        return self._xgraph.predecessors(quantum_key)

    def update(self, other: QuantumGraphSkeleton) -> None:
        """Copy all nodes from ``other`` to ``self``.

        The tasks in ``other`` must be a subset of the tasks in ``self`` (this
        method is expected to be used to populate a skeleton for a full graph
        from independent-subgraph skeletons).
        """
        for task_label, (_, quanta) in other._tasks.items():
            self._tasks[task_label][1].update(quanta)
        self._xgraph.update(other._xgraph)

    def add_quantum_node(self, task_label: str, data_id: DataCoordinate, **attrs: Any) -> QuantumKey:
        """Add a new node representing a quantum."""
        key = QuantumKey(task_label, data_id.values_tuple())
        self._xgraph.add_node(key, data_id=data_id, **attrs)
        self._tasks[key.task_label][1].add(key)
        return key

    def add_dataset_node(
        self,
        parent_dataset_type_name: str,
        data_id: DataCoordinate,
        is_global_init_output: bool = False,
        **attrs: Any,
    ) -> DatasetKey:
        """Add a new node representing a dataset."""
        key = DatasetKey(parent_dataset_type_name, data_id.values_tuple())
        self._xgraph.add_node(key, data_id=data_id, **attrs)
        if is_global_init_output:
            assert isinstance(key, DatasetKey)
            self._global_init_outputs.add(key)
        return key

    def add_prerequisite_node(
        self,
        parent_dataset_type_name: str,
        ref: DatasetRef,
        **attrs: Any,
    ) -> PrerequisiteDatasetKey:
        """Add a new node representing a prerequisite input dataset."""
        key = PrerequisiteDatasetKey(parent_dataset_type_name, ref.id.bytes)
        self._xgraph.add_node(key, data_id=ref.dataId, ref=ref, **attrs)
        return key

    def remove_quantum_node(self, key: QuantumKey, remove_outputs: bool) -> None:
        """Remove a node representing a quantum.

        Parameters
        ----------
        key : `QuantumKey`
            Identifier for the node.
        remove_outputs : `bool`
            If `True`, also remove all dataset nodes produced by this quantum.
            If `False`, any such dataset nodes will become overall inputs.
        """
        _, quanta = self._tasks[key.task_label]
        quanta.remove(key)
        if remove_outputs:
            to_remove = list(self._xgraph.successors(key))
            to_remove.append(key)
            self._xgraph.remove_nodes_from(to_remove)
        else:
            self._xgraph.remove_node(key)

    def remove_dataset_nodes(self, keys: Iterable[DatasetKey | PrerequisiteDatasetKey]) -> None:
        """Remove nodes representing datasets."""
        self._xgraph.remove_nodes_from(keys)

    def remove_task(self, task_label: str) -> None:
        """Fully remove a task from the skeleton.

        All init-output datasets and quanta for the task must already have been
        removed.
        """
        task_init_key, quanta = self._tasks.pop(task_label)
        assert not quanta, "Cannot remove task unless all quanta have already been removed."
        assert not list(self._xgraph.successors(task_init_key))
        self._xgraph.remove_node(task_init_key)

    def add_input_edges(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_keys: Iterable[DatasetKey | PrerequisiteDatasetKey],
    ) -> None:
        """Add edges connecting datasets to a quantum that consumes them.

        Notes
        -----
        This must only be called if the task node has already been added.
        Use `add_input_edge` if this cannot be assumed.

        Dataset nodes that are not already present will be created.
        """
        assert task_key in self._xgraph
        self._xgraph.add_edges_from((dataset_key, task_key) for dataset_key in dataset_keys)

    def remove_input_edges(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_keys: Iterable[DatasetKey | PrerequisiteDatasetKey],
    ) -> None:
        """Remove edges connecting datasets to a quantum that consumes them."""
        self._xgraph.remove_edges_from((dataset_key, task_key) for dataset_key in dataset_keys)

    def add_input_edge(
        self,
        task_key: QuantumKey | TaskInitKey,
        dataset_key: DatasetKey | PrerequisiteDatasetKey,
        ignore_unrecognized_quanta: bool = False,
    ) -> bool:
        """Add an edge connecting a dataset to a quantum that consumes it.

        Parameters
        ----------
        task_key : `QuantumKey` or `TaskInitKey`
            Identifier for the quantum node.
        dataset_key : `DatasetKey` or `PrerequisiteDatasetKey`
            Identifier for the dataset node.
        ignore_unrecognized_quanta : `bool`, optional
            If `True`, do nothing if the quantum node is not already present.
            If `False`, the quantum node is assumed to be present.

        Returns
        -------
        added : `bool`
            `True` if an edge was actually added, `False` if the quantum was
            not recognized and the edge was not added as a result.

        Notes
        -----
        Dataset nodes that are not already present will be created.
        """
        if ignore_unrecognized_quanta and task_key not in self._xgraph:
            return False
        self._xgraph.add_edge(dataset_key, task_key)
        return True

    def add_output_edge(self, task_key: QuantumKey | TaskInitKey, dataset_key: DatasetKey) -> None:
        """Add an edge connecting a dataset to the quantum that produces it.

        Parameters
        ----------
        task_key : `QuantumKey` or `TaskInitKey`
            Identifier for the quantum node. Must identify a node already
            present in the graph.
        dataset_key : `DatasetKey`
            Identifier for the dataset node. Must identify a node already
            present in the graph.
        """
        assert task_key in self._xgraph
        assert dataset_key in self._xgraph
        self._xgraph.add_edge(task_key, dataset_key)

    def remove_orphan_datasets(self) -> None:
        """Remove any dataset nodes that do not have any edges."""

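        # Only plain dataset nodes are dropped here; task/task-init nodes and
        # global init-output datasets are kept even when they have no edges.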
        for orphan in list(networkx.isolates(self._xgraph)):
            if not orphan.is_task and orphan not in self._global_init_outputs:
                self._xgraph.remove_node(orphan)

    def extract_overall_inputs(self) -> dict[DatasetKey | PrerequisiteDatasetKey, DatasetRef]:
        """Find overall input datasets.

        Returns
        -------
        datasets : `dict` [ `DatasetKey` or `PrerequisiteDatasetKey`, \
                `~lsst.daf.butler.DatasetRef` ]
            Overall-input datasets, including prerequisites and init-inputs.
        """
        result = {}

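        # Only the first topological generation (nodes with no predecessors)
        # can contain overall inputs, so the loop breaks after processing it.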
        for generation in networkx.algorithms.topological_generations(self._xgraph):
            for dataset_key in generation:
                if dataset_key.is_task:
                    continue
                try:
                    result[dataset_key] = self[dataset_key]["ref"]
                except KeyError:
                    raise AssertionError(
                        f"Logic bug in QG generation: dataset {dataset_key} was never resolved."
                    )
            break
        return result
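

# ---------------------------------------------------------------------------
# Illustrative usage sketch: how a builder might drive the skeleton, using
# only the operations above that do not require butler data IDs.  The task
# labels ("isr", "calibrate") are made-up placeholders for the example.
#
#     skeleton = QuantumGraphSkeleton(["isr", "calibrate"])
#     skeleton.has_task("isr")          # True; tasks are fixed at construction
#     skeleton.n_nodes                  # 2: one TaskInitKey node per task
#     skeleton.get_quanta("isr")        # empty set; no quanta added yet
#
#     # Keys are plain NamedTuples, hashed and compared by value, so they can
#     # be rebuilt on the fly for lookups and node-attribute access:
#     init_key = skeleton.get_task_init_node("isr")
#     init_key == TaskInitKey("isr")    # True
#     skeleton[init_key]["note"] = "flexible node attributes live here"
# ---------------------------------------------------------------------------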