Coverage for python/lsst/pipe/base/graph/graph.py: 17%

367 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-04 09:40 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("QuantumGraph", "IncompatibleGraphError") 

24 

25import io 

26import json 

27import lzma 

28import os 

29import pickle 

30import struct 

31import time 

32import uuid 

33import warnings 

34from collections import defaultdict, deque 

35from itertools import chain 

36from types import MappingProxyType 

37from typing import ( 

38 Any, 

39 BinaryIO, 

40 DefaultDict, 

41 Deque, 

42 Dict, 

43 FrozenSet, 

44 Generator, 

45 Iterable, 

46 List, 

47 Mapping, 

48 MutableMapping, 

49 Optional, 

50 Set, 

51 Tuple, 

52 TypeVar, 

53 Union, 

54) 

55 

56import networkx as nx 

57from lsst.daf.butler import DatasetRef, DatasetType, DimensionRecordsAccumulator, DimensionUniverse, Quantum 

58from lsst.resources import ResourcePath, ResourcePathExpression 

59from lsst.utils.introspection import get_full_type_name 

60from networkx.drawing.nx_agraph import write_dot 

61 

62from ..connections import iterConnections 

63from ..pipeline import TaskDef 

64from ._implDetails import DatasetTypeName, _DatasetTracker, _pruner 

65from ._loadHelpers import LoadHelper 

66from ._versionDeserializers import DESERIALIZER_MAP 

67from .quantumNode import BuildId, QuantumNode 

68 

# Type variable used so methods that build new graphs (subset, prune, load)
# return the concrete subclass they were invoked on, not the base class.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION: int = 3

# Strings used to describe the format for the preamble bytes in a file save
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE: str = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING: Dict[int, str] = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES: bytes = b"qgraph4\xf6\xe8\xa9"

91 

92 

class IncompatibleGraphError(Exception):
    """Raised when a lookup by ``NodeId`` cannot be satisfied because the
    node belongs to an incompatible graph.
    """

99 

100 

101class QuantumGraph: 

102 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects 

103 

104 This data structure represents a concrete workflow generated from a 

105 `Pipeline`. 

106 

107 Parameters 

108 ---------- 

109 quanta : Mapping of `TaskDef` to sets of `Quantum` 

110 This maps tasks (and their configs) to the sets of data they are to 

111 process. 

112 metadata : Optional Mapping of `str` to primitives 

113 This is an optional parameter of extra data to carry with the graph. 

114 Entries in this mapping should be able to be serialized in JSON. 

115 pruneRefs : iterable [ `DatasetRef` ], optional 

116 Set of dataset refs to exclude from a graph. 

117 initInputs : `Mapping`, optional 

118 Maps tasks to their InitInput dataset refs. Dataset refs can be either 

119 resolved or non-resolved. Presently the same dataset refs are included 

120 in each `Quantum` for the same task. 

121 initOutputs : `Mapping`, optional 

122 Maps tasks to their InitOutput dataset refs. Dataset refs can be either 

123 resolved or non-resolved. For intermediate resolved refs their dataset 

124 ID must match ``initInputs`` and Quantum ``initInputs``. 

125 

126 Raises 

127 ------ 

128 ValueError 

129 Raised if the graph is pruned such that some tasks no longer have nodes 

130 associated with them. 

