Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 25%

241 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-30 10:01 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from lsst.pipe.base.connectionTypes import BaseInput, Output 

30 

31__all__ = ( 

32 "DynamicConnectionConfig", 

33 "DynamicTestPipelineTask", 

34 "DynamicTestPipelineTaskConfig", 

35 "ForcedFailure", 

36 "MockPipelineTask", 

37 "MockPipelineTaskConfig", 

38 "mock_task_defs", 

39 "mock_pipeline_graph", 

40) 

41 

42import dataclasses 

43import logging 

44from collections.abc import Collection, Iterable, Mapping 

45from typing import TYPE_CHECKING, Any, ClassVar, TypeVar 

46 

47from astropy.units import Quantity 

48from lsst.daf.butler import DataCoordinate, DatasetRef, DeferredDatasetHandle, SerializedDatasetType 

49from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField 

50from lsst.utils.doImport import doImportType 

51from lsst.utils.introspection import get_full_type_name 

52from lsst.utils.iteration import ensure_iterable 

53 

54from ... import automatic_connection_constants as acc 

55from ... import connectionTypes as cT 

56from ...config import PipelineTaskConfig 

57from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections 

58from ...pipeline import TaskDef 

59from ...pipeline_graph import PipelineGraph 

60from ...pipelineTask import PipelineTask 

61from ._data_id_match import DataIdMatch 

62from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name 

63 

64_LOG = logging.getLogger(__name__) 

65 

66if TYPE_CHECKING: 

67 from ..._quantumContext import QuantumContext 

68 

69 

70_T = TypeVar("_T", bound=cT.BaseConnection) 

71 

72 

@dataclasses.dataclass
class ForcedFailure:
    """Information about an exception that should be raised by one or more
    quanta.
    """

    condition: str
    """Butler expression-language string that matches the data IDs that should
    raise.
    """

    exception_type: type[Exception] | None = None
    """The type of exception to raise."""

    memory_required: Quantity | None = None
    """If not `None`, this failure simulates an out-of-memory failure by
    raising only if this value exceeds `ExecutionResources.max_mem`.
    """

    def set_config(self, config: MockPipelineTaskConfig) -> None:
        """Copy this failure description into a mock task configuration.

        Parameters
        ----------
        config : `MockPipelineTaskConfig`
            Configuration to modify in place.  ``fail_condition`` and
            ``memory_required`` are always overwritten; ``fail_exception`` is
            only overwritten when `exception_type` is set.
        """
        config.fail_condition = self.condition
        if self.exception_type:
            config.fail_exception = get_full_type_name(self.exception_type)
        # The ``memory_required`` config field is declared as a string field
        # in this file and is later re-parsed with ``Quantity(...)``, so the
        # value must be stored in its string form rather than as the
        # `Quantity` object itself.  `None` (no memory requirement) is kept.
        config.memory_required = None if self.memory_required is None else str(self.memory_required)

97 

98 

def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, ForcedFailure] | None = None,
) -> list[TaskDef]:
    """Build a mock `TaskDef` for every `TaskDef` in an iterable.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `ForcedFailure` ]
        Mapping from original task label to information about an exception one
        or more quanta for this task should raise.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
    """
    # Materialize once: the argument may be a one-shot iterable, and it is
    # consumed again for every task below.
    unmocked_names = tuple(unmocked_dataset_types)
    failures: Mapping[str, ForcedFailure] = {} if force_failures is None else force_failures
    mocked: list[TaskDef] = []
    for task_def in originals:
        mock_config = MockPipelineTaskConfig()
        # Point the configurable field at the original task class, then copy
        # the original task's configuration into it.
        mock_config.original.retarget(task_def.taskClass)
        mock_config.original = task_def.config
        mock_config.unmocked_dataset_types.extend(unmocked_names)
        if task_def.label in failures:
            failures[task_def.label].set_config(mock_config)
        mocked.append(
            TaskDef(config=mock_config, taskClass=MockPipelineTask, label=get_mock_name(task_def.label))
        )
    return mocked

139 

140 

def mock_pipeline_graph(
    original_graph: PipelineGraph,
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, ForcedFailure] | None = None,
) -> PipelineGraph:
    """Build a pipeline graph of mock tasks from an existing pipeline graph.

    Parameters
    ----------
    original_graph : `~..pipeline_graph.PipelineGraph`
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `ForcedFailure` ]
        Mapping from original task label to information about an exception one
        or more quanta for this task should raise.

    Returns
    -------
    mocked : `~..pipeline_graph.PipelineGraph`
        Pipeline graph using `MockPipelineTask` configurations that target the
        original tasks.  Never resolved.
    """
    # Materialize once: the argument may be a one-shot iterable, and it is
    # consumed again for every task below.
    unmocked_names = tuple(unmocked_dataset_types)
    failures: Mapping[str, ForcedFailure] = {} if force_failures is None else force_failures
    mocked_graph = PipelineGraph(description=original_graph.description)
    for task_node in original_graph.tasks.values():
        mock_config = MockPipelineTaskConfig()
        # Point the configurable field at the original task class, then copy
        # the original task's configuration into it.
        mock_config.original.retarget(task_node.task_class)
        mock_config.original = task_node.config
        mock_config.unmocked_dataset_types.extend(unmocked_names)
        if task_node.label in failures:
            failures[task_node.label].set_config(mock_config)
        mocked_graph.add_task(get_mock_name(task_node.label), MockPipelineTask, config=mock_config)
    return mocked_graph

178 

179 

class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Connections class that declares no dimensions and no connections of
    its own.
    """

182 

183 

class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    """Configuration describing when and how a test task should simulate a
    failure.
    """

    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on Data ID to raise an exception. String expression which includes attributes of "
            "quantum data ID using a syntax of daf_butler user expressions (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify non-failure exception."
        ),
    )

    # The task raises when the available memory is *less than* this value, so
    # the description is phrased as "this value exceeds max_mem".
    memory_required = Field[str](
        dtype=str,
        default=None,
        optional=True,
        doc=(
            "If not None, simulate an out-of-memory failure by raising only if this value exceeds "
            "ExecutionResources.max_mem. This string should include units as parsed by "
            "astropy.units.Quantity (e.g. '4GB')."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        """Build the data ID matcher for the configured failure condition.

        Returns
        -------
        match : `DataIdMatch` or `None`
            Matcher constructed from ``fail_condition``, or `None` when
            ``fail_condition`` is empty.
        """
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)

218 

219 

class BaseTestPipelineTask(PipelineTask):
    """Base class for test-utility `PipelineTask` classes whose `runQuantum`
    reads and writes mock datasets.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        The pipeline task config.
    initInputs : `~collections.abc.Mapping`
        The init inputs datasets.
    **kwargs : `~typing.Any`
        Keyword parameters passed to base class constructor.

    Notes
    -----
    `runQuantum` is overridden to gather its inputs and to write a small
    provenance record into every output, which is always a `MockDataset`.
    The task can be configured to raise an exception for matching data IDs.
    Inputs with a mock storage class are read through the quantum context;
    inputs with any other storage class are not read at all and are instead
    represented by a `MockDataset` built from their
    `~lsst.daf.butler.DatasetRef`.

    Subclasses define the connections.  On construction every init-input must
    be a `MockDataset` whose dataset type name and storage class agree with
    the connection, and every init-output is created as a `MockDataset`
    attribute named after its connection.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    config: BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        # The exception type is only imported when a failure condition is
        # actually configured.
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        required = self.config.memory_required
        self.memory_required = None if required is None else Quantity(required)

        # Validate the init-inputs and record them as provenance.
        connections = self.ConfigClass.ConnectionsClass(config=config)
        label = self.getName()
        provenance = MockDatasetQuantum(task_label=label, data_id={}, inputs={})
        for connection_name in connections.initInputs:
            init_input = initInputs[connection_name]
            if not isinstance(init_input, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {label}.{connection_name}: "
                    f"got {init_input!r} of type {type(init_input)!r}."
                )
            expected = connections.allConnections[connection_name]
            found_name = init_input.dataset_type.name
            if found_name != expected.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {label}.{connection_name}: "
                    f"got {found_name!r}, expected {expected.name!r}."
                )
            found_storage_class = init_input.storage_class
            if found_storage_class != expected.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {label}.{connection_name}: "
                    f"got {found_storage_class!r}, expected {expected.storageClass!r}."
                )
            # Drop the input's own provenance so that the recorded history
            # is only one level deep.
            init_input.quantum = None
            provenance.inputs[connection_name] = [init_input]

        # Expose each init-output as an attribute named after its connection.
        for connection_name in connections.initOutputs:
            declared = connections.allConnections[connection_name]
            init_output = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=declared.name,
                    storageClass=declared.storageClass,
                    dimensions=[],
                ),
                data_id={},
                run=None,  # task also has no way to get this
                quantum=provenance,
                output_connection_name=connection_name,
            )
            setattr(self, connection_name, init_output)

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum
        label = self.getName()
        data_id = quantum.dataId

        _LOG.info("Mocking execution of task '%s' on quantum %s", label, data_id)

        assert data_id is not None, "Quantum DataId cannot be None"

        # Raise the configured exception if this data ID was selected.  When
        # a memory requirement is set, the failure only happens if the
        # available memory is below it.
        # NOTE(review): the comparison assumes ``resources.max_mem`` is not
        # `None` whenever a memory requirement is configured - confirm.
        if self.data_id_match is not None and self.data_id_match.match(data_id):
            assert self.fail_exception is not None, "Exception type must be defined"
            message = f"Simulated failure: task={label} dataId={data_id}"
            if self.memory_required is None:
                _LOG.info("Simulating failure of task '%s' on quantum %s", label, data_id)
                raise self.fail_exception(message)
            if butlerQC.resources.max_mem < self.memory_required:
                _LOG.info(
                    "Simulating out-of-memory failure for task '%s' on quantum %s",
                    label,
                    data_id,
                )
                raise self.fail_exception(message)

        # Build the provenance record that is attached to every output.
        _LOG.info("Reading input data for task '%s' on quantum %s", label, data_id)
        provenance = MockDatasetQuantum(task_label=label, data_id=dict(data_id.mapping), inputs={})
        for connection_name, refs in inputRefs:
            datasets = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if not isinstance(ref.datasetType.storageClass, MockStorageClass):
                    # Not a mock dataset: describe it from its reference
                    # instead of reading it.
                    datasets.append(
                        MockDataset(
                            dataset_id=ref.id,
                            dataset_type=ref.datasetType.to_simple(),
                            data_id=dict(ref.dataId.mapping),
                            run=ref.run,
                        )
                    )
                    continue
                loaded = butlerQC.get(ref)
                if isinstance(loaded, DeferredDatasetHandle):
                    loaded = loaded.get()
                if not isinstance(loaded, MockDataset):
                    raise TypeError(
                        f"Expected MockDataset instance for {ref}; "
                        f"got {loaded!r} of type {type(loaded)!r}."
                    )
                # Drop the input's own provenance so that the recorded
                # history is only one level deep.
                loaded.quantum = None
                datasets.append(loaded)
            provenance.inputs[connection_name] = datasets

        # Write one mock dataset per output reference.
        for connection_name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                butlerQC.put(
                    MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=dict(ref.dataId.mapping),
                        run=ref.run,
                        quantum=provenance,
                        output_connection_name=connection_name,
                    ),
                    ref,
                )

        _LOG.info("Finished mocking task '%s' on quantum %s", label, data_id)

385 

386 

class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    """Connections class that declares no dimensions and no connections."""

389 

390 

class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    """Config class with no fields of its own, paired with
    `MockPipelineDefaultTargetConnections`.
    """

395 

396 

