Coverage for python/lsst/pipe/base/graph/graph.py: 20%

426 statements  

coverage.py v7.5.0, created at 2024-04-30 02:55 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = ("QuantumGraph", "IncompatibleGraphError")

import datetime
import getpass
import io
import json
import lzma
import os
import struct
import sys
import time
import uuid
from collections import defaultdict, deque
from collections.abc import Generator, Iterable, Iterator, Mapping, MutableMapping
from itertools import chain
from types import MappingProxyType
from typing import Any, BinaryIO, TypeVar

import networkx as nx
from lsst.daf.butler import (
    DatasetId,
    DatasetRef,
    DatasetType,
    DimensionRecordsAccumulator,
    DimensionUniverse,
    Quantum,
)
from lsst.daf.butler.persistence_context import PersistenceContextVars
from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils.introspection import get_full_type_name
from lsst.utils.packages import Packages
from networkx.drawing.nx_agraph import write_dot

from ..connections import iterConnections
from ..pipeline import TaskDef
from ..pipeline_graph import PipelineGraph
from ._implDetails import DatasetTypeName, _DatasetTracker
from ._loadHelpers import LoadHelper
from ._versionDeserializers import DESERIALIZER_MAP
from .graphSummary import QgraphSummary, QgraphTaskSummary
from .quantumNode import BuildId, QuantumNode

_T = TypeVar("_T", bound="QuantumGraph")

# Modify this constant any time the on-disk representation of the save file
# changes, and update the load helpers to behave properly for each version.
SAVE_VERSION = 3

# Strings used to describe the format for the preamble bytes in a file save.
# The base is a big-endian encoded unsigned short that is used to hold the
# file format version. This allows reading the version bytes and determining
# which loading code should be used for the rest of the file.
STRUCT_FMT_BASE = ">H"
#
# Version 1
# This marks a big-endian encoded format with an unsigned short, an unsigned
# long long, and an unsigned long long in the byte stream.
# Version 2
# A big-endian encoded format with an unsigned long long in the byte stream,
# used to indicate the total length of the entire header.
STRUCT_FMT_STRING = {1: ">QQ", 2: ">Q"}

# Magic bytes that help determine this is a graph save.
MAGIC_BYTES = b"qgraph4\xf6\xe8\xa9"
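
# As an illustrative sketch (not part of this module's API), the preamble
# laid out above could be inspected like this; ``data`` is a hypothetical
# ``bytes`` object holding the start of a version 2 or later save file:
#
#     offset = len(MAGIC_BYTES)
#     assert data[:offset] == MAGIC_BYTES
#     (version,) = struct.unpack(STRUCT_FMT_BASE, data[offset : offset + 2])
#     offset += 2
#     (header_len,) = struct.unpack(STRUCT_FMT_STRING[2], data[offset : offset + 8])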


class IncompatibleGraphError(Exception):
    """Exception class to indicate that a lookup by NodeId is impossible due
    to incompatibilities.
    """

    pass


class QuantumGraph:
    """QuantumGraph is a directed acyclic graph of `QuantumNode` objects.

    This data structure represents a concrete workflow generated from a
    `Pipeline`.

    Parameters
    ----------
    quanta : `~collections.abc.Mapping` [ `TaskDef`, \
            `set` [ `~lsst.daf.butler.Quantum` ] ]
        This maps tasks (and their configs) to the sets of data they are to
        process.
    metadata : Optional `~collections.abc.Mapping` of `str` to primitives
        This is an optional parameter of extra data to carry with the graph.
        Entries in this mapping should be able to be serialized in JSON.
    universe : `~lsst.daf.butler.DimensionUniverse`, optional
        The dimensions in which quanta can be defined. Need only be provided
        if no quanta have data IDs.
    initInputs : `~collections.abc.Mapping`, optional
        Maps tasks to their InitInput dataset refs. Dataset refs can be either
        resolved or non-resolved. Presently the same dataset refs are included
        in each `~lsst.daf.butler.Quantum` for the same task.
    initOutputs : `~collections.abc.Mapping`, optional
        Maps tasks to their InitOutput dataset refs. Dataset refs can be
        either resolved or non-resolved. For intermediate resolved refs their
        dataset ID must match ``initInputs`` and Quantum ``initInputs``.
    globalInitOutputs : iterable [ `~lsst.daf.butler.DatasetRef` ], optional
        Dataset refs for some global objects produced by the pipeline. These
        objects include task configurations and package versions. Typically
        they have an empty DataId, but there is no real restriction on what
        can appear here.
    registryDatasetTypes : iterable [ `~lsst.daf.butler.DatasetType` ], \
            optional
        Dataset types which are used by this graph; their definitions must
        match the registry. If the registry does not define a dataset type
        yet, then it should match one that will be created later.

    Raises
    ------
    ValueError
        Raised if the graph is pruned such that some tasks no longer have
        nodes associated with them.
    """

    def __init__(
        self,
        quanta: Mapping[TaskDef, set[Quantum]],
        metadata: Mapping[str, Any] | None = None,
        universe: DimensionUniverse | None = None,
        initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        globalInitOutputs: Iterable[DatasetRef] | None = None,
        registryDatasetTypes: Iterable[DatasetType] | None = None,
    ):
        self._buildGraphs(
            quanta,
            metadata=metadata,
            universe=universe,
            initInputs=initInputs,
            initOutputs=initOutputs,
            globalInitOutputs=globalInitOutputs,
            registryDatasetTypes=registryDatasetTypes,
        )

    def _buildGraphs(
        self,
        quanta: Mapping[TaskDef, set[Quantum]],
        *,
        _quantumToNodeId: Mapping[Quantum, uuid.UUID] | None = None,
        _buildId: BuildId | None = None,
        metadata: Mapping[str, Any] | None = None,
        universe: DimensionUniverse | None = None,
        initInputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        initOutputs: Mapping[TaskDef, Iterable[DatasetRef]] | None = None,
        globalInitOutputs: Iterable[DatasetRef] | None = None,
        registryDatasetTypes: Iterable[DatasetType] | None = None,
    ) -> None:
        """Build the graph that is used to store the relation between tasks,
        and the graph that holds the relations between quanta.
        """
        # Save packages to metadata
        self._metadata = dict(metadata) if metadata is not None else {}
        self._metadata["packages"] = Packages.fromSystem()
        self._metadata["user"] = getpass.getuser()
        self._metadata["time"] = f"{datetime.datetime.now()}"
        self._metadata["full_command"] = " ".join(sys.argv)

        self._buildId = _buildId if _buildId is not None else BuildId(f"{time.time()}-{os.getpid()}")
        # Data structure used to identify relations between
        # DatasetTypeName -> TaskDef.
        self._datasetDict = _DatasetTracker(createInverse=True)

        # Temporary graph that will have dataset UUIDs (as raw bytes) and
        # QuantumNode objects as nodes; will be collapsed down to just quanta
        # later.
        bipartite_graph = nx.DiGraph()

        self._nodeIdMap: dict[uuid.UUID, QuantumNode] = {}
        self._taskToQuantumNode: MutableMapping[TaskDef, set[QuantumNode]] = defaultdict(set)
        for taskDef, quantumSet in quanta.items():
            connections = taskDef.connections

            # For each type of connection in the task, add a key to the
            # `_DatasetTracker` for the connection's name, with a value of
            # the TaskDef in the appropriate field.
            for inpt in iterConnections(connections, ("inputs", "prerequisiteInputs", "initInputs")):
                # Have to handle components in inputs.
                dataset_name, _, _ = inpt.name.partition(".")
                self._datasetDict.addConsumer(DatasetTypeName(dataset_name), taskDef)

            for output in iterConnections(connections, ("outputs",)):
                # Have to handle possible components in outputs.
                dataset_name, _, _ = output.name.partition(".")
                self._datasetDict.addProducer(DatasetTypeName(dataset_name), taskDef)

            # For each `Quantum` in the set of all `Quantum` for this task,
            # add a key to the `_DatasetTracker` that is a `DatasetRef` for
            # one of the individual datasets inside the `Quantum`, with a
            # value of a newly created QuantumNode to the appropriate
            # input/output field.
            for quantum in quantumSet:
                if quantum.dataId is not None:
                    if universe is None:
                        universe = quantum.dataId.universe
                    elif universe != quantum.dataId.universe:
                        raise RuntimeError(
                            "Mismatched dimension universes in QuantumGraph construction: "
                            f"{universe} != {quantum.dataId.universe}."
                        )

                if _quantumToNodeId:
                    if (nodeId := _quantumToNodeId.get(quantum)) is None:
                        raise ValueError(
                            "If _quantumToNodeId is not None, all quanta must have an "
                            "associated value in the mapping"
                        )
                else:
                    nodeId = uuid.uuid4()

                inits = quantum.initInputs.values()
                inputs = quantum.inputs.values()
                value = QuantumNode(quantum, taskDef, nodeId)
                self._taskToQuantumNode[taskDef].add(value)
                self._nodeIdMap[nodeId] = value

                bipartite_graph.add_node(value, bipartite=0)
                for dsRef in chain(inits, inputs):
                    # Unfortunately, `Quantum` allows inits to be individual
                    # `DatasetRef`s or an Iterable of such, so there must
                    # be an instance check here.
                    if isinstance(dsRef, Iterable):
                        for sub in dsRef:
                            bipartite_graph.add_node(sub.id.bytes, bipartite=1)
                            bipartite_graph.add_edge(sub.id.bytes, value)
                    else:
                        assert isinstance(dsRef, DatasetRef)
                        if dsRef.isComponent():
                            dsRef = dsRef.makeCompositeRef()
                        bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                        bipartite_graph.add_edge(dsRef.id.bytes, value)
                for dsRef in chain.from_iterable(quantum.outputs.values()):
                    bipartite_graph.add_node(dsRef.id.bytes, bipartite=1)
                    bipartite_graph.add_edge(value, dsRef.id.bytes)

        # Dimension universe
        if universe is None:
            raise RuntimeError(
                "Dimension universe or at least one quantum with a data ID "
                "must be provided when constructing a QuantumGraph."
            )
        self._universe = universe

        # Make a graph of quanta relations by projecting out the dataset
        # nodes in the bipartite_graph, leaving just the quanta.
        self._connectedQuanta = nx.algorithms.bipartite.projected_graph(
            bipartite_graph, self._nodeIdMap.values()
        )
        self._count = len(self._connectedQuanta)

        # Graph of task relations, used in various methods
        self._taskGraph = self._datasetDict.makeNetworkXGraph()

        # Convert the default dict into a regular dict to prevent accidental
        # key insertion.
        self._taskToQuantumNode = dict(self._taskToQuantumNode.items())

        self._initInputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._initOutputRefs: dict[TaskDef, list[DatasetRef]] = {}
        self._globalInitOutputRefs: list[DatasetRef] = []
        self._registryDatasetTypes: list[DatasetType] = []
        if initInputs is not None:
            self._initInputRefs = {taskDef: list(refs) for taskDef, refs in initInputs.items()}
        if initOutputs is not None:
            self._initOutputRefs = {taskDef: list(refs) for taskDef, refs in initOutputs.items()}
        if globalInitOutputs is not None:
            self._globalInitOutputRefs = list(globalInitOutputs)
        if registryDatasetTypes is not None:
            self._registryDatasetTypes = list(registryDatasetTypes)

        # The PipelineGraph is currently constructed on first use.
        # TODO DM-40442: use PipelineGraph instead of TaskDef
        # collections.
        self._pipeline_graph: PipelineGraph | None = None

    @property
    def pipeline_graph(self) -> PipelineGraph:
        """A graph representation of the tasks and dataset types in the
        quantum graph.
        """
        if self._pipeline_graph is None:
            # Construct into a temporary for strong exception safety.
            pipeline_graph = PipelineGraph()
            for task_def in self._taskToQuantumNode.keys():
                pipeline_graph.add_task(
                    task_def.label, task_def.taskClass, task_def.config, connections=task_def.connections
                )
            dataset_types = {dataset_type.name: dataset_type for dataset_type in self._registryDatasetTypes}
            pipeline_graph.resolve(dimensions=self._universe, dataset_types=dataset_types)
            self._pipeline_graph = pipeline_graph
        return self._pipeline_graph

    def get_task_quanta(self, label: str) -> Mapping[uuid.UUID, Quantum]:
        """Return the quanta associated with the given task label.

        Parameters
        ----------
        label : `str`
            Task label.

        Returns
        -------
        quanta : `~collections.abc.Mapping` [ `uuid.UUID`, `Quantum` ]
            Mapping from quantum ID to quantum. Empty if ``label`` does not
            correspond to a task in this graph.
        """
        task_def = self.findTaskDefByLabel(label)
        if not task_def:
            return {}
        return {node.nodeId: node.quantum for node in self.getNodesForTask(task_def)}
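
    # A usage sketch, assuming ``qgraph`` is a QuantumGraph instance; the
    # task label "isr" is hypothetical:
    #
    #     for quantum_id, quantum in qgraph.get_task_quanta("isr").items():
    #         print(quantum_id, quantum.dataId)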

    @property
    def taskGraph(self) -> nx.DiGraph:
        """A graph representing the relations between the tasks inside
        the quantum graph (`networkx.DiGraph`).
        """
        return self._taskGraph

    @property
    def graph(self) -> nx.DiGraph:
        """A graph representing the relations between all the `QuantumNode`
        objects (`networkx.DiGraph`).

        The graph should usually be iterated over, or passed to methods of
        this class, but sometimes direct access to the ``networkx`` object
        may be helpful.
        """
        return self._connectedQuanta

    @property
    def inputQuanta(self) -> Iterable[QuantumNode]:
        """The nodes that are inputs to the graph (iterable [`QuantumNode`]).

        These are the nodes that do not depend on any other nodes in the
        graph.
        """
        return (q for q, n in self._connectedQuanta.in_degree if n == 0)

    @property
    def outputQuanta(self) -> Iterable[QuantumNode]:
        """The nodes that are outputs of the graph
        (iterable [`QuantumNode`]).

        These are the nodes that have no nodes that depend on them in the
        graph.
        """
        return [q for q, n in self._connectedQuanta.out_degree if n == 0]
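
    # A sketch of inspecting the graph boundaries, assuming ``qgraph`` is a
    # QuantumGraph instance:
    #
    #     first = list(qgraph.inputQuanta)   # quanta with no in-graph inputs
    #     last = list(qgraph.outputQuanta)   # quanta nothing depends on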

    @property
    def allDatasetTypes(self) -> tuple[DatasetTypeName, ...]:
        """All the dataset type names that are present in the graph
        (`tuple` [`str`]).

        These types do not include global init-outputs.
        """
        return tuple(self._datasetDict.keys())

    @property
    def isConnected(self) -> bool:
        """Whether all of the nodes in the graph are connected, ignoring
        directionality of connections (`bool`).
        """
        return nx.is_weakly_connected(self._connectedQuanta)

    def getQuantumNodeByNodeId(self, nodeId: uuid.UUID) -> QuantumNode:
        """Look up a `QuantumNode` from an id associated with the node.

        Parameters
        ----------
        nodeId : `uuid.UUID`
            The UUID associated with a node.

        Returns
        -------
        node : `QuantumNode`
            The node corresponding to the input UUID.

        Raises
        ------
        KeyError
            Raised if the requested nodeId is not in the graph.
        """
        return self._nodeIdMap[nodeId]

    def getQuantaForTask(self, taskDef: TaskDef) -> frozenset[Quantum]:
        """Return all the `~lsst.daf.butler.Quantum` associated with a
        `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
            queried.

        Returns
        -------
        quanta : `frozenset` of `~lsst.daf.butler.Quantum`
            The `set` of `~lsst.daf.butler.Quantum` that is associated with
            the specified `TaskDef`.
        """
        return frozenset(node.quantum for node in self._taskToQuantumNode.get(taskDef, ()))

    def getNumberOfQuantaForTask(self, taskDef: TaskDef) -> int:
        """Return the number of `~lsst.daf.butler.Quantum` associated with
        a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `~lsst.daf.butler.Quantum` are to be
            queried.

        Returns
        -------
        count : `int`
            The number of `~lsst.daf.butler.Quantum` that are associated
            with the specified `TaskDef`.
        """
        return len(self._taskToQuantumNode.get(taskDef, ()))

    def getNodesForTask(self, taskDef: TaskDef) -> frozenset[QuantumNode]:
        r"""Return all the `QuantumNode`\s associated with a `TaskDef`.

        Parameters
        ----------
        taskDef : `TaskDef`
            The `TaskDef` for which `QuantumNode`\s are to be queried.

        Returns
        -------
        nodes : `frozenset` [ `QuantumNode` ]
            A `frozenset` of `QuantumNode` that is associated with the
            specified `TaskDef`.
        """
        return frozenset(self._taskToQuantumNode[taskDef])

    def findTasksWithInput(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that have the specified dataset type name as an
        input.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried;
            can also accept a `DatasetTypeName`, which is a `~typing.NewType`
            of `str`, for type safety in static type checking.

        Returns
        -------
        tasks : iterable of `TaskDef`
            `TaskDef` objects that have the specified `DatasetTypeName` as an
            input; the list will be empty if no tasks use the specified
            `DatasetTypeName` as an input.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the
            `QuantumGraph`.
        """
        return (c for c in self._datasetDict.getConsumers(datasetTypeName))

    def findTaskWithOutput(self, datasetTypeName: DatasetTypeName) -> TaskDef | None:
        """Find the task that has the specified dataset type name as an
        output.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried;
            can also accept a `DatasetTypeName`, which is a `~typing.NewType`
            of `str`, for type safety in static type checking.

        Returns
        -------
        result : `TaskDef` or `None`
            The `TaskDef` that produces `DatasetTypeName` as an output, or
            `None` if none of the tasks produce this `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the
            `QuantumGraph`.
        """
        return self._datasetDict.getProducer(datasetTypeName)

    def tasksWithDSType(self, datasetTypeName: DatasetTypeName) -> Iterable[TaskDef]:
        """Find all tasks that are associated with the specified dataset type
        name.

        Parameters
        ----------
        datasetTypeName : `str`
            A string representing the name of a dataset type to be queried;
            can also accept a `DatasetTypeName`, which is a `~typing.NewType`
            of `str`, for type safety in static type checking.

        Returns
        -------
        result : iterable of `TaskDef`
            `TaskDef` objects that are associated with the specified
            `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the
            `QuantumGraph`.
        """
        return self._datasetDict.getAll(datasetTypeName)

    def findTaskDefByName(self, taskName: str) -> list[TaskDef]:
        """Determine which `TaskDef` objects in this graph are associated
        with a `str` representing a task name (looks at the ``taskName``
        property of `TaskDef` objects).

        Returns a list of `TaskDef` objects as a `PipelineTask` may appear
        multiple times in a graph with different labels.

        Parameters
        ----------
        taskName : `str`
            Name of a task to search for.

        Returns
        -------
        result : `list` of `TaskDef`
            List of the `TaskDef` objects that have the name specified.
            Multiple values are returned in the case that a task is used
            multiple times with different labels.
        """
        results = []
        for task in self._taskToQuantumNode:
            split = task.taskName.split(".")
            if split[-1] == taskName:
                results.append(task)
        return results

    def findTaskDefByLabel(self, label: str) -> TaskDef | None:
        """Determine which `TaskDef` object in this graph is associated
        with a `str` representing a task's label.

        Parameters
        ----------
        label : `str`
            Label of a task to search for.

        Returns
        -------
        result : `TaskDef` or `None`
            The `TaskDef` object that has the specified label, or `None` if
            there is no such task in the graph.
        """
        for task in self._taskToQuantumNode:
            if label == task.label:
                return task
        return None

    def findQuantaWithDSType(self, datasetTypeName: DatasetTypeName) -> set[Quantum]:
        r"""Return all the `~lsst.daf.butler.Quantum` that contain a
        specified `DatasetTypeName`.

        Parameters
        ----------
        datasetTypeName : `str`
            The name of the dataset type to search for as a string;
            can also accept a `DatasetTypeName`, which is a `~typing.NewType`
            of `str`, for type safety in static type checking.

        Returns
        -------
        result : `set` of `~lsst.daf.butler.Quantum` objects
            A `set` of `~lsst.daf.butler.Quantum`\s that contain the
            specified `DatasetTypeName`.

        Raises
        ------
        KeyError
            Raised if the `DatasetTypeName` is not part of the
            `QuantumGraph`.
        """
        tasks = self._datasetDict.getAll(datasetTypeName)
        result: set[Quantum] = set()
        result = result.union(quantum for task in tasks for quantum in self.getQuantaForTask(task))
        return result

    def checkQuantumInGraph(self, quantum: Quantum) -> bool:
        """Check if the specified quantum appears in the graph as part of a
        node.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            The quantum to search for.

        Returns
        -------
        in_graph : `bool`
            The result of searching for the quantum.
        """
        return any(quantum == node.quantum for node in self)

    def writeDotGraph(self, output: str | io.BufferedIOBase) -> None:
        """Write out the graph as a dot graph.

        Parameters
        ----------
        output : `str` or `io.BufferedIOBase`
            Either a filesystem path to write to, or a file handle object.
        """
        write_dot(self._connectedQuanta, output)

    def subset(self: _T, nodes: QuantumNode | Iterable[QuantumNode]) -> _T:
        """Create a new graph object that contains the subset of the nodes
        specified as input. Node number is preserved.

        Parameters
        ----------
        nodes : `QuantumNode` or iterable of `QuantumNode`
            Nodes from which to create the subset.

        Returns
        -------
        graph : instance of graph type
            An instance of the type from which the subset was created.
        """
        if not isinstance(nodes, Iterable):
            nodes = (nodes,)
        quantumSubgraph = self._connectedQuanta.subgraph(nodes).nodes
        quantumMap = defaultdict(set)

        dataset_type_names: set[str] = set()
        node: QuantumNode
        for node in quantumSubgraph:
            quantumMap[node.taskDef].add(node.quantum)
            dataset_type_names.update(
                dstype.name
                for dstype in chain(
                    node.quantum.inputs.keys(), node.quantum.outputs.keys(), node.quantum.initInputs.keys()
                )
            )

        # May need to trim dataset types from registryDatasetTypes.
        for taskDef in quantumMap:
            if refs := self.initOutputRefs(taskDef):
                dataset_type_names.update(ref.datasetType.name for ref in refs)
        dataset_type_names.update(ref.datasetType.name for ref in self._globalInitOutputRefs)
        registryDatasetTypes = [
            dstype for dstype in self._registryDatasetTypes if dstype.name in dataset_type_names
        ]

        # convert to standard dict to prevent accidental key insertion
        quantumDict: dict[TaskDef, set[Quantum]] = dict(quantumMap.items())
        # Create an empty graph, and then populate it with custom mapping.
        newInst = type(self)({}, universe=self._universe)
        # TODO: Do we need to copy initInputs/initOutputs?
        newInst._buildGraphs(
            quantumDict,
            _quantumToNodeId={n.quantum: n.nodeId for n in nodes},
            _buildId=self._buildId,
            metadata=self._metadata,
            universe=self._universe,
            globalInitOutputs=self._globalInitOutputRefs,
            registryDatasetTypes=registryDatasetTypes,
        )
        return newInst

    def subsetToConnected(self: _T) -> tuple[_T, ...]:
        """Generate a list of subgraphs where each is connected.

        Returns
        -------
        result : `list` of `QuantumGraph`
            A list of graphs that are each connected.
        """
        return tuple(
            self.subset(connectedSet)
            for connectedSet in nx.weakly_connected_components(self._connectedQuanta)
        )
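
    # A sketch of splitting a graph into independently runnable pieces,
    # assuming ``qgraph`` is a QuantumGraph instance:
    #
    #     for component in qgraph.subsetToConnected():
    #         assert component.isConnected
    #         print(len(component), "quanta in this connected component")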

    def determineInputsToQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct inputs to a
        specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which inputs are to be determined.

        Returns
        -------
        inputs : `set` of `QuantumNode`
            All the nodes that are direct inputs to the specified node.
        """
        return set(self._connectedQuanta.predecessors(node))

    def determineOutputsOfQuantumNode(self, node: QuantumNode) -> set[QuantumNode]:
        """Return a set of `QuantumNode` that are direct outputs of a
        specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which outputs are to be determined.

        Returns
        -------
        outputs : `set` of `QuantumNode`
            All the nodes that are direct outputs of the specified node.
        """
        return set(self._connectedQuanta.successors(node))

    def determineConnectionsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of `QuantumNode` that are direct inputs and outputs
        of a specified node.

        Parameters
        ----------
        node : `QuantumNode`
            The node of the graph for which connected nodes are to be
            determined.

        Returns
        -------
        graph : graph of `QuantumNode`
            All the nodes that are directly connected to the specified node.
        """
        nodes = self.determineInputsToQuantumNode(node).union(self.determineOutputsOfQuantumNode(node))
        nodes.add(node)
        return self.subset(nodes)

    def determineAncestorsOfQuantumNode(self: _T, node: QuantumNode) -> _T:
        """Return a graph of the specified node and all the ancestor nodes
        directly reachable by walking edges.

        Parameters
        ----------
        node : `QuantumNode`
            The node for which all ancestors are to be determined.

        Returns
        -------
        ancestors : graph of `QuantumNode`
            Graph of the node and all of its ancestors.
        """
        predecessorNodes = nx.ancestors(self._connectedQuanta, node)
        predecessorNodes.add(node)
        return self.subset(predecessorNodes)

    def findCycle(self) -> list[tuple[QuantumNode, QuantumNode]]:
        """Check the graph for the presence of cycles and return the edges of
        any cycles found, or an empty list if there is no cycle.

        Returns
        -------
        result : `list` of `tuple` of [ `QuantumNode`, `QuantumNode` ]
            A list of any graph edges that form a cycle, or an empty list if
            there is no cycle. An empty list is returned to support the
            ``if graph.findCycle():`` syntax, as an empty list is falsy.
        """
        try:
            return nx.find_cycle(self._connectedQuanta)
        except nx.NetworkXNoCycle:
            return []
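
    # A sketch of the falsy-empty-list idiom described above, assuming
    # ``qgraph`` is a QuantumGraph instance:
    #
    #     if edges := qgraph.findCycle():
    #         raise RuntimeError(f"Graph is not a DAG; cycle edges: {edges}")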

    def saveUri(self, uri: ResourcePathExpression) -> None:
        """Save `QuantumGraph` to the specified URI.

        Parameters
        ----------
        uri : convertible to `~lsst.resources.ResourcePath`
            URI to where the graph should be saved.
        """
        buffer = self._buildSaveObject()
        path = ResourcePath(uri)
        if path.getExtension() not in {".qgraph"}:
            raise TypeError(f"Can currently only save a graph in qgraph format, not {uri}")
        path.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

    @property
    def metadata(self) -> MappingProxyType[str, Any]:
        """Extra data carried with the graph (mapping [`str`]).

        The mapping is a dynamic view of this object's metadata. Values
        should be able to be serialized in JSON.
        """
        return MappingProxyType(self._metadata)

    def initInputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
        """Return DatasetRefs for a given task's InitInputs.

        Parameters
        ----------
        taskDef : `TaskDef`
            Task definition structure.

        Returns
        -------
        refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
            DatasetRefs for the task's InitInputs; can be `None`. This can
            return either resolved or non-resolved references.
        """
        return self._initInputRefs.get(taskDef)

    def initOutputRefs(self, taskDef: TaskDef) -> list[DatasetRef] | None:
        """Return DatasetRefs for a given task's InitOutputs.

        Parameters
        ----------
        taskDef : `TaskDef`
            Task definition structure.

        Returns
        -------
        refs : `list` [ `~lsst.daf.butler.DatasetRef` ] or `None`
            DatasetRefs for the task's InitOutputs; can be `None`. This can
            return either resolved or non-resolved references. A resolved
            reference will match the Quantum's initInputs if this is an
            intermediate dataset type.
        """
        return self._initOutputRefs.get(taskDef)

    def globalInitOutputRefs(self) -> list[DatasetRef]:
        """Return DatasetRefs for global InitOutputs.

        Returns
        -------
        refs : `list` [ `~lsst.daf.butler.DatasetRef` ]
            DatasetRefs for global InitOutputs.
        """
        return self._globalInitOutputRefs

    def registryDatasetTypes(self) -> list[DatasetType]:
        """Return the dataset types used by this graph; their definitions
        match the dataset types from the registry.

        Returns
        -------
        refs : `list` [ `~lsst.daf.butler.DatasetType` ]
            Dataset types for this graph.
        """
        return self._registryDatasetTypes

    @classmethod
    def loadUri(
        cls,
        uri: ResourcePathExpression,
        universe: DimensionUniverse | None = None,
        nodes: Iterable[uuid.UUID] | None = None,
        graphID: BuildId | None = None,
        minimumVersion: int = 3,
    ) -> QuantumGraph:
        """Read `QuantumGraph` from a URI.

        Parameters
        ----------
        uri : convertible to `~lsst.resources.ResourcePath`
            URI from where to load the graph.
        universe : `~lsst.daf.butler.DimensionUniverse`, optional
            If `None` it is loaded from the `QuantumGraph` saved structure.
            If supplied, the `~lsst.daf.butler.DimensionUniverse` from the
            loaded `QuantumGraph` will be validated against the supplied
            argument for compatibility.
        nodes : iterable of `uuid.UUID` or `None`
            UUIDs that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to `None`, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified, this ID is verified against the loaded graph prior
            to loading any nodes. This defaults to `None`, in which case no
            validation is done.
        minimumVersion : `int`
            Minimum version of a save file to load. Set to -1 to load all
            versions. Older versions may need to be loaded, and re-saved
            to upgrade them to the latest format before they can be used in
            production.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if the file contains an instance of a type other than
            `QuantumGraph`.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph`, if the graphID parameter does not match the
            graph being loaded, or if the supplied uri does not point at a
            valid `QuantumGraph` save file.
        RuntimeError
            Raised if the supplied `~lsst.daf.butler.DimensionUniverse` is
            not compatible with the `~lsst.daf.butler.DimensionUniverse`
            saved in the graph.
        """
        uri = ResourcePath(uri)
        if uri.getExtension() in {".qgraph"}:
            with LoadHelper(uri, minimumVersion) as loader:
                qgraph = loader.load(universe, nodes, graphID)
        else:
            raise ValueError(f"Only know how to handle files saved as `.qgraph`, not {uri}")
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph file {uri} contains unexpected object type: {type(qgraph)}")
        return qgraph
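
    # A save/load round-trip sketch; the path is hypothetical:
    #
    #     qgraph.saveUri("pipeline.qgraph")
    #     restored = QuantumGraph.loadUri("pipeline.qgraph")
    #     assert restored.graphID == qgraph.graphID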

    @classmethod
    def readHeader(cls, uri: ResourcePathExpression, minimumVersion: int = 3) -> str | None:
        """Read the header of a `QuantumGraph` pointed to by the uri
        parameter and return it as a string.

        Parameters
        ----------
        uri : convertible to `~lsst.resources.ResourcePath`
            The location of the `QuantumGraph` to load. If the argument is a
            string, it must correspond to a valid
            `~lsst.resources.ResourcePath` path.
        minimumVersion : `int`
            Minimum version of a save file to load. Set to -1 to load all
            versions. Older versions may need to be loaded, and re-saved
            to upgrade them to the latest format before they can be used in
            production.

        Returns
        -------
        header : `str` or `None`
            The header associated with the specified `QuantumGraph` if there
            is one, else `None`.

        Raises
        ------
        ValueError
            Raised if the extension of the file specified by uri is not a
            `QuantumGraph` extension.
        """
        uri = ResourcePath(uri)
        if uri.getExtension() in {".qgraph"}:
            return LoadHelper(uri, minimumVersion).readHeader()
        else:
            raise ValueError("Only know how to handle files saved as `.qgraph`")

    def buildAndPrintHeader(self) -> None:
        """Create the header that would be used in a save of this object and
        print it to standard out.
        """
        _, header = self._buildSaveObject(returnHeader=True)
        print(json.dumps(header))

    def save(self, file: BinaryIO) -> None:
        """Save QuantumGraph to a file.

        Parameters
        ----------
        file : `io.BufferedIOBase`
            File to write data to, open in binary mode.
        """
        buffer = self._buildSaveObject()
        file.write(buffer)  # type: ignore # Ignore because bytearray is safe to use in place of bytes

    def _buildSaveObject(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
        thing = PersistenceContextVars()
        result = thing.run(self._buildSaveObjectImpl, returnHeader)
        return result

    def _buildSaveObjectImpl(self, returnHeader: bool = False) -> bytearray | tuple[bytearray, dict]:
        # make some containers
        jsonData: deque[bytes] = deque()
        # The node map is a list because JSON does not accept mapping keys
        # that are not strings, so we store a list of key, value pairs that
        # will be converted to a mapping on load.
        nodeMap = []
        taskDefMap = {}
        headerData: dict[str, Any] = {}

        # Store the QuantumGraph BuildId; this will allow validating BuildIds
        # at load time, prior to loading any QuantumNodes. Name chosen for
        # unlikely conflicts.
        headerData["GraphBuildID"] = self.graphID
        headerData["Metadata"] = self._metadata

        # Store the universe this graph was created with
        universeConfig = self._universe.dimensionConfig
        headerData["universe"] = universeConfig.toDict()

        # counter for the number of bytes processed thus far
        count = 0
        # Serialize out the task defs, recording the start and end bytes of
        # each taskDef.
        inverseLookup = self._datasetDict.inverse
        taskDef: TaskDef
        # sort by task label to ensure serialization happens in the same order
        for taskDef in self.taskGraph:
            # Compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing.
            taskDescription: dict[str, Any] = {}
            # save the fully qualified name.
            taskDescription["taskName"] = get_full_type_name(taskDef.taskClass)
            # Save the config as a text stream that will be un-persisted on
            # the other end.
            stream = io.StringIO()
            taskDef.config.saveToStream(stream)
            taskDescription["config"] = stream.getvalue()
            taskDescription["label"] = taskDef.label
            if (refs := self._initInputRefs.get(taskDef)) is not None:
                taskDescription["initInputRefs"] = [ref.to_json() for ref in refs]
            if (refs := self._initOutputRefs.get(taskDef)) is not None:
                taskDescription["initOutputRefs"] = [ref.to_json() for ref in refs]

            inputs = []
            outputs = []

            # Determine the connections between all of the tasks and save
            # them in the header as a list of connections and edges for each
            # task; this will help in un-persisting, and possibly in a
            # "quick view" method that does not require everything to be
            # un-persisted.
            #
            # Typing returns can't be parameter dependent
            for connection in inverseLookup[taskDef]:  # type: ignore
                consumers = self._datasetDict.getConsumers(connection)
                producer = self._datasetDict.getProducer(connection)
                if taskDef in consumers:
                    # This checks if the task consumes the connection
                    # directly from the datastore or if it is produced by
                    # another task.
                    producerLabel = producer.label if producer is not None else "datastore"
                    inputs.append((producerLabel, connection))
                elif taskDef not in consumers and producer is taskDef:
                    # If there are no consumers for this task's produced
                    # connection, the output will be said to be the
                    # datastore, in which case the for loop will be a
                    # zero-length loop.
                    if not consumers:
                        outputs.append(("datastore", connection))
                    for td in consumers:
                        outputs.append((td.label, connection))

            # Dump to a json string, encode that string to bytes, and then
            # compress those bytes.
            dump = lzma.compress(json.dumps(taskDescription).encode(), preset=2)
            # record the sizing and relation information
            taskDefMap[taskDef.label] = {
                "bytes": (count, count + len(dump)),
                "inputs": inputs,
                "outputs": outputs,
            }
            count += len(dump)
            jsonData.append(dump)

        headerData["TaskDefs"] = taskDefMap

        # serialize the nodes, recording the start and end bytes of each node
        dimAccumulator = DimensionRecordsAccumulator()
        for node in self:
            # Compressing has very little impact on saving or load time, but
            # a large impact on on-disk size, so it is worth doing.
            simpleNode = node.to_simple(accumulator=dimAccumulator)

            dump = lzma.compress(simpleNode.model_dump_json().encode(), preset=2)
            jsonData.append(dump)
            nodeMap.append(
                (
                    str(node.nodeId),
                    {
                        "bytes": (count, count + len(dump)),
                        "inputs": [str(n.nodeId) for n in self.determineInputsToQuantumNode(node)],
                        "outputs": [str(n.nodeId) for n in self.determineOutputsOfQuantumNode(node)],
                    },
                )
            )
            count += len(dump)

        headerData["DimensionRecords"] = {
            key: value.model_dump()
            for key, value in dimAccumulator.makeSerializedDimensionRecordMapping().items()
        }

        # This needs to be serialized as a series of key, value tuples
        # because of a limitation where JSON can't use anything but strings
        # as mapping keys.
        headerData["Nodes"] = nodeMap

        if self._globalInitOutputRefs:
            headerData["GlobalInitOutputRefs"] = [ref.to_json() for ref in self._globalInitOutputRefs]

        if self._registryDatasetTypes:
            headerData["RegistryDatasetTypes"] = [dstype.to_json() for dstype in self._registryDatasetTypes]

        # dump the headerData to json
        header_encode = lzma.compress(json.dumps(headerData).encode())

        # Record the save version as an unsigned short, followed by the
        # header length as an unsigned long long.
        save_bytes = struct.pack(STRUCT_FMT_BASE, SAVE_VERSION)

        fmt_string = DESERIALIZER_MAP[SAVE_VERSION].FMT_STRING()
        map_lengths = struct.pack(fmt_string, len(header_encode))

        # write each component of the save out in a deterministic order
        buffer = bytearray()
        buffer.extend(MAGIC_BYTES)
        buffer.extend(save_bytes)
        buffer.extend(map_lengths)
        buffer.extend(header_encode)
        # Iterate over the length of jsonData, and for each element pop the
        # leftmost element off the deque and write it out. This is to save
        # memory: as the data is added to the buffer object, it is removed
        # from the container.
        #
        # Only this section needs to worry about memory pressure because
        # everything else written to the buffer prior to this data is
        # only on the order of kilobytes to low numbers of megabytes.
        while jsonData:
            buffer.extend(jsonData.popleft())
        if returnHeader:
            return buffer, headerData
        else:
            return buffer

    @classmethod
    def load(
        cls,
        file: BinaryIO,
        universe: DimensionUniverse | None = None,
        nodes: Iterable[uuid.UUID] | None = None,
        graphID: BuildId | None = None,
        minimumVersion: int = 3,
    ) -> QuantumGraph:
        """Read `QuantumGraph` from a file that was made by `save`.

        Parameters
        ----------
        file : `io.IO` of bytes
            File with data, open in binary mode.
        universe : `~lsst.daf.butler.DimensionUniverse`, optional
            If `None` it is loaded from the `QuantumGraph` saved structure.
            If supplied, the `~lsst.daf.butler.DimensionUniverse` from the
            loaded `QuantumGraph` will be validated against the supplied
            argument for compatibility.
        nodes : iterable of `uuid.UUID` or `None`
            UUIDs that correspond to nodes in the graph. If specified, only
            these nodes will be loaded. Defaults to `None`, in which case all
            nodes will be loaded.
        graphID : `str` or `None`
            If specified, this ID is verified against the loaded graph prior
            to loading any nodes. This defaults to `None`, in which case no
            validation is done.
        minimumVersion : `int`
            Minimum version of a save file to load. Set to -1 to load all
            versions. Older versions may need to be loaded, and re-saved
            to upgrade them to the latest format before they can be used in
            production.

        Returns
        -------
        graph : `QuantumGraph`
            Resulting QuantumGraph instance.

        Raises
        ------
        TypeError
            Raised if the data contains an instance of a type other than
            `QuantumGraph`.
        ValueError
            Raised if one or more of the nodes requested is not in the
            `QuantumGraph`, if the graphID parameter does not match the
            graph being loaded, or if the supplied file does not contain a
            valid `QuantumGraph` save.
        """
        with LoadHelper(file, minimumVersion) as loader:
            qgraph = loader.load(universe, nodes, graphID)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph file contains unexpected object type: {type(qgraph)}")
        return qgraph

    def iterTaskGraph(self) -> Generator[TaskDef, None, None]:
        """Iterate over the `taskGraph` attribute in topological order.

        Yields
        ------
        taskDef : `TaskDef`
            `TaskDef` objects in topological order.
        """
        yield from nx.topological_sort(self.taskGraph)
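
    # A sketch of iterating tasks in dependency order, assuming ``qgraph``
    # is a QuantumGraph instance:
    #
    #     for task_def in qgraph.iterTaskGraph():
    #         print(task_def.label, qgraph.getNumberOfQuantaForTask(task_def))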

    def updateRun(self, run: str, *, metadata_key: str | None = None, update_graph_id: bool = False) -> None:
        """Change output run and dataset ID for each output dataset.

        Parameters
        ----------
        run : `str`
            New output run name.
        metadata_key : `str` or `None`
            Specifies the metadata key corresponding to the output run name
            to update with the new run name. If `None` or if the metadata is
            missing, it is not updated. If the metadata is present but the
            key is missing, it will be added.
        update_graph_id : `bool`, optional
            If `True` then also update the graph ID with a new unique value.
        """
        dataset_id_map: dict[DatasetId, DatasetId] = {}

        def _update_output_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: MutableMapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update a collection of `~lsst.daf.butler.DatasetRef` with new
            run and dataset IDs.
            """
            for ref in refs:
                new_ref = ref.replace(run=run)
                dataset_id_map[ref.id] = new_ref.id
                yield new_ref

        def _update_intermediate_refs(
            refs: Iterable[DatasetRef], run: str, dataset_id_map: Mapping[DatasetId, DatasetId]
        ) -> Iterator[DatasetRef]:
            """Update intermediate references with new run and IDs. Only the
            references that appear in ``dataset_id_map`` are updated; others
            are returned unchanged.
            """
            for ref in refs:
                if dataset_id := dataset_id_map.get(ref.id):
                    ref = ref.replace(run=run, id=dataset_id)
                yield ref

        # Replace quantum output refs first.
        for node in self._connectedQuanta:
            quantum = node.quantum
            outputs = {
                dataset_type: tuple(_update_output_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.outputs.items()
            }
            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=quantum.initInputs,
                inputs=quantum.inputs,
                outputs=outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initOutputRefs = {
            task_def: list(_update_output_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initOutputRefs.items()
        }
        self._globalInitOutputRefs = list(
            _update_output_refs(self._globalInitOutputRefs, run, dataset_id_map)
        )

        # Update all intermediates from their matching outputs.
        for node in self._connectedQuanta:
            quantum = node.quantum
            inputs = {
                dataset_type: tuple(_update_intermediate_refs(refs, run, dataset_id_map))
                for dataset_type, refs in quantum.inputs.items()
            }
            initInputs = list(_update_intermediate_refs(quantum.initInputs.values(), run, dataset_id_map))

            updated_quantum = Quantum(
                taskName=quantum.taskName,
                dataId=quantum.dataId,
                initInputs=initInputs,
                inputs=inputs,
                outputs=quantum.outputs,
                datastore_records=quantum.datastore_records,
            )
            node._replace_quantum(updated_quantum)

        self._initInputRefs = {
            task_def: list(_update_intermediate_refs(refs, run, dataset_id_map))
            for task_def, refs in self._initInputRefs.items()
        }

        if update_graph_id:
            self._buildId = BuildId(f"{time.time()}-{os.getpid()}")

        # Update metadata with the new run name if requested.
        if metadata_key is not None:
            self._metadata[metadata_key] = run
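
    # A sketch of retargeting a graph at a new output run; the run name and
    # metadata key below are hypothetical:
    #
    #     qgraph.updateRun("u/user/rerun_2024", metadata_key="output_run",
    #                      update_graph_id=True)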

    @property
    def graphID(self) -> BuildId:
        """The ID generated by the graph at construction time (`str`)."""
        return self._buildId

    @property
    def universe(self) -> DimensionUniverse:
        """Dimension universe associated with this graph
        (`~lsst.daf.butler.DimensionUniverse`).
        """
        return self._universe

    def __iter__(self) -> Generator[QuantumNode, None, None]:
        yield from nx.topological_sort(self._connectedQuanta)

    def __len__(self) -> int:
        return self._count

    def __contains__(self, node: QuantumNode) -> bool:
        return self._connectedQuanta.has_node(node)

    def __getstate__(self) -> dict:
        """Store a compact form of the graph as a list of graph nodes, and a
        tuple of task labels and task configs. The full graph can be
        reconstructed with this information, and it preserves the ordering
        of the graph nodes.
        """
        universe: DimensionUniverse | None = None
        for node in self:
            dId = node.quantum.dataId
            if dId is None:
                continue
            universe = dId.universe
        return {"reduced": self._buildSaveObject(), "graphId": self._buildId, "universe": universe}

    def __setstate__(self, state: dict) -> None:
        """Reconstruct the state of the graph from the information persisted
        in `__getstate__`.
        """
        buffer = io.BytesIO(state["reduced"])
        with LoadHelper(buffer, minimumVersion=3) as loader:
            qgraph = loader.load(state["universe"], graphID=state["graphId"])

        self._metadata = qgraph._metadata
        self._buildId = qgraph._buildId
        self._datasetDict = qgraph._datasetDict
        self._nodeIdMap = qgraph._nodeIdMap
        self._count = len(qgraph)
        self._taskToQuantumNode = qgraph._taskToQuantumNode
        self._taskGraph = qgraph._taskGraph
        self._connectedQuanta = qgraph._connectedQuanta
        self._initInputRefs = qgraph._initInputRefs
        self._initOutputRefs = qgraph._initOutputRefs

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, QuantumGraph):
            return False
        if len(self) != len(other):
            return False
        for node in self:
            if node not in other:
                return False
            if self.determineInputsToQuantumNode(node) != other.determineInputsToQuantumNode(node):
                return False
            if self.determineOutputsOfQuantumNode(node) != other.determineOutputsOfQuantumNode(node):
                return False
        if set(self.allDatasetTypes) != set(other.allDatasetTypes):
            return False
        return set(self.taskGraph) == set(other.taskGraph)

    def getSummary(self) -> QgraphSummary:
        """Create a summary of the graph.

        Returns
        -------
        summary : `QgraphSummary`
            Summary of the QuantumGraph.
        """
        inCollection = self.metadata.get("input", None)
        if isinstance(inCollection, str):
            inCollection = [inCollection]
        summary = QgraphSummary(
            graphID=self.graphID,
            cmdLine=self.metadata.get("full_command", None),
            creationUTC=self.metadata.get("time", None),
            inputCollection=inCollection,
            outputCollection=self.metadata.get("output", None),
            outputRun=self.metadata.get("output_run", None),
        )
        for q in self:
            qts = summary.qgraphTaskSummaries.setdefault(
                q.taskDef.label, QgraphTaskSummary(taskLabel=q.taskDef.label)
            )
            qts.numQuanta += 1

            for k in q.quantum.inputs.keys():
                qts.numInputs[k.name] += 1

            for k in q.quantum.outputs.keys():
                qts.numOutputs[k.name] += 1

        return summary
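
    # A sketch of printing a quick per-task overview, assuming ``qgraph`` is
    # a QuantumGraph instance:
    #
    #     summary = qgraph.getSummary()
    #     for label, task_summary in summary.qgraphTaskSummaries.items():
    #         print(label, task_summary.numQuanta)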