Coverage for python/lsst/pipe/base/graph/graph.py: 19%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

340 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from lsst.daf.butler.core.datasets.type import DatasetType 

24 

25__all__ = ("QuantumGraph", "IncompatibleGraphError") 

26 

27import io 

28import json 

29import lzma 

30import os 

31import pickle 

32import struct 

33import time 

34import uuid 

35import warnings 

36from collections import defaultdict, deque 

37from itertools import chain 

38from types import MappingProxyType 

39from typing import ( 

40 Any, 

41 DefaultDict, 

42 Dict, 

43 FrozenSet, 

44 Generator, 

45 Iterable, 

46 List, 

47 Mapping, 

48 MutableMapping, 

49 Optional, 

50 Set, 

51 Tuple, 

52 TypeVar, 

53 Union, 

54) 

55 

56import networkx as nx 

57from lsst.daf.butler import DatasetRef, DimensionRecordsAccumulator, DimensionUniverse, Quantum 

58from lsst.resources import ResourcePath, ResourcePathExpression 

59from networkx.drawing.nx_agraph import write_dot 

60 

61from ..connections import iterConnections 

62from ..pipeline import TaskDef 

63from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner 

64from ._loadHelpers import LoadHelper 

65from ._versionDeserializers import DESERIALIZER_MAP 

66from .quantumNode import BuildId, QuantumNode 

67 

# Generic type variable bound to QuantumGraph so that methods like `subset`
# and `pruneGraphFromRefs` can declare that they return the subclass they
# were invoked on.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

90 

91 

class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be performed because the node
    identifier is incompatible with this graph.
    """

    pass

98 

99 

100class QuantumGraph: 

101 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects 

102 

103 This data structure represents a concrete workflow generated from a 

104 `Pipeline`. 

105 

106 Parameters 

107 ---------- 

108 quanta : Mapping of `TaskDef` to sets of `Quantum` 

109 This maps tasks (and their configs) to the sets of data they are to 

110 process. 

111 metadata : Optional Mapping of `str` to primitives 

112 This is an optional parameter of extra data to carry with the graph. 

113 Entries in this mapping should be able to be serialized in JSON. 

114 

115 Raises 

116 ------ 

117 ValueError 

118 Raised if the graph is pruned such that some tasks no longer have nodes 

119 associated with them. 

120 """ 

121 

    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
    ):
        # All construction is delegated to _buildGraphs so that alternate
        # constructors (subset, pruneGraphFromRefs) can rebuild the internal
        # structures on a bare instance created with object.__new__.
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs)

129 

130 def _buildGraphs( 

131 self, 

132 quanta: Mapping[TaskDef, Set[Quantum]], 

133 *, 

134 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None, 

135 _buildId: Optional[BuildId] = None, 

136 metadata: Optional[Mapping[str, Any]] = None, 

137 pruneRefs: Optional[Iterable[DatasetRef]] = None, 

138 ): 

139 """Builds the graph that is used to store the relation between tasks, 

140 and the graph that holds the relations between quanta 

141 """ 

142 self._metadata = metadata 

143 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}") 

144 # Data structures used to identify relations between components; 

145 # DatasetTypeName -> TaskDef for task, 

146 # and DatasetRef -> QuantumNode for the quanta 

147 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True) 

148 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]() 

149 

150 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {} 

151 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set) 

152 for taskDef, quantumSet in quanta.items(): 

153 connections = taskDef.connections 

154 

155 # For each type of connection in the task, add a key to the 

156 # `_DatasetTracker` for the connections name, with a value of 

157 # the TaskDef in the appropriate field 

158 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")): 

159 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef) 

160 

161 for output in iterConnections(connections, ("outputs",)): 

162 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef) 

163 

164 # For each `Quantum` in the set of all `Quantum` for this task, 

165 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one 

166 # of the individual datasets inside the `Quantum`, with a value of 

167 # a newly created QuantumNode to the appropriate input/output 

168 # field. 

169 for quantum in quantumSet: 

170 if _quantumToNodeId: 

171 if (nodeId := _quantumToNodeId.get(quantum)) is None: 

