Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 24% (206 statements)
coverage.py v7.3.1, created at 2023-09-19 10:39 +0000

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from lsst.pipe.base.connectionTypes import BaseInput, Output

__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_task_defs",
)

import dataclasses
import logging
from collections.abc import Collection, Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from lsst.daf.butler import (
    DataCoordinate,
    DatasetRef,
    DeferredDatasetHandle,
    SerializedDatasetType,
    SerializedDimensionGraph,
)
from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


_T = TypeVar("_T", bound=cT.BaseConnection)


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ]
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed.  The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results
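
# Example (a minimal sketch, not part of this module): wrap the tasks of an
# expanded pipeline in mocks, keeping "raw" unmocked and forcing a
# hypothetical "isr" task to fail on one visit.  ``pipeline`` stands in for
# an `lsst.pipe.base.Pipeline` instance.
#
#     task_defs = mock_task_defs(
#         pipeline.toExpandedPipeline(),
#         unmocked_dataset_types=["raw"],
#         force_failures={"isr": ("visit = 123", ValueError)},
#     )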


class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    pass


class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID that triggers an exception.  A string expression in the "
            "daf_butler user-expression syntax that may reference attributes of the quantum data ID "
            "(e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Fully-qualified name of the exception class to raise when the fail condition is triggered. "
            "Can be 'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
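
# Example (a minimal sketch): the fail-condition fields defined above can be
# set on any subclass config, e.g. MockPipelineTaskConfig (defined below), to
# make matching quanta raise.  The data ID values are hypothetical.
#
#     config = MockPipelineTaskConfig()
#     config.fail_condition = "visit = 123 AND detector = 45"
#     config.fail_exception = "lsst.pipe.base.NoWorkFound"
#     assert config.data_id_match() is not None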


class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances).  It
    can also be configured to raise exceptions on certain data IDs.  It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections, but init-input and
    init-output connections are not supported at runtime (they may be present
    as long as the task is never constructed).  All output connections must
    use mock storage classes.  `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single level.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=SerializedDimensionGraph(names=[]),
                ),
                data_id={},
                run=None,  # the task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
            )
            setattr(self, connection_name, output_dataset)

    config: BaseTestPipelineTaskConfig


    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Docstring is inherited from the base class.
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.full.byName(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a single
                    # level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=ref.dataId.full.byName(),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=ref.dataId.full.byName(),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)
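
# A sketch of the provenance written by runQuantum above: each output is a
# MockDataset whose ``quantum`` field records the task label, the quantum
# data ID, and a single level of inputs (field values here are hypothetical).
#
#     MockDataset(
#         dataset_id=...,
#         dataset_type=...,  # serialized dataset type of the output
#         data_id={"visit": 123, "detector": 45},
#         run="test_run",
#         quantum=MockDatasetQuantum(
#             task_label="task_label",
#             data_id={"visit": 123, "detector": 45},
#             inputs={"connection_name": [MockDataset(..., quantum=None)]},
#         ),
#         output_connection_name="connection_name",
#     )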


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code, and it would not make sense to lift it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in self.original.allConnections.items():
            if connection.name not in self.unmocked_dataset_types:
                # We register the mock storage class with the global singleton
                # here, but can only put its name in the connection.  That
                # means the same global singleton (or one that also has these
                # registrations) has to be available whenever this dataset
                # type is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs: dict[str, Any] = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7,
                    # since that requires the dataset type to have already
                    # been registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in self.original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            elif name in self.original.initInputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                )
            elif name in self.original.initOutputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                )
            setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = {}
        for connection_name, (_, mock_refs) in inputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_inputs[connection_name] = (original_connection, refs)
        original_outputs = {}
        for connection_name, (_, mock_refs) in outputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_outputs[connection_name] = (original_connection, refs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_inputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_inputs[connection_name] = (getattr(self, connection_name), refs)
        adjusted_outputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_outputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_outputs[connection_name] = (getattr(self, connection_name), refs)
        return adjusted_inputs, adjusted_outputs


class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked.  May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )
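
# Example (a minimal sketch): configure a mock of a specific task directly,
# much as mock_task_defs does above.  ``ExampleTask`` is a hypothetical
# PipelineTask subclass and "camera" a hypothetical dataset type name.
#
#     config = MockPipelineTaskConfig()
#     config.original.retarget(ExampleTask)
#     config.unmocked_dataset_types.append("camera")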


class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself.  Because `MockPipelineTask` never
    instantiates the task it mocks (just its connections class), this is a
    limitation on what the mocks can be used to test, not anything deeper.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig


class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection."""

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        if issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )
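
# Example (a minimal sketch): build a regular Input connection from a
# DynamicConnectionConfig.  The dataset type name and dimensions here are
# hypothetical.
#
#     connection_config = DynamicConnectionConfig()
#     connection_config.dataset_type_name = "example_catalog"
#     connection_config.dimensions = ["visit", "detector"]
#     connection = connection_config.make_connection(cT.Input)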


class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        connection_config: DynamicConnectionConfig
        for connection_name, connection_config in config.init_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
        for connection_name, connection_config in config.init_outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
        for connection_name, connection_config in config.prerequisite_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
        for connection_name, connection_config in config.inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Input))
        for connection_name, connection_config in config.outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Output))


class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )


class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig
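
# Example (a minimal sketch): a fully config-defined task with one input and
# one output connection.  All names and dimensions here are hypothetical.
#
#     config = DynamicTestPipelineTaskConfig()
#     config.dimensions = ["visit", "detector"]
#     config.inputs["input_image"] = DynamicConnectionConfig()
#     config.inputs["input_image"].dataset_type_name = "example_image"
#     config.inputs["input_image"].dimensions = ["visit", "detector"]
#     config.outputs["output_catalog"] = DynamicConnectionConfig()
#     config.outputs["output_catalog"].dataset_type_name = "example_output"
#     config.outputs["output_catalog"].dimensions = ["visit", "detector"]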