
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from lsst.pipe.base.connectionTypes import BaseInput, Output

__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_task_defs",
)

import dataclasses
import logging
from collections.abc import Collection, Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from lsst.daf.butler import (
    DataCoordinate,
    DatasetRef,
    DeferredDatasetHandle,
    SerializedDatasetType,
    SerializedDimensionGraph,
)
from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


_T = TypeVar("_T", bound=cT.BaseConnection)


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
73 """Create mocks for an iterable of TaskDefs. 

74 

75 Parameters 

76 ---------- 

77 originals : `~collections.abc.Iterable` [ `TaskDef` ] 

78 Original tasks and configuration to mock. 

79 unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional 

80 Names of overall-input dataset types that should not be replaced with 

81 mocks. 

82 force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \ 

83 `type` [ `Exception` ] or `None` ] ] 

84 Mapping from original task label to a 2-tuple indicating that some 

85 quanta should raise an exception when executed. The first entry is a 

86 data ID match using the butler expression language (i.e. a string of 

87 the sort passed ass the ``where`` argument to butler query methods), 

88 while the second is the type of exception to raise when the quantum 

89 data ID matches the expression. 

90 

91 Returns 

92 ------- 

93 mocked : `list` [ `TaskDef` ] 

94 List of `TaskDef` objects using `MockPipelineTask` configurations that 

95 target the original tasks, in the same order. 

96 """ 

    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results
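
# Illustrative usage sketch (hypothetical pipeline and labels; not executed
# here): mock every task in an expanded pipeline, keep "raw" unmocked, and
# force the task labeled "isr" to fail on one detector.
#
#     task_defs = mock_task_defs(
#         pipeline.toExpandedPipeline(),
#         unmocked_dataset_types=["raw"],
#         force_failures={"isr": ("detector = 4", ValueError)},
#     )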



class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    pass


class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID for raising an exception. A string "
            "expression referencing attributes of the data ID, using the daf_butler "
            "user expression syntax (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when the fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
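
# Illustrative sketch: a config fragment that forces quanta matching a data ID
# expression to fail.  Using 'lsst.pipe.base.NoWorkFound' as the exception
# name instead simulates a benign no-work condition rather than a failure.
#
#     config.fail_condition = "visit = 123"
#     config.fail_exception = "builtins.RuntimeError"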



class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances).  It
    can also be configured to raise exceptions on certain data IDs.  It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections.  Init-input values
    passed to the constructor must be `MockDataset` instances that match
    their connections' dataset type names and storage classes, and
    init-output attributes are populated with `MockDataset` instances at
    construction.  All output connections must use mock storage classes.
    `..Input` and `..PrerequisiteInput` connections that do not use mock
    storage classes will be handled by constructing a `MockDataset` from the
    `~lsst.daf.butler.DatasetRef` rather than actually reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single
            # level.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=SerializedDimensionGraph(names=[]),
                ),
                data_id={},
                run=None,  # task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
            )
            setattr(self, connection_name, output_dataset)

    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.full.byName(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a
                    # single level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=ref.dataId.full.byName(),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=ref.dataId.full.byName(),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)



class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code, and it wouldn't make sense to relax it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in self.original.allConnections.items():
            if connection.name not in self.unmocked_dataset_types:
                # We register the mock storage class with the global singleton
                # here, but can only put its name in the connection.  That
                # means the same global singleton (or one that also has these
                # registrations) has to be available whenever this dataset
                # type is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs: dict[str, Any] = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7,
                    # since that requires the dataset type to have already
                    # been registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in self.original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            elif name in self.original.initInputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                )
            elif name in self.original.initOutputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                )
            setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = {}
        for connection_name, (_, mock_refs) in inputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_inputs[connection_name] = (original_connection, refs)
        original_outputs = {}
        for connection_name, (_, mock_refs) in outputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_outputs[connection_name] = (original_connection, refs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_inputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_inputs[connection_name] = (getattr(self, connection_name), refs)
        adjusted_outputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_outputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_outputs[connection_name] = (getattr(self, connection_name), refs)
        return adjusted_inputs, adjusted_outputs
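
# Illustrative sketch of the renaming performed by this connections class
# (treating the exact form of the mocked names as an implementation detail of
# ._storage_class):
#
#     original connection: name="src", storageClass="SourceCatalog"
#     mocked connection:   name=get_mock_name("src"),
#                          storageClass=MockStorageClass.get_or_register_mock(
#                              "SourceCatalog"
#                          ).name
#
# Dataset types listed in unmocked_dataset_types pass through unchanged, and
# are only permitted for regular and prerequisite inputs.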



class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked.  May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )


class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself.  Because the original task is never
    instantiated (only its connections class is), this is a limitation on
    what the mocks can be used to test, not anything deeper.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig
452 

453 

454class DynamicConnectionConfig(Config): 

455 """A config class that defines a completely dynamic connection.""" 

456 

457 dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str) 

458 dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[]) 

459 storage_class = Field[str]( 

460 doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict" 

461 ) 

462 is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False) 

463 multiple = Field[bool]( 

464 doc="Whether this connection gets or puts multiple datasets for each quantum.", 

465 dtype=bool, 

466 default=False, 

467 ) 

468 mock_storage_class = Field[bool]( 

469 doc="Whether the storage class should actually be a mock of the storage class given.", 

470 dtype=bool, 

471 default=True, 

472 ) 

473 

474 def make_connection(self, cls: type[_T]) -> _T: 

475 storage_class = self.storage_class 

476 if self.mock_storage_class: 

477 storage_class = MockStorageClass.get_or_register_mock(storage_class).name 

478 if issubclass(cls, cT.DimensionedConnection): 

479 return cls( # type: ignore 

480 name=self.dataset_type_name, 

481 storageClass=storage_class, 

482 isCalibration=self.is_calibration, 

483 multiple=self.multiple, 

484 dimensions=frozenset(self.dimensions), 

485 ) 

486 else: 

487 return cls( 

488 name=self.dataset_type_name, 

489 storageClass=storage_class, 

490 multiple=self.multiple, 

491 ) 
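
# Illustrative sketch (hypothetical names): building a regular Input
# connection from a DynamicConnectionConfig.
#
#     connection_config = DynamicConnectionConfig()
#     connection_config.dataset_type_name = "input_catalog"
#     connection_config.dimensions = ["visit", "detector"]
#     input_connection = connection_config.make_connection(cT.Input)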



class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        connection_config: DynamicConnectionConfig
        for connection_name, connection_config in config.init_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
        for connection_name, connection_config in config.init_outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
        for connection_name, connection_config in config.prerequisite_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
        for connection_name, connection_config in config.inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Input))
        for connection_name, connection_config in config.outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Output))


class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc="Init-input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc="Init-output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )


class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig
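
# Illustrative configuration sketch (hypothetical dataset type names): a
# dynamic task whose quanta have visit-level granularity, with one input and
# one output connection.
#
#     config = DynamicTestPipelineTaskConfig()
#     config.dimensions = ["visit"]
#     config.inputs["my_input"] = DynamicConnectionConfig()
#     config.inputs["my_input"].dataset_type_name = "input_catalog"
#     config.inputs["my_input"].dimensions = ["visit"]
#     config.outputs["my_output"] = DynamicConnectionConfig()
#     config.outputs["my_output"].dataset_type_name = "output_catalog"
#     config.outputs["my_output"].dimensions = ["visit"]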