class MockPipelineDefaultTargetTask(PipelineTask):
    """Placeholder `~lsst.pipe.base.PipelineTask` that serves as the default
    target of ``MockPipelineTaskConfig.original``.

    `lsst.pex.config.ConfigurableField` does not support ``optional=True``,
    so a target is always needed; this class is the workaround.  That
    limitation is generally reasonable for production code, and lifting it
    only for the sake of test utilities would not make sense.
    """

    ConfigClass = MockPipelineDefaultTargetConfig

408 

409 

class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """Connections class whose connections are mock versions of the
    connections of a real PipelineTask.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        The config to use for the connection.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        # Storage classes of the automatic output connections.  Those outputs
        # are not mocked, so any connection that uses one of them as an input
        # must keep the real storage class as well.
        automatic_storage_classes = (
            acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
            acc.METADATA_OUTPUT_STORAGE_CLASS,
            acc.LOG_OUTPUT_STORAGE_CLASS,
        )
        for name, connection in self.original.allConnections.items():
            if connection.name in self.unmocked_dataset_types:
                # Only regular and prerequisite inputs may stay unmocked.
                if name in self.original.outputs:
                    raise ValueError(
                        f"Unmocked dataset type {connection.name!r} cannot be used as an output."
                    )
                if name in self.original.initInputs:
                    raise ValueError(
                        f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                    )
                if name in self.original.initOutputs:
                    raise ValueError(
                        f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                    )
                setattr(self, name, connection)
                continue

            if connection.storageClass in automatic_storage_classes:
                mock_storage_class_name = connection.storageClass
            else:
                # The mock storage class is registered with the global
                # singleton here, but only its name can be stored in the
                # connection.  The same singleton (or one holding the same
                # registrations) therefore has to be available wherever this
                # dataset type is used.
                mock_storage_class_name = MockStorageClass.get_or_register_mock(
                    connection.storageClass
                ).name
            overrides: dict[str, Any] = {}
            if hasattr(connection, "dimensions"):
                mock_dimensions = set(connection.dimensions)
                # Swap the generic "skypix" placeholder for htm7, because
                # the placeholder requires the dataset type to have already
                # been registered.
                if "skypix" in mock_dimensions:
                    mock_dimensions.remove("skypix")
                    mock_dimensions.add("htm7")
                overrides["dimensions"] = mock_dimensions
            mocked = dataclasses.replace(
                connection,
                name=get_mock_name(connection.name),
                storageClass=mock_storage_class_name,
                **overrides,
            )
            setattr(self, name, mocked)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        # Delegate to the connections of the original task.
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        # Delegate to the connections of the original task.
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Runs the original connections' adjustQuantum: mock references are
        # translated to original ones on the way in, and the adjusted result
        # is translated back to mock references on the way out.

        def to_original(mock_side):
            # Pair every connection name with the original connection and
            # with references to the original dataset types.
            converted = {}
            for connection_name, (_, mock_refs) in mock_side.items():
                original_connection = getattr(self.original, connection_name)
                if original_connection.name in self.unmocked_dataset_types:
                    refs = mock_refs
                else:
                    refs = MockStorageClass.unmock_dataset_refs(mock_refs)
                converted[connection_name] = (original_connection, refs)
            return converted

        def to_mock(original_side):
            # Pair every connection name with this object's connection and
            # with references to the mock dataset types.
            converted = {}
            for connection_name, (original_connection, original_refs) in original_side.items():
                if original_connection.name in self.unmocked_dataset_types:
                    refs = original_refs
                else:
                    refs = MockStorageClass.mock_dataset_refs(original_refs)
                converted[connection_name] = (getattr(self, connection_name), refs)
            return converted

        original_inputs = to_original(inputs)
        original_outputs = to_original(outputs)
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        adjusted_inputs = to_mock(adjusted_original_inputs)
        adjusted_outputs = to_mock(adjusted_original_outputs)
        return adjusted_inputs, adjusted_outputs

526 

527 

class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration for `MockPipelineTask`: the task being mocked plus the
    dataset types that must be left unmocked.
    """

    original: ConfigurableField = ConfigurableField(
        target=MockPipelineDefaultTargetTask,
        doc="The original task being mocked by this one.",
    )

    unmocked_dataset_types = ListField[str](
        default=(),
        optional=False,
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
    )

544 

545 

class MockPipelineTask(BaseTestPipelineTask):
    """Test-utility `PipelineTask` whose connections are produced by mocking
    the connections of a real task.

    Notes
    -----
    All behavior comes from `BaseTestPipelineTask`; this class only selects
    `MockPipelineTaskConfig` as its configuration class.

    NOTE(review): earlier documentation said that ``initInput`` and
    ``initOutput`` connections of the original task are dropped.  The
    connections class in this file builds a mock for every entry of the
    original's ``allConnections``, so confirm whether that still holds.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig

561 

562 

class DynamicConnectionConfig(Config):
    """Config that fully describes a single connection, so that connections
    can be defined through configuration alone.
    """

    dataset_type_name = Field[str](dtype=str, doc="Name for the dataset type as seen by the butler.")
    dimensions = ListField[str](dtype=str, default=[], doc="Dimensions for the dataset type.")
    storage_class = Field[str](
        dtype=str, default="StructuredDataDict", doc="Name of the butler storage class for the dataset type."
    )
    is_calibration = Field[bool](dtype=bool, default=False, doc="Whether this dataset type is a calibration.")
    multiple = Field[bool](
        dtype=bool,
        default=False,
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
    )
    mock_storage_class = Field[bool](
        dtype=bool,
        default=True,
        doc="Whether the storage class should actually be a mock of the storage class given.",
    )

    def make_connection(self, cls: type[_T]) -> _T:
        """Create a connection of the given type from this configuration.

        Parameters
        ----------
        cls : `type`
            Connection class to instantiate.

        Returns
        -------
        connection
            Instance of ``cls``.  Subclasses of `DimensionedConnection` also
            receive ``isCalibration`` and ``dimensions``; other connection
            types receive only the name, storage class, and ``multiple``.
        """
        storage_class_name = self.storage_class
        if self.mock_storage_class:
            # Registers the mock storage class if necessary and uses its name.
            storage_class_name = MockStorageClass.get_or_register_mock(storage_class_name).name
        arguments: dict[str, Any] = {
            "name": self.dataset_type_name,
            "storageClass": storage_class_name,
            "multiple": self.multiple,
        }
        if issubclass(cls, cT.DimensionedConnection):
            arguments["isCalibration"] = self.is_calibration
            arguments["dimensions"] = frozenset(self.dimensions)
        return cls(**arguments)  # type: ignore

601 

602 

class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Connections class that takes both its dimensions and its connections
    entirely from configuration.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        Config to use for this connections object.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        # Config field name paired with the connection type it produces,
        # listed in the order the connections are created.
        connection_kinds = (
            ("init_inputs", cT.InitInput),
            ("init_outputs", cT.InitOutput),
            ("prerequisite_inputs", cT.PrerequisiteInput),
            ("inputs", cT.Input),
            ("outputs", cT.Output),
        )
        connection_config: DynamicConnectionConfig
        for field_name, connection_type in connection_kinds:
            for connection_name, connection_config in getattr(config, field_name).items():
                setattr(self, connection_name, connection_config.make_connection(connection_type))

626 

627 

class DynamicTestPipelineTaskConfig(
    BaseTestPipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask.

    Notes
    -----
    This derives from `BaseTestPipelineTaskConfig` because
    `DynamicTestPipelineTask` inherits the constructor and ``runQuantum`` of
    `BaseTestPipelineTask`, which read ``fail_exception`` and
    ``memory_required`` and call ``data_id_match`` on their config.  Those
    members are defined by `BaseTestPipelineTaskConfig`; without them the
    task cannot be constructed.
    """

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )

670 

671 

class DynamicTestPipelineTask(BaseTestPipelineTask):
    """Test-utility `PipelineTask` whose dimensions and connections come
    entirely from its configuration.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig