Coverage for python / lsst / pipe / base / pipeline_graph / _tasks.py: 34%

315 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-06 08:32 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29__all__ = ("TaskImportMode", "TaskInitNode", "TaskNode") 

30 

31import dataclasses 

32import enum 

33from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence 

34from typing import TYPE_CHECKING, Any, cast 

35 

36from lsst.daf.butler import ( 

37 DataCoordinate, 

38 DatasetRef, 

39 DatasetType, 

40 DimensionGroup, 

41 DimensionUniverse, 

42 Registry, 

43) 

44from lsst.pex.config import FieldValidationError 

45from lsst.utils.classes import immutable 

46from lsst.utils.doImport import doImportType 

47from lsst.utils.introspection import get_full_type_name 

48 

49from .. import automatic_connection_constants as acc 

50from ..connections import PipelineTaskConnections 

51from ..connectionTypes import BaseConnection, BaseInput, InitOutput, Output 

52from ._edges import Edge, ReadEdge, WriteEdge 

53from ._exceptions import TaskNotImportedError, UnresolvedGraphError 

54from ._nodes import NodeKey, NodeType 

55 

56if TYPE_CHECKING: 

57 from ..config import PipelineTaskConfig 

58 from ..pipelineTask import PipelineTask 

59 

60 

class TaskImportMode(enum.Enum):
    """Enumeration of the ways to handle importing tasks when reading a
    serialized PipelineGraph.
    """

    # Values are spelled out explicitly; they match what ``enum.auto()``
    # assigns (1-based, in declaration order).
    DO_NOT_IMPORT = 1
    """Leave tasks unimported; do not instantiate their configs or
    connections."""

    REQUIRE_CONSISTENT_EDGES = 2
    """Import the tasks, build their config and connection objects, and verify
    that those connections still define exactly the persisted edges.
    """

    ASSUME_CONSISTENT_EDGES = 3
    """Import the tasks and build their config and connection objects, while
    trusting (without checking) that the connections still define the
    persisted edges.

    Only safe when the task definitions are known not to have changed since
    the pipeline graph was persisted, e.g. when saving and loading with the
    same pipeline version.
    """

    OVERRIDE_EDGES = 4
    """Import the tasks, build their config and connection objects, and let
    the edges those connections define replace the persisted ones.

    Dataset type nodes may end up unresolved, because resolutions that were
    consistent with the original edges may be invalidated.
    """

91 

92 

@dataclasses.dataclass(frozen=True)
class _TaskNodeImportedData:
    """An internal struct that holds `TaskNode` and `TaskInitNode` state that
    requires task classes to be imported.
    """

    # NOTE: field order matters; `configure` constructs instances positionally.

    task_class: type[PipelineTask]
    """Type object for the task."""

    config: PipelineTaskConfig
    """Configuration object for the task."""

    connection_map: dict[str, BaseConnection]
    """Mapping from connection name to connection.

    In addition to ``connections.allConnections``, this also holds the
    "automatic" config, log, and metadata connections using the names defined
    in the `.automatic_connection_constants` module.
    """

    connections: PipelineTaskConnections
    """Configured connections object for the task."""

    @classmethod
    def configure(
        cls,
        label: str,
        task_class: type[PipelineTask],
        config: PipelineTaskConfig,
        connections: PipelineTaskConnections | None = None,
    ) -> _TaskNodeImportedData:
        """Construct while creating a `PipelineTaskConnections` instance if
        necessary.

        Parameters
        ----------
        label : `str`
            Label for the task in the pipeline.  Only used in error messages.
        task_class : `type` [ `.PipelineTask` ]
            Pipeline task `type` object.
        config : `.PipelineTaskConfig`
            Configuration for the task.
        connections : `.PipelineTaskConnections`, optional
            Object that describes the dataset types used by the task.  If not
            provided, one will be constructed from the given configuration.  If
            provided, it is assumed that ``config`` has already been validated
            and frozen.

        Returns
        -------
        data : `_TaskNodeImportedData`
            Instance of this struct.

        Raises
        ------
        FieldValidationError
            Raised (re-raised with the task label prepended) if a specific
            config field fails validation.
        ValueError
            Raised if config validation fails for any other reason; the
            original exception is chained.
        """
        if connections is None:
            # If we don't have connections yet, assume the config hasn't been
            # validated yet.
            try:
                config.validate()
            except FieldValidationError as err:
                # Prefix the task label so the user can tell which task's
                # config is at fault; re-raise the same exception type.
                err.fullname = f"{label}: {err.fullname}"
                raise err
            except Exception as err:
                raise ValueError(
                    f"Configuration validation failed for task {label!r} (see chained exception)."
                ) from err
            # Freeze only after successful validation; connections construction
            # below expects a frozen config.
            config.freeze()
            # MyPy doesn't see the metaclass attribute defined for this.
            connections = config.ConnectionsClass(config=config)  # type: ignore
        # Start from the task-declared connections, then add the automatic
        # config/metadata/log connections under their standard names.
        connection_map = dict(connections.allConnections)
        connection_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
            acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
            acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
        )
        connection_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
            acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
            acc.METADATA_OUTPUT_STORAGE_CLASS,
            dimensions=set(connections.dimensions),
        )
        # The log connection is optional: only present when the config asks
        # for log output to be saved.
        if config.saveLogOutput:
            connection_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
                acc.LOG_OUTPUT_TEMPLATE.format(label=label),
                acc.LOG_OUTPUT_STORAGE_CLASS,
                dimensions=set(connections.dimensions),
            )
        return cls(task_class, config, connection_map, connections)

178 

179 

@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Key that identifies this node in internal and exported networkx graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_init_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class to
        have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must be
        provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            # Check the arguments, not the instance attributes: the attributes
            # are unset exactly when the arguments are None, and reading an
            # unset attribute here would raise AttributeError instead of the
            # intended AssertionError.
            assert task_class_name is not None and config_str is not None, (
                "If imported_data is not present, task_class_name and config_str must be."
            )

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and
        its configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Lazily derive (and cache) the name from the imported task class.
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        # Lazily serialize (and cache) the config; requires the task to have
        # been imported.
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.

        Yields
        ------
        `ReadEdge`
            All the inputs required for construction.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.

        Yields
        ------
        `WriteEdge`
            All the outputs available after construction.
        """
        yield from self.outputs.values()
        yield self.config_output

    def get_input_edge(self, connection_name: str) -> ReadEdge:
        """Look up an input edge by connection name.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        edge : `ReadEdge`
            Input edge.
        """
        return self.inputs[connection_name]

    def get_output_edge(self, connection_name: str) -> WriteEdge:
        """Look up an output edge by connection name.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        edge : `WriteEdge`
            Output edge.
        """
        if connection_name == acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME:
            return self.config_output
        return self.outputs[connection_name]

    def get_edge(self, connection_name: str) -> Edge:
        """Look up an edge by connection name.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        edge : `Edge`
            Edge.
        """
        try:
            return self.get_input_edge(connection_name)
        except KeyError:
            pass
        return self.get_output_edge(connection_name)

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from the
        same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self`` and
            ``other``.  Will be empty if the two nodes have the same edges.
            Messages will use 'A' to refer to ``self`` and 'B' to refer to
            ``other``.
        """
        result = []
        # Compare A's (self) init inputs against B's (other); previously this
        # compared self.inputs to itself, so init-input differences were
        # silently dropped.
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this nodes's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None

    @staticmethod
    def _unreduce(kwargs: dict[str, Any]) -> TaskInitNode:
        """Unpickle a `TaskInitNode` instance."""
        # Connections classes are not pickleable, so we can't use the
        # dataclass-provided pickle implementation of _TaskNodeImportedData,
        # and it's easier to just call its `configure` method than to fix it.
        if (imported_data_args := kwargs.pop("imported_data_args", None)) is not None:
            imported_data = _TaskNodeImportedData.configure(*imported_data_args)
        else:
            imported_data = None
        return TaskInitNode(imported_data=imported_data, **kwargs)

    def __reduce__(self) -> tuple[Callable[[dict[str, Any]], TaskInitNode], tuple[dict[str, Any]]]:
        kwargs = dict(
            key=self.key,
            inputs=self.inputs,
            outputs=self.outputs,
            config_output=self.config_output,
            task_class_name=getattr(self, "_task_class_name", None),
            config_str=getattr(self, "_config_str", None),
        )
        if hasattr(self, "_imported_data"):
            kwargs["imported_data_args"] = (
                self.label,
                self.task_class,
                self.config,
            )
        return (self._unreduce, (kwargs,))

498 

499 

500@immutable 

501class TaskNode: 

502 """A node in a pipeline graph that represents a labeled configuration of a 

503 `PipelineTask`. 

504 

505 Parameters 

506 ---------- 

507 key : `NodeKey` 

508 Identifier for this node in networkx graphs. 

509 init : `TaskInitNode` 

510 Node representing the initialization of this task. 

511 prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ] 

512 Graph edges that represent prerequisite inputs to this task, keyed by 

513 connection name. 

514 

515 Prerequisite inputs must already exist in the data repository when a 

516 `QuantumGraph` is built, but have more flexibility in how they are 

517 looked up than regular inputs. 

518 inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ] 

519 Graph edges that represent regular runtime inputs to this task, keyed 

520 by connection name. 

521 outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ] 

522 Graph edges that represent regular runtime outputs of this task, keyed 

523 by connection name. 

524 

525 This does not include the special `log_output` and `metadata_output` 

526 edges; use `iter_all_outputs` to include that, too. 

527 log_output : `WriteEdge` or `None` 

528 The special runtime output that persists the task's logs. 

529 metadata_output : `WriteEdge` 

530 The special runtime output that persists the task's metadata. 

531 dimensions : `lsst.daf.butler.DimensionGroup` or `frozenset` [ `str` ] 

532 Dimensions of the task. If a `frozenset`, the dimensions have not been 

533 resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be safely 

534 compared to other sets of dimensions. 

535 

536 Notes 

537 ----- 

538 Task nodes are intentionally not equality comparable, since there are many 

539 different (and useful) ways to compare these objects with no clear winner 

540 as the most obvious behavior. 

541 

542 When included in an exported `networkx` graph (e.g. 

543 `PipelineGraph.make_xgraph`), task nodes set the following node attributes: 

544 

545 - ``task_class_name`` 

546 - ``bipartite`` (see `NodeType.bipartite`) 

547 - ``task_class`` (only if `is_imported` is `True`) 

548 - ``config`` (only if `is_imported` is `True`) 

549 """ 

550 

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGroup | frozenset[str],
    ):
        # See the class docstring for parameter descriptions.
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        # log_output is None when the config did not request saved logs.
        self.log_output = log_output
        self.metadata_output = metadata_output
        # Stored privately because it may be either a resolved DimensionGroup
        # or a raw frozenset of names; see has_resolved_dimensions,
        # dimensions, and raw_dimensions.
        self._dimensions = dimensions

571 

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the node representing the initialization of this
            task.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task class
            to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        # Build edge objects for each category of connection declared by the
        # task; init-time connections attach to init_key, runtime ones to key.
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
            # Connections with deferBinding set are deliberately excluded from
            # the graph's regular inputs.
            if not getattr(data.connections, name).deferBinding
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            # The config init-output is automatic; its connection was added to
            # connection_map by _TaskNodeImportedData.configure.
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            # The log output only exists when the config requested it.
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            # Without a universe we can only record raw dimension names;
            # callers must resolve them later.
            dimensions=(
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.conform(data.connections.dimensions)
            ),
        )
        return instance

655 

    # Class-level annotations documenting the attributes assigned in __init__.

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output` edges;
    use `iter_all_outputs` to include that, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

686 

687 @property 

688 def label(self) -> str: 

689 """Label of this configuration of a task in the pipeline.""" 

690 return self.key.name 

691 

692 @property 

693 def is_imported(self) -> bool: 

694 """Whether this the task type for this node has been imported and 

695 its configuration overrides applied. 

696 

697 If this is `False`, the `task_class` and `config` attributes may not 

698 be accessed. 

699 """ 

700 return self.init.is_imported 

701 

702 @property 

703 def task_class(self) -> type[PipelineTask]: 

704 """Type object for the task. 

705 

706 Accessing this attribute when `is_imported` is `False` will raise 

707 `TaskNotImportedError`, but accessing `task_class_name` will not. 

708 """ 

709 return self.init.task_class 

710 

711 @property 

712 def task_class_name(self) -> str: 

713 """The fully-qualified string name of the task class.""" 

714 return self.init.task_class_name 

715 

716 @property 

717 def config(self) -> PipelineTaskConfig: 

718 """Configuration for the task. 

719 

720 This is always frozen. 

721 

722 Accessing this attribute when `is_imported` is `False` will raise 

723 `TaskNotImportedError`, but calling `get_config_str` will not. 

724 """ 

725 return self.init.config 

726 

727 @property 

728 def has_resolved_dimensions(self) -> bool: 

729 """Whether the `dimensions` attribute may be accessed. 

730 

731 If `False`, the `raw_dimensions` attribute may be used to obtain a 

732 set of dimension names that has not been resolved by a 

733 `~lsst.daf.butler.DimensionsUniverse`. 

734 """ 

735 return type(self._dimensions) is DimensionGroup 

736 

737 @property 

738 def dimensions(self) -> DimensionGroup: 

739 """Standardized dimensions of the task.""" 

740 if not self.has_resolved_dimensions: 

741 raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.") 

742 return cast(DimensionGroup, self._dimensions) 

743 

744 @property 

745 def raw_dimensions(self) -> frozenset[str]: 

746 """Raw dimensions of the task, with standardization by a 

747 `~lsst.daf.butler.DimensionUniverse` not guaranteed. 

748 """ 

749 if self.has_resolved_dimensions: 

750 return frozenset(cast(DimensionGroup, self._dimensions).names) 

751 else: 

752 return cast(frozenset[str], self._dimensions) 

753 

754 def __repr__(self) -> str: 

755 if self.has_resolved_dimensions: 

756 return f"{self.label} ({self.task_class_name}, {self.dimensions})" 

757 else: 

758 return f"{self.label} ({self.task_class_name})" 

759 

760 def get_config_str(self) -> str: 

761 """Return the configuration for this task as a string of override 

762 statements. 

763 

764 Returns 

765 ------- 

766 config_str : `str` 

767 String containing configuration-overload statements. 

768 """ 

769 return self.init.get_config_str() 

770 

771 def iter_all_inputs(self) -> Iterator[ReadEdge]: 

772 """Iterate over all runtime inputs, including both regular inputs and 

773 prerequisites. 

774 

775 Yields 

776 ------ 

777 `ReadEdge` 

778 All the runtime inputs. 

779 """ 

780 yield from self.prerequisite_inputs.values() 

781 yield from self.inputs.values() 

782 

783 def iter_all_outputs(self) -> Iterator[WriteEdge]: 

784 """Iterate over all runtime outputs, including special ones. 

785 

786 Yields 

787 ------ 

788 `ReadEdge` 

789 All the runtime outputs. 

790 """ 

791 yield from self.outputs.values() 

792 yield self.metadata_output 

793 if self.log_output is not None: 

794 yield self.log_output 

795 

796 def get_input_edge(self, connection_name: str) -> ReadEdge: 

797 """Look up an input edge by connection name. 

798 

799 Parameters 

800 ---------- 

801 connection_name : `str` 

802 Name of the connection. 

803 

804 Returns 

805 ------- 

806 edge : `ReadEdge` 

807 Input edge. 

808 """ 

809 if (edge := self.inputs.get(connection_name)) is not None: 

810 return edge 

811 return self.prerequisite_inputs[connection_name] 

812 

813 def get_output_edge(self, connection_name: str) -> WriteEdge: 

814 """Look up an output edge by connection name. 

815 

816 Parameters 

817 ---------- 

818 connection_name : `str` 

819 Name of the connection. 

820 

821 Returns 

822 ------- 

823 edge : `WriteEdge` 

824 Output edge. 

825 """ 

826 if connection_name == acc.METADATA_OUTPUT_CONNECTION_NAME: 

827 return self.metadata_output 

828 if connection_name == acc.LOG_OUTPUT_CONNECTION_NAME: 

829 if self.log_output is None: 

830 raise KeyError(connection_name) 

831 return self.log_output 

832 return self.outputs[connection_name] 

833 

834 def get_edge(self, connection_name: str) -> Edge: 

835 """Look up an edge by connection name. 

836 

837 Parameters 

838 ---------- 

839 connection_name : `str` 

840 Name of the connection. 

841 

842 Returns 

843 ------- 

844 edge : `Edge` 

845 Edge. 

846 """ 

847 try: 

848 return self.get_input_edge(connection_name) 

849 except KeyError: 

850 pass 

851 return self.get_output_edge(connection_name) 

852 

853 def diff_edges(self, other: TaskNode) -> list[str]: 

854 """Compare the edges of this task node to those from the same task 

855 label in a different pipeline. 

856 

857 This also calls `TaskInitNode.diff_edges`. 

858 

859 Parameters 

860 ---------- 

861 other : `TaskInitNode` 

862 Other node to compare to. Must have the same task label, but need 

863 not have the same configuration or even the same task class. 

864 

865 Returns 

866 ------- 

867 differences : `list` [ `str` ] 

868 List of string messages describing differences between ``self`` and 

869 ``other``. Will be empty if the two nodes have the same edges. 

870 Messages will use 'A' to refer to ``self`` and 'B' to refer to 

871 ``other``. 

872 """ 

873 result = self.init.diff_edges(other.init) 

874 result += _diff_edge_mapping( 

875 self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input" 

876 ) 

877 result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input") 

878 result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output") 

879 if self.log_output is not None: 

880 if other.log_output is not None: 

881 result += self.log_output.diff(other.log_output, "log output") 

882 else: 

883 result.append("Log output is present in A, but not in B.") 

884 elif other.log_output is not None: 

885 result.append("Log output is present in B, but not in A.") 

886 result += self.metadata_output.diff(other.metadata_output, "metadata output") 

887 return result 

888 

889 def get_lookup_function( 

890 self, connection_name: str 

891 ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None: 

892 """Return the custom dataset query function for an edge, if one exists. 

893 

894 Parameters 

895 ---------- 

896 connection_name : `str` 

897 Name of the connection. 

898 

899 Returns 

900 ------- 

901 lookup_function : `~collections.abc.Callable` or `None` 

902 Callable that takes a dataset type, a butler registry, a data 

903 coordinate (the quantum data ID), and an ordered list of 

904 collections to search, and returns an iterable of 

905 `~lsst.daf.butler.DatasetRef`. 

906 """ 

907 return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None) 

908 

909 def is_optional(self, connection_name: str) -> bool: 

910 """Check whether the given connection has ``minimum==0``. 

911 

912 Parameters 

913 ---------- 

914 connection_name : `str` 

915 Name of the connection. 

916 

917 Returns 

918 ------- 

919 optional : `bool` 

920 Whether this task can run without any datasets for the given 

921 connection. 

922 """ 

923 connection = getattr(self.get_connections(), connection_name) 

924 return isinstance(connection, BaseInput) and connection.minimum == 0 

925 

926 def get_connections(self) -> PipelineTaskConnections: 

927 """Return the connections class instance for this task. 

928 

929 Returns 

930 ------- 

931 connections : `.PipelineTaskConnections` 

932 Task-provided object that defines inputs and outputs from 

933 configuration. 

934 """ 

935 return self._get_imported_data().connections 

936 

937 def get_spatial_bounds_connections(self) -> frozenset[str]: 

938 """Return the names of connections whose data IDs should be included 

939 in the calculation of the spatial bounds for this task's quanta. 

940 

941 Returns 

942 ------- 

943 connection_names : `frozenset` [ `str` ] 

944 Names of connections with spatial dimensions. 

945 """ 

946 return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections()) 

947 

948 def get_temporal_bounds_connections(self) -> frozenset[str]: 

949 """Return the names of connections whose data IDs should be included 

950 in the calculation of the temporal bounds for this task's quanta. 

951 

952 Returns 

953 ------- 

954 connection_names : `frozenset` [ `str` ] 

955 Names of connections with temporal dimensions. 

956 """ 

957 return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections()) 

958 

959 def _imported_and_configured(self, rebuild: bool) -> TaskNode: 

960 """Import the task class and use it to construct a new instance. 

961 

962 Parameters 

963 ---------- 

964 rebuild : `bool` 

965 If `True`, import the task class and configure its connections to 

966 generate new edges that may differ from the current ones. If 

967 `False`, import the task class but just update the `task_class` and 

968 `config` attributes, and assume the edges have not changed. 

969 

970 Returns 

971 ------- 

972 node : `TaskNode` 

973 Task node instance for which `is_imported` is `True`. Will be 

974 ``self`` if this is the case already. 

975 """ 

976 from ..pipelineTask import PipelineTask 

977 

978 if self.is_imported: 

979 return self 

980 task_class = doImportType(self.task_class_name) 

981 if not issubclass(task_class, PipelineTask): 

982 raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.") 

983 config = task_class.ConfigClass() 

984 config.loadFromString(self.get_config_str()) 

985 return self._reconfigured(config, rebuild=rebuild, task_class=task_class) 

986 

987 def _reconfigured( 

988 self, 

989 config: PipelineTaskConfig, 

990 rebuild: bool, 

991 task_class: type[PipelineTask] | None = None, 

992 ) -> TaskNode: 

993 """Return a version of this node with new configuration. 

994 

995 Parameters 

996 ---------- 

997 config : `.PipelineTaskConfig` 

998 New configuration for the task. 

999 rebuild : `bool` 

1000 If `True`, use the configured connections to generate new edges 

1001 that may differ from the current ones. If `False`, just update the 

1002 `task_class` and `config` attributes, and assume the edges have not 

1003 changed. 

1004 task_class : `type` [ `PipelineTask` ], optional 

1005 Subclass of `PipelineTask`. This defaults to ``self.task_class`, 

1006 but may be passed as an argument if that is not available because 

1007 the task class was not imported when ``self`` was constructed. 

1008 

1009 Returns 

1010 ------- 

1011 node : `TaskNode` 

1012 Task node instance with the new config. 

1013 """ 

1014 if task_class is None: 

1015 task_class = self.task_class 

1016 imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config) 

1017 if rebuild: 

1018 return self._from_imported_data( 

1019 self.key, 

1020 self.init.key, 

1021 imported_data, 

1022 universe=self._dimensions.universe if type(self._dimensions) is DimensionGroup else None, 

1023 ) 

1024 else: 

1025 return TaskNode( 

1026 self.key, 

1027 TaskInitNode( 

1028 self.init.key, 

1029 inputs=self.init.inputs, 

1030 outputs=self.init.outputs, 

1031 config_output=self.init.config_output, 

1032 imported_data=imported_data, 

1033 ), 

1034 prerequisite_inputs=self.prerequisite_inputs, 

1035 inputs=self.inputs, 

1036 outputs=self.outputs, 

1037 log_output=self.log_output, 

1038 metadata_output=self.metadata_output, 

1039 dimensions=self._dimensions, 

1040 ) 

1041 

1042 def _resolved(self, universe: DimensionUniverse | None) -> TaskNode: 

1043 """Return an otherwise-equivalent task node with resolved dimensions. 

1044 

1045 Parameters 

1046 ---------- 

1047 universe : `lsst.daf.butler.DimensionUniverse` or `None` 

1048 Definitions for all dimensions. 

1049 

1050 Returns 

1051 ------- 

1052 node : `TaskNode` 

1053 Task node instance with `dimensions` resolved by the given 

1054 universe. Will be ``self`` if this is the case already. 

1055 """ 

1056 if self.has_resolved_dimensions: 

1057 if cast(DimensionGroup, self._dimensions).universe is universe: 

1058 return self 

1059 elif universe is None: 

1060 return self 

1061 return TaskNode( 

1062 key=self.key, 

1063 init=self.init, 

1064 prerequisite_inputs=self.prerequisite_inputs, 

1065 inputs=self.inputs, 

1066 outputs=self.outputs, 

1067 log_output=self.log_output, 

1068 metadata_output=self.metadata_output, 

1069 dimensions=( 

1070 universe.conform(self.raw_dimensions) if universe is not None else self.raw_dimensions 

1071 ), 

1072 ) 

1073 

1074 def _to_xgraph_state(self) -> dict[str, Any]: 

1075 """Convert this nodes's attributes into a dictionary suitable for use 

1076 in exported networkx graphs. 

1077 """ 

1078 result = self.init._to_xgraph_state() 

1079 if self.has_resolved_dimensions: 

1080 result["dimensions"] = self._dimensions 

1081 result["raw_dimensions"] = self.raw_dimensions 

1082 return result 

1083 

1084 def _get_imported_data(self) -> _TaskNodeImportedData: 

1085 """Return the imported data struct. 

