Coverage for python/lsst/pipe/base/graph/graph.py: 19%

341 statements  

« prev     ^ index     » next       coverage.py v6.4, created at 2022-05-24 02:42 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("QuantumGraph", "IncompatibleGraphError") 

24 

25import io 

26import json 

27import lzma 

28import os 

29import pickle 

30import struct 

31import time 

32import uuid 

33import warnings 

34from collections import defaultdict, deque 

35from itertools import chain 

36from types import MappingProxyType 

37from typing import ( 

38 Any, 

39 BinaryIO, 

40 DefaultDict, 

41 Deque, 

42 Dict, 

43 FrozenSet, 

44 Generator, 

45 Iterable, 

46 List, 

47 Mapping, 

48 MutableMapping, 

49 Optional, 

50 Set, 

51 Tuple, 

52 TypeVar, 

53 Union, 

54) 

55 

56import networkx as nx 

57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum 

58from lsst.resources import ResourcePath, ResourcePathExpression 

59from lsst.utils.introspection import get_full_type_name 

60from networkx.drawing.nx_agraph import write_dot 

61 

62from ..connections import iterConnections 

63from ..pipeline import TaskDef 

64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner 

65from ._loadHelpers import LoadHelper 

66from ._versionDeserializers import DESERIALIZER_MAP 

67from .quantumNode import BuildId, QuantumNode 

68 

69_T = TypeVar("_T", bound="QuantumGraph") 

70 

71# modify this constant any time the on disk representation of the save file 

72# changes, and update the load helpers to behave properly for each version. 

73SAVE_VERSION = 3 

74 

75# Strings used to describe the format for the preamble bytes in a file save 

76# The base is a big endian encoded unsigned short that is used to hold the 

77# file format version. This allows reading version bytes and determine which 

78# loading code should be used for the rest of the file 

79STRUCT_FMT_BASE = ">H" 

80# 

81# Version 1 

82# This marks a big endian encoded format with an unsigned short, an unsigned 

83# long long, and an unsigned long long in the byte stream 

84# Version 2 

85# A big endian encoded format with an unsigned long long byte stream used to 

86# indicate the total length of the entire header. 

87STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"} 

88 

89# magic bytes that help determine this is a graph save 

90MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9" 

91 

92 

93class IncompatibleGraphError(Exception): 

94 """Exception class to indicate that a lookup by NodeId is impossible due 

95 to incompatibilities 

96 """ 

97 

98 pass 

99 

100 

101class QuantumGraph: 

102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects 

103 

104 This data structure represents a concrete workflow generated from a 

105 `Pipeline`. 

106 

107 Parameters 

108 ---------- 

109 quanta : Mapping of `TaskDef` to sets of `Quantum` 

110 This maps tasks (and their configs) to the sets of data they are to 

111 process. 

112 metadata : Optional Mapping of `str` to primitives 

113 This is an optional parameter of extra data to carry with the graph. 

114 Entries in this mapping should be able to be serialized in JSON. 

115 

116 Raises 

117 ------ 

118 ValueError 

119 Raised if the graph is pruned such that some tasks no longer have nodes 

120 associated with them. 

121 """ 

122 

123 def __init__( 

124 self, 

125 quanta: Mapping[TaskDef, Set[Quantum]], 

126 metadata: Optional[Mapping[str, Any]] = None, 

127 pruneRefs: Optional[Iterable[DatasetRef]] = None, 

128 ): 

129 self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs) 

130 

131 def _buildGraphs( 

132 self, 

133 quanta: Mapping[TaskDef, Set[Quantum]], 

134 *, 

135 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None, 

136 _buildId: Optional[BuildId] = None, 

137 metadata: Optional[Mapping[str, Any]] = None, 

138 pruneRefs: Optional[Iterable[DatasetRef]] = None, 

139 ) -> None: 

140 """Builds the graph that is used to store the relation between tasks, 

141 and the graph that holds the relations between quanta 