172 raise ValueError( 

173 "If _quantuMToNodeNumber is not None, all quanta must have an " 

174 "associated value in the mapping" 

175 ) 

176 else: 

177 nodeId = uuid.uuid4() 

178 

179 inits = quantum.initInputs.values() 

180 inputs = quantum.inputs.values() 

181 value = QuantumNode(quantum, taskDef, nodeId) 

182 self._taskToQuantumNode[taskDef].add(value) 

183 self._nodeIdMap[nodeId] = value 

184 

185 for dsRef in chain(inits, inputs): 

186 # unfortunately, `Quantum` allows inits to be individual 

187 # `DatasetRef`s or an Iterable of such, so there must 

188 # be an instance check here 

189 if isinstance(dsRef, Iterable): 

190 for sub in dsRef: 

191 if sub.isComponent(): 

192 sub = sub.makeCompositeRef() 

193 self._datasetRefDict.addConsumer(sub, value) 

194 else: 

195 if dsRef.isComponent(): 

196 dsRef = dsRef.makeCompositeRef() 

197 self._datasetRefDict.addConsumer(dsRef, value) 

198 for dsRef in chain.from_iterable(quantum.outputs.values()): 

199 self._datasetRefDict.addProducer(dsRef, value) 

200 

201 if pruneRefs is not None: 

202 # track what refs were pruned and prune the graph 

203 prunes = set() 

204 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes) 

205 

206 # recreate the taskToQuantumNode dict removing nodes that have been 

207 # pruned. Keep track of task defs that now have no QuantumNodes 

208 emptyTasks: Set[str] = set() 

209 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set) 

210 # accumulate all types 

211 types_ = set() 

212 # tracker for any pruneRefs that have caused tasks to have no nodes 

213 # This helps the user find out what caused the issues seen. 

214 culprits = set() 

215 # Find all the types from the refs to prune 

216 for r in pruneRefs: 

217 types_.add(r.datasetType) 

218 

219 # For each of the tasks, and their associated nodes, remove any 

220 # any nodes that were pruned. If there are no nodes associated 

221 # with a task, record that task, and find out if that was due to 

222 # a type from an input ref to prune. 

223 for td, taskNodes in self._taskToQuantumNode.items(): 

224 diff = taskNodes.difference(prunes) 

225 if len(diff) == 0: 

226 if len(taskNodes) != 0: 

227 tp: DatasetType 

228 for tp in types_: 

229 if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set( 

230 tmpRefs 

231 ).difference(pruneRefs): 

232 culprits.add(tp.name) 

233 emptyTasks.add(td.label) 

234 newTaskToQuantumNode[td] = diff 

235 

236 # update the internal dict 

237 self._taskToQuantumNode = newTaskToQuantumNode 

238 

239 if emptyTasks: 

240 raise ValueError( 

241 f"{', '.join(emptyTasks)} task(s) have no nodes associated with them " 

242 f"after graph pruning; {', '.join(culprits)} caused over-pruning" 

243 ) 

244 

245 # Graph of quanta relations 

246 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph() 

247 self._count = len(self._connectedQuanta) 

248 

249 # Graph of task relations, used in various methods 

250 self._taskGraph = self._datasetDict.makeNetworkXGraph() 

251 

252 # convert default dict into a regular to prevent accidental key 

253 # insertion 

254 self._taskToQuantumNode = dict(self._taskToQuantumNode.items()) 

255 

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        return self._taskGraph

267 

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        return self._connectedQuanta

282 

283 @property 

284 def inputQuanta(self) -> Iterable[QuantumNode]: 

285 """Make a `list` of all `QuantumNode` objects that are 'input' nodes 

286 to the graph, meaning those nodes to not depend on any other nodes in 

287 the graph. 

288 

289 Returns 

290 ------- 

291 inputNodes : iterable of `QuantumNode` 

292 A list of nodes that are inputs to the graph 

293 """ 

294 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

295 

296 @property 

297 def outputQuanta(self) -> Iterable[QuantumNode]: 

298 """Make a `list` of all `QuantumNode` objects that are 'output' nodes 

299 to the graph, meaning those nodes have no nodes that depend them in 

300 the graph. 

301 

302 Returns 

303 ------- 

304 outputNodes : iterable of `QuantumNode` 

305 A list of nodes that are outputs of the graph 

306 """ 

307 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

308 

309 @property 

310 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]: 

311 """Return all the `DatasetTypeName` objects that are contained inside 

312 the graph. 

313 

314 Returns 

315 ------- 

316 tuple of `DatasetTypeName` 

317 All the data set type names that are present in the graph 

318 """ 

319 return tuple(self._datasetDict.keys()) 

320 

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        # Weak connectivity treats the directed graph as undirected.
        return nx.is_weakly_connected(self._connectedQuanta)

327 

328 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T: 

329 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s 

330 and nodes which depend on them. 

331 

332 Parameters 

333 ---------- 

334 refs : `Iterable` of `DatasetRef` 

335 Refs which should be removed from resulting graph 

336 

337 Returns 

338 ------- 

339 graph : `QuantumGraph` 

340 A graph that has been pruned of specified refs and the nodes that 

341 depend on them. 

342 """ 

343 newInst = object.__new__(type(self)) 

344 quantumMap = defaultdict(set) 

345 for node in self: 

346 quantumMap[node.taskDef].add(node.quantum) 

347 

348 # convert to standard dict to prevent accidental key insertion 

349 quantumMap = dict(quantumMap.items()) 

350 

351 newInst._buildGraphs( 

352 quantumMap, 

353 _quantumToNodeId={n.quantum: n.nodeId for n in self}, 

354 metadata=self._metadata, 

355 pruneRefs=refs, 

356 ) 

357 return newInst 

358 

359 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

360 """Lookup a `QuantumNode` from an id associated with the node. 

361 

362 Parameters 

363 ---------- 

364 nodeId : `NodeId` 

365 The number associated with a node 

366 

367 Returns 

368 ------- 

369 node : `QuantumNode` 

370 The node corresponding with input number 

371 

372 Raises 

373 ------ 

374 KeyError 

375 Raised if the requested nodeId is not in the graph. 

376 """ 

377 return self._nodeIdMap[nodeId] 

378 

379 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]: 

380 """Return all the `Quantum` associated with a `TaskDef`. 

381 

382 Parameters 

383 ---------- 

384 taskDef : `TaskDef` 

385 The `TaskDef` for which `Quantum` are to be queried 

386 

387 Returns 

388 ------- 

389 frozenset of `Quantum` 

390 The `set` of `Quantum` that is associated with the specified 

391 `TaskDef`. 

392 """ 

393 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef]) 

394 

395 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]: 

396 """Return all the `QuantumNodes` associated with a `TaskDef`. 

397 

398 Parameters 

399 ---------- 

400 taskDef : `TaskDef` 

401 The `TaskDef` for which `Quantum` are to be queried 

402 

403 Returns 

404 ------- 

405 frozenset of `QuantumNodes` 

406 The `frozenset` of `QuantumNodes` that is associated with the 

407 specified `TaskDef`. 

408 """ 

409 return frozenset(self._taskToQuantumNode[taskDef]) 

410 

411 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

412 """Find all tasks that have the specified dataset type name as an 

413 input. 

414 

415 Parameters 

416 ---------- 

417 datasetTypeName : `str` 

418 A string representing the name of a dataset type to be queried, 

419 can also accept a `DatasetTypeName` which is a `NewType` of str for 

420 type safety in static type checking. 

421 

422 Returns 

423 ------- 

424 tasks : iterable of `TaskDef` 

425 `TaskDef` objects that have the specified `DatasetTypeName` as an 

426 input, list will be empty if no tasks use specified 

427 `DatasetTypeName` as an input. 

428 

429 Raises 

430 ------ 

431 KeyError 

432 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

433 """ 

434 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

435 

436 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]: 

437 """Find all tasks that have the specified dataset type name as an 

438 output. 

439 

440 Parameters 

441 ---------- 

442 datasetTypeName : `str` 

443 A string representing the name of a dataset type to be queried, 

444 can also accept a `DatasetTypeName` which is a `NewType` of str for 

445 type safety in static type checking. 

446 

447 Returns 

448 ------- 

449 `TaskDef` or `None` 

450 `TaskDef` that outputs `DatasetTypeName` as an output or None if 

451 none of the tasks produce this `DatasetTypeName`. 

452 

453 Raises 

454 ------ 

455 KeyError 

456 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

457 """ 

458 return self._datasetDict.getProducer(datasetTypeName) 

459 

460 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

461 """Find all tasks that are associated with the specified dataset type 

462 name. 

463 

464 Parameters 

465 ---------- 

466 datasetTypeName : `str` 

467 A string representing the name of a dataset type to be queried, 

468 can also accept a `DatasetTypeName` which is a `NewType` of str for 

469 type safety in static type checking. 

470 

471 Returns 

472 ------- 

473 result : iterable of `TaskDef` 

474 `TaskDef` objects that are associated with the specified 

475 `DatasetTypeName` 

476 

477 Raises 

478 ------ 

479 KeyError 

480 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

481 """ 

482 return self._datasetDict.getAll(datasetTypeName) 

483 

484 def findTaskDefByName(self, taskName: str) -> List[TaskDef]: 

485 """Determine which `TaskDef` objects in this graph are associated 

486 with a `str` representing a task name (looks at the taskName property 

487 of `TaskDef` objects). 

488 

489 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

490 multiple times in a graph with different labels. 

491 

492 Parameters 

493 ---------- 

494 taskName : str 

495 Name of a task to search for 

496 

497 Returns 

498 ------- 

499 result : list of `TaskDef` 

500 List of the `TaskDef` objects that have the name specified. 

501 Multiple values are returned in the case that a task is used 

502 multiple times with different labels. 

503 """ 

504 results = [] 

505 for task in self._taskToQuantumNode.keys(): 

506 split = task.taskName.split(".") 

507 if split[-1] == taskName: 

508 results.append(task) 

509 return results 

510 

511 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]: 

512 """Determine which `TaskDef` objects in this graph are associated 

513 with a `str` representing a tasks label. 

514 

515 Parameters 

516 ---------- 

517 taskName : str 

518 Name of a task to search for 

519 

520 Returns 

521 ------- 

522 result : `TaskDef` 

523 `TaskDef` objects that has the specified label. 

524 """ 

525 for task in self._taskToQuantumNode.keys(): 

526 if label == task.label: 

527 return task 

528 return None 

529 

530 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]: 

531 """Return all the `Quantum` that contain a specified `DatasetTypeName`. 

532 

533 Parameters 

534 ---------- 

535 datasetTypeName : `str` 

536 The name of the dataset type to search for as a string, 

537 can also accept a `DatasetTypeName` which is a `NewType` of str for 

538 type safety in static type checking. 

539 

540 Returns 

541 ------- 

542 result : `set` of `QuantumNode` objects 

543 A `set` of `QuantumNode`s that contain specified `DatasetTypeName` 

544 

545 Raises 

546 ------ 

547 KeyError 

548 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

549 

550 """ 

551 tasks = self._datasetDict.getAll(datasetTypeName) 

552 result: Set[Quantum] = set() 

553 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

554 return result 

555 

556 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

557 """Check if specified quantum appears in the graph as part of a node. 

558 

559 Parameters 

560 ---------- 

561 quantum : `Quantum` 

562 The quantum to search for 

563 

564 Returns 

565 ------- 

566 `bool` 

567 The result of searching for the quantum 

568 """ 

569 for node in self: 

570 if quantum == node.quantum: 

571 return True 

572 return False 

573 

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        # Delegates serialization entirely to networkx's graphviz writer.
        write_dot(self._connectedQuanta, output)

583 

584 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T: 

585 """Create a new graph object that contains the subset of the nodes 

586 specified as input. Node number is preserved. 

587 

588 Parameters 

589 ---------- 

590 nodes : `QuantumNode` or iterable of `QuantumNode` 

591 

592 Returns 

593 ------- 

594 graph : instance of graph type 

595 An instance of the type from which the subset was created 

596 """ 

597 if not isinstance(nodes, Iterable): 

598 nodes = (nodes,) 

599 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes 

600 quantumMap = defaultdict(set) 

601 

602 node: QuantumNode 

603 for node in quantumSubgraph: 

604 quantumMap[node.taskDef].add(node.quantum) 

605 

606 # convert to standard dict to prevent accidental key insertion 

607 quantumMap = dict(quantumMap.items()) 

608 # Create an empty graph, and then populate it with custom mapping 

609 newInst = type(self)({}) 

610 newInst._buildGraphs( 

611 quantumMap, 

612 _quantumToNodeId={n.quantum: n.nodeId for n in nodes}, 

613 _buildId=self._buildId, 

614 metadata=self._metadata, 

615 ) 

616 return newInst 

617 

618 def subsetToConnected(self: _T) -> Tuple[_T, ...]: 

619 """Generate a list of subgraphs where each is connected. 

620 

621 Returns 

622 ------- 

623 result : list of `QuantumGraph` 

624 A list of graphs that are each connected 

625 """ 

626 return tuple( 

627 self.subset(connectedSet) 

628 for connectedSet in nx.weakly_connected_components(self._connectedQuanta) 

629 ) 

630 

631 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

632 """Return a set of `QuantumNode` that are direct inputs to a specified 

633 node. 

634 

635 Parameters 

636 ---------- 

637 node : `QuantumNode` 

638 The node of the graph for which inputs are to be determined 

639 

640 Returns 

641 ------- 

642 set of `QuantumNode` 

643 All the nodes that are direct inputs to specified node 

644 """ 

645 return set(pred for pred in self._connectedQuanta.predecessors(node)) 

646 

647 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

648 """Return a set of `QuantumNode` that are direct outputs of a specified 

649 node. 

650 

651 Parameters 

652 ---------- 

653 node : `QuantumNode` 

654 The node of the graph for which outputs are to be determined 

655 

656 Returns 

657 ------- 

658 set of `QuantumNode` 

659 All the nodes that are direct outputs to specified node 

660 """ 

661 return set(succ for succ in self._connectedQuanta.successors(node)) 

662 

663 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

664 """Return a graph of `QuantumNode` that are direct inputs and outputs 

665 of a specified node. 

666 

667 Parameters 

668 ---------- 

669 node : `QuantumNode` 

670 The node of the graph for which connected nodes are to be 

671 determined. 

672 

673 Returns 

674 ------- 

675 graph : graph of `QuantumNode` 

676 All the nodes that are directly connected to specified node 

677 """ 

678 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

679 nodes.add(node) 

680 return self.subset(nodes) 

681 

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        # include the starting node itself in the returned subgraph
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

699 

    def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
        """Check a graph for the presence of cycles and return the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : list of tuple of `QuantumNode`, `QuantumNode`
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list supports the
            ``if graph.findCycle():`` idiom, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            # No cycle is the expected case for a DAG
            return []

715 

716 def saveUri(self, uri): 

717 """Save `QuantumGraph` to the specified URI. 

718 

719 Parameters 

720 ---------- 

721 uri : convertible to `ResourcePath` 

722 URI to where the graph should be saved. 

723 """ 

724 buffer = self._buildSaveObject() 

725 path = ResourcePath(uri) 

726 if path.getExtension() not in (".qgraph"): 

727 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

728 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

729 

    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra metadata carried with the graph, as a read-only mapping
        view, or `None` if no metadata was supplied at construction.
        """
        if self._metadata is None:
            return None
        # MappingProxyType prevents callers from mutating the metadata
        return MappingProxyType(self._metadata)

736 

737 @classmethod 

738 def loadUri( 

739 cls, 

740 uri: ResourcePathExpression, 

741 universe: DimensionUniverse, 

742 nodes: Optional[Iterable[Union[str, uuid.UUID]]] = None, 

743 graphID: Optional[BuildId] = None, 

744 minimumVersion: int = 3, 

745 ) -> QuantumGraph: 

746 """Read `QuantumGraph` from a URI. 

747 

748 Parameters 

749 ---------- 

750 uri : convertible to `ResourcePath` 

751 URI from where to load the graph. 

752 universe: `~lsst.daf.butler.DimensionUniverse` 

753 DimensionUniverse instance, not used by the method itself but 

754 needed to ensure that registry data structures are initialized. 

755 nodes: iterable of `int` or None 

756 Numbers that correspond to nodes in the graph. If specified, only 

757 these nodes will be loaded. Defaults to None, in which case all 

758 nodes will be loaded. 

759 graphID : `str` or `None` 

760 If specified this ID is verified against the loaded graph prior to 

761 loading any Nodes. This defaults to None in which case no 

762 validation is done. 

763 minimumVersion : int 

764 Minimum version of a save file to load. Set to -1 to load all 

765 versions. Older versions may need to be loaded, and re-saved 

766 to upgrade them to the latest format before they can be used in 

767 production. 

768 

769 Returns 

770 ------- 

771 graph : `QuantumGraph` 

772 Resulting QuantumGraph instance. 

773 

774 Raises 

775 ------ 

776 TypeError 

777 Raised if pickle contains instance of a type other than 

778 QuantumGraph. 

779 ValueError 

780 Raised if one or more of the nodes requested is not in the 

781 `QuantumGraph` or if graphID parameter does not match the graph 

782 being loaded or if the supplied uri does not point at a valid 

783 `QuantumGraph` save file. 

784 

785 

786 Notes 

787 ----- 

788 Reading Quanta from pickle requires existence of singleton 

789 DimensionUniverse which is usually instantiated during Registry 

790 initialization. To make sure that DimensionUniverse exists this method 

791 accepts dummy DimensionUniverse argument. 

792 """ 

793 uri = ResourcePath(uri) 

794 # With ResourcePath we have the choice of always using a local file 

795 # or reading in the bytes directly. Reading in bytes can be more 

796 # efficient for reasonably-sized pickle files when the resource 

797 # is remote. For now use the local file variant. For a local file 

798 # as_local() does nothing. 

799 

800 if uri.getExtension() in (".pickle", ".pkl"): 

801 with uri.as_local() as local, open(local.ospath, "rb") as fd: 

802 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

803 qgraph = pickle.load(fd) 

804 elif uri.getExtension() in (".qgraph"): 

805 with LoadHelper(uri, minimumVersion) as loader: 

806 qgraph = loader.load(universe, nodes, graphID) 

807 else: 

808 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`") 

809 if not isinstance(qgraph, QuantumGraph): 

810 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}") 

811 return qgraph 

812 

813 @classmethod 

814 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]: 

815 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

816 and return it as a string. 

817 

818 Parameters 

819 ---------- 

820 uri : convertible to `ResourcePath` 

821 The location of the `QuantumGraph` to load. If the argument is a 

822 string, it must correspond to a valid `ResourcePath` path. 

823 minimumVersion : int 

824 Minimum version of a save file to load. Set to -1 to load all 

825 versions. Older versions may need to be loaded, and re-saved 

826 to upgrade them to the latest format before they can be used in 

827 production. 

828 

829 Returns 

830 ------- 

831 header : `str` or `None` 

832 The header associated with the specified `QuantumGraph` it there is 

833 one, else `None`. 

834 

835 Raises 

836 ------ 

837 ValueError 

838 Raised if `QuantuGraph` was saved as a pickle. 

839 Raised if the extention of the file specified by uri is not a 

840 `QuantumGraph` extention. 

841 """ 

842 uri = ResourcePath(uri) 

843 if uri.getExtension() in (".pickle", ".pkl"): 

844 raise ValueError("Reading a header from a pickle save is not supported") 

845 elif uri.getExtension() in (".qgraph"): 

846 return LoadHelper(uri, minimumVersion).readHeader() 

847 else: 

848 raise ValueError("Only know how to handle files saved as `qgraph`") 

849 

    def buildAndPrintHeader(self) -> None:
        """Creates a header that would be used in a save of this object and
        prints it out to standard out.
        """
        # Only the header part of the save object is needed here; the
        # serialized body is discarded.
        _, header = self._buildSaveObject(returnHeader=True)
        print(json.dumps(header))

856 

    def save(self, file: io.BufferedIOBase) -> None:
        """Save QuantumGraph to a file.

        The graph is serialized with the custom on-disk format built by
        `_buildSaveObject` (see ``SAVE_VERSION``).

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write serialized graph data to, open in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

870 

    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into its binary save format.

        The result is: magic bytes, a packed save version, the packed length
        of the compressed header, the LZMA-compressed JSON header, then one
        LZMA-compressed JSON record per task definition followed by one per
        quantum node. The header records the (start, end) byte offsets of
        every record so individual pieces can be loaded without reading the
        whole file.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping.

        Returns
        -------
        buffer : `bytearray` or `tuple` of (`bytearray`, `dict`)
            The serialized graph, optionally paired with the header mapping.
        """
        # make some containers
        jsonData = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData = {}

        # Store the QuantumGraph BuildId; this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Names chosen to
        # make key conflicts unlikely.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # counter for the number of bytes written thus far; every record's
        # (start, end) offsets in the header are derived from it
        count = 0
        # serialize out the task defs, recording the start and end bytes of
        # each one
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # NOTE(review): serialization order is whatever self.taskGraph yields;
        # an earlier comment claimed a sort by task label, but no sort is
        # performed here — confirm taskGraph iteration is deterministic.
        for taskDef in self.taskGraph:
            # compressing has very little impact on save or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription = {}
            # save the fully qualified name, as TaskDef does not require it;
            # doing so saves space and is easier to transport
            taskDescription["taskName"] = f"{taskDef.taskClass.__module__}.{taskDef.taskClass.__qualname__}"
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label

            inputs = []
            outputs = []

            # Determine the connections between all of the tasks and save
            # that in the header as a list of connections and edges for each
            # task. This helps in un-persisting, and possibly in a "quick
            # view" method that does not require everything to be
            # un-persisted.
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or if it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output is said to go to the datastore
                    # (in which case the for loop below is zero-length)
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to a json string, encode that string to bytes, and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on save or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        headerData["DimensionRecords"] = {
            key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # this must be serialized as a series of (key, value) tuples because
        # json can only use strings as mapping keys
        headerData["Nodes"] = nodeMap

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # record the save version, then the header length, in fixed-size
        # packed integers so the loader can find the header
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Pop each element off the left of the deque and append it to the
        # buffer. This saves memory: as data is added to the buffer object,
        # it is removed from the container.
        #
        # Only this section needs to worry about memory pressure, because
        # everything else written to the buffer prior to this node data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

1011 

1012 @classmethod 

1013 def load( 

1014 cls, 

1015 file: io.IO[bytes], 

1016 universe: DimensionUniverse, 

1017 nodes: Optional[Iterable[uuid.UUID]] = None, 

1018 graphID: Optional[BuildId] = None, 

1019 minimumVersion: int = 3, 

1020 ) -> QuantumGraph: 

1021 """Read QuantumGraph from a file that was made by `save`. 

1022 

1023 Parameters 

1024 ---------- 

1025 file : `io.IO` of bytes 

1026 File with pickle data open in binary mode. 

1027 universe: `~lsst.daf.butler.DimensionUniverse` 

1028 DimensionUniverse instance, not used by the method itself but 

1029 needed to ensure that registry data structures are initialized. 

1030 nodes: iterable of `int` or None 

1031 Numbers that correspond to nodes in the graph. If specified, only 

1032 these nodes will be loaded. Defaults to None, in which case all 

1033 nodes will be loaded. 

1034 graphID : `str` or `None` 

1035 If specified this ID is verified against the loaded graph prior to 

1036 loading any Nodes. This defaults to None in which case no 

1037 validation is done. 

1038 minimumVersion : int 

1039 Minimum version of a save file to load. Set to -1 to load all 

1040 versions. Older versions may need to be loaded, and re-saved 

1041 to upgrade them to the latest format before they can be used in 

1042 production. 

1043 

1044 Returns 

1045 ------- 

1046 graph : `QuantumGraph` 

1047 Resulting QuantumGraph instance. 

1048 

1049 Raises 

1050 ------ 

1051 TypeError 

1052 Raised if pickle contains instance of a type other than 

1053 QuantumGraph. 

1054 ValueError 

1055 Raised if one or more of the nodes requested is not in the 

1056 `QuantumGraph` or if graphID parameter does not match the graph 

1057 being loaded or if the supplied uri does not point at a valid 

1058 `QuantumGraph` save file. 

1059 

1060 Notes 

1061 ----- 

1062 Reading Quanta from pickle requires existence of singleton 

1063 DimensionUniverse which is usually instantiated during Registry 

1064 initialization. To make sure that DimensionUniverse exists this method 

1065 accepts dummy DimensionUniverse argument. 

1066 """ 

1067 # Try to see if the file handle contains pickle data, this will be 

1068 # removed in the future 

1069 try: 

1070 qgraph = pickle.load(file) 

1071 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

1072 except pickle.UnpicklingError: 

1073 # needed because we don't have Protocols yet 

1074 with LoadHelper(file, minimumVersion) as loader: # type: ignore 

1075 qgraph = loader.load(universe, nodes, graphID) 

1076 if not isinstance(qgraph, QuantumGraph): 

1077 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}") 

1078 return qgraph 

1079 

1080 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1081 """Iterate over the `taskGraph` attribute in topological order 

1082 

1083 Yields 

1084 ------ 

1085 taskDef : `TaskDef` 

1086 `TaskDef` objects in topological order 

1087 """ 

1088 yield from nx.topological_sort(self.taskGraph) 

1089 

    @property
    def graphID(self) -> BuildId:
        """The unique build ID generated when this graph was constructed
        (read-only).
        """
        return self._buildId

1094 

1095 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1096 yield from nx.topological_sort(self._connectedQuanta) 

1097 

    def __len__(self) -> int:
        # Number of quantum nodes; tracked at construction time rather than
        # delegating to the underlying networkx graph.
        return self._count

1100 

1101 def __contains__(self, node: QuantumNode) -> bool: 

1102 return self._connectedQuanta.has_node(node) 

1103 

1104 def __getstate__(self) -> dict: 

1105 """Stores a compact form of the graph as a list of graph nodes, and a 

1106 tuple of task labels and task configs. The full graph can be 

1107 reconstructed with this information, and it preseves the ordering of 

1108 the graph ndoes. 

1109 """ 

1110 universe: Optional[DimensionUniverse] = None 

1111 for node in self: 

1112 dId = node.quantum.dataId 

1113 if dId is None: 

1114 continue 

1115 universe = dId.graph.universe 

1116 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe} 

1117 

1118 def __setstate__(self, state: dict): 

1119 """Reconstructs the state of the graph from the information persisted 

1120 in getstate. 

1121 """ 

1122 buffer = io.BytesIO(state["reduced"]) 

1123 with LoadHelper(buffer, minimumVersion=3) as loader: 

1124 qgraph = loader.load(state["universe"], graphID=state["graphId"]) 

1125 

1126 self._metadata = qgraph._metadata 

1127 self._buildId = qgraph._buildId 

1128 self._datasetDict = qgraph._datasetDict 

1129 self._nodeIdMap = qgraph._nodeIdMap 

1130 self._count = len(qgraph) 

1131 self._taskToQuantumNode = qgraph._taskToQuantumNode 

1132 self._taskGraph = qgraph._taskGraph 

1133 self._connectedQuanta = qgraph._connectedQuanta 

1134 

1135 def __eq__(self, other: object) -> bool: 

1136 if not isinstance(other, QuantumGraph): 

1137 return False 

1138 if len(self) != len(other): 

1139 return False 

1140 for node in self: 

1141 if node not in other: 

1142 return False 

1143 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1144 return False 

1145 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1146 return False 

1147 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1148 return False 

1149 return set(self.taskGraph) == set(other.taskGraph)