Coverage for python / lsst / pipe / base / tests / mocks / _pipeline_task.py: 26%

258 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:47 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27from __future__ import annotations 

28 

29from lsst.pipe.base.connectionTypes import BaseInput, Output 

30 

31__all__ = ( 

32 "DynamicConnectionConfig", 

33 "DynamicTestPipelineTask", 

34 "DynamicTestPipelineTaskConfig", 

35 "ForcedFailure", 

36 "MockAlgorithmError", 

37 "MockPipelineTask", 

38 "MockPipelineTaskConfig", 

39 "mock_pipeline_graph", 

40) 

41 

42import dataclasses 

43import logging 

44import signal 

45import time 

46from collections.abc import Collection, Iterable, Mapping 

47from typing import TYPE_CHECKING, Any, ClassVar, TypeVar 

48 

49from astropy.units import Quantity 

50 

51from lsst.daf.butler import DataCoordinate, DatasetRef, DeferredDatasetHandle, Quantum, SerializedDatasetType 

52from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField 

53from lsst.utils.doImport import doImportType 

54from lsst.utils.introspection import get_full_type_name 

55from lsst.utils.iteration import ensure_iterable 

56 

57from ... import connectionTypes as cT 

58from ..._status import AlgorithmError, AnnotatedPartialOutputsError 

59from ...automatic_connection_constants import METADATA_OUTPUT_CONNECTION_NAME, METADATA_OUTPUT_STORAGE_CLASS 

60from ...config import PipelineTaskConfig 

61from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections 

62from ...pipeline_graph import PipelineGraph 

63from ...pipelineTask import PipelineTask 

64from ._data_id_match import DataIdMatch 

65from ._storage_class import ( 

66 ConvertedUnmockedDataset, 

67 MockDataset, 

68 MockDatasetQuantum, 

69 MockStorageClass, 

70 get_mock_name, 

71) 

72 

73_LOG = logging.getLogger(__name__) 

74 

75if TYPE_CHECKING: 

76 from ..._quantumContext import QuantumContext 

77 

78 

79_T = TypeVar("_T", bound=cT.BaseConnection) 

80 

81 

@dataclasses.dataclass
class ForcedFailure:
    """Information about an exception that should be raised by one or more
    quanta.
    """

    condition: str
    """Butler expression-language string that matches the data IDs that should
    raise.
    """

    exception_type: type[BaseException] | None = None
    """The type of exception to raise."""

    memory_required: Quantity | None = None
    """If not `None`, this failure simulates an out-of-memory failure by
    raising only if this value exceeds `ExecutionResources.max_mem`.
    """

    def set_config(self, config: MockPipelineTaskConfig) -> None:
        """Copy this failure specification into the given mock task config.

        Parameters
        ----------
        config : `MockPipelineTaskConfig`
            Config to modify in place.
        """
        config.fail_condition = self.condition
        # Only override the config's default exception type when one was
        # given explicitly.
        if self.exception_type:
            config.fail_exception = get_full_type_name(self.exception_type)
        # NOTE(review): ``memory_required`` here is an astropy ``Quantity``
        # (or None) while the config field is declared as a str Field;
        # presumably pex.config accepts/stringifies this — confirm against
        # lsst.pex.config.Field assignment semantics.
        config.memory_required = self.memory_required

106 

107 

class MockAlgorithmError(AlgorithmError):
    """A subclass of `..AlgorithmError` chained to
    `..AnnotatedPartialOutputsError` when the latter is configured to be raised
    by `MockPipelineTask`.
    """

    @property
    def metadata(self) -> dict[str, int]:
        # Fixed, arbitrary payload so tests can verify that algorithm-error
        # metadata is propagated through the error-annotation machinery.
        payload = {"badness": 12}
        return payload

117 

118 

def mock_pipeline_graph(
    original_graph: PipelineGraph,
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, ForcedFailure] | None = None,
) -> PipelineGraph:
    """Create mocks for a full pipeline graph.

    Parameters
    ----------
    original_graph : `~..pipeline_graph.PipelineGraph`
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.  "Automatic" datasets written by the execution framework such
        as configs, logs, and metadata are implicitly included.
    force_failures : `~collections.abc.Mapping` [ `str`, `ForcedFailure` ]
        Mapping from original task label to information about an exception one
        or more quanta for this task should raise.

    Returns
    -------
    mocked : `~..pipeline_graph.PipelineGraph`
        Pipeline graph using `MockPipelineTask` configurations that target the
        original tasks.  Never resolved.
    """
    if force_failures is None:
        force_failures = {}
    # Framework-written datasets (configs, logs, metadata) are never mocked;
    # gather their names for every task up front.
    passthrough_types = list(unmocked_dataset_types)
    for task_node in original_graph.tasks.values():
        passthrough_types.append(task_node.init.config_output.dataset_type_name)
        if task_node.log_output is not None:
            passthrough_types.append(task_node.log_output.dataset_type_name)
        passthrough_types.append(task_node.metadata_output.dataset_type_name)
    result = PipelineGraph(description=original_graph.description)
    for original_task_node in original_graph.tasks.values():
        config = MockPipelineTaskConfig()
        # Point the mock at the original task class and reuse its config.
        config.original.retarget(original_task_node.task_class)
        config.original = original_task_node.config
        config.unmocked_dataset_types.extend(passthrough_types)
        failure = force_failures.get(original_task_node.label)
        if failure is not None:
            failure.set_config(config)
        result.add_task(get_mock_name(original_task_node.label), MockPipelineTask, config=config)
    return result

162 

163 

