Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22import warnings 

23 

24__all__ = ("QuantumGraph", "IncompatibleGraphError") 

25 

import copy
import io
import lzma
import os
import pickle
import struct
import time
from collections import defaultdict, deque
from itertools import chain, count
from typing import (DefaultDict, Dict, FrozenSet, Generator, IO, Iterable, List, Mapping, Optional, Set,
                    Tuple, TypeVar, Union)

import networkx as nx
from networkx.drawing.nx_agraph import write_dot

from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse

from ..connections import iterConnections
from ..pipeline import TaskDef
from ._implDetails import _DatasetTracker, DatasetTypeName
from .quantumNode import QuantumNode, NodeId, BuildId
from ._loadHelpers import LoadHelper

48 

49 

# Type variable bound to QuantumGraph so that methods annotated with it return
# the same (sub)class they were called on.
_T = TypeVar("_T", bound="QuantumGraph")

# Version number of the on disk save format. Bump this whenever the on disk
# representation changes, and teach the load helpers how to read each version.
SAVE_VERSION = 1

# `struct` layout of the preamble that follows the magic bytes in a save file:
# big endian, one unsigned short (the save version) followed by two unsigned
# long longs (the byte lengths of the two pickled lookup maps).
STRUCT_FMT_STRING = ">HQQ"

# Byte sequence written at the very start of a save, used to recognize that a
# stream contains a graph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

64 

65 

class IncompatibleGraphError(Exception):
    """Signal that a lookup by `NodeId` cannot be performed.

    Raised when a `NodeId` and the graph it is looked up in are incompatible
    with one another, which makes the lookup impossible.
    """

71 

72 

