Coverage for python/lsst/pipe/base/pipeline_graph/_tasks.py: 41%

262 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-31 09:39 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("TaskNode", "TaskInitNode", "TaskImportMode") 

24 

25import dataclasses 

26import enum 

27from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence 

28from typing import TYPE_CHECKING, Any, cast 

29 

30from lsst.daf.butler import ( 

31 DataCoordinate, 

32 DatasetRef, 

33 DatasetType, 

34 DimensionGraph, 

35 DimensionUniverse, 

36 Registry, 

37) 

38from lsst.utils.classes import immutable 

39from lsst.utils.doImport import doImportType 

40from lsst.utils.introspection import get_full_type_name 

41 

42from .. import automatic_connection_constants as acc 

43from ..connections import PipelineTaskConnections 

44from ..connectionTypes import BaseConnection, InitOutput, Output 

45from ._edges import Edge, ReadEdge, WriteEdge 

46from ._exceptions import TaskNotImportedError, UnresolvedGraphError 

47from ._nodes import NodeKey, NodeType 

48 

49if TYPE_CHECKING: 

50 from ..config import PipelineTaskConfig 

51 from ..pipelineTask import PipelineTask 

52 

53 

class TaskImportMode(enum.Enum):
    """Options for handling task imports when reading a serialized
    PipelineGraph.
    """

    DO_NOT_IMPORT = enum.auto()
    """Leave tasks unimported; do not build their config or connection
    objects.
    """

    REQUIRE_CONSISTENT_EDGES = enum.auto()
    """Import each task, build its config and connection objects, and verify
    that the resulting edges match those stored in the serialized graph.
    """

    ASSUME_CONSISTENT_EDGES = enum.auto()
    """Import each task and build its config and connection objects, trusting
    that the resulting edges match the serialized ones without checking.

    Only safe when the caller knows the task definitions have not changed
    since the graph was persisted, e.g. when it was written and read by the
    same pipeline version.
    """

    OVERRIDE_EDGES = enum.auto()
    """Import each task, build its config and connection objects, and let the
    edges they define replace the ones stored in the serialized graph.

    This can leave dataset type nodes unresolved, because resolutions that
    matched the original edges may no longer be valid.
    """

84 

85 

@dataclasses.dataclass(frozen=True)
class _TaskNodeImportedData:
    """An internal struct that holds `TaskNode` and `TaskInitNode` state that
    requires task classes to be imported.
    """

    task_class: type[PipelineTask]
    """Type object for the task."""

    config: PipelineTaskConfig
    """Configuration object for the task."""

    connection_map: dict[str, BaseConnection]
    """Mapping from connection name to connection.

    In addition to ``connections.allConnections``, this also holds the
    "automatic" config, log, and metadata connections using the names defined
    in the `.automatic_connection_constants` module.
    """

    connections: PipelineTaskConnections
    """Configured connections object for the task."""

    @classmethod
    def configure(
        cls,
        label: str,
        task_class: type[PipelineTask],
        config: PipelineTaskConfig,
        connections: PipelineTaskConnections | None = None,
    ) -> _TaskNodeImportedData:
        """Construct while creating a `PipelineTaskConnections` instance if
        necessary.

        Parameters
        ----------
        label : `str`
            Label for the task in the pipeline.  Only used in error messages.
        task_class : `type` [ `.PipelineTask` ]
            Pipeline task `type` object.
        config : `.PipelineTaskConfig`
            Configuration for the task.
        connections : `.PipelineTaskConnections`, optional
            Object that describes the dataset types used by the task.  If not
            provided, one will be constructed from the given configuration.
            If provided, it is assumed that ``config`` has already been
            validated and frozen.

        Returns
        -------
        data : `_TaskNodeImportedData`
            Instance of this struct.
        """
        if connections is None:
            # No connections were supplied, so the config is assumed not to
            # have been validated/frozen yet; do both before building them.
            try:
                config.validate()
            except Exception as err:
                raise ValueError(
                    f"Configuration validation failed for task {label!r} (see chained exception)."
                ) from err
            config.freeze()
            connections = task_class.ConfigClass.ConnectionsClass(config=config)
        # Start from the task-declared connections and append the automatic
        # config/metadata/log entries under their standard names.
        edge_map = dict(connections.allConnections)
        edge_map[acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME] = InitOutput(
            acc.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=label),
            acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
        )
        if not config.saveMetadata:
            raise ValueError(f"Metadata for task {label} cannot be disabled.")
        edge_map[acc.METADATA_OUTPUT_CONNECTION_NAME] = Output(
            acc.METADATA_OUTPUT_TEMPLATE.format(label=label),
            acc.METADATA_OUTPUT_STORAGE_CLASS,
            dimensions=set(connections.dimensions),
        )
        if config.saveLogOutput:
            edge_map[acc.LOG_OUTPUT_CONNECTION_NAME] = Output(
                acc.LOG_OUTPUT_TEMPLATE.format(label=label),
                acc.LOG_OUTPUT_STORAGE_CLASS,
                dimensions=set(connections.dimensions),
            )
        return cls(task_class, config, edge_map, connections)

169 

170 

@immutable
class TaskInitNode:
    """A node in a pipeline graph that represents the construction of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Key that identifies this node in internal and exported networkx
        graphs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent inputs required just to construct an
        instance of this task, keyed by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent outputs of this task that are available
        after just constructing it, keyed by connection name.

        This does not include the special `config_output` edge; use
        `iter_all_outputs` to include that, too.
    config_output : `WriteEdge`
        The special init output edge that persists the task's configuration.
    imported_data : `_TaskNodeImportedData`, optional
        Internal struct that holds information that requires the task class to
        have been imported.
    task_class_name : `str`, optional
        Fully-qualified name of the task class.  Must be provided if
        ``imported_data`` is not.
    config_str : `str`, optional
        Configuration for the task as a string of override statements.  Must
        be provided if ``imported_data`` is not.

    Notes
    -----
    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task initialization nodes set the following
    node attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        *,
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        config_output: WriteEdge,
        imported_data: _TaskNodeImportedData | None = None,
        task_class_name: str | None = None,
        config_str: str | None = None,
    ):
        self.key = key
        self.inputs = inputs
        self.outputs = outputs
        self.config_output = config_output
        # Instead of setting attributes to None, we do not set them at all;
        # this works better with the @immutable decorator, which supports
        # deferred initialization but not reassignment.
        if task_class_name is not None:
            self._task_class_name = task_class_name
        if config_str is not None:
            self._config_str = config_str
        if imported_data is not None:
            self._imported_data = imported_data
        else:
            assert (
                self._task_class_name is not None and self._config_str is not None
            ), "If imported_data is not present, task_class_name and config_str must be."

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent inputs required just to construct an instance
    of this task, keyed by connection name.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent outputs of this task that are available after
    just constructing it, keyed by connection name.

    This does not include the special `config_output` edge; use
    `iter_all_outputs` to include that, too.
    """

    config_output: WriteEdge
    """The special output edge that persists the task's configuration.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return str(self.key)

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return hasattr(self, "_imported_data")

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self._get_imported_data().task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        try:
            return self._task_class_name
        except AttributeError:
            pass
        # Not cached yet: derive it from the (imported) task class.
        self._task_class_name = get_full_type_name(self.task_class)
        return self._task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self._get_imported_data().config

    def __repr__(self) -> str:
        return f"{self.label} [init] ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        try:
            return self._config_str
        except AttributeError:
            pass
        # Not cached yet: serialize the (imported) config.
        self._config_str = self.config.saveToString()
        return self._config_str

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all inputs required for construction.

        This is the same as iteration over ``inputs.values()``, but it will be
        updated to include any automatic init-input connections added in the
        future, while `inputs` will continue to hold only task-defined init
        inputs.
        """
        return iter(self.inputs.values())

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all outputs available after construction, including
        special ones.
        """
        yield from self.outputs.values()
        yield self.config_output

    def diff_edges(self, other: TaskInitNode) -> list[str]:
        """Compare the edges of this task initialization node to those from
        the same task label in a different pipeline.

        Parameters
        ----------
        other : `TaskInitNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = []
        # Compare self's init inputs against other's; the previous code
        # passed ``self.inputs`` for both sides, so init-input differences
        # were never reported.
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "init input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "init output")
        result += self.config_output.diff(other.config_output, "config init output")
        return result

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = {"task_class_name": self.task_class_name, "bipartite": NodeType.TASK_INIT.bipartite}
        if hasattr(self, "_imported_data"):
            result["task_class"] = self.task_class
            result["config"] = self.config
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        try:
            return self._imported_data
        except AttributeError:
            raise TaskNotImportedError(
                f"Task class {self.task_class_name!r} for label {self.label!r} has not been imported "
                "(see PipelineGraph.import_and_configure)."
            ) from None

398 

@immutable
class TaskNode:
    """A node in a pipeline graph that represents a labeled configuration of a
    `PipelineTask`.

    Parameters
    ----------
    key : `NodeKey`
        Identifier for this node in networkx graphs.
    init : `TaskInitNode`
        Node representing the initialization of this task.
    prerequisite_inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent prerequisite inputs to this task, keyed by
        connection name.

        Prerequisite inputs must already exist in the data repository when a
        `QuantumGraph` is built, but have more flexibility in how they are
        looked up than regular inputs.
    inputs : `~collections.abc.Mapping` [ `str`, `ReadEdge` ]
        Graph edges that represent regular runtime inputs to this task, keyed
        by connection name.
    outputs : `~collections.abc.Mapping` [ `str`, `WriteEdge` ]
        Graph edges that represent regular runtime outputs of this task, keyed
        by connection name.

        This does not include the special `log_output` and `metadata_output`
        edges; use `iter_all_outputs` to include those, too.
    log_output : `WriteEdge` or `None`
        The special runtime output that persists the task's logs.
    metadata_output : `WriteEdge`
        The special runtime output that persists the task's metadata.
    dimensions : `lsst.daf.butler.DimensionGraph` or `frozenset`
        Dimensions of the task.  If a `frozenset`, the dimensions have not
        been resolved by a `~lsst.daf.butler.DimensionUniverse` and cannot be
        safely compared to other sets of dimensions.

    Notes
    -----
    Task nodes are intentionally not equality comparable, since there are many
    different (and useful) ways to compare these objects with no clear winner
    as the most obvious behavior.

    When included in an exported `networkx` graph (e.g.
    `PipelineGraph.make_xgraph`), task nodes set the following node
    attributes:

    - ``task_class_name``
    - ``bipartite`` (see `NodeType.bipartite`)
    - ``task_class`` (only if `is_imported` is `True`)
    - ``config`` (only if `is_imported` is `True`)
    """

    def __init__(
        self,
        key: NodeKey,
        init: TaskInitNode,
        *,
        prerequisite_inputs: Mapping[str, ReadEdge],
        inputs: Mapping[str, ReadEdge],
        outputs: Mapping[str, WriteEdge],
        log_output: WriteEdge | None,
        metadata_output: WriteEdge,
        dimensions: DimensionGraph | frozenset,
    ):
        self.key = key
        self.init = init
        self.prerequisite_inputs = prerequisite_inputs
        self.inputs = inputs
        self.outputs = outputs
        self.log_output = log_output
        self.metadata_output = metadata_output
        self._dimensions = dimensions

    @staticmethod
    def _from_imported_data(
        key: NodeKey,
        init_key: NodeKey,
        data: _TaskNodeImportedData,
        universe: DimensionUniverse | None,
    ) -> TaskNode:
        """Construct from a `PipelineTask` type and its configuration.

        Parameters
        ----------
        key : `NodeKey`
            Identifier for this node in networkx graphs.
        init_key : `NodeKey`
            Identifier for the task's init node in networkx graphs.
        data : `_TaskNodeImportedData`
            Internal struct that holds information that requires the task
            class to have been imported.
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions of all dimensions.

        Returns
        -------
        node : `TaskNode`
            New task node.

        Raises
        ------
        ValueError
            Raised if configuration validation failed when constructing
            ``connections``.
        """
        init_inputs = {
            name: ReadEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initInputs
        }
        prerequisite_inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map, is_prerequisite=True)
            for name in data.connections.prerequisiteInputs
        }
        inputs = {
            name: ReadEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.inputs
        }
        init_outputs = {
            name: WriteEdge._from_connection_map(init_key, name, data.connection_map)
            for name in data.connections.initOutputs
        }
        outputs = {
            name: WriteEdge._from_connection_map(key, name, data.connection_map)
            for name in data.connections.outputs
        }
        init = TaskInitNode(
            key=init_key,
            inputs=init_inputs,
            outputs=init_outputs,
            config_output=WriteEdge._from_connection_map(
                init_key, acc.CONFIG_INIT_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            imported_data=data,
        )
        instance = TaskNode(
            key=key,
            init=init,
            prerequisite_inputs=prerequisite_inputs,
            inputs=inputs,
            outputs=outputs,
            log_output=(
                WriteEdge._from_connection_map(key, acc.LOG_OUTPUT_CONNECTION_NAME, data.connection_map)
                if data.config.saveLogOutput
                else None
            ),
            metadata_output=WriteEdge._from_connection_map(
                key, acc.METADATA_OUTPUT_CONNECTION_NAME, data.connection_map
            ),
            dimensions=(
                frozenset(data.connections.dimensions)
                if universe is None
                else universe.extract(data.connections.dimensions)
            ),
        )
        return instance

    key: NodeKey
    """Key that identifies this node in internal and exported networkx graphs.
    """

    init: TaskInitNode
    """Node representing the initialization of this task.
    """

    prerequisite_inputs: Mapping[str, ReadEdge]
    """Graph edges that represent prerequisite inputs to this task.

    Prerequisite inputs must already exist in the data repository when a
    `QuantumGraph` is built, but have more flexibility in how they are looked
    up than regular inputs.
    """

    inputs: Mapping[str, ReadEdge]
    """Graph edges that represent regular runtime inputs to this task.
    """

    outputs: Mapping[str, WriteEdge]
    """Graph edges that represent regular runtime outputs of this task.

    This does not include the special `log_output` and `metadata_output` edges;
    use `iter_all_outputs` to include those, too.
    """

    log_output: WriteEdge | None
    """The special runtime output that persists the task's logs.
    """

    metadata_output: WriteEdge
    """The special runtime output that persists the task's metadata.
    """

    @property
    def label(self) -> str:
        """Label of this configuration of a task in the pipeline."""
        return self.key.name

    @property
    def is_imported(self) -> bool:
        """Whether the task type for this node has been imported and its
        configuration overrides applied.

        If this is `False`, the `task_class` and `config` attributes may not
        be accessed.
        """
        return self.init.is_imported

    @property
    def task_class(self) -> type[PipelineTask]:
        """Type object for the task.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but accessing `task_class_name` will not.
        """
        return self.init.task_class

    @property
    def task_class_name(self) -> str:
        """The fully-qualified string name of the task class."""
        return self.init.task_class_name

    @property
    def config(self) -> PipelineTaskConfig:
        """Configuration for the task.

        This is always frozen.

        Accessing this attribute when `is_imported` is `False` will raise
        `TaskNotImportedError`, but calling `get_config_str` will not.
        """
        return self.init.config

    @property
    def has_resolved_dimensions(self) -> bool:
        """Whether the `dimensions` attribute may be accessed.

        If `False`, the `raw_dimensions` attribute may be used to obtain a
        set of dimension names that has not been resolved by a
        `~lsst.daf.butler.DimensionsUniverse`.
        """
        return type(self._dimensions) is DimensionGraph

    @property
    def dimensions(self) -> DimensionGraph:
        """Standardized dimensions of the task."""
        if not self.has_resolved_dimensions:
            raise UnresolvedGraphError(f"Dimensions for task {self.label!r} have not been resolved.")
        return cast(DimensionGraph, self._dimensions)

    @property
    def raw_dimensions(self) -> frozenset[str]:
        """Raw dimensions of the task, with standardization by a
        `~lsst.daf.butler.DimensionUniverse` not guaranteed.
        """
        if self.has_resolved_dimensions:
            return frozenset(cast(DimensionGraph, self._dimensions).names)
        else:
            return cast(frozenset[str], self._dimensions)

    def __repr__(self) -> str:
        if self.has_resolved_dimensions:
            return f"{self.label} ({self.task_class_name}, {self.dimensions})"
        else:
            return f"{self.label} ({self.task_class_name})"

    def get_config_str(self) -> str:
        """Return the configuration for this task as a string of override
        statements.

        Returns
        -------
        config_str : `str`
            String containing configuration-overload statements.
        """
        return self.init.get_config_str()

    def iter_all_inputs(self) -> Iterator[ReadEdge]:
        """Iterate over all runtime inputs, including both regular inputs and
        prerequisites.
        """
        yield from self.prerequisite_inputs.values()
        yield from self.inputs.values()

    def iter_all_outputs(self) -> Iterator[WriteEdge]:
        """Iterate over all runtime outputs, including special ones."""
        yield from self.outputs.values()
        yield self.metadata_output
        if self.log_output is not None:
            yield self.log_output

    def diff_edges(self, other: TaskNode) -> list[str]:
        """Compare the edges of this task node to those from the same task
        label in a different pipeline.

        This also calls `TaskInitNode.diff_edges`.

        Parameters
        ----------
        other : `TaskNode`
            Other node to compare to.  Must have the same task label, but need
            not have the same configuration or even the same task class.

        Returns
        -------
        differences : `list` [ `str` ]
            List of string messages describing differences between ``self``
            and ``other``.  Will be empty if the two nodes have the same
            edges.  Messages will use 'A' to refer to ``self`` and 'B' to
            refer to ``other``.
        """
        result = self.init.diff_edges(other.init)
        result += _diff_edge_mapping(
            self.prerequisite_inputs, other.prerequisite_inputs, self.label, "prerequisite input"
        )
        result += _diff_edge_mapping(self.inputs, other.inputs, self.label, "input")
        result += _diff_edge_mapping(self.outputs, other.outputs, self.label, "output")
        if self.log_output is not None:
            if other.log_output is not None:
                result += self.log_output.diff(other.log_output, "log output")
            else:
                result.append("Log output is present in A, but not in B.")
        elif other.log_output is not None:
            result.append("Log output is present in B, but not in A.")
        result += self.metadata_output.diff(other.metadata_output, "metadata output")
        return result

    def get_lookup_function(
        self, connection_name: str
    ) -> Callable[[DatasetType, Registry, DataCoordinate, Sequence[str]], Iterable[DatasetRef]] | None:
        """Return the custom dataset query function for an edge, if one
        exists.

        Parameters
        ----------
        connection_name : `str`
            Name of the connection.

        Returns
        -------
        lookup_function : `~collections.abc.Callable` or `None`
            Callable that takes a dataset type, a butler registry, a data
            coordinate (the quantum data ID), and an ordered list of
            collections to search, and returns an iterable of
            `~lsst.daf.butler.DatasetRef`.
        """
        return getattr(self._get_imported_data().connection_map[connection_name], "lookupFunction", None)

    def get_spatial_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the spatial bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with spatial dimensions.
        """
        return frozenset(self._get_imported_data().connections.getSpatialBoundsConnections())

    def get_temporal_bounds_connections(self) -> frozenset[str]:
        """Return the names of connections whose data IDs should be included
        in the calculation of the temporal bounds for this task's quanta.

        Returns
        -------
        connection_names : `frozenset` [ `str` ]
            Names of connections with temporal dimensions.
        """
        return frozenset(self._get_imported_data().connections.getTemporalBoundsConnections())

    def _imported_and_configured(self, rebuild: bool) -> TaskNode:
        """Import the task class and use it to construct a new instance.

        Parameters
        ----------
        rebuild : `bool`
            If `True`, import the task class and configure its connections to
            generate new edges that may differ from the current ones.  If
            `False`, import the task class but just update the `task_class`
            and `config` attributes, and assume the edges have not changed.

        Returns
        -------
        node : `TaskNode`
            Task node instance for which `is_imported` is `True`.  Will be
            ``self`` if this is the case already.
        """
        from ..pipelineTask import PipelineTask

        if self.is_imported:
            return self
        task_class = doImportType(self.task_class_name)
        if not issubclass(task_class, PipelineTask):
            raise TypeError(f"{self.task_class_name!r} is not a PipelineTask subclass.")
        config = task_class.ConfigClass()
        config.loadFromString(self.get_config_str())
        return self._reconfigured(config, rebuild=rebuild, task_class=task_class)

    def _reconfigured(
        self,
        config: PipelineTaskConfig,
        rebuild: bool,
        task_class: type[PipelineTask] | None = None,
    ) -> TaskNode:
        """Return a version of this node with new configuration.

        Parameters
        ----------
        config : `.PipelineTaskConfig`
            New configuration for the task.
        rebuild : `bool`
            If `True`, use the configured connections to generate new edges
            that may differ from the current ones.  If `False`, just update
            the `task_class` and `config` attributes, and assume the edges
            have not changed.
        task_class : `type` [ `PipelineTask` ], optional
            Subclass of `PipelineTask`.  This defaults to ``self.task_class``,
            but may be passed as an argument if that is not available because
            the task class was not imported when ``self`` was constructed.

        Returns
        -------
        node : `TaskNode`
            Task node instance with the new config.
        """
        if task_class is None:
            task_class = self.task_class
        imported_data = _TaskNodeImportedData.configure(self.key.name, task_class, config)
        if rebuild:
            return self._from_imported_data(
                self.key,
                self.init.key,
                imported_data,
                universe=self._dimensions.universe if type(self._dimensions) is DimensionGraph else None,
            )
        else:
            return TaskNode(
                self.key,
                TaskInitNode(
                    self.init.key,
                    inputs=self.init.inputs,
                    outputs=self.init.outputs,
                    config_output=self.init.config_output,
                    imported_data=imported_data,
                ),
                prerequisite_inputs=self.prerequisite_inputs,
                inputs=self.inputs,
                outputs=self.outputs,
                log_output=self.log_output,
                metadata_output=self.metadata_output,
                dimensions=self._dimensions,
            )

    def _resolved(self, universe: DimensionUniverse | None) -> TaskNode:
        """Return an otherwise-equivalent task node with resolved dimensions.

        Parameters
        ----------
        universe : `lsst.daf.butler.DimensionUniverse` or `None`
            Definitions for all dimensions.

        Returns
        -------
        node : `TaskNode`
            Task node instance with `dimensions` resolved by the given
            universe.  Will be ``self`` if this is the case already.
        """
        if self.has_resolved_dimensions:
            if cast(DimensionGraph, self._dimensions).universe is universe:
                return self
        elif universe is None:
            return self
        return TaskNode(
            key=self.key,
            init=self.init,
            prerequisite_inputs=self.prerequisite_inputs,
            inputs=self.inputs,
            outputs=self.outputs,
            log_output=self.log_output,
            metadata_output=self.metadata_output,
            dimensions=(
                universe.extract(self.raw_dimensions) if universe is not None else self.raw_dimensions
            ),
        )

    def _to_xgraph_state(self) -> dict[str, Any]:
        """Convert this node's attributes into a dictionary suitable for use
        in exported networkx graphs.
        """
        result = self.init._to_xgraph_state()
        if self.has_resolved_dimensions:
            result["dimensions"] = self._dimensions
        result["raw_dimensions"] = self.raw_dimensions
        return result

    def _get_imported_data(self) -> _TaskNodeImportedData:
        """Return the imported data struct.

        Returns
        -------
        imported_data : `_TaskNodeImportedData`
            Internal structure holding state that requires the task class to
            have been imported.

        Raises
        ------
        TaskNotImportedError
            Raised if `is_imported` is `False`.
        """
        return self.init._get_imported_data()

901 

902 

def _diff_edge_mapping(
    a_mapping: Mapping[str, Edge], b_mapping: Mapping[str, Edge], task_label: str, connection_type: str
) -> list[str]:
    """Compare a pair of mappings of edges.

    Parameters
    ----------
    a_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
        First mapping to compare.  Expected to have connection names as keys.
    b_mapping : `~collections.abc.Mapping` [ `str`, `Edge` ]
        Second mapping to compare.  If keys differ from those of
        ``a_mapping``, this will be reported as a difference (in addition to
        element-wise comparisons).
    task_label : `str`
        Task label associated with both mappings.
    connection_type : `str`
        Type of connection (e.g. "input" or "init output") associated with
        both connections.  This is a human-readable string to include in
        difference messages.

    Returns
    -------
    differences : `list` [ `str` ]
        List of string messages describing differences between the two
        mappings.  Will be empty if the two mappings have the same edges.
        Messages will include "A" and "B", and are expected to be a preceded
        by a message describing what "A" and "B" are in the context in which
        this method is called.

    Notes
    -----
    This is expected to be used to compare one edge-holding mapping attribute
    of a task or task init node to the same attribute on another task or task
    init node (i.e. any of `TaskNode.inputs`, `TaskNode.outputs`,
    `TaskNode.prerequisite_inputs`, `TaskInitNode.inputs`,
    `TaskInitNode.outputs`).
    """
    results = []
    b_to_do = set(b_mapping.keys())
    for connection_name, a_edge in a_mapping.items():
        if (b_edge := b_mapping.get(connection_name)) is None:
            results.append(
                f"{connection_type.capitalize()} {connection_name!r} of task "
                f"{task_label!r} exists in A, but not in B (or it may have a different connection type)."
            )
        else:
            results.extend(a_edge.diff(b_edge, connection_type))
        b_to_do.discard(connection_name)
    # Anything left in b_to_do appears only in ``b_mapping``; the old message
    # here had the direction backwards ("in A, but not in B").
    for connection_name in b_to_do:
        results.append(
            f"{connection_type.capitalize()} {connection_name!r} of task "
            f"{task_label!r} exists in B, but not in A (or it may have a different connection type)."
        )
    return results