Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("QuantumGraph", "IncompatibleGraphError") 

24 

25from collections import defaultdict 

26 

27from itertools import chain, count 

28import io 

29import networkx as nx 

30from networkx.drawing.nx_agraph import write_dot 

31import os 

32import pickle 

33import time 

34from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple, 

35 Union, TypeVar) 

36 

37from ..connections import iterConnections 

38from ..pipeline import TaskDef 

39from lsst.daf.butler import Quantum, DatasetRef 

40 

41from ._implDetails import _DatasetTracker, DatasetTypeName 

42from .quantumNode import QuantumNode, NodeId, BuildId 

43 

44_T = TypeVar("_T", bound="QuantumGraph") 

45 

46 

class IncompatibleGraphError(Exception):
    """Signal that a `NodeId` cannot be looked up because it belongs to a
    graph that is incompatible with the one being queried.
    """

52 

53 

54class QuantumGraph: 

55 """QuantumGraph is a directed acyclic graph of `QuantumNode`s 

56 

57 This data structure represents a concrete workflow generated from a 

58 `Pipeline`. 

59 

60 Parameters 

61 ---------- 

62 quanta : Mapping of `TaskDef` to sets of `Quantum` 

63 This maps tasks (and their configs) to the sets of data they are to 

64 process. 

65 """ 

def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
    # Construction is delegated entirely to `_buildGraphs`, which is also
    # invoked directly when an instance is rebuilt outside the constructor.
    self._buildGraphs(quanta)

68 

def _buildGraphs(self,
                 quanta: Mapping[TaskDef, Set[Quantum]],
                 *,
                 _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                 _buildId: Optional[BuildId] = None):
    """Build the graph that stores the relations between tasks, and the
    graph that holds the relations between quanta.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        Tasks and the quanta each of them is to process; stored on the
        instance as given.
    _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
        Node ids to reuse. When a non-empty mapping is supplied every
        quantum in ``quanta`` must have an entry in it; otherwise node ids
        are generated from a counter starting at zero.
    _buildId : `BuildId`, optional
        Build id to record on the graph. When `None` a new id is derived
        from the current time and the process id.

    Raises
    ------
    ValueError
        Raised if ``_quantumToNodeId`` is supplied but has no entry for one
        of the quanta.
    """
    self._quanta = quanta
    self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
    # Data structures used to identify relations between components;
    # DatasetTypeName -> TaskDef for task,
    # and DatasetRef -> QuantumNode for the quanta
    self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
    self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

    nodeNumberGenerator = count()
    self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
    self._count = 0
    for taskDef, quantumSet in self._quanta.items():
        connections = taskDef.connections

        # For each type of connection in the task, add a key to the
        # `_DatasetTracker` for the connections name, with a value of
        # the TaskDef in the appropriate field
        for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
            self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

        for output in iterConnections(connections, ("outputs", "initOutputs")):
            self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

        # For each `Quantum` in the set of all `Quantum` for this task,
        # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
        # of the individual datasets inside the `Quantum`, with a value of a
        # newly created QuantumNode to the appropriate input/output field.
        self._count += len(quantumSet)
        for quantum in quantumSet:
            # An empty mapping is treated the same as no mapping at all.
            if _quantumToNodeId:
                nodeId = _quantumToNodeId.get(quantum)
                if nodeId is None:
                    raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                     "associated value in the mapping")
            else:
                nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

            inits = quantum.initInputs.values()
            inputs = quantum.inputs.values()
            value = QuantumNode(quantum, taskDef, nodeId)
            self._nodeIdMap[nodeId] = value

            for dsRef in chain(inits, inputs):
                # unfortunately, `Quantum` allows inits to be individual
                # `DatasetRef`s or an Iterable of such, so there must
                # be an instance check here
                if isinstance(dsRef, Iterable):
                    for sub in dsRef:
                        self._datasetRefDict.addInput(sub, value)
                else:
                    self._datasetRefDict.addInput(dsRef, value)
            for dsRef in chain.from_iterable(quantum.outputs.values()):
                self._datasetRefDict.addOutput(dsRef, value)

    # Graph of task relations, used in various methods
    self._taskGraph = self._datasetDict.makeNetworkXGraph()

    # Graph of quanta relations
    self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

136 

@property
def taskGraph(self) -> nx.DiGraph:
    """Graph describing how the tasks inside the quantum graph relate to
    one another.

    Returns
    -------
    taskGraph : `networkx.DiGraph`
        Internal data structure that holds the relations of `TaskDef`s.
    """
    taskRelations = self._taskGraph
    return taskRelations

148 

@property
def graph(self) -> nx.DiGraph:
    """Graph describing the relations between all the `QuantumNode`s.

    Iterating over this class and using its methods is generally
    preferable, but direct access to the networkx object is sometimes
    helpful.

    Returns
    -------
    graph : `networkx.DiGraph`
        Internal data structure that holds the relations of `QuantumNode`s.
    """
    quantaRelations = self._connectedQuanta
    return quantaRelations

162 

@property
def inputQuanta(self) -> Iterable[QuantumNode]:
    """Iterate over the `QuantumNode`s that are 'input' nodes of the graph,
    meaning nodes that do not depend on any other node in the graph.

    Returns
    -------
    inputNodes : iterable of `QuantumNode`
        A generator over the nodes whose in-degree is zero.
    """
    inDegrees = self._connectedQuanta.in_degree
    return (candidate for candidate, nParents in inDegrees if nParents == 0)

175 

@property
def outputQuanta(self) -> Iterable[QuantumNode]:
    """Collect the `QuantumNode`s that are 'output' nodes of the graph,
    meaning nodes that no other node in the graph depends on.

    Returns
    -------
    outputNodes : iterable of `QuantumNode`
        A list of the nodes whose out-degree is zero.
    """
    terminalNodes = []
    for candidate, nChildren in self._connectedQuanta.out_degree:
        if nChildren == 0:
            terminalNodes.append(candidate)
    return terminalNodes

187 

@property
def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
    """All the `DatasetTypeName`s that are contained inside the graph.

    Returns
    -------
    tuple of `DatasetTypeName`
        Every dataset type name that is tracked by the graph.
    """
    trackedNames = self._datasetDict.keys()
    return tuple(trackedNames)

198 

@property
def isConnected(self) -> bool:
    """`True` when every node in the graph is connected to the others,
    disregarding the direction of the connections.
    """
    quantaRelations = self._connectedQuanta
    return nx.is_weakly_connected(quantaRelations)

205 

def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
    """Look up the `QuantumNode` registered under a node id.

    Parameters
    ----------
    nodeId : `NodeId`
        Identifier associated with a node.

    Returns
    -------
    node : `QuantumNode`
        The node stored under ``nodeId``.

    Raises
    ------
    KeyError
        Raised if the requested nodeId is not in the graph.
    IncompatibleGraphError
        Raised if the build id carried by ``nodeId`` differs from the build
        id of this graph, i.e. the id was produced for a different graph
        (one that is neither this instance nor a graph this instance was
        derived from through an operation such as subset).
    """
    builtElsewhere = nodeId.buildId != self._buildId
    if builtElsewhere:
        raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
    return self._nodeIdMap[nodeId]

231 

def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
    """Return every `Quantum` that belongs to a `TaskDef`.

    Parameters
    ----------
    taskDef : `TaskDef`
        The `TaskDef` whose `Quantum` are requested.

    Returns
    -------
    frozenset of `Quantum`
        An immutable copy of the `Quantum` stored for ``taskDef``.
    """
    taskQuanta = self._quanta[taskDef]
    return frozenset(taskQuanta)

247 

def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
    """Find the tasks that consume the given dataset type name as an input.

    Parameters
    ----------
    datasetTypeName : `str`
        Name of the dataset type to query. A `DatasetTypeName` (a
        `NewType` of str used for static type checking) is accepted too.

    Returns
    -------
    tasks : iterable of `TaskDef`
        Generator over the `TaskDef`s that take the `DatasetTypeName` as an
        input; it yields nothing when no task uses it as an input.

    Raises
    ------
    KeyError
        Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
    """
    consumers = self._datasetDict.getInputs(datasetTypeName)
    return (taskDef for taskDef in consumers)

271 

def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
    """Find the task that produces the given dataset type name as an
    output.

    Parameters
    ----------
    datasetTypeName : `str`
        Name of the dataset type to query. A `DatasetTypeName` (a
        `NewType` of str used for static type checking) is accepted too.

    Returns
    -------
    `TaskDef` or `None`
        The `TaskDef` that outputs the `DatasetTypeName`, or `None` if none
        of the tasks produce it.

    Raises
    ------
    KeyError
        Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
    """
    producer = self._datasetDict.getOutput(datasetTypeName)
    return producer

295 

def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
    """Find every task associated with the given dataset type name, either
    as a consumer or as the producer.

    Parameters
    ----------
    datasetTypeName : `str`
        Name of the dataset type to query. A `DatasetTypeName` (a
        `NewType` of str used for static type checking) is accepted too.

    Returns
    -------
    result : iterable of `TaskDef`
        The consuming `TaskDef`s, followed by the producing `TaskDef` when
        there is one.

    Raises
    ------
    KeyError
        Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
    """
    consumers = self.findTasksWithInput(datasetTypeName)
    producer = self.findTaskWithOutput(datasetTypeName)
    if producer is None:
        return consumers
    return chain(consumers, (producer,))

322 

def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
    """Find the `TaskDef`s in this graph whose ``taskName`` property ends in
    the given task name.

    Only the last dot-separated component of each ``taskName`` is compared.
    A list is returned because a `PipelineTask` may appear multiple times
    in a graph under different labels.

    Parameters
    ----------
    taskName : str
        Name of a task to search for

    Returns
    -------
    result : list of `TaskDef`
        The `TaskDef`s that have the specified name; more than one entry
        when the task is used several times with different labels.
    """
    return [candidate for candidate in self._quanta.keys()
            if candidate.taskName.split('.')[-1] == taskName]

349 

def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
    """Find the `TaskDef` in this graph that carries the given label.

    Parameters
    ----------
    label : str
        Label of the task to search for

    Returns
    -------
    result : `TaskDef` or `None`
        The first `TaskDef` whose label equals ``label``, or `None` when no
        task has that label.
    """
    matches = (candidate for candidate in self._quanta.keys() if label == candidate.label)
    return next(matches, None)

368 

def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
    """Return all the `Quantum` of the tasks associated with a specified
    `DatasetTypeName`.

    Parameters
    ----------
    datasetTypeName : `str`
        Name of the dataset type to search for. A `DatasetTypeName` (a
        `NewType` of str used for static type checking) is accepted too.

    Returns
    -------
    result : `set` of `Quantum`
        Union of the `Quantum` sets of every task associated with the
        `DatasetTypeName`.

    Raises
    ------
    KeyError
        Raised if the `DatasetTypeName` is not part of the `QuantumGraph`

    """
    collected: Set[Quantum] = set()
    for taskDef in self._datasetDict.getAll(datasetTypeName):
        collected.update(self._quanta[taskDef])
    return collected

394 

def checkQuantumInGraph(self, quantum: Quantum) -> bool:
    """Check whether a quantum belongs to any task of the graph.

    Parameters
    ----------
    quantum : `Quantum`
        The quantum to search for

    Returns
    -------
    `bool`
        `True` if the quantum is in the quantum set of at least one task,
        `False` otherwise.
    """
    return any(quantum in taskQuanta for taskQuanta in self._quanta.values())

412 

def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
    """Write the graph of quanta out in dot format.

    Parameters
    ----------
    output : str or `io.BufferedIOBase`
        Either a filesystem path to write to, or a file handle object
    """
    quantaRelations = self._connectedQuanta
    write_dot(quantaRelations, output)

422 

def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
    """Create a new graph object that contains the subset of the nodes
    specified as input. Node number is preserved.

    Parameters
    ----------
    nodes : `QuantumNode` or iterable of `QuantumNode`
        The node, or nodes, to keep. Any iterable is accepted, including
        one that can only be consumed once such as a generator.

    Returns
    -------
    graph : instance of graph type
        An instance of the type from which the subset was created
    """
    if not isinstance(nodes, Iterable):
        nodes = (nodes, )
    quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
    quantumMap = defaultdict(set)
    # Build the id mapping from the subgraph rather than from ``nodes``:
    # ``nodes`` may be a one-shot iterable that ``subgraph`` has already
    # exhausted, which would leave the mapping empty and cause new node
    # ids to be generated instead of the existing ones being preserved.
    quantumToNodeId = {}

    node: QuantumNode
    for node in quantumSubgraph:
        quantumMap[node.taskDef].add(node.quantum)
        quantumToNodeId[node.quantum] = node.nodeId
    # Create an empty graph, and then populate it with custom mapping
    newInst = type(self)({})
    newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId)
    return newInst

449 

def subsetToConnected(self: _T) -> Tuple[_T, ...]:
    """Split the graph into subgraphs that are each connected.

    Returns
    -------
    result : tuple of `QuantumGraph`
        One graph per weakly connected component of this graph.
    """
    components = nx.weakly_connected_components(self._connectedQuanta)
    return tuple(self.subset(component) for component in components)

460 

def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
    """Return the set of `QuantumNode` that are direct inputs to a
    specified node.

    Parameters
    ----------
    node : `QuantumNode`
        The node of the graph for which inputs are to be determined

    Returns
    -------
    set of `QuantumNode`
        The predecessors of ``node`` in the graph of quanta.
    """
    upstream = self._connectedQuanta.predecessors(node)
    return set(upstream)

476 

def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
    """Return the set of `QuantumNode` that are direct outputs of a
    specified node.

    Parameters
    ----------
    node : `QuantumNode`
        The node of the graph for which outputs are to be determined

    Returns
    -------
    set of `QuantumNode`
        The successors of ``node`` in the graph of quanta.
    """
    downstream = self._connectedQuanta.successors(node)
    return set(downstream)

492 

def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
    """Return a graph made of a specified node together with its direct
    inputs and direct outputs.

    Parameters
    ----------
    node : `QuantumNode`
        The node of the graph for which connected nodes are to be determined

    Returns
    -------
    graph : graph of `QuantumNode`
        Subset of this graph holding ``node`` and all the nodes that are
        directly connected to it.
    """
    upstream = self.determineInputsToQuantumNode(node)
    downstream = self.determineOutputsOfQuantumNode(node)
    neighbourhood = upstream.union(downstream)
    neighbourhood.add(node)
    return self.subset(neighbourhood)

510 

def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
    """Return a graph of the specified node and all of the ancestor nodes
    reachable from it by walking edges.

    Parameters
    ----------
    node : `QuantumNode`
        The node for which all ancestors are to be determined

    Returns
    -------
    graph of `QuantumNode`
        Subset of this graph holding ``node`` and all of its ancestors.
    """
    lineage = nx.ancestors(self._connectedQuanta, node)
    # `nx.ancestors` leaves the node itself out, so it is added back in.
    lineage.add(node)
    return self.subset(lineage)

528 

def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
    """Check the graph for the presence of cycles and return the edges of
    a cycle if one is found.

    Returns
    -------
    result : list of tuple of `QuantumNode`, `QuantumNode`
        The graph edges that form a cycle, or an empty list if there is
        no cycle. An empty list is falsy, which allows the result to be
        used directly in an ``if`` test.
    """
    try:
        cycleEdges = nx.find_cycle(self._connectedQuanta)
    except nx.NetworkXNoCycle:
        cycleEdges = []
    return cycleEdges

544 

def save(self, file):
    """Save QuantumGraph to a file.

    The graph is currently written in pickle format; this could change in
    the future if a better format is found.

    Parameters
    ----------
    file : `io.BufferedIOBase`
        File to write pickle data open in binary mode.
    """
    pickler = pickle.Pickler(file)
    pickler.dump(self)

556 

@classmethod
def load(cls, file, universe):
    """Read QuantumGraph from a file that was made by `save`.

    Parameters
    ----------
    file : `io.BufferedIOBase`
        File with pickle data open in binary mode.
    universe: `~lsst.daf.butler.DimensionUniverse`
        DimensionUniverse instance, not used by the method itself but
        needed to ensure that registry data structures are initialized.

    Returns
    -------
    graph : `QuantumGraph`
        Resulting QuantumGraph instance.

    Raises
    ------
    TypeError
        Raised if pickle contains instance of a type other than
        QuantumGraph.

    Notes
    -----
    Reading Quanta from pickle requires existence of singleton
    DimensionUniverse which is usually instantiated during Registry
    initialization. To make sure that DimensionUniverse exists this method
    accepts dummy DimensionUniverse argument.
    """
    # NOTE(review): unpickling can execute arbitrary code embedded in the
    # stream, and the type check below only runs after that has happened;
    # ``file`` must therefore come from a trusted source.
    qgraph = pickle.load(file)
    if not isinstance(qgraph, QuantumGraph):
        raise TypeError(f"QuantumGraph pickle file contains unexpected object type: {type(qgraph)}")
    return qgraph

590 

def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
    """Iterate over the nodes of the `taskGraph` attribute in topological
    order.

    Yields
    ------
    `TaskDef`
        `TaskDef` objects in topological order
    """
    orderedTasks = nx.topological_sort(self.taskGraph)
    yield from orderedTasks

600 

def __iter__(self) -> Generator[QuantumNode, None, None]:
    # Nodes of the graph of quanta are produced in topological order.
    orderedNodes = nx.topological_sort(self._connectedQuanta)
    yield from orderedNodes

603 

def __len__(self) -> int:
    # The count of quanta is accumulated while the graphs are built.
    nQuanta = self._count
    return nQuanta

606 

def __contains__(self, node: QuantumNode) -> bool:
    # Membership is answered by the graph of quanta itself.
    quantaRelations = self._connectedQuanta
    return quantaRelations.has_node(node)

609 

def __getstate__(self) -> dict:
    """Store a compact form of the graph: the list of graph nodes, in the
    order produced by iterating over the graph.

    The full graph can be reconstructed from this information, and the
    ordering of the graph nodes is preserved.
    """
    orderedNodes = list(self)
    return {"nodesList": orderedNodes}

617 

def __setstate__(self, state: dict):
    """Rebuild the graph from the node list persisted by getstate.

    Quanta are regrouped by task, every node keeps its original node id,
    and the build id is taken from the last node in the list (`None` when
    the list is empty).
    """
    quantaByTask: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
    nodeIdByQuantum: Dict[Quantum, NodeId] = {}
    restoredBuildId = None
    for persistedNode in state['nodesList']:
        quantaByTask[persistedNode.taskDef].add(persistedNode.quantum)
        nodeIdByQuantum[persistedNode.quantum] = persistedNode.nodeId
        restoredBuildId = persistedNode.nodeId.buildId
    self._buildGraphs(quantaByTask, _quantumToNodeId=nodeIdByQuantum, _buildId=restoredBuildId)

630 

def __eq__(self, other: object) -> bool:
    # Two graphs are equal when they hold the same number of quanta, every
    # node of this graph is in the other with identical direct inputs and
    # outputs, and both task graphs list their nodes in the same order.
    if not isinstance(other, QuantumGraph):
        return False
    if len(self) != len(other):
        return False
    nodesAgree = all(
        node in other
        and self.determineInputsToQuantumNode(node) == other.determineInputsToQuantumNode(node)
        and self.determineOutputsOfQuantumNode(node) == other.determineOutputsOfQuantumNode(node)
        for node in self
    )
    if not nodesAgree:
        return False
    return list(self.taskGraph) == list(other.taskGraph)

638 return False 

639 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

640 return False 

641 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

642 return False 

643 return list(self.taskGraph) == list(other.taskGraph)