142 """ 

143 self._metadata = metadata 

144 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}") 

145 # Data structures used to identify relations between components; 

146 # DatasetTypeName -> TaskDef for task, 

147 # and DatasetRef -> QuantumNode for the quanta 

148 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True) 

149 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]() 

150 

151 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {} 

152 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set) 

153 for taskDef, quantumSet in quanta.items(): 

154 connections = taskDef.connections 

155 

156 # For each type of connection in the task, add a key to the 

157 # `_DatasetTracker` for the connections name, with a value of 

158 # the TaskDef in the appropriate field 

159 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")): 

160 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef) 

161 

162 for output in iterConnections(connections, ("outputs",)): 

163 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef) 

164 

165 # For each `Quantum` in the set of all `Quantum` for this task, 

166 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one 

167 # of the individual datasets inside the `Quantum`, with a value of 

168 # a newly created QuantumNode to the appropriate input/output 

169 # field. 

170 for quantum in quantumSet: 

171 if _quantumToNodeId: 

172 if (nodeId := _quantumToNodeId.get(quantum)) is None: 

173 raise ValueError( 

174 "If _quantuMToNodeNumber is not None, all quanta must have an " 

175 "associated value in the mapping" 

176 ) 

177 else: 

178 nodeId = uuid.uuid4() 

179 

180 inits = quantum.initInputs.values() 

181 inputs = quantum.inputs.values() 

182 value = QuantumNode(quantum, taskDef, nodeId) 

183 self._taskToQuantumNode[taskDef].add(value) 

184 self._nodeIdMap[nodeId] = value 

185 

186 for dsRef in chain(inits, inputs): 

187 # unfortunately, `Quantum` allows inits to be individual 

188 # `DatasetRef`s or an Iterable of such, so there must 

189 # be an instance check here 

190 if isinstance(dsRef, Iterable): 

191 for sub in dsRef: 

192 if sub.isComponent(): 

193 sub = sub.makeCompositeRef() 

194 self._datasetRefDict.addConsumer(sub, value) 

195 else: 

196 assert isinstance(dsRef, DatasetRef) 

197 if dsRef.isComponent(): 

198 dsRef = dsRef.makeCompositeRef() 

199 self._datasetRefDict.addConsumer(dsRef, value) 

200 for dsRef in chain.from_iterable(quantum.outputs.values()): 

201 self._datasetRefDict.addProducer(dsRef, value) 

202 

203 if pruneRefs is not None: 

204 # track what refs were pruned and prune the graph 

205 prunes: Set[QuantumNode] = set() 

206 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes) 

207 

208 # recreate the taskToQuantumNode dict removing nodes that have been 

209 # pruned. Keep track of task defs that now have no QuantumNodes 

210 emptyTasks: Set[str] = set() 

211 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set) 

212 # accumulate all types 

213 types_ = set() 

214 # tracker for any pruneRefs that have caused tasks to have no nodes 

215 # This helps the user find out what caused the issues seen. 

216 culprits = set() 

217 # Find all the types from the refs to prune 

218 for r in pruneRefs: 

219 types_.add(r.datasetType) 

220 

221 # For each of the tasks, and their associated nodes, remove any 

222 # any nodes that were pruned. If there are no nodes associated 

223 # with a task, record that task, and find out if that was due to 

224 # a type from an input ref to prune. 

225 for td, taskNodes in self._taskToQuantumNode.items(): 

226 diff = taskNodes.difference(prunes) 

227 if len(diff) == 0: 

228 if len(taskNodes) != 0: 

229 tp: DatasetType 

230 for tp in types_: 

231 if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set( 

232 tmpRefs 

233 ).difference(pruneRefs): 

234 culprits.add(tp.name) 

235 emptyTasks.add(td.label) 

236 newTaskToQuantumNode[td] = diff 

237 

238 # update the internal dict 

239 self._taskToQuantumNode = newTaskToQuantumNode 

240 

241 if emptyTasks: 

242 raise ValueError( 

243 f"{', '.join(emptyTasks)} task(s) have no nodes associated with them " 

244 f"after graph pruning; {', '.join(culprits)} caused over-pruning" 

245 ) 

246 

247 # Graph of quanta relations 

248 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph() 

249 self._count = len(self._connectedQuanta) 

250 

251 # Graph of task relations, used in various methods 

252 self._taskGraph = self._datasetDict.makeNetworkXGraph() 

253 

254 # convert default dict into a regular to prevent accidental key 

255 # insertion 

256 self._taskToQuantumNode = dict(self._taskToQuantumNode.items()) 

257 

258 @property 

259 def taskGraph(self) -> nx.DiGraph: 

260 """Return a graph representing the relations between the tasks inside 

261 the quantum graph. 

262 

263 Returns 

264 ------- 

265 taskGraph : `networkx.Digraph` 

266 Internal datastructure that holds relations of `TaskDef` objects 

267 """ 

268 return self._taskGraph 

269 

270 @property 

271 def graph(self) -> nx.DiGraph: 

272 """Return a graph representing the relations between all the 

273 `QuantumNode` objects. Largely it should be preferred to iterate 

274 over, and use methods of this class, but sometimes direct access to 

275 the networkx object may be helpful 

276 

277 Returns 

278 ------- 

279 graph : `networkx.Digraph` 

280 Internal datastructure that holds relations of `QuantumNode` 

281 objects 

282 """ 

283 return self._connectedQuanta 

284 

285 @property 

286 def inputQuanta(self) -> Iterable[QuantumNode]: 

287 """Make a `list` of all `QuantumNode` objects that are 'input' nodes 

288 to the graph, meaning those nodes to not depend on any other nodes in 

289 the graph. 

290 

291 Returns 

292 ------- 

293 inputNodes : iterable of `QuantumNode` 

294 A list of nodes that are inputs to the graph 

295 """ 

296 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

297 

298 @property 

299 def outputQuanta(self) -> Iterable[QuantumNode]: 

300 """Make a `list` of all `QuantumNode` objects that are 'output' nodes 

301 to the graph, meaning those nodes have no nodes that depend them in 

302 the graph. 

303 

304 Returns 

305 ------- 

306 outputNodes : iterable of `QuantumNode` 

307 A list of nodes that are outputs of the graph 

308 """ 

309 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

310 

311 @property 

312 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]: 

313 """Return all the `DatasetTypeName` objects that are contained inside 

314 the graph. 

315 

316 Returns 

317 ------- 

318 tuple of `DatasetTypeName` 

319 All the data set type names that are present in the graph 

320 """ 

321 return tuple(self._datasetDict.keys()) 

322 

323 @property 

324 def isConnected(self) -> bool: 

325 """Return True if all of the nodes in the graph are connected, ignores 

326 directionality of connections. 

327 """ 

328 return nx.is_weakly_connected(self._connectedQuanta) 

329 

330 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T: 

331 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s 

332 and nodes which depend on them. 

333 

334 Parameters 

335 ---------- 

336 refs : `Iterable` of `DatasetRef` 

337 Refs which should be removed from resulting graph 

338 

339 Returns 

340 ------- 

341 graph : `QuantumGraph` 

342 A graph that has been pruned of specified refs and the nodes that 

343 depend on them. 

344 """ 

345 newInst = object.__new__(type(self)) 

346 quantumMap = defaultdict(set) 

347 for node in self: 

348 quantumMap[node.taskDef].add(node.quantum) 

349 

350 # convert to standard dict to prevent accidental key insertion 

351 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

352 

353 newInst._buildGraphs( 

354 quantumDict, 

355 _quantumToNodeId={n.quantum: n.nodeId for n in self}, 

356 metadata=self._metadata, 

357 pruneRefs=refs, 

358 ) 

359 return newInst 

360 

361 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

362 """Lookup a `QuantumNode` from an id associated with the node. 

363 

364 Parameters 

365 ---------- 

366 nodeId : `NodeId` 

367 The number associated with a node 

368 

369 Returns 

370 ------- 

371 node : `QuantumNode` 

372 The node corresponding with input number 

373 

374 Raises 

375 ------ 

376 KeyError 

377 Raised if the requested nodeId is not in the graph. 

378 """ 

379 return self._nodeIdMap[nodeId] 

380 

381 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]: 

382 """Return all the `Quantum` associated with a `TaskDef`. 

383 

384 Parameters 

385 ---------- 

386 taskDef : `TaskDef` 

387 The `TaskDef` for which `Quantum` are to be queried 

388 

389 Returns 

390 ------- 

391 frozenset of `Quantum` 

392 The `set` of `Quantum` that is associated with the specified 

393 `TaskDef`. 

394 """ 

395 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef]) 

396 

397 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]: 

398 """Return all the `QuantumNodes` associated with a `TaskDef`. 

399 

400 Parameters 

401 ---------- 

402 taskDef : `TaskDef` 

403 The `TaskDef` for which `Quantum` are to be queried 

404 

405 Returns 

406 ------- 

407 frozenset of `QuantumNodes` 

408 The `frozenset` of `QuantumNodes` that is associated with the 

409 specified `TaskDef`. 

410 """ 

411 return frozenset(self._taskToQuantumNode[taskDef]) 

412 

413 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

414 """Find all tasks that have the specified dataset type name as an 

415 input. 

416 

417 Parameters 

418 ---------- 

419 datasetTypeName : `str` 

420 A string representing the name of a dataset type to be queried, 

421 can also accept a `DatasetTypeName` which is a `NewType` of str for 

422 type safety in static type checking. 

423 

424 Returns 

425 ------- 

426 tasks : iterable of `TaskDef` 

427 `TaskDef` objects that have the specified `DatasetTypeName` as an 

428 input, list will be empty if no tasks use specified 

429 `DatasetTypeName` as an input. 

430 

431 Raises 

432 ------ 

433 KeyError 

434 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

435 """ 

436 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

437 

438 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]: 

439 """Find all tasks that have the specified dataset type name as an 

440 output. 

441 

442 Parameters 

443 ---------- 

444 datasetTypeName : `str` 

445 A string representing the name of a dataset type to be queried, 

446 can also accept a `DatasetTypeName` which is a `NewType` of str for 

447 type safety in static type checking. 

448 

449 Returns 

450 ------- 

451 `TaskDef` or `None` 

452 `TaskDef` that outputs `DatasetTypeName` as an output or None if 

453 none of the tasks produce this `DatasetTypeName`. 

454 

455 Raises 

456 ------ 

457 KeyError 

458 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

459 """ 

460 return self._datasetDict.getProducer(datasetTypeName) 

461 

462 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

463 """Find all tasks that are associated with the specified dataset type 

464 name. 

465 

466 Parameters 

467 ---------- 

468 datasetTypeName : `str` 

469 A string representing the name of a dataset type to be queried, 

470 can also accept a `DatasetTypeName` which is a `NewType` of str for 

471 type safety in static type checking. 

472 

473 Returns 

474 ------- 

475 result : iterable of `TaskDef` 

476 `TaskDef` objects that are associated with the specified 

477 `DatasetTypeName` 

478 

479 Raises 

480 ------ 

481 KeyError 

482 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

483 """ 

484 return self._datasetDict.getAll(datasetTypeName) 

485 

486 def findTaskDefByName(self, taskName: str) -> List[TaskDef]: 

487 """Determine which `TaskDef` objects in this graph are associated 

488 with a `str` representing a task name (looks at the taskName property 

489 of `TaskDef` objects). 

490 

491 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

492 multiple times in a graph with different labels. 

493 

494 Parameters 

495 ---------- 

496 taskName : str 

497 Name of a task to search for 

498 

499 Returns 

500 ------- 

501 result : list of `TaskDef` 

502 List of the `TaskDef` objects that have the name specified. 

503 Multiple values are returned in the case that a task is used 

504 multiple times with different labels. 

505 """ 

506 results = [] 

507 for task in self._taskToQuantumNode.keys(): 

508 split = task.taskName.split(".") 

509 if split[-1] == taskName: 

510 results.append(task) 

511 return results 

512 

513 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]: 

514 """Determine which `TaskDef` objects in this graph are associated 

515 with a `str` representing a tasks label. 

516 

517 Parameters 

518 ---------- 

519 taskName : str 

520 Name of a task to search for 

521 

522 Returns 

523 ------- 

524 result : `TaskDef` 

525 `TaskDef` objects that has the specified label. 

526 """ 

527 for task in self._taskToQuantumNode.keys(): 

528 if label == task.label: 

529 return task 

530 return None 

531 

532 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]: 

533 """Return all the `Quantum` that contain a specified `DatasetTypeName`. 

534 

535 Parameters 

536 ---------- 

537 datasetTypeName : `str` 

538 The name of the dataset type to search for as a string, 

539 can also accept a `DatasetTypeName` which is a `NewType` of str for 

540 type safety in static type checking. 

541 

542 Returns 

543 ------- 

544 result : `set` of `QuantumNode` objects 

545 A `set` of `QuantumNode`s that contain specified `DatasetTypeName` 

546 

547 Raises 

548 ------ 

549 KeyError 

550 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

551 

552 """ 

553 tasks = self._datasetDict.getAll(datasetTypeName) 

554 result: Set[Quantum] = set() 

555 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

556 return result 

557 

558 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

559 """Check if specified quantum appears in the graph as part of a node. 

560 

561 Parameters 

562 ---------- 

563 quantum : `Quantum` 

564 The quantum to search for 

565 

566 Returns 

567 ------- 

568 `bool` 

569 The result of searching for the quantum 

570 """ 

571 for node in self: 

572 if quantum == node.quantum: 

573 return True 

574 return False 

575 

576 def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None: 

577 """Write out the graph as a dot graph. 

578 

579 Parameters 

580 ---------- 

581 output : str or `io.BufferedIOBase` 

582 Either a filesystem path to write to, or a file handle object 

583 """ 

584 write_dot(self._connectedQuanta, output) 

585 

586 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T: 

587 """Create a new graph object that contains the subset of the nodes 

588 specified as input. Node number is preserved. 

589 

590 Parameters 

591 ---------- 

592 nodes : `QuantumNode` or iterable of `QuantumNode` 

593 

594 Returns 

595 ------- 

596 graph : instance of graph type 

597 An instance of the type from which the subset was created 

598 """ 

599 if not isinstance(nodes, Iterable): 

600 nodes = (nodes,) 

601 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes 

602 quantumMap = defaultdict(set) 

603 

604 node: QuantumNode 

605 for node in quantumSubgraph: 

606 quantumMap[node.taskDef].add(node.quantum) 

607 

608 # convert to standard dict to prevent accidental key insertion 

609 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

610 # Create an empty graph, and then populate it with custom mapping 

611 newInst = type(self)({}) 

612 newInst._buildGraphs( 

613 quantumDict, 

614 _quantumToNodeId={n.quantum: n.nodeId for n in nodes}, 

615 _buildId=self._buildId, 

616 metadata=self._metadata, 

617 ) 

618 return newInst 

619 

620 def subsetToConnected(self: _T) -> Tuple[_T, ...]: 

621 """Generate a list of subgraphs where each is connected. 

622 

623 Returns 

624 ------- 

625 result : list of `QuantumGraph` 

626 A list of graphs that are each connected 

627 """ 

628 return tuple( 

629 self.subset(connectedSet) 

630 for connectedSet in nx.weakly_connected_components(self._connectedQuanta) 

631 ) 

632 

633 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

634 """Return a set of `QuantumNode` that are direct inputs to a specified 

635 node. 

636 

637 Parameters 

638 ---------- 

639 node : `QuantumNode` 

640 The node of the graph for which inputs are to be determined 

641 

642 Returns 

643 ------- 

644 set of `QuantumNode` 

645 All the nodes that are direct inputs to specified node 

646 """ 

647 return set(pred for pred in self._connectedQuanta.predecessors(node)) 

648 

649 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

650 """Return a set of `QuantumNode` that are direct outputs of a specified 

651 node. 

652 

653 Parameters 

654 ---------- 

655 node : `QuantumNode` 

656 The node of the graph for which outputs are to be determined 

657 

658 Returns 

659 ------- 

660 set of `QuantumNode` 

661 All the nodes that are direct outputs to specified node 

662 """ 

663 return set(succ for succ in self._connectedQuanta.successors(node)) 

664 

665 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

666 """Return a graph of `QuantumNode` that are direct inputs and outputs 

667 of a specified node. 

668 

669 Parameters 

670 ---------- 

671 node : `QuantumNode` 

672 The node of the graph for which connected nodes are to be 

673 determined. 

674 

675 Returns 

676 ------- 

677 graph : graph of `QuantumNode` 

678 All the nodes that are directly connected to specified node 

679 """ 

680 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

681 nodes.add(node) 

682 return self.subset(nodes) 

683 

684 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

685 """Return a graph of the specified node and all the ancestor nodes 

686 directly reachable by walking edges. 

687 

688 Parameters 

689 ---------- 

690 node : `QuantumNode` 

691 The node for which all ansestors are to be determined 

692 

693 Returns 

694 ------- 

695 graph of `QuantumNode` 

696 Graph of node and all of its ansestors 

697 """ 

698 predecessorNodes = nx.ancestors(self._connectedQuanta, node) 

699 predecessorNodes.add(node) 

700 return self.subset(predecessorNodes) 

701 

702 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]: 

703 """Check a graph for the presense of cycles and returns the edges of 

704 any cycles found, or an empty list if there is no cycle. 

705 

706 Returns 

707 ------- 

708 result : list of tuple of `QuantumNode`, `QuantumNode` 

709 A list of any graph edges that form a cycle, or an empty list if 

710 there is no cycle. Empty list to so support if graph.find_cycle() 

711 syntax as an empty list is falsy. 

712 """ 

713 try: 

714 return nx.find_cycle(self._connectedQuanta) 

715 except nx.NetworkXNoCycle: 

716 return [] 

717 

718 def saveUri(self, uri: ResourcePathExpression) -> None: 

719 """Save `QuantumGraph` to the specified URI. 

720 

721 Parameters 

722 ---------- 

723 uri : convertible to `ResourcePath` 

724 URI to where the graph should be saved. 

725 """ 

726 buffer = self._buildSaveObject() 

727 path = ResourcePath(uri) 

728 if path.getExtension() not in (".qgraph"): 

729 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

730 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

731 

732 @property 

733 def metadata(self) -> Optional[MappingProxyType[str, Any]]: 

734 """ """ 

735 if self._metadata is None: 

736 return None 

737 return MappingProxyType(self._metadata) 

738 

739 @classmethod 

740 def loadUri( 

741 cls, 

742 uri: ResourcePathExpression, 

743 universe: DimensionUniverse, 

744 nodes: Optional[Iterable[uuid.UUID]] = None, 

745 graphID: Optional[BuildId] = None, 

746 minimumVersion: int = 3, 

747 ) -> QuantumGraph: 

748 """Read `QuantumGraph` from a URI. 

749 

750 Parameters 

751 ---------- 

752 uri : convertible to `ResourcePath` 

753 URI from where to load the graph. 

754 universe: `~lsst.daf.butler.DimensionUniverse` 

755 DimensionUniverse instance, not used by the method itself but 

756 needed to ensure that registry data structures are initialized. 

757 nodes: iterable of `int` or None 

758 Numbers that correspond to nodes in the graph. If specified, only 

759 these nodes will be loaded. Defaults to None, in which case all 

760 nodes will be loaded. 

761 graphID : `str` or `None` 

762 If specified this ID is verified against the loaded graph prior to 

763 loading any Nodes. This defaults to None in which case no 

764 validation is done. 

765 minimumVersion : int 

766 Minimum version of a save file to load. Set to -1 to load all 

767 versions. Older versions may need to be loaded, and re-saved 

768 to upgrade them to the latest format before they can be used in 

769 production. 

770 

771 Returns 

772 ------- 

773 graph : `QuantumGraph` 

774 Resulting QuantumGraph instance. 

775 

776 Raises 

777 ------ 

778 TypeError 

779 Raised if pickle contains instance of a type other than 

780 QuantumGraph. 

781 ValueError 

782 Raised if one or more of the nodes requested is not in the 

783 `QuantumGraph` or if graphID parameter does not match the graph 

784 being loaded or if the supplied uri does not point at a valid 

785 `QuantumGraph` save file. 

786 

787 

788 Notes 

789 ----- 

790 Reading Quanta from pickle requires existence of singleton 

791 DimensionUniverse which is usually instantiated during Registry 

792 initialization. To make sure that DimensionUniverse exists this method 

793 accepts dummy DimensionUniverse argument. 

794 """ 

795 uri = ResourcePath(uri) 

796 # With ResourcePath we have the choice of always using a local file 

797 # or reading in the bytes directly. Reading in bytes can be more 

798 # efficient for reasonably-sized pickle files when the resource 

799 # is remote. For now use the local file variant. For a local file 

800 # as_local() does nothing. 

801 

802 if uri.getExtension() in (".pickle", ".pkl"): 

803 with uri.as_local() as local, open(local.ospath, "rb") as fd: 

804 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

805 qgraph = pickle.load(fd) 

806 elif uri.getExtension() in (".qgraph"): 

807 with LoadHelper(uri, minimumVersion) as loader: 

808 qgraph = loader.load(universe, nodes, graphID) 

809 else: 

810 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`") 

811 if not isinstance(qgraph, QuantumGraph): 

812 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}") 

813 return qgraph 

814 

815 @classmethod 

816 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]: 

817 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

818 and return it as a string. 

819 

820 Parameters 

821 ---------- 

822 uri : convertible to `ResourcePath` 

823 The location of the `QuantumGraph` to load. If the argument is a 

824 string, it must correspond to a valid `ResourcePath` path. 

825 minimumVersion : int 

826 Minimum version of a save file to load. Set to -1 to load all 

827 versions. Older versions may need to be loaded, and re-saved 

828 to upgrade them to the latest format before they can be used in 

829 production. 

830 

831 Returns 

832 ------- 

833 header : `str` or `None` 

834 The header associated with the specified `QuantumGraph` it there is 

835 one, else `None`. 

836 

837 Raises 

838 ------ 

839 ValueError 

840 Raised if `QuantuGraph` was saved as a pickle. 

841 Raised if the extention of the file specified by uri is not a 

842 `QuantumGraph` extention. 

843 """ 

844 uri = ResourcePath(uri) 

845 if uri.getExtension() in (".pickle", ".pkl"): 

846 raise ValueError("Reading a header from a pickle save is not supported") 

847 elif uri.getExtension() in (".qgraph"): 

848 return LoadHelper(uri, minimumVersion).readHeader() 

849 else: 

850 raise ValueError("Only know how to handle files saved as `qgraph`") 

851 

852 def buildAndPrintHeader(self) -> None: 

853 """Creates a header that would be used in a save of this object and 

854 prints it out to standard out. 

855 """ 

856 _, header = self._buildSaveObject(returnHeader=True) 

857 print(json.dumps(header)) 

858 

859 def save(self, file: BinaryIO) -> None: 

860 """Save QuantumGraph to a file. 

861 

862 Parameters 

863 ---------- 

864 file : `io.BufferedIOBase` 

865 File to write pickle data open in binary mode. 

866 """ 

867 buffer = self._buildSaveObject() 

868 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

869 

870 def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]: 

871 # make some containers 

872 jsonData: Deque[bytes] = deque() 

873 # node map is a list because json does not accept mapping keys that 

874 # are not strings, so we store a list of key, value pairs that will 

875 # be converted to a mapping on load 

876 nodeMap = [] 

877 taskDefMap = {} 

878 headerData: Dict[str, Any] = {} 

879 

880 # Store the QauntumGraph BuildId, this will allow validating BuildIds 

881 # at load time, prior to loading any QuantumNodes. Name chosen for 

882 # unlikely conflicts. 

883 headerData["GraphBuildID"] = self.graphID 

884 headerData["Metadata"] = self._metadata 

885 

886 # counter for the number of bytes processed thus far 

887 count = 0 

888 # serialize out the task Defs recording the start and end bytes of each 

889 # taskDef 

890 inverseLookup = self._datasetDict.inverse 

891 taskDef: TaskDef 

892 # sort by task label to ensure serialization happens in the same order 

893 for taskDef in self.taskGraph: 

894 # compressing has very little impact on saving or load time, but 

895 # a large impact on on disk size, so it is worth doing 

896 taskDescription = {} 

897 # save the fully qualified name. 

898 taskDescription["taskName"] = get_full_type_name(taskDef.taskClass) 

899 # save the config as a text stream that will be un-persisted on the 

900 # other end 

901 stream = io.StringIO() 

902 taskDef.config.saveToStream(stream) 

903 taskDescription["config"] = stream.getvalue() 

904 taskDescription["label"] = taskDef.label 

905 

906 inputs = [] 

907 outputs = [] 

908 

909 # Determine the connection between all of tasks and save that in 

910 # the header as a list of connections and edges in each task 

911 # this will help in un-persisting, and possibly in a "quick view" 

912 # method that does not require everything to be un-persisted 

913 # 

914 # Typing returns can't be parameter dependent 

915 for connection in inverseLookup[taskDef]: # type: ignore 

916 consumers = self._datasetDict.getConsumers(connection) 

917 producer = self._datasetDict.getProducer(connection) 

918 if taskDef in consumers: 

919 # This checks if the task consumes the connection directly 

920 # from the datastore or it is produced by another task 

921 producerLabel = producer.label if producer is not None else "datastore" 

922 inputs.append((producerLabel, connection)) 

923 elif taskDef not in consumers and producer is taskDef: 

924 # If there are no consumers for this tasks produced 

925 # connection, the output will be said to be the datastore 

926 # in which case the for loop will be a zero length loop 

927 if not consumers: 

928 outputs.append(("datastore", connection)) 

929 for td in consumers: 

930 outputs.append((td.label, connection)) 

931 

932 # dump to json string, and encode that string to bytes and then 

933 # conpress those bytes 

934 dump = lzma.compress(json.dumps(taskDescription).encode()) 

935 # record the sizing and relation information 

936 taskDefMap[taskDef.label] = { 

937 "bytes": (count, count + len(dump)), 

938 "inputs": inputs, 

939 "outputs": outputs, 

940 } 

941 count += len(dump) 

942 jsonData.append(dump) 

943 

944 headerData["TaskDefs"] = taskDefMap 

945 

946 # serialize the nodes, recording the start and end bytes of each node 

947 dimAccumulator = DimensionRecordsAccumulator() 

948 for node in self: 

949 # compressing has very little impact on saving or load time, but 

950 # a large impact on on disk size, so it is worth doing 

951 simpleNode = node.to_simple(accumulator=dimAccumulator) 

952 

953 dump = lzma.compress(simpleNode.json().encode()) 

954 jsonData.append(dump) 

955 nodeMap.append( 

956 ( 

957 str(node.nodeId), 

958 { 

959 "bytes": (count, count + len(dump)), 

960 "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)], 

961 "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)], 

962 }, 

963 ) 

964 ) 

965 count += len(dump) 

966 

967 headerData["DimensionRecords"] = { 

968 key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items() 

969 } 

970 

971 # need to serialize this as a series of key,value tuples because of 

972 # a limitation on how json cant do anyting but strings as keys 

973 headerData["Nodes"] = nodeMap 

974 

975 # dump the headerData to json 

976 header_encode = lzma.compress(json.dumps(headerData).encode()) 

977 

978 # record the sizes as 2 unsigned long long numbers for a total of 16 

979 # bytes 

980 save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION) 

981 

982 fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING() 

983 map_lengths = struct.pack(fmt_string, len(header_encode)) 

984 

985 # write each component of the save out in a deterministic order 

986 # buffer = io.BytesIO() 

987 # buffer.write(map_lengths) 

988 # buffer.write(taskDef_pickle) 

989 # buffer.write(map_pickle) 

990 buffer = bytearray() 

991 buffer.extend(MAGIC_BYTES) 

992 buffer.extend(save_bytes) 

993 buffer.extend(map_lengths) 

994 buffer.extend(header_encode) 

995 # Iterate over the length of pickleData, and for each element pop the 

996 # leftmost element off the deque and write it out. This is to save 

997 # memory, as the memory is added to the buffer object, it is removed 

998 # from from the container. 

999 # 

1000 # Only this section needs to worry about memory pressue because 

1001 # everything else written to the buffer prior to this pickle data is 

1002 # only on the order of kilobytes to low numbers of megabytes. 

1003 while jsonData: 

1004 buffer.extend(jsonData.popleft()) 

1005 if returnHeader: 

1006 return buffer, headerData 

1007 else: 

1008 return buffer 

1009 

1010 @classmethod 

1011 def load( 

1012 cls, 

1013 file: BinaryIO, 

1014 universe: DimensionUniverse, 

1015 nodes: Optional[Iterable[uuid.UUID]] = None, 

1016 graphID: Optional[BuildId] = None, 

1017 minimumVersion: int = 3, 

1018 ) -> QuantumGraph: 

1019 """Read QuantumGraph from a file that was made by `save`. 

1020 

1021 Parameters 

1022 ---------- 

1023 file : `io.IO` of bytes 

1024 File with pickle data open in binary mode. 

1025 universe: `~lsst.daf.butler.DimensionUniverse` 

1026 DimensionUniverse instance, not used by the method itself but 

1027 needed to ensure that registry data structures are initialized. 

1028 nodes: iterable of `int` or None 

1029 Numbers that correspond to nodes in the graph. If specified, only 

1030 these nodes will be loaded. Defaults to None, in which case all 

1031 nodes will be loaded. 

1032 graphID : `str` or `None` 

1033 If specified this ID is verified against the loaded graph prior to 

1034 loading any Nodes. This defaults to None in which case no 

1035 validation is done. 

1036 minimumVersion : int 

1037 Minimum version of a save file to load. Set to -1 to load all 

1038 versions. Older versions may need to be loaded, and re-saved 

1039 to upgrade them to the latest format before they can be used in 

1040 production. 

1041 

1042 Returns 

1043 ------- 

1044 graph : `QuantumGraph` 

1045 Resulting QuantumGraph instance. 

1046 

1047 Raises 

1048 ------ 

1049 TypeError 

1050 Raised if pickle contains instance of a type other than 

1051 QuantumGraph. 

1052 ValueError 

1053 Raised if one or more of the nodes requested is not in the 

1054 `QuantumGraph` or if graphID parameter does not match the graph 

1055 being loaded or if the supplied uri does not point at a valid 

1056 `QuantumGraph` save file. 

1057 

1058 Notes 

1059 ----- 

1060 Reading Quanta from pickle requires existence of singleton 

1061 DimensionUniverse which is usually instantiated during Registry 

1062 initialization. To make sure that DimensionUniverse exists this method 

1063 accepts dummy DimensionUniverse argument. 

1064 """ 

1065 # Try to see if the file handle contains pickle data, this will be 

1066 # removed in the future 

1067 try: 

1068 qgraph = pickle.load(file) 

1069 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

1070 except pickle.UnpicklingError: 

1071 with LoadHelper(file, minimumVersion) as loader: 

1072 qgraph = loader.load(universe, nodes, graphID) 

1073 if not isinstance(qgraph, QuantumGraph): 

1074 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}") 

1075 return qgraph 

1076 

1077 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1078 """Iterate over the `taskGraph` attribute in topological order 

1079 

1080 Yields 

1081 ------ 

1082 taskDef : `TaskDef` 

1083 `TaskDef` objects in topological order 

1084 """ 

1085 yield from nx.topological_sort(self.taskGraph) 

1086 

1087 @property 

1088 def graphID(self) -> BuildId: 

1089 """Returns the ID generated by the graph at construction time""" 

1090 return self._buildId 

1091 

1092 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1093 yield from nx.topological_sort(self._connectedQuanta) 

1094 

1095 def __len__(self) -> int: 

1096 return self._count 

1097 

1098 def __contains__(self, node: QuantumNode) -> bool: 

1099 return self._connectedQuanta.has_node(node) 

1100 

1101 def __getstate__(self) -> dict: 

1102 """Stores a compact form of the graph as a list of graph nodes, and a 

1103 tuple of task labels and task configs. The full graph can be 

1104 reconstructed with this information, and it preseves the ordering of 

1105 the graph ndoes. 

1106 """ 

1107 universe: Optional[DimensionUniverse] = None 

1108 for node in self: 

1109 dId = node.quantum.dataId 

1110 if dId is None: 

1111 continue 

1112 universe = dId.graph.universe 

1113 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe} 

1114 

1115 def __setstate__(self, state: dict) -> None: 

1116 """Reconstructs the state of the graph from the information persisted 

1117 in getstate. 

1118 """ 

1119 buffer = io.BytesIO(state["reduced"]) 

1120 with LoadHelper(buffer, minimumVersion=3) as loader: 

1121 qgraph = loader.load(state["universe"], graphID=state["graphId"]) 

1122 

1123 self._metadata = qgraph._metadata 

1124 self._buildId = qgraph._buildId 

1125 self._datasetDict = qgraph._datasetDict 

1126 self._nodeIdMap = qgraph._nodeIdMap 

1127 self._count = len(qgraph) 

1128 self._taskToQuantumNode = qgraph._taskToQuantumNode 

1129 self._taskGraph = qgraph._taskGraph 

1130 self._connectedQuanta = qgraph._connectedQuanta 

1131 

1132 def __eq__(self, other: object) -> bool: 

1133 if not isinstance(other, QuantumGraph): 

1134 return False 

1135 if len(self) != len(other): 

1136 return False 

1137 for node in self: 

1138 if node not in other: 

1139 return False 

1140 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1141 return False 

1142 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1143 return False 

1144 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1145 return False 

1146 return set(self.taskGraph) == set(other.taskGraph)