Coverage for python/lsst/pipe/base/graph/graph.py: 19%

346 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-07-20 02:51 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("QuantumGraph", "IncompatibleGraphError") 

24 

25import io 

26import json 

27import lzma 

28import os 

29import pickle 

30import struct 

31import time 

32import uuid 

33import warnings 

34from collections import defaultdict, deque 

35from itertools import chain 

36from types import MappingProxyType 

37from typing import ( 

38 Any, 

39 BinaryIO, 

40 DefaultDict, 

41 Deque, 

42 Dict, 

43 FrozenSet, 

44 Generator, 

45 Iterable, 

46 List, 

47 Mapping, 

48 MutableMapping, 

49 Optional, 

50 Set, 

51 Tuple, 

52 TypeVar, 

53 Union, 

54) 

55 

56import networkx as nx 

57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum 

58from lsst.resources import ResourcePath, ResourcePathExpression 

59from lsst.utils.introspection import get_full_type_name 

60from networkx.drawing.nx_agraph import write_dot 

61 

62from ..connections import iterConnections 

63from ..pipeline import TaskDef 

64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner 

65from ._loadHelpers import LoadHelper 

66from ._versionDeserializers import DESERIALIZER_MAP 

67from .quantumNode import BuildId, QuantumNode 

68 

# Type variable bound to QuantumGraph so that methods such as `subset` can be
# declared to return an instance of the same (possibly derived) class.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save.
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading the version bytes and determining
# which loading code should be used for the rest of the file.
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

91 

92 

class IncompatibleGraphError(Exception):
    """Exception raised when a lookup by NodeId cannot be satisfied because
    the node identifier is incompatible with this graph.
    """

99 

100 

101class QuantumGraph: 

102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects 

103 

104 This data structure represents a concrete workflow generated from a 

105 `Pipeline`. 

106 

107 Parameters 

108 ---------- 

109 quanta : Mapping of `TaskDef` to sets of `Quantum` 

110 This maps tasks (and their configs) to the sets of data they are to 

111 process. 

112 metadata : Optional Mapping of `str` to primitives 

113 This is an optional parameter of extra data to carry with the graph. 

114 Entries in this mapping should be able to be serialized in JSON. 

115 

116 Raises 

117 ------ 

118 ValueError 

119 Raised if the graph is pruned such that some tasks no longer have nodes 

120 associated with them. 

121 """ 

122 

    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
    ):
        # All construction is delegated to _buildGraphs so that alternate
        # constructors elsewhere in this class can rebuild state on a bare
        # instance created with object.__new__.
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs, universe=universe)

131 

132 def _buildGraphs( 

133 self, 

134 quanta: Mapping[TaskDef, Set[Quantum]], 

135 *, 

136 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None, 

137 _buildId: Optional[BuildId] = None, 

138 metadata: Optional[Mapping[str, Any]] = None, 

139 pruneRefs: Optional[Iterable[DatasetRef]] = None, 

140 universe: Optional[DimensionUniverse] = None, 

141 ) -> None: 

142 """Builds the graph that is used to store the relation between tasks, 

143 and the graph that holds the relations between quanta 

