Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 22%

227 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-23 03:26 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from lsst.pipe.base.connectionTypes import BaseInput, Output 

30 

31__all__ = ( 

32 "DynamicConnectionConfig", 

33 "DynamicTestPipelineTask", 

34 "DynamicTestPipelineTaskConfig", 

35 "MockPipelineTask", 

36 "MockPipelineTaskConfig", 

37 "mock_task_defs", 

38 "mock_pipeline_graph", 

39) 

40 

41import dataclasses 

42import logging 

43from collections.abc import Collection, Iterable, Mapping 

44from typing import TYPE_CHECKING, Any, ClassVar, TypeVar 

45 

46from lsst.daf.butler import DataCoordinate, DatasetRef, DeferredDatasetHandle, SerializedDatasetType 

47from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField 

48from lsst.utils.doImport import doImportType 

49from lsst.utils.introspection import get_full_type_name 

50from lsst.utils.iteration import ensure_iterable 

51 

52from ... import automatic_connection_constants as acc 

53from ... import connectionTypes as cT 

54from ...config import PipelineTaskConfig 

55from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections 

56from ...pipeline import TaskDef 

57from ...pipeline_graph import PipelineGraph 

58from ...pipelineTask import PipelineTask 

59from ._data_id_match import DataIdMatch 

60from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name 

61 

62_LOG = logging.getLogger(__name__) 

63 

64if TYPE_CHECKING: 

65 from ..._quantumContext import QuantumContext 

66 

67 

68_T = TypeVar("_T", bound=cT.BaseConnection) 

69 

70 

def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ], optional
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed.  The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.  An exception type of `None` uses
        the default, `ValueError`.  Labels not present in the mapping never
        fail.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.  Each mocked task's
        label is derived from the original label via `get_mock_name`.
    """
    # Materialize once: the same names are appended to every task's config.
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        # Point the ``original`` field at the real task class, then install
        # the real task's configuration into it.
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            # Leave the config default (ValueError) in place when no type is
            # given.
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results

120 

121 

def mock_pipeline_graph(
    original_graph: PipelineGraph,
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> PipelineGraph:
    """Build a new pipeline graph in which every task is a mock of the
    corresponding task in ``original_graph``.

    Parameters
    ----------
    original_graph : `~..pipeline_graph.PipelineGraph`
        Graph holding the original tasks and their configuration.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types to keep as-is rather than
        replacing with mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ], optional
        Maps an original task label to ``(condition, exception_type)``.
        ``condition`` is a butler expression-language string (the kind passed
        as ``where`` to butler queries); quanta whose data ID matches it raise
        ``exception_type`` when executed, or `ValueError` if that is `None`.

    Returns
    -------
    mocked : `~..pipeline_graph.PipelineGraph`
        A new, unresolved graph whose tasks are `MockPipelineTask` instances
        configured to target the original tasks, labeled via
        `get_mock_name`.
    """
    unmocked = tuple(unmocked_dataset_types)
    failures = {} if force_failures is None else force_failures
    mocked_graph = PipelineGraph(description=original_graph.description)
    for task_node in original_graph.tasks.values():
        mock_config = MockPipelineTaskConfig()
        # Retarget at the real task class, then copy in its configuration.
        mock_config.original.retarget(task_node.task_class)
        mock_config.original = task_node.config
        mock_config.unmocked_dataset_types.extend(unmocked)
        if task_node.label in failures:
            condition, exception_type = failures[task_node.label]
            mock_config.fail_condition = condition
            if exception_type is not None:
                mock_config.fail_exception = get_full_type_name(exception_type)
        mocked_graph.add_task(get_mock_name(task_node.label), MockPipelineTask, config=mock_config)
    return mocked_graph

168 

169 

class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Connections class with no connections and no dimensions, used as the
    connections class of `BaseTestPipelineTaskConfig`.
    """

172 

173 

class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    """Configuration shared by the test-utility tasks, controlling when and
    how a simulated failure is raised.
    """

    # Empty string means "never fail" (see ``data_id_match``).
    fail_condition = Field[str](
        doc=(
            "Condition on Data ID to raise an exception. String expression which includes attributes of "
            "quantum data ID using a syntax of daf_butler user expressions (e.g. 'visit = 123')."
        ),
        dtype=str,
        default="",
    )

    # Fully-qualified name; imported lazily by the task when needed.
    fail_exception = Field[str](
        doc=(
            "Class name of the exception to raise when fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify non-failure exception."
        ),
        dtype=str,
        default="builtins.ValueError",
    )

    def data_id_match(self) -> DataIdMatch | None:
        """Return a matcher built from ``fail_condition``, or `None` when
        ``fail_condition`` is empty.
        """
        return DataIdMatch(self.fail_condition) if self.fail_condition else None

197 

198 

class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and write
    mock datasets in `runQuantum`.

    Parameters
    ----------
    config : `BaseTestPipelineTaskConfig`
        The pipeline task config.
    initInputs : `~collections.abc.Mapping` [ `str`, `~typing.Any` ]
        The init-input datasets, keyed by connection name.  There must be an
        entry for every init-input connection, and each value must be a
        `MockDataset`.
    **kwargs : `~typing.Any`
        Keyword parameters passed to base class constructor.

    Raises
    ------
    TypeError
        Raised if an init-input is not a `MockDataset`.
    RuntimeError
        Raised if an init-input's dataset type name or storage class does not
        match its connection.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances).  It
    can also be configured to raise exceptions on certain data IDs.  It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections.  When the task is
    constructed, init-inputs are checked against their connections and
    init-outputs are created as `MockDataset` instances stored as instance
    attributes named after their connections.  All output connections must
    use mock storage classes.  `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    config: BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        # Test against None (not truthiness) so this agrees with the check in
        # ``runQuantum``, which asserts ``fail_exception`` is set whenever a
        # matcher exists.
        if self.data_id_match is not None:
            self.fail_exception = doImportType(self.config.fail_exception)
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single
            # level.  Note that this modifies the caller's MockDataset in
            # place.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=[],
                ),
                data_id={},
                run=None,  # task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
            )
            setattr(self, connection_name, output_dataset)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=dict(quantum.dataId.mapping), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    # Mock storage class: actually read the dataset, which
                    # must be a MockDataset.
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a single
                    # level.
                    input_dataset.quantum = None
                else:
                    # Non-mock storage class: simulate the read from the ref.
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=dict(ref.dataId.mapping),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # store mock outputs
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=dict(ref.dataId.mapping),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)

352 

353 

class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    """Connections for `MockPipelineDefaultTargetTask`: no connections and no
    dimensions.
    """

356 

357 

class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    """Configuration for `MockPipelineDefaultTargetTask`; defines no fields of
    its own.
    """

362 

363 

class MockPipelineDefaultTargetTask(PipelineTask):
    """Placeholder `~lsst.pipe.base.PipelineTask` that serves only as the
    default target of ``MockPipelineTaskConfig.original``.

    `lsst.pex.config.ConfigurableField` cannot be declared with
    ``optional=True``, so some concrete default target is needed even though
    the field is retargeted at the real task being mocked.  That restriction
    is sensible for production code, and lifting it just for test utilities
    would not be worthwhile.
    """

    ConfigClass = MockPipelineDefaultTargetConfig

375 

376 

class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.

    Parameters
    ----------
    config : `MockPipelineTaskConfig`
        The config to use for the connection.  ``config.original`` holds the
        configuration of the task being mocked.

    Raises
    ------
    ValueError
        Raised if a dataset type listed in ``config.unmocked_dataset_types``
        is used by an output, init-input, or init-output connection of the
        original task.

    Notes
    -----
    Every connection of the original task whose dataset type is not listed
    in ``config.unmocked_dataset_types`` is replaced by a copy with a mock
    dataset type name (via `get_mock_name`) and, unless it uses one of the
    automatic config/metadata/log output storage classes, a mock storage
    class.  Connection names are unchanged.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in self.original.allConnections.items():
            if connection.name not in self.unmocked_dataset_types:
                if connection.storageClass in (
                    acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
                    acc.METADATA_OUTPUT_STORAGE_CLASS,
                    acc.LOG_OUTPUT_STORAGE_CLASS,
                ):
                    # We don't mock the automatic output connections, so if
                    # they're used as an input in any other connection, we
                    # can't mock them there either.
                    storage_class_name = connection.storageClass
                else:
                    # We register the mock storage class with the global
                    # singleton here, but can only put its name in the
                    # connection. That means the same global singleton (or one
                    # that also has these registrations) has to be available
                    # whenever this dataset type is used.
                    storage_class_name = MockStorageClass.get_or_register_mock(connection.storageClass).name
                kwargs: dict[str, Any] = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7,
                    # because resolving the placeholder requires the dataset
                    # type to have already been registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class_name,
                    **kwargs,
                )
            elif name in self.original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            elif name in self.original.initInputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                )
            elif name in self.original.initOutputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                )
            setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        # Connection names are preserved by mocking, so the original's answer
        # applies unchanged.
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        # Connection names are preserved by mocking, so the original's answer
        # applies unchanged.
        return self.original.getTemporalBoundsConnections()

    def _unmock_quantum_mapping(
        self, mocked: Mapping[str, tuple[cT.BaseConnection, Collection[DatasetRef]]]
    ) -> dict[str, tuple[cT.BaseConnection, Collection[DatasetRef]]]:
        """Convert a ``{connection name: (connection, refs)}`` mapping from
        this class's mock connections to the original task's connections,
        unmocking refs unless their dataset type is in
        ``unmocked_dataset_types``.
        """
        result = {}
        for connection_name, (_, mock_refs) in mocked.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            result[connection_name] = (original_connection, refs)
        return result

    def _mock_quantum_mapping(
        self, original: Mapping[str, tuple[cT.BaseConnection, Collection[DatasetRef]]]
    ) -> dict[str, tuple[cT.BaseConnection, Collection[DatasetRef]]]:
        """Inverse of `_unmock_quantum_mapping`: convert a mapping keyed by
        connection name from the original task's connections back to this
        class's mock connections, mocking refs unless their dataset type is
        in ``unmocked_dataset_types``.
        """
        result = {}
        for connection_name, (original_connection, original_refs) in original.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            result[connection_name] = (getattr(self, connection_name), refs)
        return result

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = self._unmock_quantum_mapping(inputs)
        original_outputs = self._unmock_quantum_mapping(outputs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = self._mock_quantum_mapping(adjusted_original_inputs)
        adjusted_outputs = self._mock_quantum_mapping(adjusted_original_outputs)
        return adjusted_inputs, adjusted_outputs

493 

494 

class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`.

    ``original`` carries the configuration of the task being mocked; its
    default target is only a placeholder and is expected to be retargeted at
    the real task class.
    """

    original: ConfigurableField = ConfigurableField(
        target=MockPipelineDefaultTargetTask,
        doc="The original task being mocked by this one.",
    )

    unmocked_dataset_types = ListField[str](
        optional=False,
        default=(),
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
    )

511 

512 

class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    The original task is never instantiated; only its connections class is
    constructed (by `MockPipelineTaskConnections`), and the mock connections
    are derived from it.  Init-input and init-output connections of the
    original are mocked along with all other connections.  When this task is
    constructed, its init-inputs must therefore be `MockDataset` instances,
    and its init-outputs are provided as `MockDataset` instance attributes
    (see `BaseTestPipelineTask`).  Dataset types listed in
    ``unmocked_dataset_types`` may not be used for init-inputs or
    init-outputs.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig

528 

529 

class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection: dataset
    type name, dimensions, storage class and flags all come from config.
    """

    dataset_type_name = Field[str](dtype=str, doc="Name for the dataset type as seen by the butler.")
    dimensions = ListField[str](dtype=str, default=[], doc="Dimensions for the dataset type.")
    storage_class = Field[str](
        dtype=str, default="StructuredDataDict", doc="Name of the butler storage class for the dataset type."
    )
    is_calibration = Field[bool](dtype=bool, default=False, doc="Whether this dataset type is a calibration.")
    multiple = Field[bool](
        dtype=bool,
        default=False,
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
    )
    mock_storage_class = Field[bool](
        dtype=bool,
        default=True,
        doc="Whether the storage class should actually be a mock of the storage class given.",
    )

    def make_connection(self, cls: type[_T]) -> _T:
        """Instantiate a connection of type ``cls`` from this configuration.

        Parameters
        ----------
        cls : `type`
            Connection class to construct.  ``dimensions`` and
            ``isCalibration`` are passed only when it is a subclass of
            `~lsst.pipe.base.connectionTypes.DimensionedConnection`.

        Returns
        -------
        connection
            New instance of ``cls``.  When ``mock_storage_class`` is `True`
            its storage class is the name of the mock of ``storage_class``
            (registered if necessary); otherwise ``storage_class`` itself.
        """
        if self.mock_storage_class:
            storage_class_name = MockStorageClass.get_or_register_mock(self.storage_class).name
        else:
            storage_class_name = self.storage_class
        connection_kwargs: dict[str, Any] = {
            "name": self.dataset_type_name,
            "storageClass": storage_class_name,
            "multiple": self.multiple,
        }
        if issubclass(cls, cT.DimensionedConnection):
            connection_kwargs["isCalibration"] = self.is_calibration
            connection_kwargs["dimensions"] = frozenset(self.dimensions)
        return cls(**connection_kwargs)  # type: ignore

568 

569 

class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.

    Parameters
    ----------
    config : `DynamicTestPipelineTaskConfig`
        Config whose ``dimensions`` and per-kind connection dictionaries
        define this connections object.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        # Pair each config dictionary with the connection type it produces;
        # connections are added in this order.
        sources_by_type = (
            (config.init_inputs, cT.InitInput),
            (config.init_outputs, cT.InitOutput),
            (config.prerequisite_inputs, cT.PrerequisiteInput),
            (config.inputs, cT.Input),
            (config.outputs, cT.Output),
        )
        connection_config: DynamicConnectionConfig
        for connection_configs, connection_type in sources_by_type:
            for connection_name, connection_config in connection_configs.items():
                setattr(self, connection_name, connection_config.make_connection(connection_type))

593 

594 

class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask.

    Each dictionary field maps a connection name (as seen by the task) to a
    `DynamicConnectionConfig` describing that connection.
    """

    dimensions = ListField[str](dtype=str, default=[], doc="Dimensions for the task's quanta.")
    init_inputs = ConfigDictField(
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
    )
    init_outputs = ConfigDictField(
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
    )
    prerequisite_inputs = ConfigDictField(
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
    )
    inputs = ConfigDictField(
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
        doc="Regular input connections, keyed by the connection name as seen by the task.",
    )
    outputs = ConfigDictField(
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
        doc="Regular output connections, keyed by the connection name as seen by the task.",
    )

637 

638 

class DynamicTestPipelineTask(BaseTestPipelineTask):
    """Test-utility `PipelineTask` whose quantum dimensions and connections
    are taken entirely from `DynamicTestPipelineTaskConfig`.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig