Coverage for python/lsst/pipe/base/graph/graph.py: 21%

406 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-10 03:25 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("QuantumGraph", "IncompatibleGraphError") 

30 

31import io 

32import json 

33import lzma 

34import os 

35import struct 

36import time 

37import uuid 

38from collections import defaultdict, deque 

39from collections.abc import Generator, Iterable, Iterator, Mapping, MutableMapping 

40from itertools import chain 

41from types import MappingProxyType 

42from typing import Any, BinaryIO, TypeVar 

43 

44import networkx as nx 

45from lsst.daf.butler import ( 

46 DatasetId, 

47 DatasetRef, 

48 DatasetType, 

49 DimensionRecordsAccumulator, 

50 DimensionUniverse, 

51 Quantum, 

52) 

53from lsst.daf.butler.persistence_context import PersistenceContextVars 

54from lsst.resources import ResourcePath, ResourcePathExpression 

55from lsst.utils.introspection import get_full_type_name 

56from lsst.utils.packages import Packages 

57from networkx.drawing.nx_agraph import write_dot 

58 

59from ..connections import iterConnections 

60from ..pipeline import TaskDef 

61from ..pipeline_graph import PipelineGraph 

62from ._implDetails import DatasetTypeName, _DatasetTracker 

63from ._loadHelpers import LoadHelper 

64from ._versionDeserializers import DESERIALIZER_MAP 

65from .quantumNode import BuildId, QuantumNode 

66 

# Generic type variable bound to QuantumGraph so that methods such as
# ``subset`` can be declared to return the same (sub)class they were
# invoked on.
_T = TypeVar("_T", bound="QuantumGraph")

# modify this constant any time the on disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save
# (formats are `struct` module format strings).
# The base is a big endian encoded unsigned short that is used to hold the
# file format version. This allows reading version bytes and determine which
# loading code should be used for the rest of the file
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream
# Version 2
# A big endian encoded format with an unsigned long long byte stream used to
# indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# magic bytes that help determine this is a graph save
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"

90 

class IncompatibleGraphError(Exception):
    """Raised when a lookup by NodeId cannot be performed because of
    incompatibilities between graphs.
    """

97 

98 

99class QuantumGraph: 

100 """QuantumGraph is a directed acyclic graph of `QuantumNode` objects. 

101 

102 This data structure represents a concrete workflow generated from a 

103 `Pipeline`. 

104 

105 Parameters 

106 ---------- 

107 quanta : `~collections.abc.Mapping` [ `TaskDef`, \ 

108 `set` [ `~lsst.daf.butler.Quantum` ] ] 

109 This maps tasks (and their configs) to the sets of data they are to 

110 process. 

111 metadata : Optional `~collections.abc.Mapping` of `str` to primitives 

112 This is an optional parameter of extra data to carry with the graph. 

113 Entries in this mapping should be able to be serialized in JSON. 

114 universe : `~lsst.daf.butler.DimensionUniverse`, optional 

115 The dimensions in which quanta can be defined. Need only be provided if 

116 no quanta have data IDs. 

117 initInputs : `~collections.abc.Mapping`, optional 

118 Maps tasks to their InitInput dataset refs. Dataset refs can be either 

119 resolved or non-resolved. Presently the same dataset refs are included 

120 in each `~lsst.daf.butler.Quantum` for the same task. 

121 initOutputs : `~collections.abc.Mapping`, optional 

122 Maps tasks to their InitOutput dataset refs. Dataset refs can be either 

123 resolved or non-resolved. For intermediate resolved refs their dataset 

124 ID must match ``initInputs`` and Quantum ``initInputs``. 

125 globalInitOutputs : iterable [ `~lsst.daf.butler.DatasetRef` ], optional 

126 Dataset refs for some global objects produced by pipeline. These 

127 objects include task configurations and package versions. Typically 

128 they have an empty DataId, but there is no real restriction on what 

129 can appear here. 

130 registryDatasetTypes : iterable [ `~lsst.daf.butler.DatasetType` ], \ 

131 optional 

132 Dataset types which are used by this graph, their definitions must 

133 match registry. If registry does not define dataset type yet, then 

134 it should match one that will be created later. 

135 

136 Raises 

137 ------ 

138 ValueError 

139 Raised if the graph is pruned such that some tasks no longer have nodes 

140 associated with them. 

141 """ 

142 

    def __init__(
        self,
        quanta: Mapping[TaskDef, set[Quantum]],
        metadata: Mapping[str, Any] | None = None,
        universe: DimensionUniverse | None = None,
        initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        globalInitOutputs: Iterable[DatasetRef] | None = None,
        registryDatasetTypes: Iterable[DatasetType] | None = None,
    ):
        # All real construction work is delegated to _buildGraphs, which
        # additionally accepts private keyword arguments (_quantumToNodeId,
        # _buildId) used when rebuilding an instance in ``subset``.
        self._buildGraphs(
            quanta,
            metadata=metadata,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=registryDatasetTypes,
        )