1086 

1087 Returns 

1088 ------- 

1089 imported_data : `_TaskNodeImportedData` 

1090 Internal structure holding state that requires the task class to 

1091 have been imported. 

1092 

1093 Raises 

1094 ------ 

1095 TaskNotImportedError 

1096 Raised if `is_imported` is `False`. 

1097 """ 

1098 return self.init._get_imported_data() 

1099 

1100 @staticmethod 

1101 def _unreduce(kwargs: dict[str, Any]) -> TaskNode: 

1102 """Unpickle a `TaskNode` instance.""" 

1103 return TaskNode(**kwargs) 

1104 

1105 def __reduce__(self) -> tuple[Callable[[dict[str, Any]], TaskNode], tuple[dict[str, Any]]]: 

1106 return ( 

1107 self._unreduce, 

1108 ( 

1109 dict( 

1110 key=self.key, 

1111 init=self.init, 

1112 prerequisite_inputs=self.prerequisite_inputs, 

1113 inputs=self.inputs, 

1114 outputs=self.outputs, 

1115 log_output=self.log_output, 

1116 metadata_output=self.metadata_output, 

1117 dimensions=self._dimensions, 

1118 ), 

1119 ), 

1120 ) 

1121 

1122 

1123def _diff_edge_mapping( 

1124 a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str 

1125) -> list[str]: 

1126 """Compare a pair of mappings of edges. 

1127 

1128 Parameters 

1129 ---------- 

1130 a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] 

1131 First mapping to compare. Expected to have connection names as keys. 

1132 b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ] 

1133 First mapping to compare. If keys differ from those of ``a_mapping``, 

1134 this will be reported as a difference (in addition to element-wise 

1135 comparisons). 

1136 task_label : `str` 

1137 Task label associated with both mappings. 

1138 connection_type : `str` 

1139 Type of connection (e.g. "input" or "init output") associated with both 

1140 connections. This is a human-readable string to include in difference 

1141 messages. 

1142 

1143 Returns 

1144 ------- 

1145 differences : `list` [ `str` ] 

1146 List of string messages describing differences between the two 

1147 mappings. Will be empty if the two mappings have the same edges. 

1148 Messages will include "A" and "B", and are expected to be a preceded 

1149 by a message describing what "A" and "B" are in the context in which 

1150 this method is called. 

1151 

1152 Notes 

1153 ----- 

1154 This is expected to be used to compare one edge-holding mapping attribute 

1155 of a task or task init node to the same attribute on another task or task 

1156 init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`, 

1157 `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`, 

1158 `TaskInitNode.outputs`). 

1159 """ 

1160 results = [] 

1161 b_to_do = set(b_mapping.keys()) 

1162 for connection_name, a_edge in a_mapping.items(): 

1163 if (b_edge := b_mapping.get(connection_name)) is None: 

1164 results.append( 

1165 f"{connection_type.capitalize()} {connection_name!r} of task " 

1166 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." 

1167 ) 

1168 else: 

1169 results.extend(a_edge.diff(b_edge, connection_type)) 

1170 b_to_do.discard(connection_name) 

1171 for connection_name in b_to_do: 

1172 results.append( 

1173 f"{connection_type.capitalize()} {connection_name!r} of task " 

1174 f"{task_label!r} exists in A, but not in B (or it may have a different connection type)." 

1175 ) 

1176 return results