Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

from __future__ import annotations

__all__ = ("QuantumGraph", "IncompatibleGraphError")

import copy
import io
import json
import lzma
import os
import pickle
import struct
import time
import warnings
from collections import defaultdict, deque
from itertools import chain, count
from types import MappingProxyType
from typing import (Any, BinaryIO, DefaultDict, Dict, FrozenSet, Generator, Iterable, List, Mapping,
                    Optional, Set, Tuple, TypeVar, Union)

import networkx as nx
from networkx.drawing.nx_agraph import write_dot

from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse

from ..connections import iterConnections
from ..pipeline import TaskDef
from ._implDetails import _DatasetTracker, DatasetTypeName
from .quantumNode import QuantumNode, NodeId, BuildId
from ._loadHelpers import LoadHelper

50 

51 

# Type variable bound to QuantumGraph so methods can be annotated as
# returning the same (sub)class they were called on.
_T = TypeVar("_T", bound="QuantumGraph")

# Version of the on-disk save format. Bump it whenever the serialized layout
# changes, and teach the load helpers how to read each version.
SAVE_VERSION = 2

# struct format of the version field that follows the magic bytes: a big
# endian unsigned short. Reading it first lets a loader decide which code
# path should interpret the remainder of the file.
STRUCT_FMT_BASE = ">H"

# struct formats, keyed by save version, for the bytes that follow the
# version field.
#
# Version 1
#     Two big endian unsigned long long values.
# Version 2
#     One big endian unsigned long long holding the total length of the
#     entire header.
STRUCT_FMT_STRING = {
    1: ">QQ",
    2: ">Q",
}

# Magic bytes written at the very start of a save to help identify the file
# as a graph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

78 

79 

class IncompatibleGraphError(Exception):
    """Raised when a `NodeId` cannot be looked up in a graph because the
    two are incompatible with one another.
    """

85 

86 