162 

    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, set[Quantum]],
        *,
        _quantumToNodeId: Mapping[Quantum, uuid.UUID] | None = None,
        _buildId: BuildId | None = None,
        metadata: Mapping[str, Any] | None = None,
        universe: DimensionUniverse | None = None,
        initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        globalInitOutputs: Iterable[DatasetRef] | None = None,
        registryDatasetTypes: Iterable[DatasetType] | None = None,
    ) -> None:
        """Build the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.

        Parameters
        ----------
        quanta : `~collections.abc.Mapping` [ `TaskDef`, `set` [ `Quantum` ] ]
            Maps tasks to the sets of quanta they are to process.
        _quantumToNodeId : `~collections.abc.Mapping`, optional
            Private: pre-assigned node UUIDs for each quantum (used when a
            subset is rebuilt so node ids are preserved).  When supplied,
            every quantum must have an entry.
        _buildId : `BuildId`, optional
            Private: reuse an existing build id instead of generating one.
        metadata, universe, initInputs, initOutputs, globalInitOutputs,
        registryDatasetTypes :
            See the class docstring.
        """
        # Save packages to metadata
        self._metadata = dict(metadata) if metadata is not None else {}
        self._metadata["packages"] = Packages.fromSystem()

        # A build id derived from wall-clock time and pid unless one was
        # handed in (e.g. by ``subset``).
        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structure used to identify relations between
        # DatasetTypeName -> TaskDef.
        self._datasetDict = _DatasetTracker(createInverse=True)

        # Temporary graph that will have dataset UUIDs (as raw bytes) and
        # QuantumNode objects as nodes; will be collapsed down to just quanta
        # later.
        bipartite_graph = nx.DiGraph()

        self._nodeIdMap: dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connections name, with a value of
            # the TaskDef in the appropriate field
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                # Have to handle components in inputs.
                dataset_name, _, _ = inpt.name.partition(".")
                self._datasetDict.addConsumer(DatasetTypeName(dataset_name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                # Have to handle possible components in outputs.
                dataset_name, _, _ = output.name.partition(".")
                self._datasetDict.addProducer(DatasetTypeName(dataset_name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for one
            # of the individual datasets inside the `Quantum`, with a value of
            # a newly created QuantumNode to the appropriate input/output
            # field.
            for quantum in quantumSet:
                if quantum.dataId is not None:
                    # Infer the universe from the first data ID seen, and
                    # require all subsequent ones to agree with it.
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}. "
                        )

                if _quantumToNodeId:
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantuMToNodeNumber is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                # ``bipartite=0`` marks quantum nodes, ``bipartite=1`` marks
                # dataset nodes (keyed by the dataset UUID's raw bytes).
                bipartite_graph.add_node(value, bipartite=0)
                for dsRef in chain(inits, inputs):
                    # unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            bipartite_graph.add_node(sub.id.bytes, bipartite=1)
                            bipartite_graph.add_edge(sub.id.bytes, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        # Components share their parent composite's dataset
                        # id, so graph edges are made against the composite.
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                        bipartite_graph.add_edge(dsRef.id.bytes, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                    bipartite_graph.add_edge(value, dsRef.id.bytes)

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Make graph of quanta relations, by projecting out the dataset nodes
        # in the bipartite_graph, leaving just the quanta.
        self._connectedQuanta = nx.algorithms.bipartite.projected_graph(
            bipartite_graph, self._nodeIdMap.values()
        )
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # convert default dict into a regular to prevent accidental key
        # insertion
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

        self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._globalInitOutputRefs: list[DatasetRef] = []
        self._registryDatasetTypes: list[DatasetType] = []
        if initInputs is not None:
            self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
        if initOutputs is not None:
            self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
        if globalInitOutputs is not None:
            self._globalInitOutputRefs = list(globalInitOutputs)
        if registryDatasetTypes is not None:
            self._registryDatasetTypes = list(registryDatasetTypes)

        # PipelineGraph is current constructed on first use.
        # TODO DM-40442: use PipelineGraph instead of TaskDef
        # collections.
        self._pipeline_graph: PipelineGraph | None = None

299 

300 @property 

301 def pipeline_graph(self) -> PipelineGraph: 

302 """A graph representation of the tasks and dataset types in the quantum 

303 graph. 

304 """ 

305 if self._pipeline_graph is None: 

306 # Construct into a temporary for strong exception safety. 

307 pipeline_graph = PipelineGraph() 

308 for task_def in self._taskToQuantumNode.keys(): 

309 pipeline_graph.add_task( 

310 task_def.label, task_def.taskClass, task_def.config, connections=task_def.connections 

311 ) 

312 dataset_types = {dataset_type.name: dataset_type for dataset_type in self._registryDatasetTypes} 

313 pipeline_graph.resolve(dimensions=self._universe, dataset_types=dataset_types) 

314 self._pipeline_graph = pipeline_graph 

315 return self._pipeline_graph 

316 

317 def get_task_quanta(self, label: str) -> Mapping[uuid.UUID, Quantum]: 

318 """Return the quanta associated with the given task label. 

319 

320 Parameters 

321 ---------- 

322 label : `str` 

323 Task label. 

324 

325 Returns 

326 ------- 

327 quanta : `~collections.abc.Mapping` [ uuid.UUID, `Quantum` ] 

328 Mapping from quantum ID to quantum. Empty if ``label`` does not 

329 correspond to a task in this graph. 

330 """ 

331 task_def = self.findTaskDefByLabel(label) 

332 if not task_def: 

333 return {} 

334 return {node.nodeId: node.quantum for node in self.getNodesForTask(task_def)} 

335 

336 @property 

337 def taskGraph(self) -> nx.DiGraph: 

338 """A graph representing the relations between the tasks inside 

339 the quantum graph (`networkx.DiGraph`). 

340 """ 

341 return self._taskGraph 

342 

343 @property 

344 def graph(self) -> nx.DiGraph: 

345 """A graph representing the relations between all the `QuantumNode` 

346 objects (`networkx.DiGraph`). 

347 

348 The graph should usually be iterated over, or passed to methods of this 

349 class, but sometimes direct access to the ``networkx`` object may be 

350 helpful. 

351 """ 

352 return self._connectedQuanta 

353 

354 @property 

355 def inputQuanta(self) -> Iterable[QuantumNode]: 

356 """The nodes that are inputs to the graph (iterable [`QuantumNode`]). 

357 

358 These are the nodes that do not depend on any other nodes in the 

359 graph. 

360 """ 

361 return (q for q, n in self._connectedQuanta.in_degree if n == 0) 

362 

363 @property 

364 def outputQuanta(self) -> Iterable[QuantumNode]: 

365 """The nodes that are outputs of the graph (iterable [`QuantumNode`]). 

366 

367 These are the nodes that have no nodes that depend on them in the 

368 graph. 

369 """ 

370 return [q for q, n in self._connectedQuanta.out_degree if n == 0] 

371 

372 @property 

373 def allDatasetTypes(self) -> tuple[DatasetTypeName, ...]: 

374 """All the data set type names that are present in the graph 

375 (`tuple` [`str`]). 

376 

377 These types do not include global init-outputs. 

378 """ 

379 return tuple(self._datasetDict.keys()) 

380 

381 @property 

382 def isConnected(self) -> bool: 

383 """Whether all of the nodes in the graph are connected, ignoring 

384 directionality of connections (`bool`). 

385 """ 

386 return nx.is_weakly_connected(self._connectedQuanta) 

387 

388 def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode: 

389 """Lookup a `QuantumNode` from an id associated with the node. 

390 

391 Parameters 

392 ---------- 

393 nodeId : `NodeId` 

394 The number associated with a node. 

395 

396 Returns 

397 ------- 

398 node : `QuantumNode` 

399 The node corresponding with input number. 

400 

401 Raises 

402 ------ 

403 KeyError 

404 Raised if the requested nodeId is not in the graph. 

405 """ 

406 return self._nodeIdMap[nodeId] 

407 

408 def getQuantaForTask(self, taskDef: TaskDef) -> frozenset[Quantum]: 

409 """Return all the `~lsst.daf.butler.Quantum` associated with a 

410 `TaskDef`. 

411 

412 Parameters 

413 ---------- 

414 taskDef : `TaskDef` 

415 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be 

416 queried. 

417 

418 Returns 

419 ------- 

420 quanta : `frozenset` of `~lsst.daf.butler.Quantum` 

421 The `set` of `~lsst.daf.butler.Quantum` that is associated with the 

422 specified `TaskDef`. 

423 """ 

424 return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ())) 

425 

426 def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int: 

427 """Return the number of `~lsst.daf.butler.Quantum` associated with 

428 a `TaskDef`. 

429 

430 Parameters 

431 ---------- 

432 taskDef : `TaskDef` 

433 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be 

434 queried. 

435 

436 Returns 

437 ------- 

438 count : `int` 

439 The number of `~lsst.daf.butler.Quantum` that are associated with 

440 the specified `TaskDef`. 

441 """ 

442 return len(self._taskToQuantumNode.get(taskDef, ())) 

443 

444 def getNodesForTask(self, taskDef: TaskDef) -> frozenset[QuantumNode]: 

445 r"""Return all the `QuantumNode`\s associated with a `TaskDef`. 

446 

447 Parameters 

448 ---------- 

449 taskDef : `TaskDef` 

450 The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be 

451 queried. 

452 

453 Returns 

454 ------- 

455 nodes : `frozenset` [ `QuantumNode` ] 

456 A `frozenset` of `QuantumNode` that is associated with the 

457 specified `TaskDef`. 

458 """ 

459 return frozenset(self._taskToQuantumNode[taskDef]) 

460 

461 def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

462 """Find all tasks that have the specified dataset type name as an 

463 input. 

464 

465 Parameters 

466 ---------- 

467 datasetTypeName : `str` 

468 A string representing the name of a dataset type to be queried, 

469 can also accept a `DatasetTypeName` which is a `~typing.NewType` of 

470 `str` for type safety in static type checking. 

471 

472 Returns 

473 ------- 

474 tasks : iterable of `TaskDef` 

475 `TaskDef` objects that have the specified `DatasetTypeName` as an 

476 input, list will be empty if no tasks use specified 

477 `DatasetTypeName` as an input. 

478 

479 Raises 

480 ------ 

481 KeyError 

482 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`. 

483 """ 

484 return (c for c in self._datasetDict.getConsumers(datasetTypeName)) 

485 

486 def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> TaskDef | None: 

487 """Find all tasks that have the specified dataset type name as an 

488 output. 

489 

490 Parameters 

491 ---------- 

492 datasetTypeName : `str` 

493 A string representing the name of a dataset type to be queried, 

494 can also accept a `DatasetTypeName` which is a `~typing.NewType` of 

495 `str` for type safety in static type checking. 

496 

497 Returns 

498 ------- 

499 result : `TaskDef` or `None` 

500 `TaskDef` that outputs `DatasetTypeName` as an output or `None` if 

501 none of the tasks produce this `DatasetTypeName`. 

502 

503 Raises 

504 ------ 

505 KeyError 

506 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`. 

507 """ 

508 return self._datasetDict.getProducer(datasetTypeName) 

509 

510 def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]: 

511 """Find all tasks that are associated with the specified dataset type 

512 name. 

513 

514 Parameters 

515 ---------- 

516 datasetTypeName : `str` 

517 A string representing the name of a dataset type to be queried, 

518 can also accept a `DatasetTypeName` which is a `~typing.NewType` of 

519 `str` for type safety in static type checking. 

520 

521 Returns 

522 ------- 

523 result : iterable of `TaskDef` 

524 `TaskDef` objects that are associated with the specified 

525 `DatasetTypeName`. 

526 

527 Raises 

528 ------ 

529 KeyError 

530 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`. 

531 """ 

532 return self._datasetDict.getAll(datasetTypeName) 

533 

534 def findTaskDefByName(self, taskName: str) -> list[TaskDef]: 

535 """Determine which `TaskDef` objects in this graph are associated 

536 with a `str` representing a task name (looks at the ``taskName`` 

537 property of `TaskDef` objects). 

538 

539 Returns a list of `TaskDef` objects as a `PipelineTask` may appear 

540 multiple times in a graph with different labels. 

541 

542 Parameters 

543 ---------- 

544 taskName : `str` 

545 Name of a task to search for. 

546 

547 Returns 

548 ------- 

549 result : `list` of `TaskDef` 

550 List of the `TaskDef` objects that have the name specified. 

551 Multiple values are returned in the case that a task is used 

552 multiple times with different labels. 

553 """ 

554 results = [] 

555 for task in self._taskToQuantumNode: 

556 split = task.taskName.split(".") 

557 if split[-1] == taskName: 

558 results.append(task) 

559 return results 

560 

561 def findTaskDefByLabel(self, label: str) -> TaskDef | None: 

562 """Determine which `TaskDef` objects in this graph are associated 

563 with a `str` representing a tasks label. 

564 

565 Parameters 

566 ---------- 

567 label : `str` 

568 Name of a task to search for. 

569 

570 Returns 

571 ------- 

572 result : `TaskDef` 

573 `TaskDef` objects that has the specified label. 

574 """ 

575 for task in self._taskToQuantumNode: 

576 if label == task.label: 

577 return task 

578 return None 

579 

580 def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> set[Quantum]: 

581 r"""Return all the `~lsst.daf.butler.Quantum` that contain a specified 

582 `DatasetTypeName`. 

583 

584 Parameters 

585 ---------- 

586 datasetTypeName : `str` 

587 The name of the dataset type to search for as a string, 

588 can also accept a `DatasetTypeName` which is a `~typing.NewType` of 

589 `str` for type safety in static type checking. 

590 

591 Returns 

592 ------- 

593 result : `set` of `QuantumNode` objects 

594 A `set` of `QuantumNode`\s that contain specified 

595 `DatasetTypeName`. 

596 

597 Raises 

598 ------ 

599 KeyError 

600 Raised if the `DatasetTypeName` is not part of the `QuantumGraph`. 

601 """ 

602 tasks = self._datasetDict.getAll(datasetTypeName) 

603 result: set[Quantum] = set() 

604 result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task)) 

605 return result 

606 

607 def checkQuantumInGraph(self, quantum: Quantum) -> bool: 

608 """Check if specified quantum appears in the graph as part of a node. 

609 

610 Parameters 

611 ---------- 

612 quantum : `lsst.daf.butler.Quantum` 

613 The quantum to search for. 

614 

615 Returns 

616 ------- 

617 in_graph : `bool` 

618 The result of searching for the quantum. 

619 """ 

620 return any(quantum == node.quantum for node in self) 

621 

622 def writeDotGraph(self, output: str | io.BufferedIOBase) -> None: 

623 """Write out the graph as a dot graph. 

624 

625 Parameters 

626 ---------- 

627 output : `str` or `io.BufferedIOBase` 

628 Either a filesystem path to write to, or a file handle object. 

629 """ 

630 write_dot(self._connectedQuanta, output) 

631 

632 def subset(self: _T, nodes: QuantumNode | Iterable[QuantumNode]) -> _T: 

633 """Create a new graph object that contains the subset of the nodes 

634 specified as input. Node number is preserved. 

635 

636 Parameters 

637 ---------- 

638 nodes : `QuantumNode` or iterable of `QuantumNode` 

639 Nodes from which to create subset. 

640 

641 Returns 

642 ------- 

643 graph : instance of graph type 

644 An instance of the type from which the subset was created. 

645 """ 

646 if not isinstance(nodes, Iterable): 

647 nodes = (nodes,) 

648 quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes 

649 quantumMap = defaultdict(set) 

650 

651 dataset_type_names: set[str] = set() 

652 node: QuantumNode 

653 for node in quantumSubgraph: 

654 quantumMap[node.taskDef].add(node.quantum) 

655 dataset_type_names.update( 

656 dstype.name 

657 for dstype in chain( 

658 node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys() 

659 ) 

660 ) 

661 

662 # May need to trim dataset types from registryDatasetTypes. 

663 for taskDef in quantumMap: 

664 if refs := self.initOutputRefs(taskDef): 

665 dataset_type_names.update(ref.datasetType.name for ref in refs) 

666 dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs) 

667 registryDatasetTypes = [ 

668 dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names 

669 ] 

670 

671 # convert to standard dict to prevent accidental key insertion 

672 quantumDict: dict[TaskDef, set[Quantum]] = dict(quantumMap.items()) 

673 # Create an empty graph, and then populate it with custom mapping 

674 newInst = type(self)({}, universe=self._universe) 

675 # TODO: Do we need to copy initInputs/initOutputs? 

676 newInst._buildGraphs( 

677 quantumDict, 

678 _quantumToNodeId={n.quantum: n.nodeId for n in nodes}, 

679 _buildId=self._buildId, 

680 metadata=self._metadata, 

681 universe=self._universe, 

682 globalInitOutputs=self._globalInitOutputRefs, 

683 registryDatasetTypes=registryDatasetTypes, 

684 ) 

685 return newInst 

686 

687 def subsetToConnected(self: _T) -> tuple[_T, ...]: 

688 """Generate a list of subgraphs where each is connected. 

689 

690 Returns 

691 ------- 

692 result : `list` of `QuantumGraph` 

693 A list of graphs that are each connected. 

694 """ 

695 return tuple( 

696 self.subset(connectedSet) 

697 for connectedSet in nx.weakly_connected_components(self._connectedQuanta) 

698 ) 

699 

700 def determineInputsToQuantumNode(self, node: QuantumNode) -> set[QuantumNode]: 

701 """Return a set of `QuantumNode` that are direct inputs to a specified 

702 node. 

703 

704 Parameters 

705 ---------- 

706 node : `QuantumNode` 

707 The node of the graph for which inputs are to be determined. 

708 

709 Returns 

710 ------- 

711 inputs : `set` of `QuantumNode` 

712 All the nodes that are direct inputs to specified node. 

713 """ 

714 return set(self._connectedQuanta.predecessors(node)) 

715 

716 def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]: 

717 """Return a set of `QuantumNode` that are direct outputs of a specified 

718 node. 

719 

720 Parameters 

721 ---------- 

722 node : `QuantumNode` 

723 The node of the graph for which outputs are to be determined. 

724 

725 Returns 

726 ------- 

727 outputs : `set` of `QuantumNode` 

728 All the nodes that are direct outputs to specified node. 

729 """ 

730 return set(self._connectedQuanta.successors(node)) 

731 

732 def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

733 """Return a graph of `QuantumNode` that are direct inputs and outputs 

734 of a specified node. 

735 

736 Parameters 

737 ---------- 

738 node : `QuantumNode` 

739 The node of the graph for which connected nodes are to be 

740 determined. 

741 

742 Returns 

743 ------- 

744 graph : graph of `QuantumNode` 

745 All the nodes that are directly connected to specified node. 

746 """ 

747 nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node)) 

748 nodes.add(node) 

749 return self.subset(nodes) 

750 

751 def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T: 

752 """Return a graph of the specified node and all the ancestor nodes 

753 directly reachable by walking edges. 

754 

755 Parameters 

756 ---------- 

757 node : `QuantumNode` 

758 The node for which all ancestors are to be determined. 

759 

760 Returns 

761 ------- 

762 ancestors : graph of `QuantumNode` 

763 Graph of node and all of its ancestors. 

764 """ 

765 predecessorNodes = nx.ancestors(self._connectedQuanta, node) 

766 predecessorNodes.add(node) 

767 return self.subset(predecessorNodes) 

768 

769 def findCycle(self) -> list[tuple[QuantumNode, QuantumNode]]: 

770 """Check a graph for the presense of cycles and returns the edges of 

771 any cycles found, or an empty list if there is no cycle. 

772 

773 Returns 

774 ------- 

775 result : `list` of `tuple` of [ `QuantumNode`, `QuantumNode` ] 

776 A list of any graph edges that form a cycle, or an empty list if 

777 there is no cycle. Empty list to so support if graph.find_cycle() 

778 syntax as an empty list is falsy. 

779 """ 

780 try: 

781 return nx.find_cycle(self._connectedQuanta) 

782 except nx.NetworkXNoCycle: 

783 return [] 

784 

785 def saveUri(self, uri: ResourcePathExpression) -> None: 

786 """Save `QuantumGraph` to the specified URI. 

787 

788 Parameters 

789 ---------- 

790 uri : convertible to `~lsst.resources.ResourcePath` 

791 URI to where the graph should be saved. 

792 """ 

793 buffer = self._buildSaveObject() 

794 path = ResourcePath(uri) 

795 if path.getExtension() not in (".qgraph"): 

796 raise TypeError(f"Can currently only save a graph in qgraph format not {uri}") 

797 path.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

798 

799 @property 

800 def metadata(self) -> MappingProxyType[str, Any] | None: 

801 """Extra data carried with the graph (mapping [`str`] or `None`). 

802 

803 The mapping is a dynamic view of this object's metadata. Values should 

804 be able to be serialized in JSON. 

805 """ 

806 return MappingProxyType(self._metadata) 

807 

808 def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None: 

809 """Return DatasetRefs for a given task InitInputs. 

810 

811 Parameters 

812 ---------- 

813 taskDef : `TaskDef` 

814 Task definition structure. 

815 

816 Returns 

817 ------- 

818 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None` 

819 DatasetRef for the task InitInput, can be `None`. This can return 

820 either resolved or non-resolved reference. 

821 """ 

822 return self._initInputRefs.get(taskDef) 

823 

824 def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None: 

825 """Return DatasetRefs for a given task InitOutputs. 

826 

827 Parameters 

828 ---------- 

829 taskDef : `TaskDef` 

830 Task definition structure. 

831 

832 Returns 

833 ------- 

834 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None` 

835 DatasetRefs for the task InitOutput, can be `None`. This can return 

836 either resolved or non-resolved reference. Resolved reference will 

837 match Quantum's initInputs if this is an intermediate dataset type. 

838 """ 

839 return self._initOutputRefs.get(taskDef) 

840 

841 def globalInitOutputRefs(self) -> list[DatasetRef]: 

842 """Return DatasetRefs for global InitOutputs. 

843 

844 Returns 

845 ------- 

846 refs : `list` [ `~lsst.daf.butler.DatasetRef` ] 

847 DatasetRefs for global InitOutputs. 

848 """ 

849 return self._globalInitOutputRefs 

850 

851 def registryDatasetTypes(self) -> list[DatasetType]: 

852 """Return dataset types used by this graph, their definitions match 

853 dataset types from registry. 

854 

855 Returns 

856 ------- 

857 refs : `list` [ `~lsst.daf.butler.DatasetType` ] 

858 Dataset types for this graph. 

859 """ 

860 return self._registryDatasetTypes 

861 

862 @classmethod 

863 def loadUri( 

864 cls, 

865 uri: ResourcePathExpression, 

866 universe: DimensionUniverse | None = None, 

867 nodes: Iterable[uuid.UUID] | None = None, 

868 graphID: BuildId | None = None, 

869 minimumVersion: int = 3, 

870 ) -> QuantumGraph: 

871 """Read `QuantumGraph` from a URI. 

872 

873 Parameters 

874 ---------- 

875 uri : convertible to `~lsst.resources.ResourcePath` 

876 URI from where to load the graph. 

877 universe : `~lsst.daf.butler.DimensionUniverse`, optional 

878 If `None` it is loaded from the `QuantumGraph` 

879 saved structure. If supplied, the 

880 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph` 

881 will be validated against the supplied argument for compatibility. 

882 nodes : iterable of `uuid.UUID` or `None` 

883 UUIDs that correspond to nodes in the graph. If specified, only 

884 these nodes will be loaded. Defaults to None, in which case all 

885 nodes will be loaded. 

886 graphID : `str` or `None` 

887 If specified this ID is verified against the loaded graph prior to 

888 loading any Nodes. This defaults to None in which case no 

889 validation is done. 

890 minimumVersion : `int` 

891 Minimum version of a save file to load. Set to -1 to load all 

892 versions. Older versions may need to be loaded, and re-saved 

893 to upgrade them to the latest format before they can be used in 

894 production. 

895 

896 Returns 

897 ------- 

898 graph : `QuantumGraph` 

899 Resulting QuantumGraph instance. 

900 

901 Raises 

902 ------ 

903 TypeError 

904 Raised if file contains instance of a type other than 

905 `QuantumGraph`. 

906 ValueError 

907 Raised if one or more of the nodes requested is not in the 

908 `QuantumGraph` or if graphID parameter does not match the graph 

909 being loaded or if the supplied uri does not point at a valid 

910 `QuantumGraph` save file. 

911 RuntimeError 

912 Raise if Supplied `~lsst.daf.butler.DimensionUniverse` is not 

913 compatible with the `~lsst.daf.butler.DimensionUniverse` saved in 

914 the graph. 

915 """ 

916 uri = ResourcePath(uri) 

917 if uri.getExtension() in {".qgraph"}: 

918 with LoadHelper(uri, minimumVersion) as loader: 

919 qgraph = loader.load(universe, nodes, graphID) 

920 else: 

921 raise ValueError(f"Only know how to handle files saved as `.qgraph`, not {uri}") 

922 if not isinstance(qgraph, QuantumGraph): 

923 raise TypeError(f"QuantumGraph file {uri} contains unexpected object type: {type(qgraph)}") 

924 return qgraph 

925 

926 @classmethod 

927 def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> str | None: 

928 """Read the header of a `QuantumGraph` pointed to by the uri parameter 

929 and return it as a string. 

930 

931 Parameters 

932 ---------- 

933 uri : convertible to `~lsst.resources.ResourcePath` 

934 The location of the `QuantumGraph` to load. If the argument is a 

935 string, it must correspond to a valid 

936 `~lsst.resources.ResourcePath` path. 

937 minimumVersion : `int` 

938 Minimum version of a save file to load. Set to -1 to load all 

939 versions. Older versions may need to be loaded, and re-saved 

940 to upgrade them to the latest format before they can be used in 

941 production. 

942 

943 Returns 

944 ------- 

945 header : `str` or `None` 

946 The header associated with the specified `QuantumGraph` it there is 

947 one, else `None`. 

948 

949 Raises 

950 ------ 

951 ValueError 

952 Raised if the extension of the file specified by uri is not a 

953 `QuantumGraph` extension. 

954 """ 

955 uri = ResourcePath(uri) 

956 if uri.getExtension() in {".qgraph"}: 

957 return LoadHelper(uri, minimumVersion).readHeader() 

958 else: 

959 raise ValueError("Only know how to handle files saved as `.qgraph`") 

960 

961 def buildAndPrintHeader(self) -> None: 

962 """Create a header that would be used in a save of this object and 

963 prints it out to standard out. 

964 """ 

965 _, header = self._buildSaveObject(returnHeader=True) 

966 print(json.dumps(header)) 

967 

968 def save(self, file: BinaryIO) -> None: 

969 """Save QuantumGraph to a file. 

970 

971 Parameters 

972 ---------- 

973 file : `io.BufferedIOBase` 

974 File to write data open in binary mode. 

975 """ 

976 buffer = self._buildSaveObject() 

977 file.write(buffer) # type: ignore # Ignore because bytearray is safe to use in place of bytes 

978 

979 def _buildSaveObject(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]: 

980 thing = PersistenceContextVars() 

981 result = thing.run(self._buildSaveObjectImpl, returnHeader) 

982 return result 

983 

    def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
        """Serialize this graph into a single binary buffer.

        The layout written is: magic bytes, save-format version, header
        length, lzma-compressed JSON header, then the lzma-compressed JSON
        payloads for each task definition and each quantum node. The header
        records the byte ranges of every payload so they can be loaded
        individually.

        Parameters
        ----------
        returnHeader : `bool`, optional
            If `True`, also return the (uncompressed) header mapping.

        Returns
        -------
        buffer : `bytearray` or `tuple` [ `bytearray`, `dict` ]
            The serialized graph, optionally with the header mapping.
        """
        # make some containers
        jsonData: deque[bytes] = deque()
        # node map is a list because json does not accept mapping keys that
        # are not strings, so we store a list of key, value pairs that will
        # be converted to a mapping on load
        nodeMap = []
        taskDefMap = {}
        headerData: dict[str, Any] = {}

        # Store the QuantumGraph BuildId, this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far
        count = 0
        # serialize out the task Defs recording the start and end bytes of
        # each taskDef
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # sort by task label to ensure serialization happens in the same order
        for taskDef in self.taskGraph:
            # compressing has very little impact on saving or load time, but
            # a large impact on disk size, so it is worth doing
            taskDescription: dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # save the config as a text stream that will be un-persisted on
            # the other end
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connection between all of tasks and save that in
            # the header as a list of connections and edges in each task
            # this will help in un-persisting, and possibly in a "quick view"
            # method that does not require everything to be un-persisted
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection directly
                    # from the datastore or it is produced by another task
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this tasks produced
                    # connection, the output will be said to be the datastore
                    # in which case the for loop will be a zero length loop
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # dump to json string, encode that string to bytes and then
            # compress those bytes
            dump = lzma.compress(json.dumps(taskDescription).encode(), preset=2)
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # compressing has very little impact on saving or load time, but
            # a large impact on disk size, so it is worth doing
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.model_dump_json().encode(), preset=2)
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        headerData["DimensionRecords"] = {
            key: value.model_dump()
            for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # need to serialize this as a series of key,value tuples because of
        # a limitation on how json can't do anything but strings as keys
        headerData["Nodes"] = nodeMap

        if self._globalInitOutputRefs:
            headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

        if self._registryDatasetTypes:
            headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # record the sizes as 2 unsigned long long numbers for a total of 16
        # bytes
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory, as the memory is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

1134 

1135 @classmethod 

1136 def load( 

1137 cls, 

1138 file: BinaryIO, 

1139 universe: DimensionUniverse | None = None, 

1140 nodes: Iterable[uuid.UUID] | None = None, 

1141 graphID: BuildId | None = None, 

1142 minimumVersion: int = 3, 

1143 ) -> QuantumGraph: 

1144 """Read `QuantumGraph` from a file that was made by `save`. 

1145 

1146 Parameters 

1147 ---------- 

1148 file : `io.IO` of bytes 

1149 File with data open in binary mode. 

1150 universe : `~lsst.daf.butler.DimensionUniverse`, optional 

1151 If `None` it is loaded from the `QuantumGraph` 

1152 saved structure. If supplied, the 

1153 `~lsst.daf.butler.DimensionUniverse` from the loaded `QuantumGraph` 

1154 will be validated against the supplied argument for compatibility. 

1155 nodes : iterable of `uuid.UUID` or `None` 

1156 UUIDs that correspond to nodes in the graph. If specified, only 

1157 these nodes will be loaded. Defaults to None, in which case all 

1158 nodes will be loaded. 

1159 graphID : `str` or `None` 

1160 If specified this ID is verified against the loaded graph prior to 

1161 loading any Nodes. This defaults to None in which case no 

1162 validation is done. 

1163 minimumVersion : `int` 

1164 Minimum version of a save file to load. Set to -1 to load all 

1165 versions. Older versions may need to be loaded, and re-saved 

1166 to upgrade them to the latest format before they can be used in 

1167 production. 

1168 

1169 Returns 

1170 ------- 

1171 graph : `QuantumGraph` 

1172 Resulting QuantumGraph instance. 

1173 

1174 Raises 

1175 ------ 

1176 TypeError 

1177 Raised if data contains instance of a type other than 

1178 `QuantumGraph`. 

1179 ValueError 

1180 Raised if one or more of the nodes requested is not in the 

1181 `QuantumGraph` or if graphID parameter does not match the graph 

1182 being loaded or if the supplied uri does not point at a valid 

1183 `QuantumGraph` save file. 

1184 """ 

1185 with LoadHelper(file, minimumVersion) as loader: 

1186 qgraph = loader.load(universe, nodes, graphID) 

1187 if not isinstance(qgraph, QuantumGraph): 

1188 raise TypeError(f"QuantumGraph file contains unexpected object type: {type(qgraph)}") 

1189 return qgraph 

1190 

1191 def iterTaskGraph(self) -> Generator[TaskDef, None, None]: 

1192 """Iterate over the `taskGraph` attribute in topological order. 

1193 

1194 Yields 

1195 ------ 

1196 taskDef : `TaskDef` 

1197 `TaskDef` objects in topological order. 

1198 """ 

1199 yield from nx.topological_sort(self.taskGraph) 

1200 

    def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_id: bool = False) -> None:
        """Change output run and dataset ID for each output dataset.

        Parameters
        ----------
        run : `str`
            New output run name.
        metadata_key : `str` or `None`
            Specifies metadata key corresponding to output run name to update
            with new run name. If `None` or if metadata is missing it is not
            updated. If metadata is present but key is missing, it will be
            added.
        update_graph_id : `bool`, optional
            If `True` then also update graph ID with a new unique value.
        """
        # Maps old dataset ID -> new dataset ID for every output ref that is
        # rewritten, so intermediates can be updated consistently afterwards.
        dataset_id_map: dict[DatasetId, DatasetId] = {}

        def _update_output_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: MutableMapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update a collection of `~lsst.daf.butler.DatasetRef` with new
            run and dataset IDs.

            Side effect: records each old ID -> new ID pair in
            ``dataset_id_map``.
            """
            for ref in refs:
                new_ref = ref.replace(run=run)
                dataset_id_map[ref.id] = new_ref.id
                yield new_ref

        def _update_intermediate_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: Mapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update intermediate references with new run and IDs. Only the
            references that appear in ``dataset_id_map`` are updated, others
            are returned unchanged.
            """
            for ref in refs:
                if dataset_id := dataset_id_map.get(ref.id):
                    ref = ref.replace(run=run, id=dataset_id)
                yield ref

        # Replace quantum output refs first.
        # This pass populates dataset_id_map; the second pass below depends
        # on it being complete before any inputs are rewritten.
        for node in self._connectedQuanta:
            quantum = node.quantum
            outputs = {
                dataset_type: tuple(_update_output_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.outputs.items()
            }
            # Quantum is immutable, so build a replacement with new outputs.
            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=quantum.initInputs,
                inputs=quantum.inputs,
                outputs=outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initOutputRefs = {
            task_def: list(_update_output_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initOutputRefs.items()
        }
        self._globalInitOutputRefs = list(
            _update_output_refs(self._globalInitOutputRefs, run, dataset_id_map)
        )

        # Update all intermediates from their matching outputs.
        for node in self._connectedQuanta:
            quantum = node.quantum
            inputs = {
                dataset_type: tuple(_update_intermediate_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.inputs.items()
            }
            initInputs = list(_update_intermediate_refs(quantum.initInputs.values(), run, dataset_id_map))

            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=initInputs,
                inputs=inputs,
                outputs=quantum.outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initInputRefs = {
            task_def: list(_update_intermediate_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initInputRefs.items()
        }

        if update_graph_id:
            self._buildId = BuildId(f"{time.time()}-{os.getpid()}")

        # Record the new run name in metadata if a key was given.
        if metadata_key is not None:
            self._metadata[metadata_key] = run

1296 

1297 @property 

1298 def graphID(self) -> BuildId: 

1299 """The ID generated by the graph at construction time (`str`).""" 

1300 return self._buildId 

1301 

1302 @property 

1303 def universe(self) -> DimensionUniverse: 

1304 """Dimension universe associated with this graph 

1305 (`~lsst.daf.butler.DimensionUniverse`). 

1306 """ 

1307 return self._universe 

1308 

1309 def __iter__(self) -> Generator[QuantumNode, None, None]: 

1310 yield from nx.topological_sort(self._connectedQuanta) 

1311 

    def __len__(self) -> int:
        # Number of quantum nodes; ``self._count`` is maintained by graph
        # construction/restore code, avoiding a graph traversal here.
        return self._count

1314 

1315 def __contains__(self, node: QuantumNode) -> bool: 

1316 return self._connectedQuanta.has_node(node) 

1317 

1318 def __getstate__(self) -> dict: 

1319 """Store a compact form of the graph as a list of graph nodes, and a 

1320 tuple of task labels and task configs. The full graph can be 

1321 reconstructed with this information, and it preserves the ordering of 

1322 the graph nodes. 

1323 """ 

1324 universe: DimensionUniverse | None = None 

1325 for node in self: 

1326 dId = node.quantum.dataId 

1327 if dId is None: 

1328 continue 

1329 universe = dId.universe 

1330 return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe} 

1331 

1332 def __setstate__(self, state: dict) -> None: 

1333 """Reconstructs the state of the graph from the information persisted 

1334 in getstate. 

1335 """ 

1336 buffer = io.BytesIO(state["reduced"]) 

1337 with LoadHelper(buffer, minimumVersion=3) as loader: 

1338 qgraph = loader.load(state["universe"], graphID=state["graphId"]) 

1339 

1340 self._metadata = qgraph._metadata 

1341 self._buildId = qgraph._buildId 

1342 self._datasetDict = qgraph._datasetDict 

1343 self._nodeIdMap = qgraph._nodeIdMap 

1344 self._count = len(qgraph) 

1345 self._taskToQuantumNode = qgraph._taskToQuantumNode 

1346 self._taskGraph = qgraph._taskGraph 

1347 self._connectedQuanta = qgraph._connectedQuanta 

1348 self._initInputRefs = qgraph._initInputRefs 

1349 self._initOutputRefs = qgraph._initOutputRefs 

1350 

1351 def __eq__(self, other: object) -> bool: 

1352 if not isinstance(other, QuantumGraph): 

1353 return False 

1354 if len(self) != len(other): 

1355 return False 

1356 for node in self: 

1357 if node not in other: 

1358 return False 

1359 if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node): 

1360 return False 

1361 if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node): 

1362 return False 

1363 if set(self.allDatasetTypes) != set(other.allDatasetTypes): 

1364 return False 

1365 return set(self.taskGraph) == set(other.taskGraph)