class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Empty connections class used as the ``pipelineConnections`` anchor for
    `BaseTestPipelineTaskConfig`; subclasses define the real connections.
    """

    pass

166 

167 

class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    """Common configuration for mock pipeline tasks: forced-failure behavior
    and arbitrary values to embed in mock outputs.
    """

    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on Data ID to raise an exception. String expression which includes attributes of "
            "quantum data ID using a syntax of daf_butler user expressions (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify non-failure exception."
        ),
    )

    fail_signal = Field[int](dtype=int, optional=True, doc="Signal to raise instead of an exception.")

    memory_required = Field[str](
        dtype=str,
        default=None,
        optional=True,
        doc=(
            "If not None, simulate an out-of-memory failure by raising only if ExecutionResource.max_mem "
            "exceeds this value. This string should include units as parsed by astropy.units.Quantity "
            "(e.g. '4GB')."
        ),
    )

    sleep = Field[float](
        dtype=float,
        default=0.0,
        doc="Time to sleep (seconds) before mock execution reading inputs or failing.",
    )

    int_value = Field[int](
        "Arbitrary integer value to write into mock output datasets", dtype=int, optional=True, default=None
    )
    str_value = Field[str](
        "Arbitrary string value to write into mock output datasets", dtype=str, optional=True, default=None
    )

    def data_id_match(self) -> DataIdMatch | None:
        """Return a `DataIdMatch` for ``fail_condition``, or `None` when no
        condition is configured.
        """
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)

217 

218 

class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and write
    mock datasets in `runQuantum`.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        The pipeline task config.
    initInputs : `~collections.abc.Mapping`
        The init inputs datasets.
    **kwargs : `~typing.Any`
        Keyword parameters passed to base class constructor.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances).  It
    can also be configured to raise exceptions on certain data IDs.  It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections, but init-input and
    init-output connections are not supported at runtime (they may be present
    as long as the task is never constructed).  All output connections must
    use mock storage classes.  `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        # The exception type is only imported (and only needed) when a
        # fail condition is configured.
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        # Parse the configured memory threshold once, up front.
        self.memory_required = (
            Quantity(self.config.memory_required) if self.config.memory_required is not None else None
        )
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single
            # level.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=[],
                ),
                data_id={},
                run=None,  # task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
                int_value=self.config.int_value,
                str_value=self.config.str_value,
            )
            setattr(self, connection_name, output_dataset)

    # Narrow the inherited ``config`` attribute's type for subclasses.
    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        if self.config.sleep:
            time.sleep(self.config.sleep)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            assert self.fail_exception is not None, "Exception type must be defined"

            if self.memory_required is not None:
                # Simulated OOM: fail only when the available memory is below
                # the configured requirement; otherwise run normally.
                if butlerQC.resources.max_mem < self.memory_required:
                    _LOG.info(
                        "Simulating out-of-memory failure for task '%s' on quantum %s",
                        self.getName(),
                        quantum.dataId,
                    )
                    self._fail(quantum)
            else:
                _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
                self._fail(quantum)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=dict(quantum.dataId.mapping), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if isinstance(input_dataset, MockDataset):
                        # To avoid very deep provenance we trim inputs to a
                        # single level.
                        input_dataset.quantum = None
                    elif not isinstance(input_dataset, ConvertedUnmockedDataset):
                        raise TypeError(
                            f"Expected MockDataset or ConvertedUnmockedDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                else:
                    # Non-mock storage class: simulate the read by building a
                    # MockDataset from the ref alone.
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=dict(ref.dataId.mapping),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # store mock outputs
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=dict(ref.dataId.mapping),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                    int_value=self.config.int_value,
                    str_value=self.config.str_value,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)

    def _fail(self, quantum: Quantum) -> None:
        """Raise the configured exception.

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum producing the error.
        """
        message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
        # Type annotations for optional config fields are broken, so MyPy
        # doesn't think fail_signal could be None.
        if self.config.fail_signal is not None:
            signal.raise_signal(signal.Signals(self.config.fail_signal))
        elif self.fail_exception is AnnotatedPartialOutputsError:  # type: ignore[unreachable]
            # This exception is expected to always chain another.
            try:
                raise MockAlgorithmError(message)
            except AlgorithmError as original:
                error = AnnotatedPartialOutputsError.annotate(original, self, log=self.log)
                raise error from original
        else:
            assert self.fail_exception is not None, "Method should not be called."
            raise self.fail_exception(message)

416 

417 

class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    """Empty connections class for `MockPipelineDefaultTargetConfig`."""

    pass

420 

421 

class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    """Empty config for `MockPipelineDefaultTargetTask`."""

    pass

426 

427 

class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code and it wouldn't make sense just to support
    test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig

439 

440 

class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        The config to use for the connection.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        # Instantiate the original task's connections so we can mirror them.
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in self.original.allConnections.items():
            if connection.name not in self.unmocked_dataset_types:
                # We register the mock storage class with the global
                # singleton here, but can only put its name in the
                # connection.  That means the same global singleton (or one
                # that also has these registrations) has to be available
                # whenever this dataset type is used.
                storage_class_name = MockStorageClass.get_or_register_mock(connection.storageClass).name
                kwargs: dict[str, Any] = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7, since
                    # that requires the dataset type to have already been
                    # registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                # Rebuild the connection with the mocked dataset type name
                # and storage class.
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class_name,
                    **kwargs,
                )
            elif name in self.original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            elif name in self.original.initInputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                )
            elif name in self.original.initOutputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                )
            elif (
                connection.name.endswith(METADATA_OUTPUT_CONNECTION_NAME)
                and connection.storageClass == METADATA_OUTPUT_STORAGE_CLASS
            ):
                # Task metadata does not use a mock storage class, because it's
                # written by the system, but it does end up with the _mock_*
                # prefix because the task label does.
                connection = dataclasses.replace(connection, name=get_mock_name(connection.name))
            setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        # Delegate to the original connections so spatial-bounds behavior
        # matches the task being mocked.
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        # Delegate to the original connections so temporal-bounds behavior
        # matches the task being mocked.
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = {}
        for connection_name, (_, mock_refs) in inputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_inputs[connection_name] = (original_connection, refs)
        original_outputs = {}
        for connection_name, (_, mock_refs) in outputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_outputs[connection_name] = (original_connection, refs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_inputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_inputs[connection_name] = (getattr(self, connection_name), refs)
        adjusted_outputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_outputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_outputs[connection_name] = (getattr(self, connection_name), refs)
        return adjusted_inputs, adjusted_outputs

555 

556 

class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )

573 

574 

class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the task,
    not in the task itself.  Because `MockPipelineTask` never instantiates the
    mock task (just its connections class), this is a limitation on what the
    mocks can be used to test, not anything deeper.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig

590 

591 

class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection."""

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )
    minimum = Field[int](
        doc="Minimum number of datasets per quantum required for this connection. Ignored for non-inputs.",
        dtype=int,
        default=1,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        """Construct a connection of the given type from this config.

        Parameters
        ----------
        cls : `type`
            Connection type to instantiate (a `..connectionTypes` class).

        Returns
        -------
        connection : ``cls``
            New connection instance.
        """
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        # Branch order matters: BaseInput is checked first because it accepts
        # the most constructor arguments (including ``minimum``), then any
        # other dimensioned connection, then the dimensionless kinds.
        if issubclass(cls, cT.BaseInput):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
                minimum=self.minimum,
            )
        elif issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )

644 

645 

class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        Config to use for this connections object.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        # Table of (config mapping, connection type) pairs; each configured
        # entry becomes an instance attribute of the matching connection type.
        connection_config: DynamicConnectionConfig
        groups = (
            (config.init_inputs, cT.InitInput),
            (config.init_outputs, cT.InitOutput),
            (config.prerequisite_inputs, cT.PrerequisiteInput),
            (config.inputs, cT.Input),
            (config.outputs, cT.Output),
        )
        for group, connection_type in groups:
            for connection_name, connection_config in group.items():
                setattr(self, connection_name, connection_config.make_connection(connection_type))

669 

670 

class DynamicTestPipelineTaskConfig(
    BaseTestPipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )

713 

714 

class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig