Coverage for python/lsst/pipe/base/graph/graph.py: 19%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

340 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("QuantumGraph", "IncompatibleGraphError") 

24 

25import io 

26import json 

27import lzma 

28import os 

29import pickle 

30import struct 

31import time 

32import uuid 

33import warnings 

34from collections import defaultdict, deque 

35from itertools import chain 

36from types import MappingProxyType 

37from typing import ( 

38 Any, 

39 DefaultDict, 

40 Dict, 

41 FrozenSet, 

42 Generator, 

43 Iterable, 

44 List, 

45 Mapping, 

46 MutableMapping, 

47 Optional, 

48 Set, 

49 Tuple, 

50 TypeVar, 

51 Union, 

52) 

53 

54import networkx as nx 

55from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum 

56from lsst.resources import ResourcePath, ResourcePathExpression 

57from lsst.utils.introspection import get_full_type_name 

58from networkx.drawing.nx_agraph import write_dot 

59 

60from ..connections import iterConnections 

61from ..pipeline import TaskDef 

62from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner 

63from ._loadHelpers import LoadHelper 

64from ._versionDeserializers import DESERIALIZER_MAP 

65from .quantumNode import BuildId, QuantumNode 

66 

67_T = TypeVar("_T", bound="QuantumGraph") 

68 

69# modify this constant any time the on disk representation of the save file 

70# changes, and update the load helpers to behave properly for each version. 

71SAVE_VERSION = 3 

72 

73# Strings used to describe the format for the preamble bytes in a file save 

74# The base is a big endian encoded unsigned short that is used to hold the 

75# file format version. This allows reading version bytes and determine which 

76# loading code should be used for the rest of the file 

77STRUCT_FMT_BASE = ">H" 

78# 

79# Version 1 

80# This marks a big endian encoded format with an unsigned short, an unsigned 

81# long long, and an unsigned long long in the byte stream 

82# Version 2 

83# A big endian encoded format with an unsigned long long byte stream used to 

84# indicate the total length of the entire header. 

85STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"} 

86 

87# magic bytes that help determine this is a graph save 

88MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9" 

89 

90 

class IncompatibleGraphError(Exception):
    """Exception raised when a lookup by NodeId is impossible due
    to incompatibilities between the stored graph and the request
    (e.g. the node belongs to a different build of the graph).
    """

    pass

97 

98 

99class QuantumGraph: 

100 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects 

101 

102 This data structure represents a concrete workflow generated from a 

103 `Pipeline`. 

104 

105 Parameters 

106 ---------- 

107 quanta : Mapping of `TaskDef` to sets of `Quantum` 

108 This maps tasks (and their configs) to the sets of data they are to 

109 process. 

110 metadata : Optional Mapping of `str` to primitives 

111 This is an optional parameter of extra data to carry with the graph. 

112 Entries in this mapping should be able to be serialized in JSON. 

113 

114 Raises 

115 ------ 

116 ValueError 

117 Raised if the graph is pruned such that some tasks no longer have nodes 

118 associated with them. 

119 """ 

120 

    def __init__(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
    ):
        # All construction work is delegated to _buildGraphs so that
        # alternate construction paths (subset, pruneGraphFromRefs) can
        # reuse the same machinery with extra private keyword arguments.
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs)

128 

    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        *,
        _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
        _buildId: Optional[BuildId] = None,
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
    ):
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : `Mapping` of `TaskDef` to `Set` of `Quantum`
            Tasks and the sets of quanta they are to process.
        _quantumToNodeId : `Mapping` of `Quantum` to `uuid.UUID`, optional
            Private: pre-assigned node ids; when supplied, every quantum
            must have an entry, otherwise `ValueError` is raised.
        _buildId : `BuildId`, optional
            Private: reuse an existing build id (used when subsetting);
            a fresh time/pid-based id is generated when `None`.
        metadata : `Mapping` of `str` to any, optional
            Extra data to carry with the graph.
        pruneRefs : `Iterable` of `DatasetRef`, optional
            Refs to prune from the graph along with dependent nodes.
            NOTE(review): this is iterated more than once below, so callers
            should pass a re-iterable collection, not a one-shot generator.

        Raises
        ------
        ValueError
            Raised if a quantum is missing from ``_quantumToNodeId``, or if
            pruning leaves one or more tasks with no associated nodes.
        """
        self._metadata = metadata
        # Build id identifies this particular construction of the graph;
        # time + pid makes collisions unlikely within one host.
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if _quantumToNodeId:
                    # Node ids were pre-assigned (e.g. when subsetting a
                    # graph); every quantum must appear in the mapping.
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            # Component refs are tracked via their parent
                            # composite so graph edges connect correctly.
                            if sub.isComponent():
                                sub = sub.makeCompositeRef()
                            self._datasetRefDict.addConsumer(sub, value)
                    else:
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        self._datasetRefDict.addConsumer(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addProducer(dsRef, value)

        if pruneRefs is not None:
            # track what refs were pruned and prune the graph
            prunes = set()
            _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)

            # recreate the taskToQuantumNode dict removing nodes that have been
            # pruned. Keep track of task defs that now have no QuantumNodes
            emptyTasks: Set[str] = set()
            newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
            # accumulate all types
            types_ = set()
            # tracker for any pruneRefs that have caused tasks to have no nodes
            # This helps the user find out what caused the issues seen.
            culprits = set()
            # Find all the types from the refs to prune
            # NOTE(review): second iteration of pruneRefs — see docstring
            # about one-shot iterables.
            for r in pruneRefs:
                types_.add(r.datasetType)

            # For each of the tasks, and their associated nodes, remove any
            # any nodes that were pruned. If there are no nodes associated
            # with a task, record that task, and find out if that was due to
            # a type from an input ref to prune.
            for td, taskNodes in self._taskToQuantumNode.items():
                diff = taskNodes.difference(prunes)
                if len(diff) == 0:
                    if len(taskNodes) != 0:
                        # Probe one surviving original node: if all of its
                        # inputs of a pruned type are themselves in
                        # pruneRefs, that type over-pruned this task.
                        tp: DatasetType
                        for tp in types_:
                            if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
                                tmpRefs
                            ).difference(pruneRefs):
                                culprits.add(tp.name)
                    emptyTasks.add(td.label)
                newTaskToQuantumNode[td] = diff

            # update the internal dict
            self._taskToQuantumNode = newTaskToQuantumNode

            if emptyTasks:
                raise ValueError(
                    f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
                    f"after graph pruning; {', '.join(culprits)} caused over-pruning"
                )

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # convert default dict into a regular to prevent accidental key
        # insertion
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

254 

255 @property 

256 def taskGraph(self) -> nx.DiGraph: 

257 """Return a graph representing the relations between the tasks inside 

258 the quantum graph. 

259 

260 Returns 

261 ------- 

262 taskGraph : `networkx.Digraph` 

263 Internal datastructure that holds relations of `TaskDef` objects 

264 """ 

265 return self._taskGraph 

266 

267 @property 

268 def graph(self) -> nx.DiGraph: 

269 """Return a graph representing the relations between all the 

270 `QuantumNode` objects. Largely it should be preferred to iterate 

271 over, and use methods of this class, but sometimes direct access to 

272 the networkx object may be helpful 

273 

274 Returns 

275 ------- 

276 graph : `networkx.Digraph` 

277 Internal datastructure that holds relations of `QuantumNode` 

278 objects 

279 """ 

280 return self._connectedQuanta 

281 

282 @property 

283 def inputQuanta(self) -> Iterable[QuantumNode]: 

284 """Make a `list` of all `QuantumNode` objects that are 'input' nodes 

285 to the graph, meaning those nodes to not depend on any other nodes in 

286 the graph. 

287 

288 Returns 

289 ------- 

290 inputNodes : iterable of `QuantumNode` 

291 A list of nodes that are inputs to the graph 

292 """ 

293 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

294 

295 @property 

296 def outputQuanta(self) -> Iterable[QuantumNode]: 

297 """Make a `list` of all `QuantumNode` objects that are 'output' nodes 

298 to the graph, meaning those nodes have no nodes that depend them in 

299 the graph. 

300 

301 Returns 

302 ------- 

303 outputNodes : iterable of `QuantumNode` 

304 A list of nodes that are outputs of the graph 

305 """ 

306 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

307 

308 @property 

309 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]: 

310 """Return all the `DatasetTypeName` objects that are contained inside 

311 the graph. 

312 

313 Returns 

314 ------- 

315 tuple of `DatasetTypeName` 

316 All the data set type names that are present in the graph 

317 """ 

318 return tuple(self._datasetDict.keys()) 

319 

320 @property 

321 def isConnected(self) -> bool: 

322 """Return True if all of the nodes in the graph are connected, ignores 

323 directionality of connections. 

324 """ 

325 return nx.is_weakly_connected(self._connectedQuanta) 

326 

327 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T: 

328 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s 

329 and nodes which depend on them. 

330 

331 Parameters 

332 ---------- 

333 refs : `Iterable` of `DatasetRef` 

334 Refs which should be removed from resulting graph 

335 

336 Returns 

337 ------- 

338 graph : `QuantumGraph` 

339 A graph that has been pruned of specified refs and the nodes that 

340 depend on them. 

341 """ 

342 newInst = object.__new__(type(self)) 

343 quantumMap = defaultdict(set) 

344 for node in self: 

345 quantumMap[node.taskDef].add(node.quantum) 

346 

347 # convert to standard dict to prevent accidental key insertion 

348 quantumMap = dict(quantumMap.items()) 

349 

350 newInst._buildGraphs( 

351 quantumMap, 

352 _quantumToNodeId={n.quantum: n.nodeId for n in self}, 

353 metadata=self._metadata, 

354 pruneRefs=refs, 

355 ) 

356 return newInst 

357 

358 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

359 """Lookup a `QuantumNode` from an id associated with the node. 

360 

361 Parameters 

362 ---------- 

363 nodeId : `NodeId` 

364 The number associated with a node 

365 

366 Returns 

367 ------- 

368 node : `QuantumNode` 

369 The node corresponding with input number 

370 

371 Raises 

372 ------ 

373 KeyError 

374 Raised if the requested nodeId is not in the graph. 

375 """ 

376 return self._nodeIdMap[nodeId] 

377 

378 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]: 

379 """Return all the `Quantum` associated with a `TaskDef`. 

380 

381 Parameters 

382 ---------- 

383 taskDef : `TaskDef` 

384 The `TaskDef` for which `Quantum` are to be queried 

385 

386 Returns 

387 ------- 

388 frozenset of `Quantum` 

389 The `set` of `Quantum` that is associated with the specified 

390 `TaskDef`. 

391 """ 

392 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef]) 

393 

394 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]: 

395 """Return all the `QuantumNodes` associated with a `TaskDef`. 

396 

397 Parameters 

398 ---------- 

399 taskDef : `TaskDef` 

400 The `TaskDef` for which `Quantum` are to be queried 

401 

402 Returns 

403 ------- 

404 frozenset of `QuantumNodes` 

405 The `frozenset` of `QuantumNodes` that is associated with the 

406 specified `TaskDef`. 

407 """ 

408 return frozenset(self._taskToQuantumNode[taskDef]) 

409 

410 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

411 """Find all tasks that have the specified dataset type name as an 

412 input. 

413 

414 Parameters 

415 ---------- 

416 datasetTypeName : `str` 

417 A string representing the name of a dataset type to be queried, 

418 can also accept a `DatasetTypeName` which is a `NewType` of str for 

419 type safety in static type checking. 

420 

421 Returns 

422 ------- 

423 tasks : iterable of `TaskDef` 

424 `TaskDef` objects that have the specified `DatasetTypeName` as an 

425 input, list will be empty if no tasks use specified 

426 `DatasetTypeName` as an input. 

427 

428 Raises 

429 ------ 

430 KeyError 

431 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

432 """ 

433 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

434 

435 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]: 

436 """Find all tasks that have the specified dataset type name as an 

437 output. 

438 

439 Parameters 

440 ---------- 

441 datasetTypeName : `str` 

442 A string representing the name of a dataset type to be queried, 

443 can also accept a `DatasetTypeName` which is a `NewType` of str for 

444 type safety in static type checking. 

445 

446 Returns 

447 ------- 

448 `TaskDef` or `None` 

449 `TaskDef` that outputs `DatasetTypeName` as an output or None if 

450 none of the tasks produce this `DatasetTypeName`. 

451 

452 Raises 

453 ------ 

454 KeyError 

455 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

456 """ 

457 return self._datasetDict.getProducer(datasetTypeName) 

458 

459 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

460 """Find all tasks that are associated with the specified dataset type 

461 name. 

462 

463 Parameters 

464 ---------- 

465 datasetTypeName : `str` 

466 A string representing the name of a dataset type to be queried, 

467 can also accept a `DatasetTypeName` which is a `NewType` of str for 

468 type safety in static type checking. 

469 

470 Returns 

471 ------- 

472 result : iterable of `TaskDef` 

473 `TaskDef` objects that are associated with the specified 

474 `DatasetTypeName` 

475 

476 Raises 

477 ------ 

478 KeyError 

479 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

480 """ 

481 return self._datasetDict.getAll(datasetTypeName) 

482 

483 def findTaskDefByName(self, taskName: str) -> List[TaskDef]: 

484 """Determine which `TaskDef` objects in this graph are associated 

485 with a `str` representing a task name (looks at the taskName property 

486 of `TaskDef` objects). 

487 

488 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

489 multiple times in a graph with different labels. 

490 

491 Parameters 

492 ---------- 

493 taskName : str 

494 Name of a task to search for 

495 

496 Returns 

497 ------- 

498 result : list of `TaskDef` 

499 List of the `TaskDef` objects that have the name specified. 

500 Multiple values are returned in the case that a task is used 

501 multiple times with different labels. 

502 """ 

503 results = [] 

504 for task in self._taskToQuantumNode.keys(): 

505 split = task.taskName.split(".") 

506 if split[-1] == taskName: 

507 results.append(task) 

508 return results 

509 

510 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]: 

511 """Determine which `TaskDef` objects in this graph are associated 

512 with a `str` representing a tasks label. 

513 

514 Parameters 

515 ---------- 

516 taskName : str 

517 Name of a task to search for 

518 

519 Returns 

520 ------- 

521 result : `TaskDef` 

522 `TaskDef` objects that has the specified label. 

523 """ 

524 for task in self._taskToQuantumNode.keys(): 

525 if label == task.label: 

526 return task 

527 return None 

528 

529 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]: 

530 """Return all the `Quantum` that contain a specified `DatasetTypeName`. 

531 

532 Parameters 

533 ---------- 

534 datasetTypeName : `str` 

535 The name of the dataset type to search for as a string, 

536 can also accept a `DatasetTypeName` which is a `NewType` of str for 

537 type safety in static type checking. 

538 

539 Returns 

540 ------- 

541 result : `set` of `QuantumNode` objects 

542 A `set` of `QuantumNode`s that contain specified `DatasetTypeName` 

543 

544 Raises 

545 ------ 

546 KeyError 

547 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

548 

549 """ 

550 tasks = self._datasetDict.getAll(datasetTypeName) 

551 result: Set[Quantum] = set() 

552 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

553 return result 

554 

555 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

556 """Check if specified quantum appears in the graph as part of a node. 

557 

558 Parameters 

559 ---------- 

560 quantum : `Quantum` 

561 The quantum to search for 

562 

563 Returns 

564 ------- 

565 `bool` 

566 The result of searching for the quantum 

567 """ 

568 for node in self: 

569 if quantum == node.quantum: 

570 return True 

571 return False 

572 

573 def writeDotGraph(self, output: Union[str, io.BufferedIOBase]): 

574 """Write out the graph as a dot graph. 

575 

576 Parameters 

577 ---------- 

578 output : str or `io.BufferedIOBase` 

579 Either a filesystem path to write to, or a file handle object 

580 """ 

581 write_dot(self._connectedQuanta, output) 

582 

583 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T: 

584 """Create a new graph object that contains the subset of the nodes 

585 specified as input. Node number is preserved. 

586 

587 Parameters 

588 ---------- 

589 nodes : `QuantumNode` or iterable of `QuantumNode` 

590 

591 Returns 

592 ------- 

593 graph : instance of graph type 

594 An instance of the type from which the subset was created 

595 """ 

596 if not isinstance(nodes, Iterable): 

597 nodes = (nodes,) 

598 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes 

599 quantumMap = defaultdict(set) 

600 

601 node: QuantumNode 

602 for node in quantumSubgraph: 

603 quantumMap[node.taskDef].add(node.quantum) 

604 

605 # convert to standard dict to prevent accidental key insertion 

606 quantumMap = dict(quantumMap.items()) 

607 # Create an empty graph, and then populate it with custom mapping 

608 newInst = type(self)({}) 

609 newInst._buildGraphs( 

610 quantumMap, 

611 _quantumToNodeId={n.quantum: n.nodeId for n in nodes}, 

612 _buildId=self._buildId, 

613 metadata=self._metadata, 

614 ) 

615 return newInst 

616 

617 def subsetToConnected(self: _T) -> Tuple[_T, ...]: 

618 """Generate a list of subgraphs where each is connected. 

619 

620 Returns 

621 ------- 

622 result : list of `QuantumGraph` 

623 A list of graphs that are each connected 

624 """ 

625 return tuple( 

626 self.subset(connectedSet) 

627 for connectedSet in nx.weakly_connected_components(self._connectedQuanta) 

628 ) 

629 

630 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

631 """Return a set of `QuantumNode` that are direct inputs to a specified 

632 node. 

633 

634 Parameters 

635 ---------- 

636 node : `QuantumNode` 

637 The node of the graph for which inputs are to be determined 

638 

639 Returns 

640 ------- 

641 set of `QuantumNode` 

642 All the nodes that are direct inputs to specified node 

643 """ 

644 return set(pred for pred in self._connectedQuanta.predecessors(node)) 

645 

646 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

647 """Return a set of `QuantumNode` that are direct outputs of a specified 

648 node. 

649 

650 Parameters 

651 ---------- 

652 node : `QuantumNode` 

653 The node of the graph for which outputs are to be determined 

654 

655 Returns 

656 ------- 

657 set of `QuantumNode` 

658 All the nodes that are direct outputs to specified node 

659 """ 

660 return set(succ for succ in self._connectedQuanta.successors(node)) 

661 

662 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

663 """Return a graph of `QuantumNode` that are direct inputs and outputs 

664 of a specified node. 

665 

666 Parameters 

667 ---------- 

668 node : `QuantumNode` 

669 The node of the graph for which connected nodes are to be 

670 determined. 

671 

672 Returns 

673 ------- 

674 graph : graph of `QuantumNode` 

675 All the nodes that are directly connected to specified node 

676 """ 

677 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

678 nodes.add(node) 

679 return self.subset(nodes) 

680 

681 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

682 """Return a graph of the specified node and all the ancestor nodes 

683 directly reachable by walking edges. 

684 

685 Parameters 

686 ---------- 

687 node : `QuantumNode` 

688 The node for which all ansestors are to be determined 

689 

690 Returns 

691 ------- 

692 graph of `QuantumNode` 

693 Graph of node and all of its ansestors 

694 """ 

695 predecessorNodes = nx.ancestors(self._connectedQuanta, node) 

696 predecessorNodes.add(node) 

697 return self.subset(predecessorNodes) 

698 

699 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]: 

700 """Check a graph for the presense of cycles and returns the edges of 

701 any cycles found, or an empty list if there is no cycle. 

702 

703 Returns 

704 ------- 

705 result : list of tuple of `QuantumNode`, `QuantumNode` 

706 A list of any graph edges that form a cycle, or an empty list if 

707 there is no cycle. Empty list to so support if graph.find_cycle() 

708 syntax as an empty list is falsy. 

709 """ 

710 try: 

711 return nx.find_cycle(self._connectedQuanta) 

712 except nx.NetworkXNoCycle: 

713 return [] 

714 

715 def saveUri(self, uri): 

716 """Save `QuantumGraph` to the specified URI. 

717 

718 Parameters 

719 ---------- 

720 uri : convertible to `ResourcePath` 

721 URI to where the graph should be saved. 

722 """ 

723 buffer = self._buildSaveObject() 

724 path = ResourcePath(uri) 

725 if path.getExtension() not in (".qgraph"): 

726 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

727 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

728 

    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra metadata carried with the graph, exposed as a read-only
        mapping view, or `None` when no metadata was supplied at
        construction time.
        """
        if self._metadata is None:
            return None
        return MappingProxyType(self._metadata)

735 

736 @classmethod 

737 def loadUri( 

738 cls, 

739 uri: ResourcePathExpression, 

740 universe: DimensionUniverse, 

741 nodes: Optional[Iterable[Union[str, uuid.UUID]]] = None, 

742 graphID: Optional[BuildId] = None, 

743 minimumVersion: int = 3, 

744 ) -> QuantumGraph: 

745 """Read `QuantumGraph` from a URI. 

746 

747 Parameters 

748 ---------- 

749 uri : convertible to `ResourcePath` 

750 URI from where to load the graph. 

751 universe: `~lsst.daf.butler.DimensionUniverse` 

752 DimensionUniverse instance, not used by the method itself but 

753 needed to ensure that registry data structures are initialized. 

754 nodes: iterable of `int` or None 

755 Numbers that correspond to nodes in the graph. If specified, only 

756 these nodes will be loaded. Defaults to None, in which case all 

757 nodes will be loaded. 

758 graphID : `str` or `None` 

759 If specified this ID is verified against the loaded graph prior to 

760 loading any Nodes. This defaults to None in which case no 

761 validation is done. 

762 minimumVersion : int 

763 Minimum version of a save file to load. Set to -1 to load all 

764 versions. Older versions may need to be loaded, and re-saved 

765 to upgrade them to the latest format before they can be used in 

766 production. 

767 

768 Returns 

769 ------- 

770 graph : `QuantumGraph` 

771 Resulting QuantumGraph instance. 

772 

773 Raises 

774 ------ 

775 TypeError 

776 Raised if pickle contains instance of a type other than 

777 QuantumGraph. 

778 ValueError 

779 Raised if one or more of the nodes requested is not in the 

780 `QuantumGraph` or if graphID parameter does not match the graph 

781 being loaded or if the supplied uri does not point at a valid 

782 `QuantumGraph` save file. 

783 

784 

785 Notes 

786 ----- 

787 Reading Quanta from pickle requires existence of singleton 

788 DimensionUniverse which is usually instantiated during Registry 

789 initialization. To make sure that DimensionUniverse exists this method 

790 accepts dummy DimensionUniverse argument. 

791 """ 

792 uri = ResourcePath(uri) 

793 # With ResourcePath we have the choice of always using a local file 

794 # or reading in the bytes directly. Reading in bytes can be more 

795 # efficient for reasonably-sized pickle files when the resource 

796 # is remote. For now use the local file variant. For a local file 

797 # as_local() does nothing. 

798 

799 if uri.getExtension() in (".pickle", ".pkl"): 

800 with uri.as_local() as local, open(local.ospath, "rb") as fd: 

801 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

802 qgraph = pickle.load(fd) 

803 elif uri.getExtension() in (".qgraph"): 

804 with LoadHelper(uri, minimumVersion) as loader: 

805 qgraph = loader.load(universe, nodes, graphID) 

806 else: 

807 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`") 

808 if not isinstance(qgraph, QuantumGraph): 

809 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}") 

810 return qgraph 

811 

812 @classmethod 

813 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]: 

814 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

815 and return it as a string. 

816 

817 Parameters 

818 ---------- 

819 uri : convertible to `ResourcePath` 

820 The location of the `QuantumGraph` to load. If the argument is a 

821 string, it must correspond to a valid `ResourcePath` path. 

822 minimumVersion : int 

823 Minimum version of a save file to load. Set to -1 to load all 

824 versions. Older versions may need to be loaded, and re-saved 

825 to upgrade them to the latest format before they can be used in 

826 production. 

827 

828 Returns 

829 ------- 

830 header : `str` or `None` 

831 The header associated with the specified `QuantumGraph` it there is 

832 one, else `None`. 

833 

834 Raises 

835 ------ 

836 ValueError 

837 Raised if `QuantuGraph` was saved as a pickle. 

838 Raised if the extention of the file specified by uri is not a 

839 `QuantumGraph` extention. 

840 """ 

841 uri = ResourcePath(uri) 

842 if uri.getExtension() in (".pickle", ".pkl"): 

843 raise ValueError("Reading a header from a pickle save is not supported") 

844 elif uri.getExtension() in (".qgraph"): 

845 return LoadHelper(uri, minimumVersion).readHeader() 

846 else: 

847 raise ValueError("Only know how to handle files saved as `qgraph`") 

848 

849 def buildAndPrintHeader(self): 

850 """Creates a header that would be used in a save of this object and 

851 prints it out to standard out. 

852 """ 

853 _, header = self._buildSaveObject(returnHeader=True) 

854 print(json.dumps(header)) 

855 

856 def save(self, file: io.IO[bytes]): 

857 """Save QuantumGraph to a file. 

858 

859 Presently we store QuantumGraph in pickle format, this could 

860 potentially change in the future if better format is found. 

861 

862 Parameters 

863 ---------- 

864 file : `io.BufferedIOBase` 

865 File to write pickle data open in binary mode. 

866 """ 

867 buffer = self._buildSaveObject() 

868 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

869 

    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into its on-disk byte representation.

        Layout of the returned buffer: ``MAGIC_BYTES``, the packed save
        version (``STRUCT_FMT_BASE``), the packed length of the
        lzma-compressed JSON header, the compressed header itself, and then
        one lzma-compressed JSON payload per task def and per node.  The
        header records the byte offsets of every payload so individual
        pieces can be loaded without reading the whole file.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True` also return the (uncompressed) header mapping.

        Returns
        -------
        buffer : `bytearray`
            The complete serialized graph.
        header : `dict`
            The header mapping; only returned when ``returnHeader`` is
            `True`.
        """
        # make some containers
        jsonData = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task defs, recording the start and end bytes of
        # each one
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # NOTE(review): an earlier comment claimed label-sorted order here,
        # but this iterates self.taskGraph directly — serialization order is
        # whatever taskGraph yields.
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label

            inputs = []
            outputs = []

            # Determine the connections between all of the tasks and save
            # that in the header as a list of connections and edges for each
            # task; this will help in un-persisting, and possibly in a
            # "quick view" method that does not require everything to be
            # un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection
                    # directly from the datastore or if it is produced by
                    # another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output will be said to be the
                    # datastore, in which case the for loop below is a zero
                    # length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, encode that string to bytes, and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        headerData["DimensionRecords"] = {
            key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        # dump the headerData to json and compress it
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # record the save version (this immediately follows the magic bytes
        # in the output buffer)
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        # record the compressed header length, packed with the struct format
        # belonging to the current save version
        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Pop the leftmost element off the deque and write it out until the
        # deque is empty. This is to save memory: as the data is added to
        # the buffer object, it is removed from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this payload data
        # is only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

1009 

1010 @classmethod 

1011 def load( 

1012 cls, 

1013 file: io.IO[bytes], 

1014 universe: DimensionUniverse, 

1015 nodes: Optional[Iterable[uuid.UUID]] = None, 

1016 graphID: Optional[BuildId] = None, 

1017 minimumVersion: int = 3, 

1018 ) -> QuantumGraph: 

1019 """Read QuantumGraph from a file that was made by `save`. 

1020 

1021 Parameters 

1022 ---------- 

1023 file : `io.IO` of bytes 

1024 File with pickle data open in binary mode. 

1025 universe: `~lsst.daf.butler.DimensionUniverse` 

1026 DimensionUniverse instance, not used by the method itself but 

1027 needed to ensure that registry data structures are initialized. 

1028 nodes: iterable of `int` or None 

1029 Numbers that correspond to nodes in the graph. If specified, only 

1030 these nodes will be loaded. Defaults to None, in which case all 

1031 nodes will be loaded. 

1032 graphID : `str` or `None` 

1033 If specified this ID is verified against the loaded graph prior to 

1034 loading any Nodes. This defaults to None in which case no 

1035 validation is done. 

1036 minimumVersion : int 

1037 Minimum version of a save file to load. Set to -1 to load all 

1038 versions. Older versions may need to be loaded, and re-saved 

1039 to upgrade them to the latest format before they can be used in 

1040 production. 

1041 

1042 Returns 

1043 ------- 

1044 graph : `QuantumGraph` 

1045 Resulting QuantumGraph instance. 

1046 

1047 Raises 

1048 ------ 

1049 TypeError 

1050 Raised if pickle contains instance of a type other than 

1051 QuantumGraph. 

1052 ValueError 

1053 Raised if one or more of the nodes requested is not in the 

1054 `QuantumGraph` or if graphID parameter does not match the graph 

1055 being loaded or if the supplied uri does not point at a valid 

1056 `QuantumGraph` save file. 

1057 

1058 Notes 

1059 ----- 

1060 Reading Quanta from pickle requires existence of singleton 

1061 DimensionUniverse which is usually instantiated during Registry 

1062 initialization. To make sure that DimensionUniverse exists this method 

1063 accepts dummy DimensionUniverse argument. 

1064 """ 

1065 # Try to see if the file handle contains pickle data, this will be 

1066 # removed in the future 

1067 try: 

1068 qgraph = pickle.load(file) 

1069 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

1070 except pickle.UnpicklingError: 

1071 # needed because we don't have Protocols yet 

1072 with LoadHelper(file, minimumVersion) as loader: # type: ignore 

1073 qgraph = loader.load(universe, nodes, graphID) 

1074 if not isinstance(qgraph, QuantumGraph): 

1075 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}") 

1076 return qgraph 

1077 

1078 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1079 """Iterate over the `taskGraph` attribute in topological order 

1080 

1081 Yields 

1082 ------ 

1083 taskDef : `TaskDef` 

1084 `TaskDef` objects in topological order 

1085 """ 

1086 yield from nx.topological_sort(self.taskGraph) 

1087 

    @property
    def graphID(self):
        """Returns the ID (`BuildId`) generated for this graph at
        construction time.
        """
        return self._buildId

1092 

1093 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1094 yield from nx.topological_sort(self._connectedQuanta) 

1095 

    def __len__(self) -> int:
        """Return the number of quantum nodes in the graph."""
        return self._count

1098 

    def __contains__(self, node: QuantumNode) -> bool:
        """Return `True` if ``node`` is a node of this graph."""
        return self._connectedQuanta.has_node(node)

1101 

1102 def __getstate__(self) -> dict: 

1103 """Stores a compact form of the graph as a list of graph nodes, and a 

1104 tuple of task labels and task configs. The full graph can be 

1105 reconstructed with this information, and it preseves the ordering of 

1106 the graph ndoes. 

1107 """ 

1108 universe: Optional[DimensionUniverse] = None 

1109 for node in self: 

1110 dId = node.quantum.dataId 

1111 if dId is None: 

1112 continue 

1113 universe = dId.graph.universe 

1114 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe} 

1115 

    def __setstate__(self, state: dict):
        """Reconstruct the state of the graph from the information persisted
        in `__getstate__`.

        The ``reduced`` entry holds the byte buffer produced by
        `_buildSaveObject`; it is re-loaded through `LoadHelper` and the
        resulting graph's internal state is copied onto ``self``.
        """
        buffer = io.BytesIO(state["reduced"])
        with LoadHelper(buffer, minimumVersion=3) as loader:
            qgraph = loader.load(state["universe"], graphID=state["graphId"])

        # Copy the freshly loaded graph's attributes onto this instance.
        self._metadata = qgraph._metadata
        self._buildId = qgraph._buildId
        self._datasetDict = qgraph._datasetDict
        self._nodeIdMap = qgraph._nodeIdMap
        self._count = len(qgraph)
        self._taskToQuantumNode = qgraph._taskToQuantumNode
        self._taskGraph = qgraph._taskGraph
        self._connectedQuanta = qgraph._connectedQuanta

1132 

1133 def __eq__(self, other: object) -> bool: 

1134 if not isinstance(other, QuantumGraph): 

1135 return False 

1136 if len(self) != len(other): 

1137 return False 

1138 for node in self: 

1139 if node not in other: 

1140 return False 

1141 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1142 return False 

1143 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1144 return False 

1145 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1146 return False 

1147 return set(self.taskGraph) == set(other.taskGraph)