131 """ 

132 

133 def __init__( 

134 self, 

135 quanta: Mapping[TaskDef, Set[Quantum]], 

136 metadata: Optional[Mapping[str, Any]] = None, 

137 pruneRefs: Optional[Iterable[DatasetRef]] = None, 

138 universe: Optional[DimensionUniverse] = None, 

139 initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None, 

140 initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None, 

141 ): 

142 self._buildGraphs( 

143 quanta, 

144 metadata=metadata, 

145 pruneRefs=pruneRefs, 

146 universe=universe, 

147 initInputs=initInputs, 

148 initOutputs=initOutputs, 

149 ) 

150 

    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, Set[Quantum]],
        *,
        _quantumToNodeId: Optional[Mapping[Quantum, uuid.UUID]] = None,
        _buildId: Optional[BuildId] = None,
        metadata: Optional[Mapping[str, Any]] = None,
        pruneRefs: Optional[Iterable[DatasetRef]] = None,
        universe: Optional[DimensionUniverse] = None,
        initInputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
        initOutputs: Optional[Mapping[TaskDef, Iterable[DatasetRef]]] = None,
    ) -> None:
        """Builds the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : `Mapping` of `TaskDef` to sets of `Quantum`
            Maps tasks to the sets of data they are to process.
        _quantumToNodeId : `Mapping` of `Quantum` to `uuid.UUID`, optional
            Pre-existing node ids to reuse (e.g. when subsetting a graph).
            When supplied, every quantum must have an entry.
        _buildId : `BuildId`, optional
            Identifier for this build; when `None` a new one is derived from
            the current time and process id.
        metadata : `Mapping` of `str` to primitives, optional
            Extra data to carry with the graph.
        pruneRefs : iterable of `DatasetRef`, optional
            Dataset refs to exclude from the graph.
        universe : `DimensionUniverse`, optional
            When `None`, taken from the first quantum that has a data ID.
        initInputs, initOutputs : `Mapping`, optional
            Maps tasks to their InitInput / InitOutput dataset refs.

        Raises
        ------
        ValueError
            Raised if ``_quantumToNodeId`` is given but missing a quantum, or
            if pruning leaves one or more tasks with no nodes.
        RuntimeError
            Raised if quanta carry mismatched dimension universes, or if no
            universe can be determined at all.
        """
        self._metadata = metadata
        # Build id defaults to a timestamp+pid string, unique enough to tell
        # two independently built graphs apart.
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structures used to identify relations between components;
        # DatasetTypeName -> TaskDef for task,
        # and DatasetRef -> QuantumNode for the quanta
        self._datasetDict = _DatasetTracker[DatasetTypeName, TaskDef](createInverse=True)
        self._datasetRefDict = _DatasetTracker[DatasetRef, QuantumNode]()

        self._nodeIdMap: Dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, Set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                self._datasetDict.addConsumer(DatasetTypeName(inpt.name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                self._datasetDict.addProducer(DatasetTypeName(output.name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                # Derive (or cross-check) the dimension universe from each
                # quantum's data ID; all quanta must agree.
                if quantum.dataId is not None:
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}. "
                        )

                if _quantumToNodeId:
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            # Components are tracked via their parent
                            # composite so graph edges line up with the
                            # producing task's outputs.
                            if sub.isComponent():
                                sub = sub.makeCompositeRef()
                            self._datasetRefDict.addConsumer(sub, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        self._datasetRefDict.addConsumer(dsRef, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    self._datasetRefDict.addProducer(dsRef, value)

        if pruneRefs is not None:
            # track what refs were pruned and prune the graph
            prunes: Set[QuantumNode] = set()
            _pruner(self._datasetRefDict, pruneRefs, alreadyPruned=prunes)

            # recreate the taskToQuantumNode dict removing nodes that have been
            # pruned. Keep track of task defs that now have no QuantumNodes
            emptyTasks: Set[str] = set()
            newTaskToQuantumNode: DefaultDict[TaskDef, Set[QuantumNode]] = defaultdict(set)
            # accumulate all types
            types_ = set()
            # tracker for any pruneRefs that have caused tasks to have no nodes
            # This helps the user find out what caused the issues seen.
            culprits = set()
            # Find all the types from the refs to prune
            for r in pruneRefs:
                types_.add(r.datasetType)

            # For each of the tasks, and their associated nodes, remove any
            # any nodes that were pruned. If there are no nodes associated
            # with a task, record that task, and find out if that was due to
            # a type from an input ref to prune.
            for td, taskNodes in self._taskToQuantumNode.items():
                diff = taskNodes.difference(prunes)
                if len(diff) == 0:
                    if len(taskNodes) != 0:
                        tp: DatasetType
                        for tp in types_:
                            # A type is a "culprit" when every input ref of
                            # that type in a (sample) node was pruned.
                            if (tmpRefs := next(iter(taskNodes)).quantum.inputs.get(tp)) and not set(
                                tmpRefs
                            ).difference(pruneRefs):
                                culprits.add(tp.name)
                    emptyTasks.add(td.label)
                newTaskToQuantumNode[td] = diff

            # update the internal dict
            self._taskToQuantumNode = newTaskToQuantumNode

            if emptyTasks:
                raise ValueError(
                    f"{', '.join(emptyTasks)} task(s) have no nodes associated with them "
                    f"after graph pruning; {', '.join(culprits)} caused over-pruning"
                )

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Graph of quanta relations
        self._connectedQuanta = self._datasetRefDict.makeNetworkXGraph()
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # convert default dict into a regular to prevent accidental key
        # insertion
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

        self._initInputRefs: Dict[TaskDef, List[DatasetRef]] = {}
        self._initOutputRefs: Dict[TaskDef, List[DatasetRef]] = {}
        if initInputs is not None:
            self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
        if initOutputs is not None:
            self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}

304 

305 @property 

306 def taskGraph(self) -> nx.DiGraph: 

307 """Return a graph representing the relations between the tasks inside 

308 the quantum graph. 

309 

310 Returns 

311 ------- 

312 taskGraph : `networkx.Digraph` 

313 Internal datastructure that holds relations of `TaskDef` objects 

314 """ 

315 return self._taskGraph 

316 

317 @property 

318 def graph(self) -> nx.DiGraph: 

319 """Return a graph representing the relations between all the 

