Coverage for python/lsst/pipe/base/pipeline_graph/_tasks.py: 42%

260 statements  

« prev     ^ index     » next       coverage.py v7.4.2, created at 2024-02-22 11:04 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode") 

30 

31import dataclasses 

32import enum 

33from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence 

34from typing import TYPE_CHECKING, Any, cast 

35 

36from lsst.daf.butler import ( 

37 DataCoordinate, 

38 DatasetRef, 

39 DatasetType, 

40 DimensionGroup, 

41 DimensionUniverse, 

42 Registry, 

43) 

44from lsst.utils.classes import immutable 

45from lsst.utils.doImport import doImportType 

46from lsst.utils.introspection import get_full_type_name 

47 

48from .. import automatic_connection_constants as acc 

49from ..connections import PipelineTaskConnections 

50from ..connectionTypes import BaseConnection, InitOutput, Output 

51from ._edges import Edge, ReadEdge, WriteEdge 

52from ._exceptions import TaskNotImportedError, UnresolvedGraphError 

53from ._nodes import NodeKey, NodeType 

54 

55if TYPE_CHECKING: 

56 from ..config import PipelineTaskConfig 

57 from ..pipelineTask import PipelineTask 

58 

59 

class TaskImportMode(enum.Enum):
    """Ways of handling task imports when reading a serialized
    PipelineGraph.
    """

    DO_NOT_IMPORT = enum.auto()
    """Leave tasks unimported; do not instantiate their configs or
    connections.
    """

    REQUIRE_CONSISTENT_EDGES = enum.auto()
    """Import each task, instantiate its config and connections objects, and
    verify that the connections still define the same edges as the serialized
    graph.
    """

    ASSUME_CONSISTENT_EDGES = enum.auto()
    """Import each task and instantiate its config and connections objects,
    trusting (without checking) that the connections still define the same
    edges.

    Only safe when the caller knows the task definition has not changed since
    the pipeline graph was persisted, such as when it was saved and loaded
    with the same pipeline version.
    """

    OVERRIDE_EDGES = enum.auto()
    """Import each task, instantiate its config and connections objects, and
    let any edges those connections define replace the ones stored in the
    persisted graph.

    Dataset type nodes may become unresolved as a result, since resolutions
    made against the original edges may no longer hold.
    """

90 

91 

92@dataclasses.dataclass(frozen=True) 

93class _TaskNodeImportedData: 

94 """An internal struct that holds `TaskNode` and `TaskInitNode` state that 

95 requires task classes to be imported. 

96 """ 

97 

98 task_class: type[PipelineTask] 

99 """Type object for the task.""" 

100 

101 config: PipelineTaskConfig 

102 """Configuration object for the task.""" 

103 

104 connection_map: dict[str, BaseConnection] 

105 """Mapping from connection name to connection. 

106 

107 In addition to ``connections.allConnections``, this also holds the 

108 "automatic" config, log, and metadata connections using the names defined 

109 in the `.automatic_connection_constants` module. 

110 """ 

111 

112 connections: PipelineTaskConnections 

113 """Configured connections object for the task.""" 

114 

115 @classmethod 

116 def configure( 

117 cls, 

118 label: str, 

119 task_class: type[PipelineTask], 

120 config: PipelineTaskConfig, 

121 connections: PipelineTaskConnections | None = None, 

122 ) -> _TaskNodeImportedData: 

123 """Construct while creating a `PipelineTaskConnections` instance if 

124 necessary. 

125 

126 Parameters 

127 ---------- 

128 label : `str` 

129 Label for the task in the pipeline. Only used in error messages. 

130 task_class : `type` [ `.PipelineTask` ] 

131 Pipeline task `type` object. 

132 config : `.PipelineTaskConfig` 

133 Configuration for the task. 

134 connections : `.PipelineTaskConnections`, optional 

135 Object that describes the dataset types used by the task. If not 

136 provided, one will be constructed from the given configuration. If 

137 provided, it is assumed that ``config`` has already been validated 

138 and frozen. 

139 

140 Returns 

141 ------- 

142 data : `_TaskNodeImportedData` 

143 Instance of this struct. 

144 """ 

145 if connections is None: 

146 # If we don't have connections yet, assume the config hasn't been 

147 # validated yet. 

148 try: 

149 config.validate() 

150 except Exception as err: 

151 raise ValueError( 

152 f"Configuration validation failed for task {label!r} (see chained exception)." 

153 ) from err 

154 config.freeze() 

155 connections = task_class.ConfigClass.ConnectionsClass(config=config) 

156 connection_map = dict(connections.allConnections) 

157 connection_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput( 

158 acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label), 

159 acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS, 

160 ) 

161 connection_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output( 

162 acc.METADATA_OUTPUT_TEMPLATE.format(label=label), 

163 acc.METADATA_OUTPUT_STORAGE_CLASS, 

164 dimensions=set(connections.dimensions), 

165 ) 

166 if config.saveLogOutput: 

167 connection_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output( 

168 acc.LOG_OUTPUT_TEMPLATE.format(label=label), 

169 acc.LOG_OUTPUT_STORAGE_CLASS, 

170 dimensions=set(connections.dimensions), 

171 ) 

172 return cls(task_class, config, connection_map, connections) 

173 

174 

@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Key that identifies this node in internal and exported networkx
        graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class
        to have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must
        be provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            assert (
                self._task_class_name is not None and self._config_str is not None
            ), "If imported_data is not present, task_class_name and config_str must be."

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Not set at construction; derive it from the imported task class and
        # cache it (allowed by @immutable's deferred initialization).
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        # Not set at construction; serialize the config object and cache it.
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.

        Yields
        ------
        `ReadEdge`
            All the inputs required for construction.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.

        Yields
        ------
        `WriteEdge`
            All the outputs available after construction.
        """
        yield from self.outputs.values()
        yield self.config_output

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from
        the same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but
            need not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = []
        # Compare this node's init inputs against *other*'s init inputs;
        # comparing self.inputs to itself would always report no differences.
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None

414 

@immutable
class TaskNode:
    """A node in a pipeline graph that represents a labeled configuration of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    init : `TaskInitNode`
        Node representing the initialization of this task.
    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent prerequisite inputs to this task, keyed by
        connection name.

        Prerequisite inputs must already exist in the data repository when a
        `QuantumGraph` is built, but have more flexibility in how they are
        looked up than regular inputs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent regular runtime inputs to this task, keyed
        by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent regular runtime outputs of this task,
        keyed by connection name.

        This does not include the special `log_output` and `metadata_output`
        edges; use `iter_all_outputs` to include those, too.
    log_output : `WriteEdge` or `None`
        The special runtime output that persists the task's logs, or `None`
        if log output is disabled.
    metadata_output : `WriteEdge`
        The special runtime output that persists the task's metadata.
    dimensions : `lsst.daf.butler.DimensionGroup` or `frozenset`
        Dimensions of the task.  If a `frozenset`, the dimensions have not
        been resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be
        safely compared to other sets of dimensions.

    Notes
    -----
    Task nodes are intentionally not equality comparable, since there are
    many different (and useful) ways to compare these objects with no clear
    winner as the most obvious behavior.

    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task nodes set the following node
    attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGroup | frozenset,
    ):
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        # Either a resolved DimensionGroup or a raw frozenset of str names;
        # has_resolved_dimensions distinguishes the two cases.
        self._dimensions = dimensions

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the node representing this task's initialization.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task
            class to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        # Init-time connections hang off the init node's key; runtime
        # connections hang off the task node's key.
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            dimensions=(
                # Without a universe, fall back to raw (unresolved) names.
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.conform(data.connections.dimensions)
            ),
        )
        return instance

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output`
    edges; use `iter_all_outputs` to include those, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return self.key.name

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return self.init.is_imported

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self.init.task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        return self.init.task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self.init.config

    @property
    def has_resolved_dimensions(self) -> bool:
        """Whether the `dimensions` attribute may be accessed.

        If `False`, the `raw_dimensions` attribute may be used to obtain a
        set of dimension names that has not been resolved by a
        `~lsst.daf.butler.DimensionUniverse`.
        """
        return type(self._dimensions) is DimensionGroup

    @property
    def dimensions(self) -> DimensionGroup:
        """Standardized dimensions of the task.

        Raises
        ------
        UnresolvedGraphError
            Raised if `has_resolved_dimensions` is `False`.
        """
        if not self.has_resolved_dimensions:
            raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
        return cast(DimensionGroup, self._dimensions)

    @property
    def raw_dimensions(self) -> frozenset[str]:
        """Raw dimensions of the task, with standardization by a
        `~lsst.daf.butler.DimensionUniverse` not guaranteed.
        """
        if self.has_resolved_dimensions:
            return frozenset(cast(DimensionGroup, self._dimensions).names)
        else:
            return cast(frozenset[str], self._dimensions)

    def __repr__(self) -> str:
        if self.has_resolved_dimensions:
            return f"{self.label} ({self.task_class_name}, {self.dimensions})"
        else:
            return f"{self.label} ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        return self.init.get_config_str()

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all runtime inputs, including both regular inputs and
        prerequisites.

        Yields
        ------
        `ReadEdge`
            All the runtime inputs.
        """
        yield from self.prerequisite_inputs.values()
        yield from self.inputs.values()

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all runtime outputs, including the special metadata
        output and (when enabled) the log output.

        Yields
        ------
        `WriteEdge`
            All the runtime outputs.
        """
        yield from self.outputs.values()
        yield self.metadata_output
        if self.log_output is not None:
            yield self.log_output

    def diff_edges(self, other: TaskNode) -> list[str]:
        """Compare the edges of this task node to those from the same task
        label in a different pipeline.

        This also calls `TaskInitNode.diff_edges`.

        Parameters
        ----------
        other : `TaskNode`
            Other node to compare to.  Must have the same task label, but
            need not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = self.init.diff_edges(other.init)
        result += _diff_edge_mapping(
            self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
        )
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
        # The log output is optional on both sides, so handle all four
        # presence combinations explicitly.
        if self.log_output is not None:
            if other.log_output is not None:
                result += self.log_output.diff(other.log_output, "log output")
            else:
                result.append("Log output is present in A, but not in B.")
        elif other.log_output is not None:
            result.append("Log output is present in B, but not in A.")
        result += self.metadata_output.diff(other.metadata_output, "metadata output")
        return result

    def get_lookup_function(
        self, connection_name: str
    ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None:
        """Return the custom dataset query function for an edge, if one
        exists.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        lookup_function : `~collections.abc.Callable` or `None`
            Callable that takes a dataset type, a butler registry, a data
            coordinate (the quantum data ID), and an ordered list of
            collections to search, and returns an iterable of
            `~lsst.daf.butler.DatasetRef`.
        """
        # getattr default covers connection types without a lookupFunction
        # attribute.
        return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)

    def get_spatial_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the spatial bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with spatial dimensions.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections())

    def get_temporal_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the temporal bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with temporal dimensions.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections())

    def _imported_and_configured(self, rebuild: bool) -> TaskNode:
        """Import the task class and use it to construct a new instance.

        Parameters
        ----------
        rebuild : `bool`
            If `True`, import the task class and configure its connections to
            generate new edges that may differ from the current ones.  If
            `False`, import the task class but just update the `task_class`
            and `config` attributes, and assume the edges have not changed.

        Returns
        -------
        node : `TaskNode`
            Task node instance for which `is_imported` is `True`.  Will be
            ``self`` if this is the case already.
        """
        # Function-scope import; presumably avoids an import cycle with
        # ..pipelineTask at module import time.
        from ..pipelineTask import PipelineTask

        if self.is_imported:
            return self
        task_class = doImportType(self.task_class_name)
        if not issubclass(task_class, PipelineTask):
            raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
        config = task_class.ConfigClass()
        config.loadFromString(self.get_config_str())
        return self._reconfigured(config, rebuild=rebuild, task_class=task_class)

    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones.  If `False`, just update
            the `task_class` and `config` attributes, and assume the edges
            have not changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`.  This defaults to
            ``self.task_class``, but may be passed as an argument if that is
            not available because the task class was not imported when
            ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            # Rebuild all edges from the freshly-configured connections.
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGroup else None,
            )
        else:
            # Keep the existing edges; swap in only the new imported data.
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )

    def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
        """Return an otherwise-equivalent task node with resolved dimensions.

        Parameters
        ----------
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions for all dimensions.

        Returns
        -------
        node : `TaskNode`
            Task node instance with `dimensions` resolved by the given
            universe.  Will be ``self`` if this is the case already.
        """
        # Short-circuit when already resolved against this same universe, or
        # when there is no universe to resolve an unresolved node against.
        if self.has_resolved_dimensions:
            if cast(DimensionGroup, self._dimensions).universe is universe:
                return self
        elif universe is None:
            return self
        return TaskNode(
            key=self.key,
            init=self.init,
            prerequisite_inputs=self.prerequisite_inputs,
            inputs=self.inputs,
            outputs=self.outputs,
            log_output=self.log_output,
            metadata_output=self.metadata_output,
            dimensions=(
                universe.conform(self.raw_dimensions) if universe is not None else self.raw_dimensions
            ),
        )

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = self.init._to_xgraph_state()
        if self.has_resolved_dimensions:
            result["dimensions"] = self._dimensions
        result["raw_dimensions"] = self.raw_dimensions
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return self.init._get_imported_data()

928 

929 

930def _diff_edge_mapping( 

931 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str 

932) -> list[str]: 

933 """Compare a pair of mappings of edges. 

934 

935 Parameters 

936 ---------- 

937 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] 

938 First mapping to compare. Expected to have connection names as keys. 

939 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] 

940 First mapping to compare. If keys differ from those of ``a_mapping``, 

941 this will be reported as a difference (in addition to element-wise 

942 comparisons). 

943 task_label : `str` 

944 Task label associated with both mappings. 

945 connection_type : `str` 

946 Type of connection (e.g. "input" or "init output") associated with both 

947 connections. This is a human-readable string to include in difference 

948 messages. 

949 

950 Returns 

951 ------- 

952 differences : `list` [ `str` ] 

953 List of string messages describing differences between the two 

954 mappings. Will be empty if the two mappings have the same edges. 

955 Messages will include "A" and "B", and are expected to be a preceded 

956 by a message describing what "A" and "B" are in the context in which 

957 this method is called. 

958 

959 Notes 

960 ----- 

961 This is expected to be used to compare one edge-holding mapping attribute 

962 of a task or task init node to the same attribute on another task or task 

963 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`, 

964 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`, 

965 `TaskInitNode.outputs`). 

966 """ 

967 results = [] 

968 b_to_do = set(b_mapping.keys()) 

969 for connection_name, a_edge in a_mapping.items(): 

970 if (b_edge := b_mapping.get(connection_name)) is None: 

971 results.append( 

972 f"{connection_type.capitalize()} {connection_name!r} of task " 

973 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." 

974 ) 

975 else: 

976 results.extend(a_edge.diff(b_edge, connection_type)) 

977 b_to_do.discard(connection_name) 

978 for connection_name in b_to_do: 

979 results.append( 

980 f"{connection_type.capitalize()} {connection_name!r} of task " 

981 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." 

982 ) 

983 return results