Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22import warnings 

23 

24__all__ = ("QuantumGraph", "IncompatibleGraphError") 

25 

26from collections import defaultdict, deque 

27 

28from itertools import chain, count 

29import io 

30import networkx as nx 

31from networkx.drawing.nx_agraph import write_dot 

32import os 

33import pickle 

34import lzma 

35import copy 

36import struct 

37import time 

38from typing import (DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, Set, Generator, Optional, Tuple, 

39 Union, TypeVar) 

40 

41from ..connections import iterConnections 

42from ..pipeline import TaskDef 

43from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse 

44 

45from ._implDetails import _DatasetTracker, DatasetTypeName 

46from .quantumNode import QuantumNode, NodeId, BuildId 

47from ._loadHelpers import LoadHelper 

48 

49 

# Type variable bound to `QuantumGraph`; lets methods that build a new graph
# from ``self`` be annotated as returning the same type they were called on.
_T = TypeVar("_T", bound="QuantumGraph")

# Version of the on disk save format. Bump this whenever the serialized
# layout changes, and teach the load helpers how to read each version.
SAVE_VERSION = 1

# `struct` format of the fixed-size preamble written at the front of a save:
# big endian (">"), one unsigned short ("H") followed by two unsigned long
# longs ("Q").
STRUCT_FMT_STRING = ">HQQ"

# Magic bytes placed at the very start of a save so that a file can be
# recognized as a graph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

64 

65 

class IncompatibleGraphError(Exception):
    """Error signalling that a node can not be looked up by its `NodeId`
    because the id is incompatible with the graph being queried.
    """

71 

72 

class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The tasks and the quanta each of them is to process.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            If not `None`, node ids are taken from this mapping instead of
            being generated, and every `Quantum` in ``quanta`` must have an
            entry.
        _buildId : `BuildId`, optional
            Build id to assign to this graph. If `None` a new id is created
            from the current time and process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is supplied but does not contain
            one of the quanta.
        """
        self._quanta = quanta
        if _buildId is None:
            _buildId = BuildId(f"{time.time()}-{os.getpid()}")
        self._buildId = _buildId
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                # Test against None rather than truthiness: an explicitly
                # supplied (even empty) mapping must never silently fall back
                # to generating fresh node ids.
                if _quantumToNodeId is not None:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.Digraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.Digraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Iterate over all `QuantumNode` objects that are 'input' nodes
        to the graph, meaning those nodes do not depend on any other nodes in
        the graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A generator over the nodes that are inputs to the graph. It can
            be consumed only once; access the property again for a new one.
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'output' nodes
        to the graph, meaning those nodes have no nodes that depend on them in
        the graph.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The id associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with the input id

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a different graph than
            this instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.

        Raises
        ------
        KeyError
            Raised if the `TaskDef` is not part of the `QuantumGraph`
        """
        return frozenset(self._quanta[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input, iterable will be empty if no tasks use specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return (c for c in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`; the consumers come first, followed by the
            producer if there is one.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the taskName property
        of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for. It is compared against the last
            dot separated component of each `TaskDef.taskName`.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        return [task for task in self._quanta if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` object in this graph is associated
        with a `str` representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            `TaskDef` object that has the specified label, or `None` if no
            task in the graph has that label.
        """
        for task in self._quanta:
            if task.label == label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum` objects
            A `set` of `Quantum` belonging to the tasks associated with the
            specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`

        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        result = result.union(*(self._quanta[task] for task in tasks))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node, or nodes, the new graph should be restricted to.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)
        # Build the id mapping from the subgraph nodes rather than iterating
        # ``nodes`` a second time: ``nodes`` may be a one-shot iterator that
        # the subgraph call already exhausted, which would otherwise leave
        # the mapping empty and lose the original node ids.
        quantumToNodeId = {}

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            The specified node together with all the nodes that are directly
            connected to it
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned so as to support
            ``if graph.findCycle()`` syntax, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def saveUri(self, uri):
        """Save `QuantumGraph` to the specified URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI to where the graph should be saved.

        Raises
        ------
        TypeError
            Raised if the extension of ``uri`` is not ``.qgraph``.
        """
        butlerUri = ButlerURI(uri)
        # Compare for equality: ``ext in (".qgraph")`` is a substring test
        # against a plain string, which also accepts "" and partial matches.
        # Validate before serializing so a bad URI fails without doing the
        # expensive work of building the save buffer.
        if butlerUri.getExtension() != ".qgraph":
            raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
        buffer = self._buildSaveObject()
        butlerUri.write(buffer)  # type: ignore  # Ignore because bytearray is safe to use in place of bytes

    @classmethod
    def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
                nodes: Optional[Iterable[int]] = None,
                graphID: Optional[BuildId] = None
                ) -> QuantumGraph:
        """Read `QuantumGraph` from a URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI from where to load the graph.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded. Ignored for pickle files.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done. Ignored for pickle files.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied uri does not point at a valid
            `QuantumGraph` save file.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        Unpickling can execute arbitrary code, so ``.pickle``/``.pkl`` files
        must only be loaded from trusted sources.
        """
        uri = ButlerURI(uri)
        # With ButlerURI we have the choice of always using a local file
        # or reading in the bytes directly. Reading in bytes can be more
        # efficient for reasonably-sized pickle files when the resource
        # is remote. For now use the local file variant. For a local file
        # as_local() does nothing.
        extension = uri.getExtension()
        if extension in (".pickle", ".pkl"):
            with uri.as_local() as local, open(local.ospath, "rb") as fd:
                warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
                qgraph = pickle.load(fd)
        # Exact comparison; ``in ('.qgraph')`` would be a substring test that
        # also matches an empty extension.
        elif extension == '.qgraph':
            with LoadHelper(uri) as loader:
                qgraph = loader.load(nodes, graphID)
        else:
            raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
        return qgraph

    # NOTE(review): the stdlib `io` module does not define `io.IO`; the
    # annotation below is only harmless because annotations in this module
    # are not evaluated at runtime. Consider `typing.BinaryIO`.
    def save(self, file: io.IO[bytes]):
        """Save QuantumGraph to a file.

        The graph is written in the binary layout produced by
        `_buildSaveObject` (magic bytes, a preamble, and compressed pickles
        of the task definitions and nodes).

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write the save data to, open in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore  # Ignore because bytearray is safe to use in place of bytes

    def _buildSaveObject(self) -> bytearray:
        """Serialize the graph into the bytes that make up a save file.

        The buffer is laid out as `MAGIC_BYTES`, a preamble packed with
        `STRUCT_FMT_STRING` (save version and the byte lengths of the two
        maps), the pickled task definition map, the pickled node map, and
        finally the lzma compressed pickles of every `TaskDef` followed by
        every node.

        Returns
        -------
        buffer : `bytearray`
            The complete serialized graph.
        """
        # Compressed pickles, in the order they are written after the maps
        pickleData = deque()
        # node number -> (start, stop) byte range of that node's pickle
        nodeMap = {}
        # task label -> (start, stop) byte range of that TaskDef's pickle
        taskDefMap = {}
        protocol = 3

        # Number of compressed bytes produced thus far; byte ranges in the
        # maps are counted from the start of the compressed pickle section.
        byteOffset = 0
        # serialize out the task Defs recording the start and end bytes of each
        # taskDef
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
            taskDefMap[taskDef.label] = (byteOffset, byteOffset + len(dump))
            byteOffset += len(dump)
            pickleData.append(dump)

        # Store the QuantumGraph BuildId along side the TaskDefs for
        # convenience. This will allow validating BuildIds at load time, prior
        # to loading any QuantumNodes. Name chosen for unlikely conflicts with
        # labels as this is python standard for private.
        taskDefMap['__GraphBuildID'] = self.graphID

        # serialize the nodes, recording the start and end bytes of each node
        for node in self:
            node = copy.copy(node)
            taskDef = node.taskDef
            # Explicitly overload the "frozen-ness" of nodes to normalize out
            # the taskDef, this saves a lot of space and load time. The label
            # will be used to retrieve the taskDef from the taskDefMap upon
            # load.
            #
            # This strategy was chosen instead of creating a new class that
            # looked just like a QuantumNode but containing a label in place of
            # a TaskDef because it would be needlessly slow to construct a
            # bunch of new object to immediately serialize them and destroy the
            # object. This seems like an acceptable use of Python's dynamic
            # nature in a controlled way for optimization and simplicity.
            object.__setattr__(node, 'taskDef', taskDef.label)
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(node, protocol=protocol))
            pickleData.append(dump)
            nodeMap[node.nodeId.number] = (byteOffset, byteOffset + len(dump))
            byteOffset += len(dump)

        # pickle the taskDef byte map
        taskDef_pickle = pickle.dumps(taskDefMap, protocol=protocol)

        # pickle the node byte map
        map_pickle = pickle.dumps(nodeMap, protocol=protocol)

        # Pack the preamble: the save version as an unsigned short followed
        # by the two map sizes as unsigned long longs (2 + 8 + 8 = 18 bytes).
        map_lengths = struct.pack(STRUCT_FMT_STRING, SAVE_VERSION, len(taskDef_pickle), len(map_pickle))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(map_lengths)
        buffer.extend(taskDef_pickle)
        buffer.extend(map_pickle)
        # Pop each element off the left of the deque as it is written out.
        # This is to save memory: as the bytes are added to the buffer
        # object, they are removed from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this pickle data is
        # only on the order of kilobytes to low numbers of megabytes.
        while pickleData:
            buffer.extend(pickleData.popleft())
        return buffer

    @classmethod
    def load(cls, file: io.IO[bytes], universe: DimensionUniverse,
             nodes: Optional[Iterable[int]] = None,
             graphID: Optional[BuildId] = None
             ) -> QuantumGraph:
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.IO` of bytes
            File with save data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied uri does not point at a valid
            `QuantumGraph` save file.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        The file is first tried as a pickle, and unpickling can execute
        arbitrary code, so only files from trusted sources must be loaded.
        """
        # Try to see if the file handle contains pickle data, this will be
        # removed in the future
        try:
            qgraph = pickle.load(file)
            warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
        except pickle.UnpicklingError:
            # NOTE(review): the failed unpickle attempt has already read from
            # ``file``; this presumes LoadHelper repositions the handle
            # itself -- confirm against _loadHelpers.
            with LoadHelper(file) as loader:  # type: ignore  # needed because we don't have Protocols yet
                qgraph = loader.load(nodes, graphID)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        taskDef : `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    @property
    def graphID(self):
        """Returns the ID generated by the graph at construction time
        """
        return self._buildId

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        """Iterate over the `QuantumNode` objects in topological order."""
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        """Return the number of quanta the graph was built from."""
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        """Return `True` if ``node`` is a node of the quanta graph."""
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Stores a compact form of the graph as a list of graph nodes. The
        full graph can be reconstructed with this information, and it
        preserves the ordering of the graph nodes.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstructs the state of the graph from the information persisted
        in getstate.
        """
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        # Stays None for an empty node list, in which case a fresh build id
        # is generated; otherwise the build id of the last node is used.
        _buildId: Optional[BuildId] = None
        quantumNode: QuantumNode
        for quantumNode in state['nodesList']:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
            _buildId = quantumNode.nodeId.buildId
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        """Compare two graphs node by node.

        Graphs are equal when they have the same length, every node of this
        graph is in ``other`` with the same direct inputs and outputs, and
        iterating both task graphs gives equal lists.
        """
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)