class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]]):
        self._buildGraphs(quanta)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None):
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The tasks and the quanta each of them is to process.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            If supplied and non-empty, the `NodeId` of every created
            `QuantumNode` is looked up here instead of being generated.
        _buildId : `BuildId`, optional
            Identifier to record for this graph. If `None` a new one is
            generated from the current time and process id.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is in use but does not contain an
            entry for one of the quanta.
        """
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                if _quantumToNodeId:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.Digraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.Digraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Iterate over all `QuantumNode` objects that are 'input' nodes
        to the graph, meaning those nodes do not depend on any other nodes in
        the graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A generator over the nodes that are inputs to the graph. It can
            only be consumed once.
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'output' nodes
        to the graph, meaning those nodes have no nodes that depend on them in
        the graph.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The number associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with input number

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a different graph than
            this instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        return frozenset(self._quanta[taskDef])

    def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
        """Return all the `QuantumNodes` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `QuantumNodes` are to be queried

        Returns
        -------
        frozenset of `QuantumNodes`
            The `frozenset` of `QuantumNodes` that is associated with the
            specified `TaskDef`.
        """
        return frozenset(self._taskToQuantumNode[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input, iterable will be empty if no tasks use specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return (c for c in self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`; the consumers first, followed by the producer
            if there is one.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the taskName property
        of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for. It is compared against the last
            dot separated component of each `TaskDef.taskName`.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        results = []
        for task in self._quanta.keys():
            split = task.taskName.split('.')
            if split[-1] == taskName:
                results.append(task)
        return results

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` object in this graph is associated
        with a `str` representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` object found that has the specified label, or
            `None` if no task has that label.
        """
        for task in self._quanta.keys():
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum` objects
            A `set` of the `Quantum` belonging to every task associated with
            the specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`

        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        result = result.union(*(self._quanta[task] for task in tasks))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        for qset in self._quanta.values():
            if quantum in qset:
                return True
        return False

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node, or nodes, to keep. Any iterable is accepted, including
            one that can only be consumed once such as a generator.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        # ``nodes`` is consumed exactly once here. Everything below works from
        # the nodes of the resulting subgraph, so that a one-shot iterator
        # does not end up with an empty id mapping (and therefore with newly
        # generated node ids).
        selectedNodes = tuple(self._connectedQuanta.subgraph(nodes).nodes)
        quantumMap = defaultdict(set)
        quantumToNodeId = {}

        node: QuantumNode
        for node in selectedNodes:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            The specified node and all the nodes that are directly connected
            to it
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned so as to support
            ``if graph.findCycle()`` syntax, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def saveUri(self, uri):
        """Save `QuantumGraph` to the specified URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI to where the graph should be saved. Its extension must be
            ``.qgraph``.

        Raises
        ------
        TypeError
            Raised if the extension of ``uri`` is not ``.qgraph``.
        """
        butlerUri = ButlerURI(uri)
        # Compare for equality: ``x in (".qgraph")`` is a substring test
        # against a plain string, which also lets "" and ".q" through.
        # Validate before serializing so that a bad URI fails cheaply.
        if butlerUri.getExtension() != ".qgraph":
            raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
        buffer = self._buildSaveObject()
        butlerUri.write(buffer)  # type: ignore  # Ignore because bytearray is safe to use in place of bytes

    @classmethod
    def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
                nodes: Optional[Iterable[int]] = None,
                graphID: Optional[BuildId] = None
                ) -> QuantumGraph:
        """Read `QuantumGraph` from a URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI from where to load the graph.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded. Not applied to deprecated pickle files,
            which are always loaded whole.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done. Not applied to deprecated pickle files.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if the file contains an instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied uri does not point at a valid
            `QuantumGraph` save file.


        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.
        """
        uri = ButlerURI(uri)
        extension = uri.getExtension()
        # With ButlerURI we have the choice of always using a local file
        # or reading in the bytes directly. Reading in bytes can be more
        # efficient for reasonably-sized pickle files when the resource
        # is remote. For now use the local file variant. For a local file
        # as_local() does nothing.

        if extension in (".pickle", ".pkl"):
            with uri.as_local() as local, open(local.ospath, "rb") as fd:
                warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
                # NOTE: unpickling can execute arbitrary code, only load
                # pickle graphs that come from a trusted source.
                qgraph = pickle.load(fd)
        # Equality, not ``in ('.qgraph')``: the latter is a substring test on
        # a plain string and would also match "" or ".q".
        elif extension == ".qgraph":
            with LoadHelper(uri) as loader:
                qgraph = loader.load(nodes, graphID)
        else:
            raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
        return qgraph

    def save(self, file: IO[bytes]):
        """Save QuantumGraph to a file.

        The graph is written in the binary format produced by
        `_buildSaveObject`, the same bytes `saveUri` writes.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write the save data to, open in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

    def _buildSaveObject(self) -> bytearray:
        """Serialize the graph into the bytes of a save file.

        The layout is, in order: `MAGIC_BYTES`, a preamble packed with
        `STRUCT_FMT_STRING` (the save version and the byte lengths of the two
        maps that follow), the pickled map of task label to byte range, the
        pickled map of node number to byte range, and finally the lzma
        compressed pickles of every `TaskDef` and then every `QuantumNode`.
        The byte ranges in the maps are offsets into that final section.

        Returns
        -------
        buffer : `bytearray`
            The complete contents of the save.
        """
        # make some containers
        pickleData = deque()
        nodeMap = {}
        taskDefMap = {}
        protocol = 3

        # counter for the number of bytes processed thus far
        byteOffset = 0
        # serialize out the task Defs recording the start and end bytes of each
        # taskDef
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
            taskDefMap[taskDef.label] = (byteOffset, byteOffset + len(dump))
            byteOffset += len(dump)
            pickleData.append(dump)

        # Store the QuantumGraph BuildId along side the TaskDefs for
        # convenience. This will allow validating BuildIds at load time, prior
        # to loading any QuantumNodes. Name chosen for unlikely conflicts with
        # labels as this is python standard for private.
        taskDefMap['__GraphBuildID'] = self.graphID

        # serialize the nodes, recording the start and end bytes of each node
        for node in self:
            node = copy.copy(node)
            taskDef = node.taskDef
            # Explicitly overload the "frozen-ness" of nodes to normalize out
            # the taskDef, this saves a lot of space and load time. The label
            # will be used to retrieve the taskDef from the taskDefMap upon
            # load
            #
            # This strategy was chosen instead of creating a new class that
            # looked just like a QuantumNode but containing a label in place of
            # a TaskDef because it would be needlessly slow to construct a
            # bunch of new object to immediately serialize them and destroy the
            # object. This seems like an acceptable use of Python's dynamic
            # nature in a controlled way for optimization and simplicity.
            object.__setattr__(node, 'taskDef', taskDef.label)
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(node, protocol=protocol))
            pickleData.append(dump)
            nodeMap[node.nodeId.number] = (byteOffset, byteOffset + len(dump))
            byteOffset += len(dump)

        # pickle the taskDef byte map
        taskDef_pickle = pickle.dumps(taskDefMap, protocol=protocol)

        # pickle the node byte map
        map_pickle = pickle.dumps(nodeMap, protocol=protocol)

        # record the save version as an unsigned short and the two map sizes
        # as unsigned long long numbers, for a total of 18 bytes
        map_lengths = struct.pack(STRUCT_FMT_STRING, SAVE_VERSION, len(taskDef_pickle), len(map_pickle))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(map_lengths)
        buffer.extend(taskDef_pickle)
        buffer.extend(map_pickle)
        # Pop the leftmost element off the deque and write it out until the
        # deque is empty. This is to save memory, as the memory is added to
        # the buffer object, it is removed from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this pickle data is
        # only on the order of kilobytes to low numbers of megabytes.
        while pickleData:
            buffer.extend(pickleData.popleft())
        return buffer

    @classmethod
    def load(cls, file: IO[bytes], universe: DimensionUniverse,
             nodes: Optional[Iterable[int]] = None,
             graphID: Optional[BuildId] = None
             ) -> QuantumGraph:
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `io.IO` of bytes
            File with save data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded. Not applied when the file turns out to hold
            a deprecated pickle graph.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done. Not applied when the file turns out to hold a
            deprecated pickle graph.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if the file contains an instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied file does not contain a valid
            `QuantumGraph` save.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.
        """
        # Try to see if the file handle contains pickle data, this will be
        # removed in the future
        try:
            # NOTE: unpickling can execute arbitrary code, only load files
            # that come from a trusted source.
            qgraph = pickle.load(file)
            warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
        except pickle.UnpicklingError:
            with LoadHelper(file) as loader:  # type: ignore  # needed because we don't have Protocols yet
                qgraph = loader.load(nodes, graphID)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        taskDef : `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    @property
    def graphID(self):
        """Returns the ID generated by the graph at construction time
        """
        return self._buildId

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        """Iterate over the `QuantumNode` objects in topological order."""
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        """Return the number of quanta the graph was built from."""
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        """Return `True` if ``node`` is a node of the quanta graph."""
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Stores a compact form of the graph as a list of graph nodes. The
        full graph can be reconstructed with this information, and it
        preserves the ordering of the graph nodes.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstructs the state of the graph from the information persisted
        in getstate.
        """
        nodesList = state['nodesList']
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        quantumNode: QuantumNode
        for quantumNode in nodesList:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
        # The build id is read off the last node; with no nodes at all there
        # is nothing to recover it from and a new one is generated.
        _buildId = nodesList[-1].nodeId.buildId if nodesList else None
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        """Compare two graphs node by node.

        Two graphs are equal when they hold the same number of quanta, every
        node of this graph is in the other with the same direct inputs and
        outputs, and both task graphs list the same tasks in the same order.
        Because `__eq__` is defined without `__hash__`, instances are not
        hashable.
        """
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)

875 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

876 return False 

877 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

878 return False 

879 return list(self.taskGraph) == list(other.taskGraph)