Coverage for python/lsst/pipe/base/graph/graph.py: 18%

351 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-23 23:10 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("QuantumGraph", "IncompatibleGraphError") 

24 

25import io 

26import json 

27import lzma 

28import os 

29import pickle 

30import struct 

31import time 

32import uuid 

33import warnings 

34from collections import defaultdict, deque 

35from itertools import chain 

36from types import MappingProxyType 

37from typing import ( 

38 Any, 

39 BinaryIO, 

40 DefaultDict, 

41 Deque, 

42 Dict, 

43 FrozenSet, 

44 Generator, 

45 Iterable, 

46 List, 

47 Mapping, 

48 MutableMapping, 

49 Optional, 

50 Set, 

51 Tuple, 

52 TypeVar, 

53 Union, 

54) 

55 

56import networkx as nx 

57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum 

58from lsst.resources import ResourcePath, ResourcePathExpression 

59from lsst.utils.introspection import get_full_type_name 

60from networkx.drawing.nx_agraph import write_dot 

61 

62from ..connections import iterConnections 

63from ..pipeline import TaskDef 

64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner 

65from ._loadHelpers import LoadHelper 

66from ._versionDeserializers import DESERIALIZER_MAP 

67from .quantumNode import BuildId, QuantumNode 

68 

69_T = TypeVar("_T", bound="QuantumGraph") 

70 

71# modify this constant any time the on disk representation of the save file 

72# changes, and update the load helpers to behave properly for each version. 

73SAVE_VERSION = 3 

74 

75# Strings used to describe the format for the preamble bytes in a file save 

76# The base is a big endian encoded unsigned short that is used to hold the 

77# file format version. This allows reading version bytes and determine which 

78# loading code should be used for the rest of the file 

79STRUCT_FMT_BASE = ">H" 

80# 

81# Version 1 

82# This marks a big endian encoded format with an unsigned short, an unsigned 

83# long long, and an unsigned long long in the byte stream 

84# Version 2 

85# A big endian encoded format with an unsigned long long byte stream used to 

86# indicate the total length of the entire header. 

87STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"} 

88 

89# magic bytes that help determine this is a graph save 

90MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9" 

91 

92 

class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be performed because the
    identifier is incompatible with this graph.
    """

99 

100 

class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : Mapping of `TaskDef` to sets of `Quantum`
        This maps tasks (and their configs) to the sets of data they are to
        process.
    metadata : Optional Mapping of `str` to primitives
        This is an optional parameter of extra data to carry with the graph.
        Entries in this mapping should be able to be serialized in JSON.
    pruneRefs : Optional Iterable of `DatasetRef`
        Dataset refs to exclude from the graph; nodes that would depend on
        them are pruned as well.
    universe : Optional `DimensionUniverse`
        The dimension universe the graph is built with. If `None`, it is
        taken from the data ID of a quantum in ``quanta``.

    Raises
    ------
    ValueError
        Raised if the graph is pruned such that some tasks no longer have nodes
        associated with them.
    """

    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
    ):
        # All construction is delegated to _buildGraphs so that alternate
        # constructors can rebuild state on a bare instance (object.__new__).
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs, universe=universe)

131 

    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        *,
        _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
        _buildId: Optional[BuildId] = None,
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
    ) -> None:
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : Mapping of `TaskDef` to sets of `Quantum`
            The quanta to build the graph from, grouped by task.
        _quantumToNodeId : Optional Mapping of `Quantum` to `uuid.UUID`
            Internal: preserves existing node ids when rebuilding (e.g. in
            `subset`). If supplied, every quantum must have an entry.
        _buildId : Optional `BuildId`
            Internal: reuse an existing build id instead of minting a new one.
        metadata : Optional Mapping of `str` to primitives
            Extra data to carry with the graph; must be JSON serializable.
        pruneRefs : Optional Iterable of `DatasetRef`
            Refs to remove from the graph along with dependent nodes.
            NOTE(review): this is iterated more than once below, so it is
            presumably expected to be a re-iterable collection, not a
            one-shot generator — TODO confirm at call sites.
        universe : Optional `DimensionUniverse`
            Dimension universe; if `None` it is inferred from quantum data
            IDs, and a `RuntimeError` is raised if none is available.
        """
        self._metadata = metadata
        # Build id defaults to a timestamp-pid pair, unique enough to
        # distinguish graph saves made by different processes.
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if quantum.dataId is not None:
                    # Infer the universe from the first data ID seen, and
                    # verify every subsequent quantum agrees with it.
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}. "
                        )

                if _quantumToNodeId:
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            # Component refs are tracked via their parent
                            # composite so producer/consumer edges line up.
                            if sub.isComponent():
                                sub = sub.makeCompositeRef()
                            self._datasetRefDict.addConsumer(sub, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        self._datasetRefDict.addConsumer(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addProducer(dsRef, value)

        if pruneRefs is not None:
            # track what refs were pruned and prune the graph
            prunes: Set[QuantumNode] = set()
            _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)

            # recreate the taskToQuantumNode dict removing nodes that have been
            # pruned. Keep track of task defs that now have no QuantumNodes
            emptyTasks: Set[str] = set()
            newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
            # accumulate all types
            types_ = set()
            # tracker for any pruneRefs that have caused tasks to have no nodes
            # This helps the user find out what caused the issues seen.
            culprits = set()
            # Find all the types from the refs to prune
            for r in pruneRefs:
                types_.add(r.datasetType)

            # For each of the tasks, and their associated nodes, remove any
            # nodes that were pruned. If there are no nodes associated
            # with a task, record that task, and find out if that was due to
            # a type from an input ref to prune.
            for td, taskNodes in self._taskToQuantumNode.items():
                diff = taskNodes.difference(prunes)
                if len(diff) == 0:
                    if len(taskNodes) != 0:
                        tp: DatasetType
                        for tp in types_:
                            # A type is a culprit when every one of the
                            # task's input refs of that type was pruned.
                            if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
                                tmpRefs
                            ).difference(pruneRefs):
                                culprits.add(tp.name)
                    emptyTasks.add(td.label)
                newTaskToQuantumNode[td] = diff

            # update the internal dict
            self._taskToQuantumNode = newTaskToQuantumNode

            if emptyTasks:
                raise ValueError(
                    f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
                    f"after graph pruning; {', '.join(culprits)} caused over-pruning"
                )

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # convert default dict into a regular to prevent accidental key
        # insertion
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

276 

277 @property 

278 def taskGraph(self) -> nx.DiGraph: 

279 """Return a graph representing the relations between the tasks inside 

280 the quantum graph. 

281 

282 Returns 

283 ------- 

284 taskGraph : `networkx.Digraph` 

285 Internal datastructure that holds relations of `TaskDef` objects 

286 """ 

287 return self._taskGraph 

288 

289 @property 

290 def graph(self) -> nx.DiGraph: 

291 """Return a graph representing the relations between all the 

292 `QuantumNode` objects. Largely it should be preferred to iterate 

293 over, and use methods of this class, but sometimes direct access to 

294 the networkx object may be helpful 

295 

296 Returns 

297 ------- 

298 graph : `networkx.Digraph` 

299 Internal datastructure that holds relations of `QuantumNode` 

300 objects 

301 """ 

302 return self._connectedQuanta 

303 

304 @property 

305 def inputQuanta(self) -> Iterable[QuantumNode]: 

306 """Make a `list` of all `QuantumNode` objects that are 'input' nodes 

307 to the graph, meaning those nodes to not depend on any other nodes in 

308 the graph. 

309 

310 Returns 

311 ------- 

312 inputNodes : iterable of `QuantumNode` 

313 A list of nodes that are inputs to the graph 

314 """ 

315 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

316 

317 @property 

318 def outputQuanta(self) -> Iterable[QuantumNode]: 

319 """Make a `list` of all `QuantumNode` objects that are 'output' nodes 

320 to the graph, meaning those nodes have no nodes that depend them in 

321 the graph. 

322 

323 Returns 

324 ------- 

325 outputNodes : iterable of `QuantumNode` 

326 A list of nodes that are outputs of the graph 

327 """ 

328 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

329 

330 @property 

331 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]: 

332 """Return all the `DatasetTypeName` objects that are contained inside 

333 the graph. 

334 

335 Returns 

336 ------- 

337 tuple of `DatasetTypeName` 

338 All the data set type names that are present in the graph 

339 """ 

340 return tuple(self._datasetDict.keys()) 

341 

342 @property 

343 def isConnected(self) -> bool: 

344 """Return True if all of the nodes in the graph are connected, ignores 

345 directionality of connections. 

346 """ 

347 return nx.is_weakly_connected(self._connectedQuanta) 

348 

349 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T: 

350 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s 

351 and nodes which depend on them. 

352 

353 Parameters 

354 ---------- 

355 refs : `Iterable` of `DatasetRef` 

356 Refs which should be removed from resulting graph 

357 

358 Returns 

359 ------- 

360 graph : `QuantumGraph` 

361 A graph that has been pruned of specified refs and the nodes that 

362 depend on them. 

363 """ 

364 newInst = object.__new__(type(self)) 

365 quantumMap = defaultdict(set) 

366 for node in self: 

367 quantumMap[node.taskDef].add(node.quantum) 

368 

369 # convert to standard dict to prevent accidental key insertion 

370 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

371 

372 newInst._buildGraphs( 

373 quantumDict, 

374 _quantumToNodeId={n.quantum: n.nodeId for n in self}, 

375 metadata=self._metadata, 

376 pruneRefs=refs, 

377 universe=self._universe, 

378 ) 

379 return newInst 

380 

381 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

382 """Lookup a `QuantumNode` from an id associated with the node. 

383 

384 Parameters 

385 ---------- 

386 nodeId : `NodeId` 

387 The number associated with a node 

388 

389 Returns 

390 ------- 

391 node : `QuantumNode` 

392 The node corresponding with input number 

393 

394 Raises 

395 ------ 

396 KeyError 

397 Raised if the requested nodeId is not in the graph. 

398 """ 

399 return self._nodeIdMap[nodeId] 

400 

401 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]: 

402 """Return all the `Quantum` associated with a `TaskDef`. 

403 

404 Parameters 

405 ---------- 

406 taskDef : `TaskDef` 

407 The `TaskDef` for which `Quantum` are to be queried 

408 

409 Returns 

410 ------- 

411 frozenset of `Quantum` 

412 The `set` of `Quantum` that is associated with the specified 

413 `TaskDef`. 

414 """ 

415 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef]) 

416 

417 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]: 

418 """Return all the `QuantumNodes` associated with a `TaskDef`. 

419 

420 Parameters 

421 ---------- 

422 taskDef : `TaskDef` 

423 The `TaskDef` for which `Quantum` are to be queried 

424 

425 Returns 

426 ------- 

427 frozenset of `QuantumNodes` 

428 The `frozenset` of `QuantumNodes` that is associated with the 

429 specified `TaskDef`. 

430 """ 

431 return frozenset(self._taskToQuantumNode[taskDef]) 

432 

433 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

434 """Find all tasks that have the specified dataset type name as an 

435 input. 

436 

437 Parameters 

438 ---------- 

439 datasetTypeName : `str` 

440 A string representing the name of a dataset type to be queried, 

441 can also accept a `DatasetTypeName` which is a `NewType` of str for 

442 type safety in static type checking. 

443 

444 Returns 

445 ------- 

446 tasks : iterable of `TaskDef` 

447 `TaskDef` objects that have the specified `DatasetTypeName` as an 

448 input, list will be empty if no tasks use specified 

449 `DatasetTypeName` as an input. 

450 

451 Raises 

452 ------ 

453 KeyError 

454 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

455 """ 

456 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

457 

458 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]: 

459 """Find all tasks that have the specified dataset type name as an 

460 output. 

461 

462 Parameters 

463 ---------- 

464 datasetTypeName : `str` 

465 A string representing the name of a dataset type to be queried, 

466 can also accept a `DatasetTypeName` which is a `NewType` of str for 

467 type safety in static type checking. 

468 

469 Returns 

470 ------- 

471 `TaskDef` or `None` 

472 `TaskDef` that outputs `DatasetTypeName` as an output or None if 

473 none of the tasks produce this `DatasetTypeName`. 

474 

475 Raises 

476 ------ 

477 KeyError 

478 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

479 """ 

480 return self._datasetDict.getProducer(datasetTypeName) 

481 

482 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

483 """Find all tasks that are associated with the specified dataset type 

484 name. 

485 

486 Parameters 

487 ---------- 

488 datasetTypeName : `str` 

489 A string representing the name of a dataset type to be queried, 

490 can also accept a `DatasetTypeName` which is a `NewType` of str for 

491 type safety in static type checking. 

492 

493 Returns 

494 ------- 

495 result : iterable of `TaskDef` 

496 `TaskDef` objects that are associated with the specified 

497 `DatasetTypeName` 

498 

499 Raises 

500 ------ 

501 KeyError 

502 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

503 """ 

504 return self._datasetDict.getAll(datasetTypeName) 

505 

506 def findTaskDefByName(self, taskName: str) -> List[TaskDef]: 

507 """Determine which `TaskDef` objects in this graph are associated 

508 with a `str` representing a task name (looks at the taskName property 

509 of `TaskDef` objects). 

510 

511 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

512 multiple times in a graph with different labels. 

513 

514 Parameters 

515 ---------- 

516 taskName : str 

517 Name of a task to search for 

518 

519 Returns 

520 ------- 

521 result : list of `TaskDef` 

522 List of the `TaskDef` objects that have the name specified. 

523 Multiple values are returned in the case that a task is used 

524 multiple times with different labels. 

525 """ 

526 results = [] 

527 for task in self._taskToQuantumNode.keys(): 

528 split = task.taskName.split(".") 

529 if split[-1] == taskName: 

530 results.append(task) 

531 return results 

532 

533 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]: 

534 """Determine which `TaskDef` objects in this graph are associated 

535 with a `str` representing a tasks label. 

536 

537 Parameters 

538 ---------- 

539 taskName : str 

540 Name of a task to search for 

541 

542 Returns 

543 ------- 

544 result : `TaskDef` 

545 `TaskDef` objects that has the specified label. 

546 """ 

547 for task in self._taskToQuantumNode.keys(): 

548 if label == task.label: 

549 return task 

550 return None 

551 

552 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]: 

553 """Return all the `Quantum` that contain a specified `DatasetTypeName`. 

554 

555 Parameters 

556 ---------- 

557 datasetTypeName : `str` 

558 The name of the dataset type to search for as a string, 

559 can also accept a `DatasetTypeName` which is a `NewType` of str for 

560 type safety in static type checking. 

561 

562 Returns 

563 ------- 

564 result : `set` of `QuantumNode` objects 

565 A `set` of `QuantumNode`s that contain specified `DatasetTypeName` 

566 

567 Raises 

568 ------ 

569 KeyError 

570 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

571 

572 """ 

573 tasks = self._datasetDict.getAll(datasetTypeName) 

574 result: Set[Quantum] = set() 

575 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

576 return result 

577 

578 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

579 """Check if specified quantum appears in the graph as part of a node. 

580 

581 Parameters 

582 ---------- 

583 quantum : `Quantum` 

584 The quantum to search for 

585 

586 Returns 

587 ------- 

588 `bool` 

589 The result of searching for the quantum 

590 """ 

591 for node in self: 

592 if quantum == node.quantum: 

593 return True 

594 return False 

595 

596 def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None: 

597 """Write out the graph as a dot graph. 

598 

599 Parameters 

600 ---------- 

601 output : str or `io.BufferedIOBase` 

602 Either a filesystem path to write to, or a file handle object 

603 """ 

604 write_dot(self._connectedQuanta, output) 

605 

606 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T: 

607 """Create a new graph object that contains the subset of the nodes 

608 specified as input. Node number is preserved. 

609 

610 Parameters 

611 ---------- 

612 nodes : `QuantumNode` or iterable of `QuantumNode` 

613 

614 Returns 

615 ------- 

616 graph : instance of graph type 

617 An instance of the type from which the subset was created 

618 """ 

619 if not isinstance(nodes, Iterable): 

620 nodes = (nodes,) 

621 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes 

622 quantumMap = defaultdict(set) 

623 

624 node: QuantumNode 

625 for node in quantumSubgraph: 

626 quantumMap[node.taskDef].add(node.quantum) 

627 

628 # convert to standard dict to prevent accidental key insertion 

629 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

630 # Create an empty graph, and then populate it with custom mapping 

631 newInst = type(self)({}, universe=self._universe) 

632 newInst._buildGraphs( 

633 quantumDict, 

634 _quantumToNodeId={n.quantum: n.nodeId for n in nodes}, 

635 _buildId=self._buildId, 

636 metadata=self._metadata, 

637 universe=self._universe, 

638 ) 

639 return newInst 

640 

641 def subsetToConnected(self: _T) -> Tuple[_T, ...]: 

642 """Generate a list of subgraphs where each is connected. 

643 

644 Returns 

645 ------- 

646 result : list of `QuantumGraph` 

647 A list of graphs that are each connected 

648 """ 

649 return tuple( 

650 self.subset(connectedSet) 

651 for connectedSet in nx.weakly_connected_components(self._connectedQuanta) 

652 ) 

653 

654 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

655 """Return a set of `QuantumNode` that are direct inputs to a specified 

656 node. 

657 

658 Parameters 

659 ---------- 

660 node : `QuantumNode` 

661 The node of the graph for which inputs are to be determined 

662 

663 Returns 

664 ------- 

665 set of `QuantumNode` 

666 All the nodes that are direct inputs to specified node 

667 """ 

668 return set(pred for pred in self._connectedQuanta.predecessors(node)) 

669 

670 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

671 """Return a set of `QuantumNode` that are direct outputs of a specified 

672 node. 

673 

674 Parameters 

675 ---------- 

676 node : `QuantumNode` 

677 The node of the graph for which outputs are to be determined 

678 

679 Returns 

680 ------- 

681 set of `QuantumNode` 

682 All the nodes that are direct outputs to specified node 

683 """ 

684 return set(succ for succ in self._connectedQuanta.successors(node)) 

685 

686 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

687 """Return a graph of `QuantumNode` that are direct inputs and outputs 

688 of a specified node. 

689 

690 Parameters 

691 ---------- 

692 node : `QuantumNode` 

693 The node of the graph for which connected nodes are to be 

694 determined. 

695 

696 Returns 

697 ------- 

698 graph : graph of `QuantumNode` 

699 All the nodes that are directly connected to specified node 

700 """ 

701 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

702 nodes.add(node) 

703 return self.subset(nodes) 

704 

705 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

706 """Return a graph of the specified node and all the ancestor nodes 

707 directly reachable by walking edges. 

708 

709 Parameters 

710 ---------- 

711 node : `QuantumNode` 

712 The node for which all ansestors are to be determined 

713 

714 Returns 

715 ------- 

716 graph of `QuantumNode` 

717 Graph of node and all of its ansestors 

718 """ 

719 predecessorNodes = nx.ancestors(self._connectedQuanta, node) 

720 predecessorNodes.add(node) 

721 return self.subset(predecessorNodes) 

722 

723 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]: 

724 """Check a graph for the presense of cycles and returns the edges of 

725 any cycles found, or an empty list if there is no cycle. 

726 

727 Returns 

728 ------- 

729 result : list of tuple of `QuantumNode`, `QuantumNode` 

730 A list of any graph edges that form a cycle, or an empty list if 

731 there is no cycle. Empty list to so support if graph.find_cycle() 

732 syntax as an empty list is falsy. 

733 """ 

734 try: 

735 return nx.find_cycle(self._connectedQuanta) 

736 except nx.NetworkXNoCycle: 

737 return [] 

738 

739 def saveUri(self, uri: ResourcePathExpression) -> None: 

740 """Save `QuantumGraph` to the specified URI. 

741 

742 Parameters 

743 ---------- 

744 uri : convertible to `ResourcePath` 

745 URI to where the graph should be saved. 

746 """ 

747 buffer = self._buildSaveObject() 

748 path = ResourcePath(uri) 

749 if path.getExtension() not in (".qgraph"): 

750 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

751 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

752 

753 @property 

754 def metadata(self) -> Optional[MappingProxyType[str, Any]]: 

755 """ """ 

756 if self._metadata is None: 

757 return None 

758 return MappingProxyType(self._metadata) 

759 

760 @classmethod 

761 def loadUri( 

762 cls, 

763 uri: ResourcePathExpression, 

764 universe: Optional[DimensionUniverse] = None, 

765 nodes: Optional[Iterable[uuid.UUID]] = None, 

766 graphID: Optional[BuildId] = None, 

767 minimumVersion: int = 3, 

768 ) -> QuantumGraph: 

769 """Read `QuantumGraph` from a URI. 

770 

771 Parameters 

772 ---------- 

773 uri : convertible to `ResourcePath` 

774 URI from where to load the graph. 

775 universe: `~lsst.daf.butler.DimensionUniverse` optional 

776 DimensionUniverse instance, not used by the method itself but 

777 needed to ensure that registry data structures are initialized. 

778 If None it is loaded from the QuantumGraph saved structure. If 

779 supplied, the DimensionUniverse from the loaded `QuantumGraph` 

780 will be validated against the supplied argument for compatibility. 

781 nodes: iterable of `int` or None 

782 Numbers that correspond to nodes in the graph. If specified, only 

783 these nodes will be loaded. Defaults to None, in which case all 

784 nodes will be loaded. 

785 graphID : `str` or `None` 

786 If specified this ID is verified against the loaded graph prior to 

787 loading any Nodes. This defaults to None in which case no 

788 validation is done. 

789 minimumVersion : int 

790 Minimum version of a save file to load. Set to -1 to load all 

791 versions. Older versions may need to be loaded, and re-saved 

792 to upgrade them to the latest format before they can be used in 

793 production. 

794 

795 Returns 

796 ------- 

797 graph : `QuantumGraph` 

798 Resulting QuantumGraph instance. 

799 

800 Raises 

801 ------ 

802 TypeError 

803 Raised if pickle contains instance of a type other than 

804 QuantumGraph. 

805 ValueError 

806 Raised if one or more of the nodes requested is not in the 

807 `QuantumGraph` or if graphID parameter does not match the graph 

808 being loaded or if the supplied uri does not point at a valid 

809 `QuantumGraph` save file. 

810 RuntimeError 

811 Raise if Supplied DimensionUniverse is not compatible with the 

812 DimensionUniverse saved in the graph 

813 

814 

815 Notes 

816 ----- 

817 Reading Quanta from pickle requires existence of singleton 

818 DimensionUniverse which is usually instantiated during Registry 

819 initialization. To make sure that DimensionUniverse exists this method 

820 accepts dummy DimensionUniverse argument. 

821 """ 

822 uri = ResourcePath(uri) 

823 # With ResourcePath we have the choice of always using a local file 

824 # or reading in the bytes directly. Reading in bytes can be more 

825 # efficient for reasonably-sized pickle files when the resource 

826 # is remote. For now use the local file variant. For a local file 

827 # as_local() does nothing. 

828 

829 if uri.getExtension() in (".pickle", ".pkl"): 

830 with uri.as_local() as local, open(local.ospath, "rb") as fd: 

831 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

832 qgraph = pickle.load(fd) 

833 elif uri.getExtension() in (".qgraph"): 

834 with LoadHelper(uri, minimumVersion) as loader: 

835 qgraph = loader.load(universe, nodes, graphID) 

836 else: 

837 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`") 

838 if not isinstance(qgraph, QuantumGraph): 

839 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}") 

840 return qgraph 

841 

842 @classmethod 

843 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]: 

844 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

845 and return it as a string. 

846 

847 Parameters 

848 ---------- 

849 uri : convertible to `ResourcePath` 

850 The location of the `QuantumGraph` to load. If the argument is a 

851 string, it must correspond to a valid `ResourcePath` path. 

852 minimumVersion : int 

853 Minimum version of a save file to load. Set to -1 to load all 

854 versions. Older versions may need to be loaded, and re-saved 

855 to upgrade them to the latest format before they can be used in 

856 production. 

857 

858 Returns 

859 ------- 

860 header : `str` or `None` 

861 The header associated with the specified `QuantumGraph` it there is 

862 one, else `None`. 

863 

864 Raises 

865 ------ 

866 ValueError 

867 Raised if `QuantuGraph` was saved as a pickle. 

868 Raised if the extention of the file specified by uri is not a 

869 `QuantumGraph` extention. 

870 """ 

871 uri = ResourcePath(uri) 

872 if uri.getExtension() in (".pickle", ".pkl"): 

873 raise ValueError("Reading a header from a pickle save is not supported") 

874 elif uri.getExtension() in (".qgraph"): 

875 return LoadHelper(uri, minimumVersion).readHeader() 

876 else: 

877 raise ValueError("Only know how to handle files saved as `qgraph`") 

878 

879 def buildAndPrintHeader(self) -> None: 

880 """Creates a header that would be used in a save of this object and 

881 prints it out to standard out. 

882 """ 

883 _, header = self._buildSaveObject(returnHeader=True) 

884 print(json.dumps(header)) 

885 

886 def save(self, file: BinaryIO) -> None: 

887 """Save QuantumGraph to a file. 

888 

889 Parameters 

890 ---------- 

891 file : `io.BufferedIOBase` 

892 File to write pickle data open in binary mode. 

893 """ 

894 buffer = self._buildSaveObject() 

895 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

896 

    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into an in-memory byte buffer.

        The layout written is: ``MAGIC_BYTES``, the packed save version
        number, the packed length of the lzma-compressed JSON header, the
        compressed header itself, and finally the concatenated
        lzma-compressed JSON payloads for each task definition and each
        quantum node.  The header records the ``(start, end)`` byte offsets
        of every payload (relative to the end of the header) so individual
        pieces can be loaded without reading the whole file.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping
            alongside the byte buffer.

        Returns
        -------
        buffer : `bytearray` or `tuple` [`bytearray`, `dict`]
            The serialized graph, optionally accompanied by the header data.
        """
        # make some containers
        jsonData: Deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: Dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # sort by task label to ensure serialization happens in the same order
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label

            inputs = []
            outputs = []

            # Determine the connection between all of tasks and save that in
            # the header as a list of connections and edges in each task
            # this will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this tasks produced
                    # connection, the output will be said to be the datastore
                    # in which case the for loop will be a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, and encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        headerData["DimensionRecords"] = {
            key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # record the save format version number as packed bytes
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        # pack the length of the compressed header so the loader knows how
        # many bytes to read before the payload section begins
        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this payload data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

1040 

1041 @classmethod 

1042 def load( 

1043 cls, 

1044 file: BinaryIO, 

1045 universe: Optional[DimensionUniverse] = None, 

1046 nodes: Optional[Iterable[uuid.UUID]] = None, 

1047 graphID: Optional[BuildId] = None, 

1048 minimumVersion: int = 3, 

1049 ) -> QuantumGraph: 

1050 """Read QuantumGraph from a file that was made by `save`. 

1051 

1052 Parameters 

1053 ---------- 

1054 file : `io.IO` of bytes 

1055 File with pickle data open in binary mode. 

1056 universe: `~lsst.daf.butler.DimensionUniverse`, optional 

1057 DimensionUniverse instance, not used by the method itself but 

1058 needed to ensure that registry data structures are initialized. 

1059 If None it is loaded from the QuantumGraph saved structure. If 

1060 supplied, the DimensionUniverse from the loaded `QuantumGraph` 

1061 will be validated against the supplied argument for compatibility. 

1062 nodes: iterable of `int` or None 

1063 Numbers that correspond to nodes in the graph. If specified, only 

1064 these nodes will be loaded. Defaults to None, in which case all 

1065 nodes will be loaded. 

1066 graphID : `str` or `None` 

1067 If specified this ID is verified against the loaded graph prior to 

1068 loading any Nodes. This defaults to None in which case no 

1069 validation is done. 

1070 minimumVersion : int 

1071 Minimum version of a save file to load. Set to -1 to load all 

1072 versions. Older versions may need to be loaded, and re-saved 

1073 to upgrade them to the latest format before they can be used in 

1074 production. 

1075 

1076 Returns 

1077 ------- 

1078 graph : `QuantumGraph` 

1079 Resulting QuantumGraph instance. 

1080 

1081 Raises 

1082 ------ 

1083 TypeError 

1084 Raised if pickle contains instance of a type other than 

1085 QuantumGraph. 

1086 ValueError 

1087 Raised if one or more of the nodes requested is not in the 

1088 `QuantumGraph` or if graphID parameter does not match the graph 

1089 being loaded or if the supplied uri does not point at a valid 

1090 `QuantumGraph` save file. 

1091 

1092 Notes 

1093 ----- 

1094 Reading Quanta from pickle requires existence of singleton 

1095 DimensionUniverse which is usually instantiated during Registry 

1096 initialization. To make sure that DimensionUniverse exists this method 

1097 accepts dummy DimensionUniverse argument. 

1098 """ 

1099 # Try to see if the file handle contains pickle data, this will be 

1100 # removed in the future 

1101 try: 

1102 qgraph = pickle.load(file) 

1103 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

1104 except pickle.UnpicklingError: 

1105 with LoadHelper(file, minimumVersion) as loader: 

1106 qgraph = loader.load(universe, nodes, graphID) 

1107 if not isinstance(qgraph, QuantumGraph): 

1108 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}") 

1109 return qgraph 

1110 

1111 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1112 """Iterate over the `taskGraph` attribute in topological order 

1113 

1114 Yields 

1115 ------ 

1116 taskDef : `TaskDef` 

1117 `TaskDef` objects in topological order 

1118 """ 

1119 yield from nx.topological_sort(self.taskGraph) 

1120 

1121 @property 

1122 def graphID(self) -> BuildId: 

1123 """Returns the ID generated by the graph at construction time""" 

1124 return self._buildId 

1125 

1126 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1127 yield from nx.topological_sort(self._connectedQuanta) 

1128 

    def __len__(self) -> int:
        """Return the number of quantum nodes in the graph."""
        return self._count

1131 

1132 def __contains__(self, node: QuantumNode) -> bool: 

1133 return self._connectedQuanta.has_node(node) 

1134 

1135 def __getstate__(self) -> dict: 

1136 """Stores a compact form of the graph as a list of graph nodes, and a 

1137 tuple of task labels and task configs. The full graph can be 

1138 reconstructed with this information, and it preseves the ordering of 

1139 the graph ndoes. 

1140 """ 

1141 universe: Optional[DimensionUniverse] = None 

1142 for node in self: 

1143 dId = node.quantum.dataId 

1144 if dId is None: 

1145 continue 

1146 universe = dId.graph.universe 

1147 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe} 

1148 

1149 def __setstate__(self, state: dict) -> None: 

1150 """Reconstructs the state of the graph from the information persisted 

1151 in getstate. 

1152 """ 

1153 buffer = io.BytesIO(state["reduced"]) 

1154 with LoadHelper(buffer, minimumVersion=3) as loader: 

1155 qgraph = loader.load(state["universe"], graphID=state["graphId"]) 

1156 

1157 self._metadata = qgraph._metadata 

1158 self._buildId = qgraph._buildId 

1159 self._datasetDict = qgraph._datasetDict 

1160 self._nodeIdMap = qgraph._nodeIdMap 

1161 self._count = len(qgraph) 

1162 self._taskToQuantumNode = qgraph._taskToQuantumNode 

1163 self._taskGraph = qgraph._taskGraph 

1164 self._connectedQuanta = qgraph._connectedQuanta 

1165 

1166 def __eq__(self, other: object) -> bool: 

1167 if not isinstance(other, QuantumGraph): 

1168 return False 

1169 if len(self) != len(other): 

1170 return False 

1171 for node in self: 

1172 if node not in other: 

1173 return False 

1174 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1175 return False 

1176 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1177 return False 

1178 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1179 return False 

1180 return set(self.taskGraph) == set(other.taskGraph)