144 """ 

145 if universe is None: 

146 universe = DimensionUniverse() 

147 self._universe = universe 

148 self._metadata = metadata 

149 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}") 

150 # Data structures used to identify relations between components; 

151 # DatasetTypeName -> TaskDef for task, 

152 # and DatasetRef -> QuantumNode for the quanta 

153 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True) 

154 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]() 

155 

156 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {} 

157 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set) 

158 for taskDef, quantumSet in quanta.items(): 

159 connections = taskDef.connections 

160 

161 # For each type of connection in the task, add a key to the 

162 # `_DatasetTracker` for the connections name, with a value of 

163 # the TaskDef in the appropriate field 

164 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")): 

165 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef) 

166 

167 for output in iterConnections(connections, ("outputs",)): 

168 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef) 

169 

170 # For each `Quantum` in the set of all `Quantum` for this task, 

171 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one 

172 # of the individual datasets inside the `Quantum`, with a value of 

173 # a newly created QuantumNode to the appropriate input/output 

174 # field. 

175 for quantum in quantumSet: 

176 if _quantumToNodeId: 

177 if (nodeId := _quantumToNodeId.get(quantum)) is None: 

178 raise ValueError( 

179 "If _quantuMToNodeNumber is not None, all quanta must have an " 

180 "associated value in the mapping" 

181 ) 

182 else: 

183 nodeId = uuid.uuid4() 

184 

185 inits = quantum.initInputs.values() 

186 inputs = quantum.inputs.values() 

187 value = QuantumNode(quantum, taskDef, nodeId) 

188 self._taskToQuantumNode[taskDef].add(value) 

189 self._nodeIdMap[nodeId] = value 

190 

191 for dsRef in chain(inits, inputs): 

192 # unfortunately, `Quantum` allows inits to be individual 

193 # `DatasetRef`s or an Iterable of such, so there must 

194 # be an instance check here 

195 if isinstance(dsRef, Iterable): 

196 for sub in dsRef: 

197 if sub.isComponent(): 

198 sub = sub.makeCompositeRef() 

199 self._datasetRefDict.addConsumer(sub, value) 

200 else: 

201 assert isinstance(dsRef, DatasetRef) 

202 if dsRef.isComponent(): 

203 dsRef = dsRef.makeCompositeRef() 

204 self._datasetRefDict.addConsumer(dsRef, value) 

205 for dsRef in chain.from_iterable(quantum.outputs.values()): 

206 self._datasetRefDict.addProducer(dsRef, value) 

207 

208 if pruneRefs is not None: 

209 # track what refs were pruned and prune the graph 

210 prunes: Set[QuantumNode] = set() 

211 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes) 

212 

213 # recreate the taskToQuantumNode dict removing nodes that have been 

214 # pruned. Keep track of task defs that now have no QuantumNodes 

215 emptyTasks: Set[str] = set() 

216 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set) 

217 # accumulate all types 

218 types_ = set() 

219 # tracker for any pruneRefs that have caused tasks to have no nodes 

220 # This helps the user find out what caused the issues seen. 

221 culprits = set() 

222 # Find all the types from the refs to prune 

223 for r in pruneRefs: 

224 types_.add(r.datasetType) 

225 

226 # For each of the tasks, and their associated nodes, remove any 

227 # any nodes that were pruned. If there are no nodes associated 

228 # with a task, record that task, and find out if that was due to 

229 # a type from an input ref to prune. 

230 for td, taskNodes in self._taskToQuantumNode.items(): 

231 diff = taskNodes.difference(prunes) 

232 if len(diff) == 0: 

233 if len(taskNodes) != 0: 

234 tp: DatasetType 

235 for tp in types_: 

236 if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set( 

237 tmpRefs 

238 ).difference(pruneRefs): 

239 culprits.add(tp.name) 

240 emptyTasks.add(td.label) 

241 newTaskToQuantumNode[td] = diff 

242 

243 # update the internal dict 

244 self._taskToQuantumNode = newTaskToQuantumNode 

245 

246 if emptyTasks: 

247 raise ValueError( 

248 f"{', '.join(emptyTasks)} task(s) have no nodes associated with them " 

249 f"after graph pruning; {', '.join(culprits)} caused over-pruning" 

250 ) 

251 

252 # Graph of quanta relations 

253 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph() 

254 self._count = len(self._connectedQuanta) 

255 

256 # Graph of task relations, used in various methods 

257 self._taskGraph = self._datasetDict.makeNetworkXGraph() 

258 

259 # convert default dict into a regular to prevent accidental key 

260 # insertion 

261 self._taskToQuantumNode = dict(self._taskToQuantumNode.items()) 

262 

263 @property 

264 def taskGraph(self) -> nx.DiGraph: 

265 """Return a graph representing the relations between the tasks inside 

266 the quantum graph. 

267 

268 Returns 

269 ------- 

270 taskGraph : `networkx.Digraph` 

271 Internal datastructure that holds relations of `TaskDef` objects 

272 """ 

273 return self._taskGraph 

274 

275 @property 

276 def graph(self) -> nx.DiGraph: 

277 """Return a graph representing the relations between all the 

278 `QuantumNode` objects. Largely it should be preferred to iterate 

279 over, and use methods of this class, but sometimes direct access to 

280 the networkx object may be helpful 

281 

282 Returns 

283 ------- 

284 graph : `networkx.Digraph` 

285 Internal datastructure that holds relations of `QuantumNode` 

286 objects 

287 """ 

288 return self._connectedQuanta 

289 

290 @property 

291 def inputQuanta(self) -> Iterable[QuantumNode]: 

292 """Make a `list` of all `QuantumNode` objects that are 'input' nodes 

293 to the graph, meaning those nodes to not depend on any other nodes in 

294 the graph. 

295 

296 Returns 

297 ------- 

298 inputNodes : iterable of `QuantumNode` 

299 A list of nodes that are inputs to the graph 

300 """ 

301 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

302 

303 @property 

304 def outputQuanta(self) -> Iterable[QuantumNode]: 

305 """Make a `list` of all `QuantumNode` objects that are 'output' nodes 

306 to the graph, meaning those nodes have no nodes that depend them in 

307 the graph. 

308 

309 Returns 

310 ------- 

311 outputNodes : iterable of `QuantumNode` 

312 A list of nodes that are outputs of the graph 

313 """ 

314 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

315 

316 @property 

317 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]: 

318 """Return all the `DatasetTypeName` objects that are contained inside 

319 the graph. 

320 

321 Returns 

322 ------- 

323 tuple of `DatasetTypeName` 

324 All the data set type names that are present in the graph 

325 """ 

326 return tuple(self._datasetDict.keys()) 

327 

    @property
    def isConnected(self) -> bool:
        """Return `True` if all of the nodes in the graph are connected,
        ignoring the directionality of connections.
        """
        # Weak connectivity treats the directed graph as undirected.
        return nx.is_weakly_connected(self._connectedQuanta)

334 

335 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T: 

336 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s 

337 and nodes which depend on them. 

338 

339 Parameters 

340 ---------- 

341 refs : `Iterable` of `DatasetRef` 

342 Refs which should be removed from resulting graph 

343 

344 Returns 

345 ------- 

346 graph : `QuantumGraph` 

347 A graph that has been pruned of specified refs and the nodes that 

348 depend on them. 

349 """ 

350 newInst = object.__new__(type(self)) 

351 quantumMap = defaultdict(set) 

352 for node in self: 

353 quantumMap[node.taskDef].add(node.quantum) 

354 

355 # convert to standard dict to prevent accidental key insertion 

356 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

357 

358 newInst._buildGraphs( 

359 quantumDict, 

360 _quantumToNodeId={n.quantum: n.nodeId for n in self}, 

361 metadata=self._metadata, 

362 pruneRefs=refs, 

363 ) 

364 return newInst 

365 

366 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

367 """Lookup a `QuantumNode` from an id associated with the node. 

368 

369 Parameters 

370 ---------- 

371 nodeId : `NodeId` 

372 The number associated with a node 

373 

374 Returns 

375 ------- 

376 node : `QuantumNode` 

377 The node corresponding with input number 

378 

379 Raises 

380 ------ 

381 KeyError 

382 Raised if the requested nodeId is not in the graph. 

383 """ 

384 return self._nodeIdMap[nodeId] 

385 

386 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]: 

387 """Return all the `Quantum` associated with a `TaskDef`. 

388 

389 Parameters 

390 ---------- 

391 taskDef : `TaskDef` 

392 The `TaskDef` for which `Quantum` are to be queried 

393 

394 Returns 

395 ------- 

396 frozenset of `Quantum` 

397 The `set` of `Quantum` that is associated with the specified 

398 `TaskDef`. 

399 """ 

400 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef]) 

401 

402 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]: 

403 """Return all the `QuantumNodes` associated with a `TaskDef`. 

404 

405 Parameters 

406 ---------- 

407 taskDef : `TaskDef` 

408 The `TaskDef` for which `Quantum` are to be queried 

409 

410 Returns 

411 ------- 

412 frozenset of `QuantumNodes` 

413 The `frozenset` of `QuantumNodes` that is associated with the 

414 specified `TaskDef`. 

415 """ 

416 return frozenset(self._taskToQuantumNode[taskDef]) 

417 

418 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

419 """Find all tasks that have the specified dataset type name as an 

420 input. 

421 

422 Parameters 

423 ---------- 

424 datasetTypeName : `str` 

425 A string representing the name of a dataset type to be queried, 

426 can also accept a `DatasetTypeName` which is a `NewType` of str for 

427 type safety in static type checking. 

428 

429 Returns 

430 ------- 

431 tasks : iterable of `TaskDef` 

432 `TaskDef` objects that have the specified `DatasetTypeName` as an 

433 input, list will be empty if no tasks use specified 

434 `DatasetTypeName` as an input. 

435 

436 Raises 

437 ------ 

438 KeyError 

439 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

440 """ 

441 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

442 

443 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]: 

444 """Find all tasks that have the specified dataset type name as an 

445 output. 

446 

447 Parameters 

448 ---------- 

449 datasetTypeName : `str` 

450 A string representing the name of a dataset type to be queried, 

451 can also accept a `DatasetTypeName` which is a `NewType` of str for 

452 type safety in static type checking. 

453 

454 Returns 

455 ------- 

456 `TaskDef` or `None` 

457 `TaskDef` that outputs `DatasetTypeName` as an output or None if 

458 none of the tasks produce this `DatasetTypeName`. 

459 

460 Raises 

461 ------ 

462 KeyError 

463 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

464 """ 

465 return self._datasetDict.getProducer(datasetTypeName) 

466 

467 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

468 """Find all tasks that are associated with the specified dataset type 

469 name. 

470 

471 Parameters 

472 ---------- 

473 datasetTypeName : `str` 

474 A string representing the name of a dataset type to be queried, 

475 can also accept a `DatasetTypeName` which is a `NewType` of str for 

476 type safety in static type checking. 

477 

478 Returns 

479 ------- 

480 result : iterable of `TaskDef` 

481 `TaskDef` objects that are associated with the specified 

482 `DatasetTypeName` 

483 

484 Raises 

485 ------ 

486 KeyError 

487 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

488 """ 

489 return self._datasetDict.getAll(datasetTypeName) 

490 

491 def findTaskDefByName(self, taskName: str) -> List[TaskDef]: 

492 """Determine which `TaskDef` objects in this graph are associated 

493 with a `str` representing a task name (looks at the taskName property 

494 of `TaskDef` objects). 

495 

496 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

497 multiple times in a graph with different labels. 

498 

499 Parameters 

500 ---------- 

501 taskName : str 

502 Name of a task to search for 

503 

504 Returns 

505 ------- 

506 result : list of `TaskDef` 

507 List of the `TaskDef` objects that have the name specified. 

508 Multiple values are returned in the case that a task is used 

509 multiple times with different labels. 

510 """ 

511 results = [] 

512 for task in self._taskToQuantumNode.keys(): 

513 split = task.taskName.split(".") 

514 if split[-1] == taskName: 

515 results.append(task) 

516 return results 

517 

518 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]: 

519 """Determine which `TaskDef` objects in this graph are associated 

520 with a `str` representing a tasks label. 

521 

522 Parameters 

523 ---------- 

524 taskName : str 

525 Name of a task to search for 

526 

527 Returns 

528 ------- 

529 result : `TaskDef` 

530 `TaskDef` objects that has the specified label. 

531 """ 

532 for task in self._taskToQuantumNode.keys(): 

533 if label == task.label: 

534 return task 

535 return None 

536 

537 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]: 

538 """Return all the `Quantum` that contain a specified `DatasetTypeName`. 

539 

540 Parameters 

541 ---------- 

542 datasetTypeName : `str` 

543 The name of the dataset type to search for as a string, 

544 can also accept a `DatasetTypeName` which is a `NewType` of str for 

545 type safety in static type checking. 

546 

547 Returns 

548 ------- 

549 result : `set` of `QuantumNode` objects 

550 A `set` of `QuantumNode`s that contain specified `DatasetTypeName` 

551 

552 Raises 

553 ------ 

554 KeyError 

555 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

556 

557 """ 

558 tasks = self._datasetDict.getAll(datasetTypeName) 

559 result: Set[Quantum] = set() 

560 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

561 return result 

562 

563 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

564 """Check if specified quantum appears in the graph as part of a node. 

565 

566 Parameters 

567 ---------- 

568 quantum : `Quantum` 

569 The quantum to search for 

570 

571 Returns 

572 ------- 

573 `bool` 

574 The result of searching for the quantum 

575 """ 

576 for node in self: 

577 if quantum == node.quantum: 

578 return True 

579 return False 

580 

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : `str` or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object.
        """
        # Delegates entirely to networkx's graphviz writer.
        write_dot(self._connectedQuanta, output)

590 

591 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T: 

592 """Create a new graph object that contains the subset of the nodes 

593 specified as input. Node number is preserved. 

594 

595 Parameters 

596 ---------- 

597 nodes : `QuantumNode` or iterable of `QuantumNode` 

598 

599 Returns 

600 ------- 

601 graph : instance of graph type 

602 An instance of the type from which the subset was created 

603 """ 

604 if not isinstance(nodes, Iterable): 

605 nodes = (nodes,) 

606 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes 

607 quantumMap = defaultdict(set) 

608 

609 node: QuantumNode 

610 for node in quantumSubgraph: 

611 quantumMap[node.taskDef].add(node.quantum) 

612 

613 # convert to standard dict to prevent accidental key insertion 

614 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

615 # Create an empty graph, and then populate it with custom mapping 

616 newInst = type(self)({}) 

617 newInst._buildGraphs( 

618 quantumDict, 

619 _quantumToNodeId={n.quantum: n.nodeId for n in nodes}, 

620 _buildId=self._buildId, 

621 metadata=self._metadata, 

622 ) 

623 return newInst 

624 

625 def subsetToConnected(self: _T) -> Tuple[_T, ...]: 

626 """Generate a list of subgraphs where each is connected. 

627 

628 Returns 

629 ------- 

630 result : list of `QuantumGraph` 

631 A list of graphs that are each connected 

632 """ 

633 return tuple( 

634 self.subset(connectedSet) 

635 for connectedSet in nx.weakly_connected_components(self._connectedQuanta) 

636 ) 

637 

638 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

639 """Return a set of `QuantumNode` that are direct inputs to a specified 

640 node. 

641 

642 Parameters 

643 ---------- 

644 node : `QuantumNode` 

645 The node of the graph for which inputs are to be determined 

646 

647 Returns 

648 ------- 

649 set of `QuantumNode` 

650 All the nodes that are direct inputs to specified node 

651 """ 

652 return set(pred for pred in self._connectedQuanta.predecessors(node)) 

653 

654 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

655 """Return a set of `QuantumNode` that are direct outputs of a specified 

656 node. 

657 

658 Parameters 

659 ---------- 

660 node : `QuantumNode` 

661 The node of the graph for which outputs are to be determined 

662 

663 Returns 

664 ------- 

665 set of `QuantumNode` 

666 All the nodes that are direct outputs to specified node 

667 """ 

668 return set(succ for succ in self._connectedQuanta.successors(node)) 

669 

670 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

671 """Return a graph of `QuantumNode` that are direct inputs and outputs 

672 of a specified node. 

673 

674 Parameters 

675 ---------- 

676 node : `QuantumNode` 

677 The node of the graph for which connected nodes are to be 

678 determined. 

679 

680 Returns 

681 ------- 

682 graph : graph of `QuantumNode` 

683 All the nodes that are directly connected to specified node 

684 """ 

685 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

686 nodes.add(node) 

687 return self.subset(nodes) 

688 

689 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

690 """Return a graph of the specified node and all the ancestor nodes 

691 directly reachable by walking edges. 

692 

693 Parameters 

694 ---------- 

695 node : `QuantumNode` 

696 The node for which all ansestors are to be determined 

697 

698 Returns 

699 ------- 

700 graph of `QuantumNode` 

701 Graph of node and all of its ansestors 

702 """ 

703 predecessorNodes = nx.ancestors(self._connectedQuanta, node) 

704 predecessorNodes.add(node) 

705 return self.subset(predecessorNodes) 

706 

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check the graph for the presence of cycles and return the edges
        of any cycle found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. The empty list supports the
            ``if graph.findCycle():`` idiom, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            # No cycle is the expected case for a DAG; report it as empty.
            return []

722 

723 def saveUri(self, uri: ResourcePathExpression) -> None: 

724 """Save `QuantumGraph` to the specified URI. 

725 

726 Parameters 

727 ---------- 

728 uri : convertible to `ResourcePath` 

729 URI to where the graph should be saved. 

730 """ 

731 buffer = self._buildSaveObject() 

732 path = ResourcePath(uri) 

733 if path.getExtension() not in (".qgraph"): 

734 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

735 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

736 

737 @property 

738 def metadata(self) -> Optional[MappingProxyType[str, Any]]: 

739 """ """ 

740 if self._metadata is None: 

741 return None 

742 return MappingProxyType(self._metadata) 

743 

744 @classmethod 

745 def loadUri( 

746 cls, 

747 uri: ResourcePathExpression, 

748 universe: Optional[DimensionUniverse] = None, 

749 nodes: Optional[Iterable[uuid.UUID]] = None, 

750 graphID: Optional[BuildId] = None, 

751 minimumVersion: int = 3, 

752 ) -> QuantumGraph: 

753 """Read `QuantumGraph` from a URI. 

754 

755 Parameters 

756 ---------- 

757 uri : convertible to `ResourcePath` 

758 URI from where to load the graph. 

759 universe: `~lsst.daf.butler.DimensionUniverse` optional 

760 DimensionUniverse instance, not used by the method itself but 

761 needed to ensure that registry data structures are initialized. 

762 If None it is loaded from the QuantumGraph saved structure. If 

763 supplied, the DimensionUniverse from the loaded `QuantumGraph` 

764 will be validated against the supplied argument for compatibility. 

765 nodes: iterable of `int` or None 

766 Numbers that correspond to nodes in the graph. If specified, only 

767 these nodes will be loaded. Defaults to None, in which case all 

768 nodes will be loaded. 

769 graphID : `str` or `None` 

770 If specified this ID is verified against the loaded graph prior to 

771 loading any Nodes. This defaults to None in which case no 

772 validation is done. 

773 minimumVersion : int 

774 Minimum version of a save file to load. Set to -1 to load all 

775 versions. Older versions may need to be loaded, and re-saved 

776 to upgrade them to the latest format before they can be used in 

777 production. 

778 

779 Returns 

780 ------- 

781 graph : `QuantumGraph` 

782 Resulting QuantumGraph instance. 

783 

784 Raises 

785 ------ 

786 TypeError 

787 Raised if pickle contains instance of a type other than 

788 QuantumGraph. 

789 ValueError 

790 Raised if one or more of the nodes requested is not in the 

791 `QuantumGraph` or if graphID parameter does not match the graph 

792 being loaded or if the supplied uri does not point at a valid 

793 `QuantumGraph` save file. 

794 RuntimeError 

795 Raise if Supplied DimensionUniverse is not compatible with the 

796 DimensionUniverse saved in the graph 

797 

798 

799 Notes 

800 ----- 

801 Reading Quanta from pickle requires existence of singleton 

802 DimensionUniverse which is usually instantiated during Registry 

803 initialization. To make sure that DimensionUniverse exists this method 

804 accepts dummy DimensionUniverse argument. 

805 """ 

806 uri = ResourcePath(uri) 

807 # With ResourcePath we have the choice of always using a local file 

808 # or reading in the bytes directly. Reading in bytes can be more 

809 # efficient for reasonably-sized pickle files when the resource 

810 # is remote. For now use the local file variant. For a local file 

811 # as_local() does nothing. 

812 

813 if uri.getExtension() in (".pickle", ".pkl"): 

814 with uri.as_local() as local, open(local.ospath, "rb") as fd: 

815 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

816 qgraph = pickle.load(fd) 

817 elif uri.getExtension() in (".qgraph"): 

818 with LoadHelper(uri, minimumVersion) as loader: 

819 qgraph = loader.load(universe, nodes, graphID) 

820 else: 

821 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`") 

822 if not isinstance(qgraph, QuantumGraph): 

823 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}") 

824 return qgraph 

825 

826 @classmethod 

827 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]: 

828 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

829 and return it as a string. 

830 

831 Parameters 

832 ---------- 

833 uri : convertible to `ResourcePath` 

834 The location of the `QuantumGraph` to load. If the argument is a 

835 string, it must correspond to a valid `ResourcePath` path. 

836 minimumVersion : int 

837 Minimum version of a save file to load. Set to -1 to load all 

838 versions. Older versions may need to be loaded, and re-saved 

839 to upgrade them to the latest format before they can be used in 

840 production. 

841 

842 Returns 

843 ------- 

844 header : `str` or `None` 

845 The header associated with the specified `QuantumGraph` it there is 

846 one, else `None`. 

847 

848 Raises 

849 ------ 

850 ValueError 

851 Raised if `QuantuGraph` was saved as a pickle. 

852 Raised if the extention of the file specified by uri is not a 

853 `QuantumGraph` extention. 

854 """ 

855 uri = ResourcePath(uri) 

856 if uri.getExtension() in (".pickle", ".pkl"): 

857 raise ValueError("Reading a header from a pickle save is not supported") 

858 elif uri.getExtension() in (".qgraph"): 

859 return LoadHelper(uri, minimumVersion).readHeader() 

860 else: 

861 raise ValueError("Only know how to handle files saved as `qgraph`") 

862 

    def buildAndPrintHeader(self) -> None:
        """Create the header that would be used in a save of this object
        and print it out to standard out.
        """
        # Only the header part of the save payload is needed here.
        _, header = self._buildSaveObject(returnHeader=True)
        print(json.dumps(header))

869 

    def save(self, file: BinaryIO) -> None:
        """Save QuantumGraph to a file.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write serialized graph data to, opened in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

880 

    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into an in-memory byte stream.

        The stream layout is: magic bytes, save-format version, packed
        header length, lzma-compressed JSON header, then the
        lzma-compressed JSON payload for each taskDef followed by each
        node. The header records the byte offsets of every payload so
        pieces can be loaded individually.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping that
            was embedded at the front of the byte stream.

        Returns
        -------
        buffer : `bytearray`
            The serialized graph.
        header : `dict`
            The header mapping; only returned when ``returnHeader`` is
            `True`.
        """
        # make some containers
        jsonData: Deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: Dict[str, Any] = {}

        # Store the QuantumGraph BuildId; this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far; payload byte
        # ranges recorded below are offsets relative to the end of the header
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # sort by task label to ensure serialization happens in the same order
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label

            inputs = []
            outputs = []

            # Determine the connection between all of tasks and save that in
            # the header as a list of connections and edges in each task.
            # This will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted.
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output will be said to be the datastore,
                    # in which case the for loop below is a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, and encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        # dimension records are accumulated once across all nodes so each
        # record is stored a single time
        headerData["DimensionRecords"] = {
            key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key, value tuples because of
        # a limitation in json: it can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        # dump the headerData to json and compress it
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # pack the save-format version number into the fixed-size base
        # struct so loaders can pick the matching deserializer
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        # pack the compressed header's length using the format specific to
        # the current save version
        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Pop each payload off the left of the deque and append it to the
        # buffer. This is to save memory: as the data is added to the buffer
        # object, it is removed from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this payload data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

1024 

1025 @classmethod 

1026 def load( 

1027 cls, 

1028 file: BinaryIO, 

1029 universe: Optional[DimensionUniverse] = None, 

1030 nodes: Optional[Iterable[uuid.UUID]] = None, 

1031 graphID: Optional[BuildId] = None, 

1032 minimumVersion: int = 3, 

1033 ) -> QuantumGraph: 

1034 """Read QuantumGraph from a file that was made by `save`. 

1035 

1036 Parameters 

1037 ---------- 

1038 file : `io.IO` of bytes 

1039 File with pickle data open in binary mode. 

1040 universe: `~lsst.daf.butler.DimensionUniverse`, optional 

1041 DimensionUniverse instance, not used by the method itself but 

1042 needed to ensure that registry data structures are initialized. 

1043 If None it is loaded from the QuantumGraph saved structure. If 

1044 supplied, the DimensionUniverse from the loaded `QuantumGraph` 

1045 will be validated against the supplied argument for compatibility. 

1046 nodes: iterable of `int` or None 

1047 Numbers that correspond to nodes in the graph. If specified, only 

1048 these nodes will be loaded. Defaults to None, in which case all 

1049 nodes will be loaded. 

1050 graphID : `str` or `None` 

1051 If specified this ID is verified against the loaded graph prior to 

1052 loading any Nodes. This defaults to None in which case no 

1053 validation is done. 

1054 minimumVersion : int 

1055 Minimum version of a save file to load. Set to -1 to load all 

1056 versions. Older versions may need to be loaded, and re-saved 

1057 to upgrade them to the latest format before they can be used in 

1058 production. 

1059 

1060 Returns 

1061 ------- 

1062 graph : `QuantumGraph` 

1063 Resulting QuantumGraph instance. 

1064 

1065 Raises 

1066 ------ 

1067 TypeError 

1068 Raised if pickle contains instance of a type other than 

1069 QuantumGraph. 

1070 ValueError 

1071 Raised if one or more of the nodes requested is not in the 

1072 `QuantumGraph` or if graphID parameter does not match the graph 

1073 being loaded or if the supplied uri does not point at a valid 

1074 `QuantumGraph` save file. 

1075 

1076 Notes 

1077 ----- 

1078 Reading Quanta from pickle requires existence of singleton 

1079 DimensionUniverse which is usually instantiated during Registry 

1080 initialization. To make sure that DimensionUniverse exists this method 

1081 accepts dummy DimensionUniverse argument. 

1082 """ 

1083 # Try to see if the file handle contains pickle data, this will be 

1084 # removed in the future 

1085 try: 

1086 qgraph = pickle.load(file) 

1087 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

1088 except pickle.UnpicklingError: 

1089 with LoadHelper(file, minimumVersion) as loader: 

1090 qgraph = loader.load(universe, nodes, graphID) 

1091 if not isinstance(qgraph, QuantumGraph): 

1092 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}") 

1093 return qgraph 

1094 

1095 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1096 """Iterate over the `taskGraph` attribute in topological order 

1097 

1098 Yields 

1099 ------ 

1100 taskDef : `TaskDef` 

1101 `TaskDef` objects in topological order 

1102 """ 

1103 yield from nx.topological_sort(self.taskGraph) 

1104 

1105 @property 

1106 def graphID(self) -> BuildId: 

1107 """Returns the ID generated by the graph at construction time""" 

1108 return self._buildId 

1109 

1110 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1111 yield from nx.topological_sort(self._connectedQuanta) 

1112 

1113 def __len__(self) -> int: 

1114 return self._count 

1115 

1116 def __contains__(self, node: QuantumNode) -> bool: 

1117 return self._connectedQuanta.has_node(node) 

1118 

1119 def __getstate__(self) -> dict: 

1120 """Stores a compact form of the graph as a list of graph nodes, and a 

1121 tuple of task labels and task configs. The full graph can be 

1122 reconstructed with this information, and it preseves the ordering of 

1123 the graph ndoes. 

1124 """ 

1125 universe: Optional[DimensionUniverse] = None 

1126 for node in self: 

1127 dId = node.quantum.dataId 

1128 if dId is None: 

1129 continue 

1130 universe = dId.graph.universe 

1131 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe} 

1132 

1133 def __setstate__(self, state: dict) -> None: 

1134 """Reconstructs the state of the graph from the information persisted 

1135 in getstate. 

1136 """ 

1137 buffer = io.BytesIO(state["reduced"]) 

1138 with LoadHelper(buffer, minimumVersion=3) as loader: 

1139 qgraph = loader.load(state["universe"], graphID=state["graphId"]) 

1140 

1141 self._metadata = qgraph._metadata 

1142 self._buildId = qgraph._buildId 

1143 self._datasetDict = qgraph._datasetDict 

1144 self._nodeIdMap = qgraph._nodeIdMap 

1145 self._count = len(qgraph) 

1146 self._taskToQuantumNode = qgraph._taskToQuantumNode 

1147 self._taskGraph = qgraph._taskGraph 

1148 self._connectedQuanta = qgraph._connectedQuanta 

1149 

1150 def __eq__(self, other: object) -> bool: 

1151 if not isinstance(other, QuantumGraph): 

1152 return False 

1153 if len(self) != len(other): 

1154 return False 

1155 for node in self: 

1156 if node not in other: 

1157 return False 

1158 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1159 return False 

1160 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1161 return False 

1162 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1163 return False 

1164 return set(self.taskGraph) == set(other.taskGraph)