class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    metadata : Optional Mapping of `str` to primitives
        This is an optional parameter of extra data to carry with the graph.
        Entries in this mapping should be able to be serialized in JSON.
    """
    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]],
                 metadata: Optional[Mapping[str, Any]] = None):
        self._buildGraphs(quanta, metadata=metadata)

    def _buildGraphs(self,
                     quanta: Mapping[TaskDef, Set[Quantum]],
                     *,
                     _quantumToNodeId: Optional[Mapping[Quantum, NodeId]] = None,
                     _buildId: Optional[BuildId] = None,
                     metadata: Optional[Mapping[str, Any]] = None):
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The quanta each task is to process.
        _quantumToNodeId : Mapping of `Quantum` to `NodeId`, optional
            If supplied (and not empty) node ids are taken from this mapping
            instead of being generated; every quantum must then have an
            entry.
        _buildId : `BuildId`, optional
            Build id to reuse. If `None` a new one is created from the
            current time and the process id.
        metadata : Mapping of `str` to primitives, optional
            Extra data to carry with the graph.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is supplied but does not contain
            one of the quanta.
        """
        self._metadata = metadata
        self._quanta = quanta
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef]()
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        nodeNumberGenerator = count()
        self._nodeIdMap: Dict[NodeId, QuantumNode] = {}
        self._taskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
        self._count = 0
        for taskDef, quantumSet in self._quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addInput(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs", "initOutputs")):
                self._datasetDict.addOutput(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            self._count += len(quantumSet)
            for quantum in quantumSet:
                if _quantumToNodeId:
                    nodeId = _quantumToNodeId.get(quantum)
                    if nodeId is None:
                        raise ValueError("If _quantumToNodeId is not None, all quanta must have an "
                                         "associated value in the mapping")
                else:
                    nodeId = NodeId(next(nodeNumberGenerator), self._buildId)

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            self._datasetRefDict.addInput(sub, value)
                    else:
                        self._datasetRefDict.addInput(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addOutput(dsRef, value)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()

    @staticmethod
    def _isQgraphExtension(extension: str) -> bool:
        """Return `True` only if ``extension`` is exactly ``".qgraph"``.

        An exact comparison is used on purpose: testing membership in
        ``(".qgraph")`` is a substring test against a plain string (the
        parentheses do not make a tuple), which would accept an empty
        extension.
        """
        return extension == ".qgraph"

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.Digraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.Digraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """Return a generator over all `QuantumNode` objects that are 'input'
        nodes to the graph, meaning those nodes do not depend on any other
        nodes in the graph.

        Returns
        -------
        inputNodes : iterable of `QuantumNode`
            A single-pass iterable of nodes that are inputs to the graph
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """Make a `list` of all `QuantumNode` objects that are 'output' nodes
        to the graph, meaning those nodes have no nodes that depend on them
        in the graph.

        Returns
        -------
        outputNodes : iterable of `QuantumNode`
            A list of nodes that are outputs of the graph
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]

    @property
    def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]:
        """Return all the `DatasetTypeName` objects that are contained inside
        the graph.

        Returns
        -------
        tuple of `DatasetTypeName`
            All the data set type names that are present in the graph
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: NodeId) -> QuantumNode:
        """Lookup a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `NodeId`
            The number associated with a node

        Returns
        -------
        node : `QuantumNode`
            The node corresponding with input number

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        IncompatibleGraphError
            Raised if the nodeId was built with a different graph than
            this instance (or a graph instance that produced this instance
            through an operation such as subset)
        """
        if nodeId.buildId != self._buildId:
            raise IncompatibleGraphError("This node was built from a different, incompatible, graph instance")
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]:
        """Return all the `Quantum` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `Quantum` are to be queried

        Returns
        -------
        frozenset of `Quantum`
            The `set` of `Quantum` that is associated with the specified
            `TaskDef`.
        """
        return frozenset(self._quanta[taskDef])

    def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]:
        """Return all the `QuantumNodes` associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `QuantumNodes` are to be queried

        Returns
        -------
        frozenset of `QuantumNodes`
            The `frozenset` of `QuantumNodes` that is associated with the
            specified `TaskDef`.
        """
        return frozenset(self._taskToQuantumNode[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input, the iterable will be empty if no tasks use specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return iter(self._datasetDict.getInputs(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        `TaskDef` or `None`
            `TaskDef` that outputs `DatasetTypeName` as an output or None if
            none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        return self._datasetDict.getOutput(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`
        """
        results = self.findTasksWithInput(datasetTypeName)
        output = self.findTaskWithOutput(datasetTypeName)
        if output is not None:
            results = chain(results, (output,))
        return results

    def findTaskDefByName(self, taskName: str) -> List[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the taskName property
        of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : str
            Name of a task to search for; it is compared against the last
            dot-separated component of each `TaskDef` taskName.

        Returns
        -------
        result : list of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        return [task for task in self._quanta if task.taskName.split('.')[-1] == taskName]

    def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]:
        """Determine which `TaskDef` object in this graph is associated
        with a `str` representing a task's label.

        Parameters
        ----------
        label : str
            Label of a task to search for

        Returns
        -------
        result : `TaskDef` or `None`
            The first `TaskDef` found with the specified label, or `None`
            if no task has that label.
        """
        for task in self._quanta:
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]:
        """Return all the `Quantum` that contain a specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string,
            can also accept a `DatasetTypeName` which is a `NewType` of str for
            type safety in static type checking.

        Returns
        -------
        result : `set` of `Quantum` objects
            A `set` of `Quantum` belonging to the tasks associated with the
            specified `DatasetTypeName`

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the `QuantumGraph`

        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: Set[Quantum] = set()
        result = result.union(*(self._quanta[task] for task in tasks))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if specified quantum appears in the graph as part of a node.

        Parameters
        ----------
        quantum : `Quantum`
            The quantum to search for

        Returns
        -------
        `bool`
            The result of searching for the quantum
        """
        return any(quantum in qset for qset in self._quanta.values())

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved, and the metadata of
        this graph is carried over to the new graph.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)
        # Build the id mapping from the subgraph rather than from ``nodes``:
        # ``nodes`` may be a one-shot iterable that ``subgraph`` has already
        # consumed, in which case a second pass over it would yield nothing
        # and the node ids would silently be regenerated.
        quantumToNodeId: Dict[Quantum, NodeId] = {}

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            quantumToNodeId[node.quantum] = node.nodeId
        # Create an empty graph, and then populate it with custom mapping
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId=quantumToNodeId,
                             _buildId=self._buildId, metadata=self._metadata)
        return newInst

    def subsetToConnected(self: _T) -> Tuple[_T, ...]:
        """Generate a tuple of subgraphs where each is connected.

        Returns
        -------
        result : tuple of `QuantumGraph`
            A tuple of graphs that are each connected
        """
        return tuple(self.subset(connectedSet)
                     for connectedSet in nx.weakly_connected_components(self._connectedQuanta))

    def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct inputs to specified node
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a specified
        node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined

        Returns
        -------
        set of `QuantumNode`
            All the nodes that are direct outputs to specified node
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            The specified node and all the nodes that are directly connected
            to it
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and returns the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned so as to support the
            ``if graph.findCycle()`` syntax, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []

    def saveUri(self, uri: Union[ButlerURI, str]) -> None:
        """Save `QuantumGraph` to the specified URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI to where the graph should be saved.

        Raises
        ------
        TypeError
            Raised if the extension of ``uri`` is not ``.qgraph``.
        """
        butlerUri = ButlerURI(uri)
        # Validate the destination before doing the (potentially expensive)
        # serialization of the whole graph.
        if not self._isQgraphExtension(butlerUri.getExtension()):
            raise TypeError(f"Can currently only save a graph in qgraph format not {uri}")
        buffer = self._buildSaveObject()
        butlerUri.write(buffer)  # type: ignore  # Ignore because bytearray is safe to use in place of bytes

    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Read-only view of the metadata supplied when the graph was built.

        Returns
        -------
        metadata : `types.MappingProxyType` or `None`
            A read-only proxy of the metadata mapping, or `None` if the
            graph has no metadata.
        """
        if self._metadata is None:
            return None
        return MappingProxyType(self._metadata)

    @classmethod
    def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse,
                nodes: Optional[Iterable[int]] = None,
                graphID: Optional[BuildId] = None
                ) -> QuantumGraph:
        """Read `QuantumGraph` from a URI.

        Parameters
        ----------
        uri : `ButlerURI` or `str`
            URI from where to load the graph.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied uri does not point at a valid
            `QuantumGraph` save file (including a uri whose extension is not
            one of ``.pickle``, ``.pkl`` or ``.qgraph``).


        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.

        The ``nodes`` and ``graphID`` arguments are only forwarded to the
        loader for ``.qgraph`` files; they are not used for pickle files.
        """
        uri = ButlerURI(uri)
        # With ButlerURI we have the choice of always using a local file
        # or reading in the bytes directly. Reading in bytes can be more
        # efficient for reasonably-sized pickle files when the resource
        # is remote. For now use the local file variant. For a local file
        # as_local() does nothing.
        extension = uri.getExtension()

        if extension in (".pickle", ".pkl"):
            with uri.as_local() as local, open(local.ospath, "rb") as fd:
                warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
                # NOTE: unpickling executes arbitrary code; only load pickle
                # graphs that come from a trusted source.
                qgraph = pickle.load(fd)
        elif cls._isQgraphExtension(extension):
            with LoadHelper(uri) as loader:
                qgraph = loader.load(nodes, graphID)
        else:
            raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`")
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}")
        return qgraph

    def save(self, file: BinaryIO):
        """Save QuantumGraph to a file.

        The graph is written in the format produced by `_buildSaveObject`:
        magic bytes, a version field, the compressed JSON header length and
        header, followed by the compressed pickles of the task definitions
        and nodes.

        Parameters
        ----------
        file : `typing.BinaryIO`
            File to write the save data to, open in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

    def _buildSaveObject(self) -> bytearray:
        """Serialize the graph into the bytes that make up a save file.

        Returns
        -------
        buffer : `bytearray`
            ``MAGIC_BYTES``, the packed ``SAVE_VERSION``, the packed length
            of the compressed header, the compressed JSON header, and then
            the compressed pickle of every `TaskDef` and every node.
        """
        # make some containers
        pickleData = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData = {}
        protocol = 3

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData['GraphBuildID'] = self.graphID
        # json can only serialize a real dict, while metadata is allowed to
        # be any Mapping, so hand json a plain dict copy.
        headerData['Metadata'] = None if self._metadata is None else dict(self._metadata)

        # running total of the number of pickle bytes produced thus far
        offset = 0
        # serialize out the task Defs recording the start and end bytes of each
        # taskDef
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(taskDef, protocol=protocol))
            taskDefMap[taskDef.label] = {"bytes": (offset, offset + len(dump))}
            offset += len(dump)
            pickleData.append(dump)

        headerData['TaskDefs'] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        for node in self:
            node = copy.copy(node)
            taskDef = node.taskDef
            # Explicitly overload the "frozen-ness" of nodes to normalize out
            # the taskDef, this saves a lot of space and load time. The label
            # will be used to retrieve the taskDef from the taskDefMap upon
            # load
            #
            # This strategy was chosen instead of creating a new class that
            # looked just like a QuantumNode but containing a label in place of
            # a TaskDef because it would be needlessly slow to construct a
            # bunch of new object to immediately serialize them and destroy the
            # object. This seems like an acceptable use of Python's dynamic
            # nature in a controlled way for optimization and simplicity.
            object.__setattr__(node, 'taskDef', taskDef.label)
            # compressing has very little impact on saving or load time, but
            # a large impact on on disk size, so it is worth doing
            dump = lzma.compress(pickle.dumps(node, protocol=protocol))
            pickleData.append(dump)
            nodeMap.append((int(node.nodeId.number), {"bytes": (offset, offset + len(dump))}))
            offset += len(dump)

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData['Nodes'] = nodeMap

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # pack the save version, then the length of the compressed header
        # using the struct format registered for that version
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = STRUCT_FMT_STRING[SAVE_VERSION]
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Pop the leftmost element off the deque and write it out until the
        # deque is empty. This is to save memory, as the memory is added to
        # the buffer object, it is removed from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this pickle data is
        # only on the order of kilobytes to low numbers of megabytes.
        while pickleData:
            buffer.extend(pickleData.popleft())
        return buffer

    @classmethod
    def load(cls, file: BinaryIO, universe: DimensionUniverse,
             nodes: Optional[Iterable[int]] = None,
             graphID: Optional[BuildId] = None
             ) -> QuantumGraph:
        """Read QuantumGraph from a file that was made by `save`.

        Parameters
        ----------
        file : `typing.BinaryIO`
            File with save data open in binary mode.
        universe: `~lsst.daf.butler.DimensionUniverse`
            DimensionUniverse instance, not used by the method itself but
            needed to ensure that registry data structures are initialized.
        nodes: iterable of `int` or None
            Numbers that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to None, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified this ID is verified against the loaded graph prior to
            loading any Nodes. This defaults to None in which case no
            validation is done.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if pickle contains instance of a type other than
            QuantumGraph.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph` or if graphID parameter does not match the graph
            being loaded or if the supplied uri does not point at a valid
            `QuantumGraph` save file.

        Notes
        -----
        Reading Quanta from pickle requires existence of singleton
        DimensionUniverse which is usually instantiated during Registry
        initialization. To make sure that DimensionUniverse exists this method
        accepts dummy DimensionUniverse argument.
        """
        # Try to see if the file handle contains pickle data, this will be
        # removed in the future
        # NOTE: unpickling executes arbitrary code; only load files that come
        # from a trusted source.
        try:
            qgraph = pickle.load(file)
            warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method")
        except pickle.UnpicklingError:
            # NOTE(review): the failed pickle attempt may have advanced the
            # file position; presumably LoadHelper repositions the handle -
            # confirm against LoadHelper.
            with LoadHelper(file) as loader:  # type: ignore  # needed because we don't have Protocols yet
                qgraph = loader.load(nodes, graphID)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order

        Yields
        ------
        taskDef : `TaskDef`
            `TaskDef` objects in topological order
        """
        yield from nx.topological_sort(self.taskGraph)

    @property
    def graphID(self) -> BuildId:
        """Returns the ID generated by the graph at construction time
        """
        return self._buildId

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Stores a compact form of the graph as a list of graph nodes. The
        full graph can be reconstructed with this information, and it
        preserves the ordering of the graph nodes.
        """
        return {"nodesList": list(self)}

    def __setstate__(self, state: dict):
        """Reconstructs the state of the graph from the information persisted
        in getstate.
        """
        quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
        quantumToNodeId: Dict[Quantum, NodeId] = {}
        # The build id is recovered from the nodes themselves (the last node
        # visited wins); it stays None, so that a new one is generated, when
        # there are no nodes.
        _buildId: Optional[BuildId] = None
        quantumNode: QuantumNode
        for quantumNode in state['nodesList']:
            quanta[quantumNode.taskDef].add(quantumNode.quantum)
            quantumToNodeId[quantumNode.quantum] = quantumNode.nodeId
            _buildId = quantumNode.nodeId.buildId
        self._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=_buildId)

    def __eq__(self, other: object) -> bool:
        """Two graphs are equal when they hold the same number of quanta,
        every node of this graph is in the other with the same direct inputs
        and outputs, and the task graphs list the same tasks in the same
        order. Metadata is not compared.
        """
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        return list(self.taskGraph) == list(other.taskGraph)

912 return False 

913 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

914 return False 

915 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

916 return False 

917 return list(self.taskGraph) == list(other.taskGraph)