Coverage for python/lsst/pipe/base/graph/graph.py: 19%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

339 statements  

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from lsst.daf.butler.core.datasets.type import DatasetType 

24__all__ = ("QuantumGraph", "IncompatibleGraphError") 

25 

26import warnings 

27 

28from lsst.daf.butler import Quantum, DatasetRef, ButlerURI, DimensionUniverse, DimensionRecordsAccumulator 

29 

30from collections import defaultdict, deque 

31 

32from itertools import chain 

33import io 

34import json 

35import networkx as nx 

36from networkx.drawing.nx_agraph import write_dot 

37import os 

38import pickle 

39import lzma 

40import struct 

41import time 

42import uuid 

43from types import MappingProxyType 

44from typing import (Any, DefaultDict, Dict, FrozenSet, Iterable, List, Mapping, MutableMapping, Set, 

45 Generator, Optional, Tuple, Union, TypeVar) 

46 

47from ..connections import iterConnections 

48from ..pipeline import TaskDef 

49 

50from ._implDetails import _DatasetTracker, DatasetTypeName, _pruner 

51from .quantumNode import QuantumNode, BuildId 

52from ._loadHelpers import LoadHelper 

53from ._versionDeserializers import DESERIALIZER_MAP 

54 

55 

# Type variable bound to QuantumGraph so that methods returning new graphs
# (subset, pruneGraphFromRefs, ...) preserve the subclass type.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = '>H'
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {
    1: '>QQ',
    2: '>Q'
}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

81 

82 

class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be satisfied because the node
    belongs to an incompatible graph.
    """

88 

89 

90class QuantumGraph: 

91 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects 

92 

93 This data structure represents a concrete workflow generated from a 

94 `Pipeline`. 

95 

96 Parameters 

97 ---------- 

98 quanta : Mapping of `TaskDef` to sets of `Quantum` 

99 This maps tasks (and their configs) to the sets of data they are to 

100 process. 

101 metadata : Optional Mapping of `str` to primitives 

102 This is an optional parameter of extra data to carry with the graph. 

103 Entries in this mapping should be able to be serialized in JSON. 

104 

105 Raises 

106 ------ 

107 ValueError 

108 Raised if the graph is pruned such that some tasks no longer have nodes 

109 associated with them. 

110 """ 

    def __init__(self, quanta: Mapping[TaskDef, Set[Quantum]],
                 metadata: Optional[Mapping[str, Any]] = None,
                 pruneRefs: Optional[Iterable[DatasetRef]] = None):
        # All construction work is delegated to _buildGraphs so alternate
        # constructors (subset, pruneGraphFromRefs) can build instances via
        # object.__new__ followed by a direct _buildGraphs call.
        self._buildGraphs(quanta, metadata=metadata, pruneRefs=pruneRefs)

115 

116 def _buildGraphs(self, 

117 quanta: Mapping[TaskDef, Set[Quantum]], 

118 *, 

119 _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None, 

120 _buildId: Optional[BuildId] = None, 

121 metadata: Optional[Mapping[str, Any]] = None, 

122 pruneRefs: Optional[Iterable[DatasetRef]] = None): 

123 """Builds the graph that is used to store the relation between tasks, 

124 and the graph that holds the relations between quanta 

125 """ 

126 self._metadata = metadata 

127 self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}") 

128 # Data structures used to identify relations between components; 

129 # DatasetTypeName -> TaskDef for task, 

130 # and DatasetRef -> QuantumNode for the quanta 

131 self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True) 

132 self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]() 

133 

134 self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {} 

135 self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set) 

136 for taskDef, quantumSet in quanta.items(): 

137 connections = taskDef.connections 

138 

139 # For each type of connection in the task, add a key to the 

140 # `_DatasetTracker` for the connections name, with a value of 

141 # the TaskDef in the appropriate field 

142 for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")): 

143 self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef) 

144 

145 for output in iterConnections(connections, ("outputs",)): 

146 self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef) 

147 

148 # For each `Quantum` in the set of all `Quantum` for this task, 

149 # add a key to the `_DatasetTracker` that is a `DatasetRef` for one 

150 # of the individual datasets inside the `Quantum`, with a value of 

151 # a newly created QuantumNode to the appropriate input/output 

152 # field. 

153 for quantum in quantumSet: 

154 if _quantumToNodeId: 

155 if (nodeId := _quantumToNodeId.get(quantum)) is None: 

156 raise ValueError("If _quantuMToNodeNumber is not None, all quanta must have an " 

157 "associated value in the mapping") 

158 else: 

159 nodeId = uuid.uuid4() 

160 

161 inits = quantum.initInputs.values() 

162 inputs = quantum.inputs.values() 

163 value = QuantumNode(quantum, taskDef, nodeId) 

164 self._taskToQuantumNode[taskDef].add(value) 

165 self._nodeIdMap[nodeId] = value 

166 

167 for dsRef in chain(inits, inputs): 

168 # unfortunately, `Quantum` allows inits to be individual 

169 # `DatasetRef`s or an Iterable of such, so there must 

170 # be an instance check here 

171 if isinstance(dsRef, Iterable): 

172 for sub in dsRef: 

173 if sub.isComponent(): 

174 sub = sub.makeCompositeRef() 

175 self._datasetRefDict.addConsumer(sub, value) 

176 else: 

177 if dsRef.isComponent(): 

178 dsRef = dsRef.makeCompositeRef() 

179 self._datasetRefDict.addConsumer(dsRef, value) 

180 for dsRef in chain.from_iterable(quantum.outputs.values()): 

181 self._datasetRefDict.addProducer(dsRef, value) 

182 

183 if pruneRefs is not None: 

184 # track what refs were pruned and prune the graph 

185 prunes = set() 

186 _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes) 

187 

188 # recreate the taskToQuantumNode dict removing nodes that have been 

189 # pruned. Keep track of task defs that now have no QuantumNodes 

190 emptyTasks: Set[str] = set() 

191 newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set) 

192 # accumulate all types 

193 types_ = set() 

194 # tracker for any pruneRefs that have caused tasks to have no nodes 

195 # This helps the user find out what caused the issues seen. 

196 culprits = set() 

197 # Find all the types from the refs to prune 

198 for r in pruneRefs: 

199 types_.add(r.datasetType) 

200 

201 # For each of the tasks, and their associated nodes, remove any 

202 # any nodes that were pruned. If there are no nodes associated 

203 # with a task, record that task, and find out if that was due to 

204 # a type from an input ref to prune. 

205 for td, taskNodes in self._taskToQuantumNode.items(): 

206 diff = taskNodes.difference(prunes) 

207 if len(diff) == 0: 

208 if len(taskNodes) != 0: 

209 tp: DatasetType 

210 for tp in types_: 

211 if ((tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not 

212 set(tmpRefs).difference(pruneRefs)): 

213 culprits.add(tp.name) 

214 emptyTasks.add(td.label) 

215 newTaskToQuantumNode[td] = diff 

216 

217 # update the internal dict 

218 self._taskToQuantumNode = newTaskToQuantumNode 

219 

220 if emptyTasks: 

221 raise ValueError(f"{', '.join(emptyTasks)} task(s) have no nodes associated with them " 

222 f"after graph pruning; {', '.join(culprits)} caused over-pruning") 

223 

224 # Graph of quanta relations 

225 self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph() 

226 self._count = len(self._connectedQuanta) 

227 

228 # Graph of task relations, used in various methods 

229 self._taskGraph = self._datasetDict.makeNetworkXGraph() 

230 

231 # convert default dict into a regular to prevent accidental key 

232 # insertion 

233 self._taskToQuantumNode = dict(self._taskToQuantumNode.items()) 

234 

    @property
    def taskGraph(self) -> nx.DiGraph:
        """Return a graph representing the relations between the tasks inside
        the quantum graph.

        Returns
        -------
        taskGraph : `networkx.DiGraph`
            Internal datastructure that holds relations of `TaskDef` objects
        """
        # Built once in _buildGraphs from the DatasetTypeName tracker.
        return self._taskGraph

246 

    @property
    def graph(self) -> nx.DiGraph:
        """Return a graph representing the relations between all the
        `QuantumNode` objects. Largely it should be preferred to iterate
        over, and use methods of this class, but sometimes direct access to
        the networkx object may be helpful

        Returns
        -------
        graph : `networkx.DiGraph`
            Internal datastructure that holds relations of `QuantumNode`
            objects
        """
        # Note: this exposes the live internal graph, not a copy.
        return self._connectedQuanta

261 

262 @property 

263 def inputQuanta(self) -> Iterable[QuantumNode]: 

264 """Make a `list` of all `QuantumNode` objects that are 'input' nodes 

265 to the graph, meaning those nodes to not depend on any other nodes in 

266 the graph. 

267 

268 Returns 

269 ------- 

270 inputNodes : iterable of `QuantumNode` 

271 A list of nodes that are inputs to the graph 

272 """ 

273 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

274 

275 @property 

276 def outputQuanta(self) -> Iterable[QuantumNode]: 

277 """Make a `list` of all `QuantumNode` objects that are 'output' nodes 

278 to the graph, meaning those nodes have no nodes that depend them in 

279 the graph. 

280 

281 Returns 

282 ------- 

283 outputNodes : iterable of `QuantumNode` 

284 A list of nodes that are outputs of the graph 

285 """ 

286 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

287 

288 @property 

289 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]: 

290 """Return all the `DatasetTypeName` objects that are contained inside 

291 the graph. 

292 

293 Returns 

294 ------- 

295 tuple of `DatasetTypeName` 

296 All the data set type names that are present in the graph 

297 """ 

298 return tuple(self._datasetDict.keys()) 

299 

    @property
    def isConnected(self) -> bool:
        """Return True if all of the nodes in the graph are connected, ignores
        directionality of connections.
        """
        # Weak connectivity treats the DAG as undirected, so two disjoint
        # sub-pipelines make this False even if each is internally valid.
        return nx.is_weakly_connected(self._connectedQuanta)

306 

    def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T:
        r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s
        and nodes which depend on them.

        Parameters
        ----------
        refs : `Iterable` of `DatasetRef`
            Refs which should be removed from resulting graph

        Returns
        -------
        graph : `QuantumGraph`
            A graph that has been pruned of specified refs and the nodes that
            depend on them.
        """
        # Bypass __init__ so the new instance can be populated directly by
        # _buildGraphs, whose pruneRefs argument performs the actual pruning.
        newInst = object.__new__(type(self))
        quantumMap = defaultdict(set)
        for node in self:
            quantumMap[node.taskDef].add(node.quantum)

        # convert to standard dict to prevent accidental key insertion
        quantumMap = dict(quantumMap.items())

        # Pass the existing quantum -> nodeId mapping so surviving nodes keep
        # the identities they had in this graph.
        newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in self},
                             metadata=self._metadata, pruneRefs=refs)
        return newInst

333 

334 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

335 """Lookup a `QuantumNode` from an id associated with the node. 

336 

337 Parameters 

338 ---------- 

339 nodeId : `NodeId` 

340 The number associated with a node 

341 

342 Returns 

343 ------- 

344 node : `QuantumNode` 

345 The node corresponding with input number 

346 

347 Raises 

348 ------ 

349 KeyError 

350 Raised if the requested nodeId is not in the graph. 

351 """ 

352 return self._nodeIdMap[nodeId] 

353 

354 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]: 

355 """Return all the `Quantum` associated with a `TaskDef`. 

356 

357 Parameters 

358 ---------- 

359 taskDef : `TaskDef` 

360 The `TaskDef` for which `Quantum` are to be queried 

361 

362 Returns 

363 ------- 

364 frozenset of `Quantum` 

365 The `set` of `Quantum` that is associated with the specified 

366 `TaskDef`. 

367 """ 

368 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef]) 

369 

370 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]: 

371 """Return all the `QuantumNodes` associated with a `TaskDef`. 

372 

373 Parameters 

374 ---------- 

375 taskDef : `TaskDef` 

376 The `TaskDef` for which `Quantum` are to be queried 

377 

378 Returns 

379 ------- 

380 frozenset of `QuantumNodes` 

381 The `frozenset` of `QuantumNodes` that is associated with the 

382 specified `TaskDef`. 

383 """ 

384 return frozenset(self._taskToQuantumNode[taskDef]) 

385 

386 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

387 """Find all tasks that have the specified dataset type name as an 

388 input. 

389 

390 Parameters 

391 ---------- 

392 datasetTypeName : `str` 

393 A string representing the name of a dataset type to be queried, 

394 can also accept a `DatasetTypeName` which is a `NewType` of str for 

395 type safety in static type checking. 

396 

397 Returns 

398 ------- 

399 tasks : iterable of `TaskDef` 

400 `TaskDef` objects that have the specified `DatasetTypeName` as an 

401 input, list will be empty if no tasks use specified 

402 `DatasetTypeName` as an input. 

403 

404 Raises 

405 ------ 

406 KeyError 

407 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

408 """ 

409 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

410 

411 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]: 

412 """Find all tasks that have the specified dataset type name as an 

413 output. 

414 

415 Parameters 

416 ---------- 

417 datasetTypeName : `str` 

418 A string representing the name of a dataset type to be queried, 

419 can also accept a `DatasetTypeName` which is a `NewType` of str for 

420 type safety in static type checking. 

421 

422 Returns 

423 ------- 

424 `TaskDef` or `None` 

425 `TaskDef` that outputs `DatasetTypeName` as an output or None if 

426 none of the tasks produce this `DatasetTypeName`. 

427 

428 Raises 

429 ------ 

430 KeyError 

431 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

432 """ 

433 return self._datasetDict.getProducer(datasetTypeName) 

434 

435 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

436 """Find all tasks that are associated with the specified dataset type 

437 name. 

438 

439 Parameters 

440 ---------- 

441 datasetTypeName : `str` 

442 A string representing the name of a dataset type to be queried, 

443 can also accept a `DatasetTypeName` which is a `NewType` of str for 

444 type safety in static type checking. 

445 

446 Returns 

447 ------- 

448 result : iterable of `TaskDef` 

449 `TaskDef` objects that are associated with the specified 

450 `DatasetTypeName` 

451 

452 Raises 

453 ------ 

454 KeyError 

455 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

456 """ 

457 return self._datasetDict.getAll(datasetTypeName) 

458 

459 def findTaskDefByName(self, taskName: str) -> List[TaskDef]: 

460 """Determine which `TaskDef` objects in this graph are associated 

461 with a `str` representing a task name (looks at the taskName property 

462 of `TaskDef` objects). 

463 

464 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

465 multiple times in a graph with different labels. 

466 

467 Parameters 

468 ---------- 

469 taskName : str 

470 Name of a task to search for 

471 

472 Returns 

473 ------- 

474 result : list of `TaskDef` 

475 List of the `TaskDef` objects that have the name specified. 

476 Multiple values are returned in the case that a task is used 

477 multiple times with different labels. 

478 """ 

479 results = [] 

480 for task in self._taskToQuantumNode.keys(): 

481 split = task.taskName.split('.') 

482 if split[-1] == taskName: 

483 results.append(task) 

484 return results 

485 

486 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]: 

487 """Determine which `TaskDef` objects in this graph are associated 

488 with a `str` representing a tasks label. 

489 

490 Parameters 

491 ---------- 

492 taskName : str 

493 Name of a task to search for 

494 

495 Returns 

496 ------- 

497 result : `TaskDef` 

498 `TaskDef` objects that has the specified label. 

499 """ 

500 for task in self._taskToQuantumNode.keys(): 

501 if label == task.label: 

502 return task 

503 return None 

504 

505 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]: 

506 """Return all the `Quantum` that contain a specified `DatasetTypeName`. 

507 

508 Parameters 

509 ---------- 

510 datasetTypeName : `str` 

511 The name of the dataset type to search for as a string, 

512 can also accept a `DatasetTypeName` which is a `NewType` of str for 

513 type safety in static type checking. 

514 

515 Returns 

516 ------- 

517 result : `set` of `QuantumNode` objects 

518 A `set` of `QuantumNode`s that contain specified `DatasetTypeName` 

519 

520 Raises 

521 ------ 

522 KeyError 

523 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

524 

525 """ 

526 tasks = self._datasetDict.getAll(datasetTypeName) 

527 result: Set[Quantum] = set() 

528 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

529 return result 

530 

531 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

532 """Check if specified quantum appears in the graph as part of a node. 

533 

534 Parameters 

535 ---------- 

536 quantum : `Quantum` 

537 The quantum to search for 

538 

539 Returns 

540 ------- 

541 `bool` 

542 The result of searching for the quantum 

543 """ 

544 for node in self: 

545 if quantum == node.quantum: 

546 return True 

547 return False 

548 

    def writeDotGraph(self, output: Union[str, io.BufferedIOBase]):
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : str or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object
        """
        # Delegates to networkx's nx_agraph backend (which presumably
        # requires the optional pygraphviz dependency at runtime — confirm).
        write_dot(self._connectedQuanta, output)

558 

    def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            The node or nodes the returned graph should contain.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created
        """
        # Accept a bare node for caller convenience.
        if not isinstance(nodes, Iterable):
            nodes = (nodes, )
        # Restrict the quanta graph to the requested nodes.
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)

        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)

        # convert to standard dict to prevent accidental key insertion
        quantumMap = dict(quantumMap.items())
        # Create an empty graph, and then populate it with custom mapping;
        # passing the existing ids and build id preserves node identity
        # across the subset operation.
        newInst = type(self)({})
        newInst._buildGraphs(quantumMap, _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
                             _buildId=self._buildId, metadata=self._metadata)
        return newInst

588 

589 def subsetToConnected(self: _T) -> Tuple[_T, ...]: 

590 """Generate a list of subgraphs where each is connected. 

591 

592 Returns 

593 ------- 

594 result : list of `QuantumGraph` 

595 A list of graphs that are each connected 

596 """ 

597 return tuple(self.subset(connectedSet) 

598 for connectedSet in nx.weakly_connected_components(self._connectedQuanta)) 

599 

600 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

601 """Return a set of `QuantumNode` that are direct inputs to a specified 

602 node. 

603 

604 Parameters 

605 ---------- 

606 node : `QuantumNode` 

607 The node of the graph for which inputs are to be determined 

608 

609 Returns 

610 ------- 

611 set of `QuantumNode` 

612 All the nodes that are direct inputs to specified node 

613 """ 

614 return set(pred for pred in self._connectedQuanta.predecessors(node)) 

615 

616 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

617 """Return a set of `QuantumNode` that are direct outputs of a specified 

618 node. 

619 

620 Parameters 

621 ---------- 

622 node : `QuantumNode` 

623 The node of the graph for which outputs are to be determined 

624 

625 Returns 

626 ------- 

627 set of `QuantumNode` 

628 All the nodes that are direct outputs to specified node 

629 """ 

630 return set(succ for succ in self._connectedQuanta.successors(node)) 

631 

632 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

633 """Return a graph of `QuantumNode` that are direct inputs and outputs 

634 of a specified node. 

635 

636 Parameters 

637 ---------- 

638 node : `QuantumNode` 

639 The node of the graph for which connected nodes are to be 

640 determined. 

641 

642 Returns 

643 ------- 

644 graph : graph of `QuantumNode` 

645 All the nodes that are directly connected to specified node 

646 """ 

647 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

648 nodes.add(node) 

649 return self.subset(nodes) 

650 

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined

        Returns
        -------
        graph of `QuantumNode`
            Graph of node and all of its ancestors
        """
        # nx.ancestors returns every node with a directed path TO `node`;
        # the node itself is added so the subset includes it.
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

668 

669 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]: 

670 """Check a graph for the presense of cycles and returns the edges of 

671 any cycles found, or an empty list if there is no cycle. 

672 

673 Returns 

674 ------- 

675 result : list of tuple of `QuantumNode`, `QuantumNode` 

676 A list of any graph edges that form a cycle, or an empty list if 

677 there is no cycle. Empty list to so support if graph.find_cycle() 

678 syntax as an empty list is falsy. 

679 """ 

680 try: 

681 return nx.find_cycle(self._connectedQuanta) 

682 except nx.NetworkXNoCycle: 

683 return [] 

684 

685 def saveUri(self, uri): 

686 """Save `QuantumGraph` to the specified URI. 

687 

688 Parameters 

689 ---------- 

690 uri : `ButlerURI` or `str` 

691 URI to where the graph should be saved. 

692 """ 

693 buffer = self._buildSaveObject() 

694 butlerUri = ButlerURI(uri) 

695 if butlerUri.getExtension() not in (".qgraph"): 

696 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

697 butlerUri.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

698 

699 @property 

700 def metadata(self) -> Optional[MappingProxyType[str, Any]]: 

701 """ 

702 """ 

703 if self._metadata is None: 

704 return None 

705 return MappingProxyType(self._metadata) 

706 

707 @classmethod 

708 def loadUri(cls, uri: Union[ButlerURI, str], universe: DimensionUniverse, 

709 nodes: Optional[Iterable[Union[str, uuid.UUID]]] = None, 

710 graphID: Optional[BuildId] = None, 

711 minimumVersion: int = 3 

712 ) -> QuantumGraph: 

713 """Read `QuantumGraph` from a URI. 

714 

715 Parameters 

716 ---------- 

717 uri : `ButlerURI` or `str` 

718 URI from where to load the graph. 

719 universe: `~lsst.daf.butler.DimensionUniverse` 

720 DimensionUniverse instance, not used by the method itself but 

721 needed to ensure that registry data structures are initialized. 

722 nodes: iterable of `int` or None 

723 Numbers that correspond to nodes in the graph. If specified, only 

724 these nodes will be loaded. Defaults to None, in which case all 

725 nodes will be loaded. 

726 graphID : `str` or `None` 

727 If specified this ID is verified against the loaded graph prior to 

728 loading any Nodes. This defaults to None in which case no 

729 validation is done. 

730 minimumVersion : int 

731 Minimum version of a save file to load. Set to -1 to load all 

732 versions. Older versions may need to be loaded, and re-saved 

733 to upgrade them to the latest format before they can be used in 

734 production. 

735 

736 Returns 

737 ------- 

738 graph : `QuantumGraph` 

739 Resulting QuantumGraph instance. 

740 

741 Raises 

742 ------ 

743 TypeError 

744 Raised if pickle contains instance of a type other than 

745 QuantumGraph. 

746 ValueError 

747 Raised if one or more of the nodes requested is not in the 

748 `QuantumGraph` or if graphID parameter does not match the graph 

749 being loaded or if the supplied uri does not point at a valid 

750 `QuantumGraph` save file. 

751 

752 

753 Notes 

754 ----- 

755 Reading Quanta from pickle requires existence of singleton 

756 DimensionUniverse which is usually instantiated during Registry 

757 initialization. To make sure that DimensionUniverse exists this method 

758 accepts dummy DimensionUniverse argument. 

759 """ 

760 uri = ButlerURI(uri) 

761 # With ButlerURI we have the choice of always using a local file 

762 # or reading in the bytes directly. Reading in bytes can be more 

763 # efficient for reasonably-sized pickle files when the resource 

764 # is remote. For now use the local file variant. For a local file 

765 # as_local() does nothing. 

766 

767 if uri.getExtension() in (".pickle", ".pkl"): 

768 with uri.as_local() as local, open(local.ospath, "rb") as fd: 

769 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

770 qgraph = pickle.load(fd) 

771 elif uri.getExtension() in ('.qgraph'): 

772 with LoadHelper(uri, minimumVersion) as loader: 

773 qgraph = loader.load(universe, nodes, graphID) 

774 else: 

775 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`") 

776 if not isinstance(qgraph, QuantumGraph): 

777 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}") 

778 return qgraph 

779 

780 @classmethod 

781 def readHeader(cls, uri: Union[ButlerURI, str], minimumVersion: int = 3) -> Optional[str]: 

782 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

783 and return it as a string. 

784 

785 Parameters 

786 ---------- 

787 uri : `~lsst.daf.butler.ButlerURI` or `str` 

788 The location of the `QuantumGraph` to load. If the argument is a 

789 string, it must correspond to a valid `~lsst.daf.butler.ButlerURI` 

790 path. 

791 minimumVersion : int 

792 Minimum version of a save file to load. Set to -1 to load all 

793 versions. Older versions may need to be loaded, and re-saved 

794 to upgrade them to the latest format before they can be used in 

795 production. 

796 

797 Returns 

798 ------- 

799 header : `str` or `None` 

800 The header associated with the specified `QuantumGraph` it there is 

801 one, else `None`. 

802 

803 Raises 

804 ------ 

805 ValueError 

806 Raised if `QuantuGraph` was saved as a pickle. 

807 Raised if the extention of the file specified by uri is not a 

808 `QuantumGraph` extention. 

809 """ 

810 uri = ButlerURI(uri) 

811 if uri.getExtension() in (".pickle", ".pkl"): 

812 raise ValueError("Reading a header from a pickle save is not supported") 

813 elif uri.getExtension() in ('.qgraph'): 

814 return LoadHelper(uri, minimumVersion).readHeader() 

815 else: 

816 raise ValueError("Only know how to handle files saved as `qgraph`") 

817 

    def buildAndPrintHeader(self):
        """Creates a header that would be used in a save of this object and
        prints it out to standard out.
        """
        # Build the full save object but keep only the header portion; the
        # serialized graph bytes are discarded.
        _, header = self._buildSaveObject(returnHeader=True)
        print(json.dumps(header))

824 

825 def save(self, file: io.IO[bytes]): 

826 """Save QuantumGraph to a file. 

827 

828 Presently we store QuantumGraph in pickle format, this could 

829 potentially change in the future if better format is found. 

830 

831 Parameters 

832 ---------- 

833 file : `io.BufferedIOBase` 

834 File to write pickle data open in binary mode. 

835 """ 

836 buffer = self._buildSaveObject() 

837 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

838 

    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into an in-memory byte buffer.

        The layout is: ``MAGIC_BYTES``, a struct-packed save version, a
        struct-packed header length, the lzma-compressed JSON header, and
        then the lzma-compressed per-task and per-node JSON payloads.  The
        header records the byte offsets of every payload so that individual
        tasks or nodes can be loaded without reading the whole file.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping.

        Returns
        -------
        buffer : `bytearray`
            The serialized graph.
        header : `dict`
            The header mapping; only returned when ``returnHeader`` is
            `True`.
        """
        # make some containers
        jsonData = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData['GraphBuildID'] = self.graphID
        headerData['Metadata'] = self._metadata

        # counter for the number of bytes processed thus far; payload byte
        # ranges recorded below are relative to the end of the header
        count = 0
        # serialize out the task Defs recording the start and end bytes of each
        # taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # NOTE(review): the original comment claimed tasks are sorted by
        # label here, but iteration is in taskGraph order — confirm that
        # taskGraph iteration is deterministic across processes.
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription = {}
            # save the fully qualified name, as TaskDef does not require this,
            # but by doing so can save space and is easier to transport
            taskDescription['taskName'] = f"{taskDef.taskClass.__module__}.{taskDef.taskClass.__qualname__}"
            # save the config as a text stream that will be un-persisted on the
            # other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription['config'] = stream.getvalue()
            taskDescription['label'] = taskDef.label

            inputs = []
            outputs = []

            # Determine the connection between all of tasks and save that in
            # the header as a list of connections and edges in each task
            # this will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                # NOTE(review): the "taskDef not in consumers" term is
                # redundant — this elif only runs when the preceding
                # membership test was already false.
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this tasks produced
                    # connection, the output will be said to be the datastore
                    # in which case the for loop will be a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, and encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {"bytes": (count, count+len(dump)),
                                         "inputs": inputs,
                                         "outputs": outputs}
            count += len(dump)
            jsonData.append(dump)

        headerData['TaskDefs'] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            nodeMap.append((str(node.nodeId),
                            {"bytes": (count, count+len(dump)),
                             "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                             "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)]}
                            ))
            count += len(dump)

        # dimension records accumulated while simplifying the nodes, keyed
        # for re-association at load time
        headerData['DimensionRecords'] = {key: value.dict() for key, value in
                                          dimAccumulator.makeSerializedDimensionRecordMapping().items()}

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData['Nodes'] = nodeMap

        # dump the headerData to json and compress it
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # pack the save version, then the header length using the struct
        # format defined by the deserializer for the current save version
        # (NOTE(review): the original comment claimed "2 unsigned long long
        # numbers for a total of 16 bytes", but only one length is packed
        # here — the exact width depends on the deserializer's FMT_STRING)
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this payload data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

971 

972 @classmethod 

973 def load(cls, file: io.IO[bytes], universe: DimensionUniverse, 

974 nodes: Optional[Iterable[uuid.UUID]] = None, 

975 graphID: Optional[BuildId] = None, 

976 minimumVersion: int = 3 

977 ) -> QuantumGraph: 

978 """Read QuantumGraph from a file that was made by `save`. 

979 

980 Parameters 

981 ---------- 

982 file : `io.IO` of bytes 

983 File with pickle data open in binary mode. 

984 universe: `~lsst.daf.butler.DimensionUniverse` 

985 DimensionUniverse instance, not used by the method itself but 

986 needed to ensure that registry data structures are initialized. 

987 nodes: iterable of `int` or None 

988 Numbers that correspond to nodes in the graph. If specified, only 

989 these nodes will be loaded. Defaults to None, in which case all 

990 nodes will be loaded. 

991 graphID : `str` or `None` 

992 If specified this ID is verified against the loaded graph prior to 

993 loading any Nodes. This defaults to None in which case no 

994 validation is done. 

995 minimumVersion : int 

996 Minimum version of a save file to load. Set to -1 to load all 

997 versions. Older versions may need to be loaded, and re-saved 

998 to upgrade them to the latest format before they can be used in 

999 production. 

1000 

1001 Returns 

1002 ------- 

1003 graph : `QuantumGraph` 

1004 Resulting QuantumGraph instance. 

1005 

1006 Raises 

1007 ------ 

1008 TypeError 

1009 Raised if pickle contains instance of a type other than 

1010 QuantumGraph. 

1011 ValueError 

1012 Raised if one or more of the nodes requested is not in the 

1013 `QuantumGraph` or if graphID parameter does not match the graph 

1014 being loaded or if the supplied uri does not point at a valid 

1015 `QuantumGraph` save file. 

1016 

1017 Notes 

1018 ----- 

1019 Reading Quanta from pickle requires existence of singleton 

1020 DimensionUniverse which is usually instantiated during Registry 

1021 initialization. To make sure that DimensionUniverse exists this method 

1022 accepts dummy DimensionUniverse argument. 

1023 """ 

1024 # Try to see if the file handle contains pickle data, this will be 

1025 # removed in the future 

1026 try: 

1027 qgraph = pickle.load(file) 

1028 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

1029 except pickle.UnpicklingError: 

1030 # needed because we don't have Protocols yet 

1031 with LoadHelper(file, minimumVersion) as loader: # type: ignore 

1032 qgraph = loader.load(universe, nodes, graphID) 

1033 if not isinstance(qgraph, QuantumGraph): 

1034 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}") 

1035 return qgraph 

1036 

1037 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1038 """Iterate over the `taskGraph` attribute in topological order 

1039 

1040 Yields 

1041 ------ 

1042 taskDef : `TaskDef` 

1043 `TaskDef` objects in topological order 

1044 """ 

1045 yield from nx.topological_sort(self.taskGraph) 

1046 

    @property
    def graphID(self) -> BuildId:
        """The unique ID assigned to this graph when it was constructed
        (read-only).
        """
        return self._buildId

1052 

1053 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1054 yield from nx.topological_sort(self._connectedQuanta) 

1055 

1056 def __len__(self) -> int: 

1057 return self._count 

1058 

1059 def __contains__(self, node: QuantumNode) -> bool: 

1060 return self._connectedQuanta.has_node(node) 

1061 

1062 def __getstate__(self) -> dict: 

1063 """Stores a compact form of the graph as a list of graph nodes, and a 

1064 tuple of task labels and task configs. The full graph can be 

1065 reconstructed with this information, and it preseves the ordering of 

1066 the graph ndoes. 

1067 """ 

1068 universe: Optional[DimensionUniverse] = None 

1069 for node in self: 

1070 dId = node.quantum.dataId 

1071 if dId is None: 

1072 continue 

1073 universe = dId.graph.universe 

1074 return {"reduced": self._buildSaveObject(), 

1075 'graphId': self._buildId, 'universe': universe} 

1076 

1077 def __setstate__(self, state: dict): 

1078 """Reconstructs the state of the graph from the information persisted 

1079 in getstate. 

1080 """ 

1081 buffer = io.BytesIO(state['reduced']) 

1082 with LoadHelper(buffer, minimumVersion=3) as loader: 

1083 qgraph = loader.load(state['universe'], graphID=state['graphId']) 

1084 

1085 self._metadata = qgraph._metadata 

1086 self._buildId = qgraph._buildId 

1087 self._datasetDict = qgraph._datasetDict 

1088 self._nodeIdMap = qgraph._nodeIdMap 

1089 self._count = len(qgraph) 

1090 self._taskToQuantumNode = qgraph._taskToQuantumNode 

1091 self._taskGraph = qgraph._taskGraph 

1092 self._connectedQuanta = qgraph._connectedQuanta 

1093 

1094 def __eq__(self, other: object) -> bool: 

1095 if not isinstance(other, QuantumGraph): 

1096 return False 

1097 if len(self) != len(other): 

1098 return False 

1099 for node in self: 

1100 if node not in other: 

1101 return False 

1102 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1103 return False 

1104 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1105 return False 

1106 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1107 return False 

1108 return set(self.taskGraph) == set(other.taskGraph)