320 `QuantumNode` objects. Largely it should be preferred to iterate 

321 over, and use methods of this class, but sometimes direct access to 

322 the networkx object may be helpful 

323 

324 Returns 

325 ------- 

326 graph : `networkx.Digraph` 

327 Internal datastructure that holds relations of `QuantumNode` 

328 objects 

329 """ 

330 return self._connectedQuanta 

331 

332 @property 

333 def inputQuanta(self) -> Iterable[QuantumNode]: 

334 """Make a `list` of all `QuantumNode` objects that are 'input' nodes 

335 to the graph, meaning those nodes to not depend on any other nodes in 

336 the graph. 

337 

338 Returns 

339 ------- 

340 inputNodes : iterable of `QuantumNode` 

341 A list of nodes that are inputs to the graph 

342 """ 

343 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

344 

345 @property 

346 def outputQuanta(self) -> Iterable[QuantumNode]: 

347 """Make a `list` of all `QuantumNode` objects that are 'output' nodes 

348 to the graph, meaning those nodes have no nodes that depend them in 

349 the graph. 

350 

351 Returns 

352 ------- 

353 outputNodes : iterable of `QuantumNode` 

354 A list of nodes that are outputs of the graph 

355 """ 

356 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

357 

358 @property 

359 def allDatasetTypes(self) -> Tuple[DatasetTypeName, ...]: 

360 """Return all the `DatasetTypeName` objects that are contained inside 

361 the graph. 

362 

363 Returns 

364 ------- 

365 tuple of `DatasetTypeName` 

366 All the data set type names that are present in the graph 

367 """ 

368 return tuple(self._datasetDict.keys()) 

369 

370 @property 

371 def isConnected(self) -> bool: 

372 """Return True if all of the nodes in the graph are connected, ignores 

373 directionality of connections. 

374 """ 

375 return nx.is_weakly_connected(self._connectedQuanta) 

376 

377 def pruneGraphFromRefs(self: _T, refs: Iterable[DatasetRef]) -> _T: 

378 r"""Return a graph pruned of input `~lsst.daf.butler.DatasetRef`\ s 

379 and nodes which depend on them. 

380 

381 Parameters 

382 ---------- 

383 refs : `Iterable` of `DatasetRef` 

384 Refs which should be removed from resulting graph 

385 

386 Returns 

387 ------- 

388 graph : `QuantumGraph` 

389 A graph that has been pruned of specified refs and the nodes that 

390 depend on them. 

391 """ 

392 newInst = object.__new__(type(self)) 

393 quantumMap = defaultdict(set) 

394 for node in self: 

395 quantumMap[node.taskDef].add(node.quantum) 

396 

397 # convert to standard dict to prevent accidental key insertion 

398 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

399 

400 newInst._buildGraphs( 

401 quantumDict, 

402 _quantumToNodeId={n.quantum: n.nodeId for n in self}, 

403 metadata=self._metadata, 

404 pruneRefs=refs, 

405 universe=self._universe, 

406 ) 

407 return newInst 

408 

409 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

410 """Lookup a `QuantumNode` from an id associated with the node. 

411 

412 Parameters 

413 ---------- 

414 nodeId : `NodeId` 

415 The number associated with a node 

416 

417 Returns 

418 ------- 

419 node : `QuantumNode` 

420 The node corresponding with input number 

421 

422 Raises 

423 ------ 

424 KeyError 

425 Raised if the requested nodeId is not in the graph. 

426 """ 

427 return self._nodeIdMap[nodeId] 

428 

429 def getQuantaForTask(self, taskDef: TaskDef) -> FrozenSet[Quantum]: 

430 """Return all the `Quantum` associated with a `TaskDef`. 

431 

432 Parameters 

433 ---------- 

434 taskDef : `TaskDef` 

435 The `TaskDef` for which `Quantum` are to be queried 

436 

437 Returns 

438 ------- 

439 frozenset of `Quantum` 

440 The `set` of `Quantum` that is associated with the specified 

441 `TaskDef`. 

442 """ 

443 return frozenset(node.quantum for node in self._taskToQuantumNode[taskDef]) 

444 

445 def getNodesForTask(self, taskDef: TaskDef) -> FrozenSet[QuantumNode]: 

446 """Return all the `QuantumNodes` associated with a `TaskDef`. 

447 

448 Parameters 

449 ---------- 

450 taskDef : `TaskDef` 

451 The `TaskDef` for which `Quantum` are to be queried 

452 

453 Returns 

454 ------- 

455 frozenset of `QuantumNodes` 

456 The `frozenset` of `QuantumNodes` that is associated with the 

457 specified `TaskDef`. 

458 """ 

459 return frozenset(self._taskToQuantumNode[taskDef]) 

460 

461 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

462 """Find all tasks that have the specified dataset type name as an 

463 input. 

464 

465 Parameters 

466 ---------- 

467 datasetTypeName : `str` 

468 A string representing the name of a dataset type to be queried, 

469 can also accept a `DatasetTypeName` which is a `NewType` of str for 

470 type safety in static type checking. 

471 

472 Returns 

473 ------- 

474 tasks : iterable of `TaskDef` 

475 `TaskDef` objects that have the specified `DatasetTypeName` as an 

476 input, list will be empty if no tasks use specified 

477 `DatasetTypeName` as an input. 

478 

479 Raises 

480 ------ 

481 KeyError 

482 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

483 """ 

484 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

485 

486 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> Optional[TaskDef]: 

487 """Find all tasks that have the specified dataset type name as an 

488 output. 

489 

490 Parameters 

491 ---------- 

492 datasetTypeName : `str` 

493 A string representing the name of a dataset type to be queried, 

494 can also accept a `DatasetTypeName` which is a `NewType` of str for 

495 type safety in static type checking. 

496 

497 Returns 

498 ------- 

499 `TaskDef` or `None` 

500 `TaskDef` that outputs `DatasetTypeName` as an output or None if 

501 none of the tasks produce this `DatasetTypeName`. 

502 

503 Raises 

504 ------ 

505 KeyError 

506 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

507 """ 

508 return self._datasetDict.getProducer(datasetTypeName) 

509 

510 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

511 """Find all tasks that are associated with the specified dataset type 

512 name. 

513 

514 Parameters 

515 ---------- 

516 datasetTypeName : `str` 

517 A string representing the name of a dataset type to be queried, 

518 can also accept a `DatasetTypeName` which is a `NewType` of str for 

519 type safety in static type checking. 

520 

521 Returns 

522 ------- 

523 result : iterable of `TaskDef` 

524 `TaskDef` objects that are associated with the specified 

525 `DatasetTypeName` 

526 

527 Raises 

528 ------ 

529 KeyError 

530 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

531 """ 

532 return self._datasetDict.getAll(datasetTypeName) 

533 

534 def findTaskDefByName(self, taskName: str) -> List[TaskDef]: 

535 """Determine which `TaskDef` objects in this graph are associated 

536 with a `str` representing a task name (looks at the taskName property 

537 of `TaskDef` objects). 

538 

539 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

540 multiple times in a graph with different labels. 

541 

542 Parameters 

543 ---------- 

544 taskName : str 

545 Name of a task to search for 

546 

547 Returns 

548 ------- 

549 result : list of `TaskDef` 

550 List of the `TaskDef` objects that have the name specified. 

551 Multiple values are returned in the case that a task is used 

552 multiple times with different labels. 

553 """ 

554 results = [] 

555 for task in self._taskToQuantumNode.keys(): 

556 split = task.taskName.split(".") 

557 if split[-1] == taskName: 

558 results.append(task) 

559 return results 

560 

561 def findTaskDefByLabel(self, label: str) -> Optional[TaskDef]: 

562 """Determine which `TaskDef` objects in this graph are associated 

563 with a `str` representing a tasks label. 

564 

565 Parameters 

566 ---------- 

567 taskName : str 

568 Name of a task to search for 

569 

570 Returns 

571 ------- 

572 result : `TaskDef` 

573 `TaskDef` objects that has the specified label. 

574 """ 

575 for task in self._taskToQuantumNode.keys(): 

576 if label == task.label: 

577 return task 

578 return None 

579 

580 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> Set[Quantum]: 

581 """Return all the `Quantum` that contain a specified `DatasetTypeName`. 

582 

583 Parameters 

584 ---------- 

585 datasetTypeName : `str` 

586 The name of the dataset type to search for as a string, 

587 can also accept a `DatasetTypeName` which is a `NewType` of str for 

588 type safety in static type checking. 

589 

590 Returns 

591 ------- 

592 result : `set` of `QuantumNode` objects 

593 A `set` of `QuantumNode`s that contain specified `DatasetTypeName` 

594 

595 Raises 

596 ------ 

597 KeyError 

598 Raised if the `DatasetTypeName` is not part of the `QuantumGraph` 

599 

600 """ 

601 tasks = self._datasetDict.getAll(datasetTypeName) 

602 result: Set[Quantum] = set() 

603 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

604 return result 

605 

606 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

607 """Check if specified quantum appears in the graph as part of a node. 

608 

609 Parameters 

610 ---------- 

611 quantum : `Quantum` 

612 The quantum to search for 

613 

614 Returns 

615 ------- 

616 `bool` 

617 The result of searching for the quantum 

618 """ 

619 for node in self: 

620 if quantum == node.quantum: 

621 return True 

622 return False 

623 

624 def writeDotGraph(self, output: Union[str, io.BufferedIOBase]) -> None: 

625 """Write out the graph as a dot graph. 

626 

627 Parameters 

628 ---------- 

629 output : str or `io.BufferedIOBase` 

630 Either a filesystem path to write to, or a file handle object 

631 """ 

632 write_dot(self._connectedQuanta, output) 

633 

634 def subset(self: _T, nodes: Union[QuantumNode, Iterable[QuantumNode]]) -> _T: 

635 """Create a new graph object that contains the subset of the nodes 

636 specified as input. Node number is preserved. 

637 

638 Parameters 

639 ---------- 

640 nodes : `QuantumNode` or iterable of `QuantumNode` 

641 

642 Returns 

643 ------- 

644 graph : instance of graph type 

645 An instance of the type from which the subset was created 

646 """ 

647 if not isinstance(nodes, Iterable): 

648 nodes = (nodes,) 

649 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes 

650 quantumMap = defaultdict(set) 

651 

652 node: QuantumNode 

653 for node in quantumSubgraph: 

654 quantumMap[node.taskDef].add(node.quantum) 

655 

656 # convert to standard dict to prevent accidental key insertion 

657 quantumDict: Dict[TaskDef, Set[Quantum]] = dict(quantumMap.items()) 

658 # Create an empty graph, and then populate it with custom mapping 

659 newInst = type(self)({}, universe=self._universe) 

660 newInst._buildGraphs( 

661 quantumDict, 

662 _quantumToNodeId={n.quantum: n.nodeId for n in nodes}, 

663 _buildId=self._buildId, 

664 metadata=self._metadata, 

665 universe=self._universe, 

666 ) 

667 return newInst 

668 

669 def subsetToConnected(self: _T) -> Tuple[_T, ...]: 

670 """Generate a list of subgraphs where each is connected. 

671 

672 Returns 

673 ------- 

674 result : list of `QuantumGraph` 

675 A list of graphs that are each connected 

676 """ 

677 return tuple( 

678 self.subset(connectedSet) 

679 for connectedSet in nx.weakly_connected_components(self._connectedQuanta) 

680 ) 

681 

682 def determineInputsToQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

683 """Return a set of `QuantumNode` that are direct inputs to a specified 

684 node. 

685 

686 Parameters 

687 ---------- 

688 node : `QuantumNode` 

689 The node of the graph for which inputs are to be determined 

690 

691 Returns 

692 ------- 

693 set of `QuantumNode` 

694 All the nodes that are direct inputs to specified node 

695 """ 

696 return set(pred for pred in self._connectedQuanta.predecessors(node)) 

697 

698 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> Set[QuantumNode]: 

699 """Return a set of `QuantumNode` that are direct outputs of a specified 

700 node. 

701 

702 Parameters 

703 ---------- 

704 node : `QuantumNode` 

705 The node of the graph for which outputs are to be determined 

706 

707 Returns 

708 ------- 

709 set of `QuantumNode` 

710 All the nodes that are direct outputs to specified node 

711 """ 

712 return set(succ for succ in self._connectedQuanta.successors(node)) 

713 

714 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

715 """Return a graph of `QuantumNode` that are direct inputs and outputs 

716 of a specified node. 

717 

718 Parameters 

719 ---------- 

720 node : `QuantumNode` 

721 The node of the graph for which connected nodes are to be 

722 determined. 

723 

724 Returns 

725 ------- 

726 graph : graph of `QuantumNode` 

727 All the nodes that are directly connected to specified node 

728 """ 

729 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

730 nodes.add(node) 

731 return self.subset(nodes) 

732 

733 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

734 """Return a graph of the specified node and all the ancestor nodes 

735 directly reachable by walking edges. 

736 

737 Parameters 

738 ---------- 

739 node : `QuantumNode` 

740 The node for which all ansestors are to be determined 

741 

742 Returns 

743 ------- 

744 graph of `QuantumNode` 

745 Graph of node and all of its ansestors 

746 """ 

747 predecessorNodes = nx.ancestors(self._connectedQuanta, node) 

748 predecessorNodes.add(node) 

749 return self.subset(predecessorNodes) 

750 

751 def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]: 

752 """Check a graph for the presense of cycles and returns the edges of 

753 any cycles found, or an empty list if there is no cycle. 

754 

755 Returns 

756 ------- 

757 result : list of tuple of `QuantumNode`, `QuantumNode` 

758 A list of any graph edges that form a cycle, or an empty list if 

759 there is no cycle. Empty list to so support if graph.find_cycle() 

760 syntax as an empty list is falsy. 

761 """ 

762 try: 

763 return nx.find_cycle(self._connectedQuanta) 

764 except nx.NetworkXNoCycle: 

765 return [] 

766 

767 def saveUri(self, uri: ResourcePathExpression) -> None: 

768 """Save `QuantumGraph` to the specified URI. 

769 

770 Parameters 

771 ---------- 

772 uri : convertible to `ResourcePath` 

773 URI to where the graph should be saved. 

774 """ 

775 buffer = self._buildSaveObject() 

776 path = ResourcePath(uri) 

777 if path.getExtension() not in (".qgraph"): 

778 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

779 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

780 

    @property
    def metadata(self) -> Optional[MappingProxyType[str, Any]]:
        """Extra data carried with the graph, exposed as a read-only mapping
        view, or `None` when no metadata was supplied at construction.
        """
        if self._metadata is None:
            return None
        return MappingProxyType(self._metadata)

787 

788 def initInputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]: 

789 """Return DatasetRefs for a given task InitInputs. 

790 

791 Parameters 

792 ---------- 

793 taskDef : `TaskDef` 

794 Task definition structure. 

795 

796 Returns 

797 ------- 

798 refs : `list` [ `DatasetRef` ] or None 

799 DatasetRef for the task InitInput, can be `None`. This can return 

800 either resolved or non-resolved reference. 

801 """ 

802 return self._initInputRefs.get(taskDef) 

803 

804 def initOutputRefs(self, taskDef: TaskDef) -> Optional[List[DatasetRef]]: 

805 """Return DatasetRefs for a given task InitOutputs. 

806 

807 Parameters 

808 ---------- 

809 taskDef : `TaskDef` 

810 Task definition structure. 

811 

812 Returns 

813 ------- 

814 refs : `list` [ `DatasetRef` ] or None 

815 DatasetRefs for the task InitOutput, can be `None`. This can return 

816 either resolved or non-resolved reference. Resolved reference will 

817 match Quantum's initInputs if this is an intermediate dataset type. 

818 """ 

819 return self._initOutputRefs.get(taskDef) 

820 

821 @classmethod 

822 def loadUri( 

823 cls, 

824 uri: ResourcePathExpression, 

825 universe: Optional[DimensionUniverse] = None, 

826 nodes: Optional[Iterable[uuid.UUID]] = None, 

827 graphID: Optional[BuildId] = None, 

828 minimumVersion: int = 3, 

829 ) -> QuantumGraph: 

830 """Read `QuantumGraph` from a URI. 

831 

832 Parameters 

833 ---------- 

834 uri : convertible to `ResourcePath` 

835 URI from where to load the graph. 

836 universe: `~lsst.daf.butler.DimensionUniverse` optional 

837 DimensionUniverse instance, not used by the method itself but 

838 needed to ensure that registry data structures are initialized. 

839 If None it is loaded from the QuantumGraph saved structure. If 

840 supplied, the DimensionUniverse from the loaded `QuantumGraph` 

841 will be validated against the supplied argument for compatibility. 

842 nodes: iterable of `int` or None 

843 Numbers that correspond to nodes in the graph. If specified, only 

844 these nodes will be loaded. Defaults to None, in which case all 

845 nodes will be loaded. 

846 graphID : `str` or `None` 

847 If specified this ID is verified against the loaded graph prior to 

848 loading any Nodes. This defaults to None in which case no 

849 validation is done. 

850 minimumVersion : int 

851 Minimum version of a save file to load. Set to -1 to load all 

852 versions. Older versions may need to be loaded, and re-saved 

853 to upgrade them to the latest format before they can be used in 

854 production. 

855 

856 Returns 

857 ------- 

858 graph : `QuantumGraph` 

859 Resulting QuantumGraph instance. 

860 

861 Raises 

862 ------ 

863 TypeError 

864 Raised if pickle contains instance of a type other than 

865 QuantumGraph. 

866 ValueError 

867 Raised if one or more of the nodes requested is not in the 

868 `QuantumGraph` or if graphID parameter does not match the graph 

869 being loaded or if the supplied uri does not point at a valid 

870 `QuantumGraph` save file. 

871 RuntimeError 

872 Raise if Supplied DimensionUniverse is not compatible with the 

873 DimensionUniverse saved in the graph 

874 

875 

876 Notes 

877 ----- 

878 Reading Quanta from pickle requires existence of singleton 

879 DimensionUniverse which is usually instantiated during Registry 

880 initialization. To make sure that DimensionUniverse exists this method 

881 accepts dummy DimensionUniverse argument. 

882 """ 

883 uri = ResourcePath(uri) 

884 # With ResourcePath we have the choice of always using a local file 

885 # or reading in the bytes directly. Reading in bytes can be more 

886 # efficient for reasonably-sized pickle files when the resource 

887 # is remote. For now use the local file variant. For a local file 

888 # as_local() does nothing. 

889 

890 if uri.getExtension() in (".pickle", ".pkl"): 

891 with uri.as_local() as local, open(local.ospath, "rb") as fd: 

892 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

893 qgraph = pickle.load(fd) 

894 elif uri.getExtension() in (".qgraph"): 

895 with LoadHelper(uri, minimumVersion) as loader: 

896 qgraph = loader.load(universe, nodes, graphID) 

897 else: 

898 raise ValueError("Only know how to handle files saved as `pickle`, `pkl`, or `qgraph`") 

899 if not isinstance(qgraph, QuantumGraph): 

900 raise TypeError(f"QuantumGraph save file contains unexpected object type: {type(qgraph)}") 

901 return qgraph 

902 

903 @classmethod 

904 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> Optional[str]: 

905 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

906 and return it as a string. 

907 

908 Parameters 

909 ---------- 

910 uri : convertible to `ResourcePath` 

911 The location of the `QuantumGraph` to load. If the argument is a 

912 string, it must correspond to a valid `ResourcePath` path. 

913 minimumVersion : int 

914 Minimum version of a save file to load. Set to -1 to load all 

915 versions. Older versions may need to be loaded, and re-saved 

916 to upgrade them to the latest format before they can be used in 

917 production. 

918 

919 Returns 

920 ------- 

921 header : `str` or `None` 

922 The header associated with the specified `QuantumGraph` it there is 

923 one, else `None`. 

924 

925 Raises 

926 ------ 

927 ValueError 

928 Raised if `QuantuGraph` was saved as a pickle. 

929 Raised if the extention of the file specified by uri is not a 

930 `QuantumGraph` extention. 

931 """ 

932 uri = ResourcePath(uri) 

933 if uri.getExtension() in (".pickle", ".pkl"): 

934 raise ValueError("Reading a header from a pickle save is not supported") 

935 elif uri.getExtension() in (".qgraph"): 

936 return LoadHelper(uri, minimumVersion).readHeader() 

937 else: 

938 raise ValueError("Only know how to handle files saved as `qgraph`") 

939 

940 def buildAndPrintHeader(self) -> None: 

941 """Creates a header that would be used in a save of this object and 

942 prints it out to standard out. 

943 """ 

944 _, header = self._buildSaveObject(returnHeader=True) 

945 print(json.dumps(header)) 

946 

947 def save(self, file: BinaryIO) -> None: 

948 """Save QuantumGraph to a file. 

949 

950 Parameters 

951 ---------- 

952 file : `io.BufferedIOBase` 

953 File to write pickle data open in binary mode. 

954 """ 

955 buffer = self._buildSaveObject() 

956 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

957 

    def _buildSaveObject(self, returnHeader: bool = False) -> Union[bytearray, Tuple[bytearray, Dict]]:
        """Serialize this graph into a single bytearray.

        The layout is: magic bytes, save version, header length, the
        lzma-compressed JSON header, then the concatenated lzma-compressed
        per-taskDef and per-node JSON payloads. The header records the byte
        ranges of each payload so individual pieces can be loaded lazily.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping that
            was embedded in the serialized bytes.

        Returns
        -------
        buffer : `bytearray`
            The serialized graph.
        header : `dict`
            The header data; only returned when ``returnHeader`` is `True`.
        """
        # make some containers
        jsonData: Deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: Dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task defs, recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # sort by task label to ensure serialization happens in the same order
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            taskDescription: Dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on the
            # other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connections between all of the tasks and save that
            # in the header as a list of connections and edges in each task;
            # this will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output will be said to be the datastore,
                    # in which case the for loop will be a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, and encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode())
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.json().encode())
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        headerData["DimensionRecords"] = {
            key: value.dict() for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # Pack the save-format version number, then the header length, using
        # the struct format defined for the current save version.
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

1105 

1106 @classmethod 

1107 def load( 

1108 cls, 

1109 file: BinaryIO, 

1110 universe: Optional[DimensionUniverse] = None, 

1111 nodes: Optional[Iterable[uuid.UUID]] = None, 

1112 graphID: Optional[BuildId] = None, 

1113 minimumVersion: int = 3, 

1114 ) -> QuantumGraph: 

1115 """Read QuantumGraph from a file that was made by `save`. 

1116 

1117 Parameters 

1118 ---------- 

1119 file : `io.IO` of bytes 

1120 File with pickle data open in binary mode. 

1121 universe: `~lsst.daf.butler.DimensionUniverse`, optional 

1122 DimensionUniverse instance, not used by the method itself but 

1123 needed to ensure that registry data structures are initialized. 

1124 If None it is loaded from the QuantumGraph saved structure. If 

1125 supplied, the DimensionUniverse from the loaded `QuantumGraph` 

1126 will be validated against the supplied argument for compatibility. 

1127 nodes: iterable of `int` or None 

1128 Numbers that correspond to nodes in the graph. If specified, only 

1129 these nodes will be loaded. Defaults to None, in which case all 

1130 nodes will be loaded. 

1131 graphID : `str` or `None` 

1132 If specified this ID is verified against the loaded graph prior to 

1133 loading any Nodes. This defaults to None in which case no 

1134 validation is done. 

1135 minimumVersion : int 

1136 Minimum version of a save file to load. Set to -1 to load all 

1137 versions. Older versions may need to be loaded, and re-saved 

1138 to upgrade them to the latest format before they can be used in 

1139 production. 

1140 

1141 Returns 

1142 ------- 

1143 graph : `QuantumGraph` 

1144 Resulting QuantumGraph instance. 

1145 

1146 Raises 

1147 ------ 

1148 TypeError 

1149 Raised if pickle contains instance of a type other than 

1150 QuantumGraph. 

1151 ValueError 

1152 Raised if one or more of the nodes requested is not in the 

1153 `QuantumGraph` or if graphID parameter does not match the graph 

1154 being loaded or if the supplied uri does not point at a valid 

1155 `QuantumGraph` save file. 

1156 

1157 Notes 

1158 ----- 

1159 Reading Quanta from pickle requires existence of singleton 

1160 DimensionUniverse which is usually instantiated during Registry 

1161 initialization. To make sure that DimensionUniverse exists this method 

1162 accepts dummy DimensionUniverse argument. 

1163 """ 

1164 # Try to see if the file handle contains pickle data, this will be 

1165 # removed in the future 

1166 try: 

1167 qgraph = pickle.load(file) 

1168 warnings.warn("Pickle graphs are deprecated, please re-save your graph with the save method") 

1169 except pickle.UnpicklingError: 

1170 with LoadHelper(file, minimumVersion) as loader: 

1171 qgraph = loader.load(universe, nodes, graphID) 

1172 if not isinstance(qgraph, QuantumGraph): 

1173 raise TypeError(f"QuantumGraph pickle file has contains unexpected object type: {type(qgraph)}") 

1174 return qgraph 

1175 

1176 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1177 """Iterate over the `taskGraph` attribute in topological order 

1178 

1179 Yields 

1180 ------ 

1181 taskDef : `TaskDef` 

1182 `TaskDef` objects in topological order 

1183 """ 

1184 yield from nx.topological_sort(self.taskGraph) 

1185 

1186 @property 

1187 def graphID(self) -> BuildId: 

1188 """Returns the ID generated by the graph at construction time""" 

1189 return self._buildId 

1190 

1191 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1192 yield from nx.topological_sort(self._connectedQuanta) 

1193 

1194 def __len__(self) -> int: 

1195 return self._count 

1196 

1197 def __contains__(self, node: QuantumNode) -> bool: 

1198 return self._connectedQuanta.has_node(node) 

1199 

1200 def __getstate__(self) -> dict: 

1201 """Stores a compact form of the graph as a list of graph nodes, and a 

1202 tuple of task labels and task configs. The full graph can be 

1203 reconstructed with this information, and it preseves the ordering of 

1204 the graph ndoes. 

1205 """ 

1206 universe: Optional[DimensionUniverse] = None 

1207 for node in self: 

1208 dId = node.quantum.dataId 

1209 if dId is None: 

1210 continue 

1211 universe = dId.graph.universe 

1212 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe} 

1213 

1214 def __setstate__(self, state: dict) -> None: 

1215 """Reconstructs the state of the graph from the information persisted 

1216 in getstate. 

1217 """ 

1218 buffer = io.BytesIO(state["reduced"]) 

1219 with LoadHelper(buffer, minimumVersion=3) as loader: 

1220 qgraph = loader.load(state["universe"], graphID=state["graphId"]) 

1221 

1222 self._metadata = qgraph._metadata 

1223 self._buildId = qgraph._buildId 

1224 self._datasetDict = qgraph._datasetDict 

1225 self._nodeIdMap = qgraph._nodeIdMap 

1226 self._count = len(qgraph) 

1227 self._taskToQuantumNode = qgraph._taskToQuantumNode 

1228 self._taskGraph = qgraph._taskGraph 

1229 self._connectedQuanta = qgraph._connectedQuanta 

1230 self._initInputRefs = qgraph._initInputRefs 

1231 self._initOutputRefs = qgraph._initOutputRefs 

1232 

1233 def __eq__(self, other: object) -> bool: 

1234 if not isinstance(other, QuantumGraph): 

1235 return False 

1236 if len(self) != len(other): 

1237 return False 

1238 for node in self: 

1239 if node not in other: 

1240 return False 

1241 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1242 return False 

1243 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1244 return False 

1245 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1246 return False 

1247 return set(self.taskGraph) == set(other.taskGraph)