Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("QuantumGraph", "IncompatibleGraphError") 

24 

25from collections import defaultdict 

26 

27from itertools import chain, count 

28import io 

29import networkx as nx 

30from networkx.drawing.nx_agraph import write_dot 

31import os 

32import pickle 

33import time 

34from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple, 

35 Union, TypeVar) 

36 

37from ..connections import iterConnections 

38from ..pipeline import TaskDef 

39from lsst.daf.butler import Quantum, DatasetRef 

40 

41from ._implDetails import _DatasetTracker, DatasetTypeName 

42from .quantumNode import QuantumNode, NodeId, BuildId 

43 

44_T = TypeVar("_T", bound="QuantumGraph") 

45 

46 

class IncompatibleGraphError(Exception):
    """Signal that a `NodeId` cannot be looked up in a graph.

    Raised when the id was created by a graph build that is incompatible
    with the graph it is being looked up in.
    """

52 

53 

class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode`s

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.

    Notes
    -----
    The class defines ``__eq__`` without ``__hash__``, so instances are not
    hashable.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Build the graph that stores the relations between tasks, and the
        graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The tasks and the quanta each of them is to process.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            Pre-existing node ids to reuse. When a non-empty mapping is
            supplied every quantum in ``quanta`` must have an entry. An empty
            mapping is treated the same as `None`.
        _buildId : `BuildId`, optional
            Build id to stamp on this graph. When `None` a new one is created
            from the current time and process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is supplied (non-empty) but lacks
            an entry for one of the quanta.
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                # Truthiness (not ``is not None``) is deliberate: an empty
                # mapping falls through to generating fresh ids.
                if _quantumToNodeId:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.Digraph`
            Internal datastructure that holds relations of `TaskDef`s
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode`s. Largely it should be preferred to iterate over, and use
        methods of this class, but sometimes direct access to the networkx
        object may be helpful

        Returns
        -------
        graph : `networkx.Digraph`
            Internal datastructure that holds relations of `QuantumNode`s
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Iterate over all `QuantumNode`s that are 'input' nodes to the
        graph, meaning those nodes do not depend on any other nodes in the
        graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A generator over the nodes that are inputs to the graph; it can
            be consumed only once per property access.
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode`s that are 'output' nodes to the
        graph, meaning those nodes have no nodes that depend on them in the
        graph.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeNames` that are contained inside the
        graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The id associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with the input id

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a graph that is not this
            instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        return frozenset(self._quanta[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef`s that have the specified `DatasetTypeName` as an input,
            the iterable will be empty if no tasks use specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return (c for c in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef`s that are associated with the specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef`s in this graph are associated with a `str`
        representing a task name (looks at the taskName property of
        `TaskDef`s).

        Returns a list of `TaskDef`s as a `PipelineTask` may appear multiple
        times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for; compared against the last
            dot-separated component of each `TaskDef`'s ``taskName``.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef`s that have the name specified. Multiple
            values are returned in the case that a task is used multiple times
            with different labels.
        """
        return [task for task in self._quanta.keys() if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` in this graph is associated with a `str`
        representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` found with the specified label, or `None` if
            no task has that label.
        """
        for task in self._quanta.keys():
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum`
            A `set` of `Quantum` belonging to the tasks associated with the
            specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`

        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        result = result.union(*(self._quanta[task] for task in tasks))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node or nodes to keep. Any iterable is accepted, including
            one-shot iterables such as generators.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        # ``nodes`` is consumed exactly once here. Everything below works from
        # the nodes of the resulting subgraph, so a generator argument cannot
        # leave the id mapping empty (which would silently renumber the
        # nodes).
        selected = tuple(self._connectedQuanta.subgraph(nodes).nodes)

        quantumMap = defaultdict(set)
        quantumToNodeId = {}
        node: QuantumNode
        for node in selected:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId

        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each (weakly) connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            The specified node together with all the nodes that are directly
            connected to it
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and return the edges of
        any cycle found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned so that
            ``if graph.findCycle()`` works, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def save(self, file):
        """Save QuantumGraph to a file.
        Presently we store QuantumGraph in pickle format, this could
        potentially change in the future if better format is found.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write pickle data open in binary mode.
        """
        pickle.dump(self, file)

    @classmethod
    def load(cls, file, universe):
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File with pickle data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        Unpickling can execute arbitrary code, so only load files that come
        from a trusted source.
        """
        # NOTE(security): pickle.load must never be fed untrusted data; the
        # type check below runs only after the payload has been unpickled.
        qgraph = pickle.load(file)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        """Yield the `QuantumNode`s of the graph in topological order."""
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        """Return the total number of quanta supplied when the graph was
        built.
        """
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        """Return True if ``node`` is a node of the quantum graph."""
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Store a compact form of the graph as a list of graph nodes in
        topological order. The full graph can be reconstructed from this
        information, and node ids are preserved.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstruct the state of the graph from the information persisted
        in getstate.
        """
        nodesList = state['nodesList']
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        quantumNode: QuantumNode
        for quantumNode in nodesList:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        # The build id is read from the last persisted node; an empty graph
        # gets a freshly generated build id (``None`` here).
        _buildId = nodesList[-1].nodeId.buildId if nodesList else None
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        """Return True if ``other`` is a `QuantumGraph` with the same number
        of quanta, the same nodes with the same direct inputs and outputs,
        and the same task graph node ordering.
        """
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)

645 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

646 return False 

647 return list(self.taskGraph) == list(other.taskGraph)