Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 31%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_task_defs",
)

import dataclasses
import logging
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from lsst.daf.butler import DatasetRef, DeferredDatasetHandle
from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


_T = TypeVar("_T", bound=cT.BaseConnection)


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ], optional
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed.  The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results
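
# A minimal usage sketch (the pipeline file name, task label, and data ID
# expression below are hypothetical): wrap the task definitions from an
# existing pipeline in mocks, leaving ``raw`` unmocked and forcing the task
# labeled ``isr`` to fail on one detector.
#
#     from lsst.pipe.base import Pipeline
#
#     task_defs = list(Pipeline.fromFile("my_pipeline.yaml").toExpandedPipeline())
#     mocked = mock_task_defs(
#         task_defs,
#         unmocked_dataset_types=["raw"],
#         force_failures={"isr": ("detector = 4", ValueError)},
#     )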


class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    pass


class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID under which to raise an exception: a string "
            "expression that references data ID attributes using the daf_butler user-expression "
            "syntax (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Fully qualified name of the exception class to raise when the fail condition is "
            "triggered.  Can be 'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
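
# A minimal configuration sketch (the data ID value here is hypothetical):
# raise NoWorkFound instead of a hard failure whenever the quantum data ID
# matches the expression.
#
#     config = MockPipelineTaskConfig()
#     config.fail_condition = "visit = 123"
#     config.fail_exception = "lsst.pipe.base.NoWorkFound"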


class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances).  It
    can also be configured to raise exceptions on certain data IDs.  It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections, but init-input and
    init-output connections are not supported at runtime (they may be present
    as long as the task is never constructed).  All output connections must
    use mock storage classes.  `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)

    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Docstring is inherited from the base class.
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.to_simple(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a
                    # single level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(ref=ref.to_simple())
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)
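
# For reference, a sketch of the provenance attached to each output written
# above (all field values here are hypothetical):
#
#     MockDataset(
#         ref=...,  # simplified form of the output DatasetRef
#         quantum=MockDatasetQuantum(
#             task_label="...",
#             data_id={"instrument": "HSC", "visit": 123, "detector": 4},
#             inputs={"exposure": [MockDataset(...)]},  # trimmed to one level
#         ),
#         output_connection_name="outputExposure",
#     )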


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code, and it would not make sense to lift it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(original.dimensions)
        unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in original.allConnections.items():
            if name in original.initInputs or name in original.initOutputs:
                # We just ignore initInputs and initOutputs, because the task
                # is never given DatasetRefs for those and hence can't create
                # mocks.
                continue
            if connection.name not in unmocked_dataset_types:
                # We register the mock storage class with the global
                # singleton here, but can only put its name in the
                # connection.  That means the same global singleton (or one
                # that also has these registrations) has to be available
                # whenever this dataset type is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7,
                    # since that requires the dataset type to have already
                    # been registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            setattr(self, name, connection)


class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked.  May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )


class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself.  Because `MockPipelineTask` never
    instantiates the mock task (just its connections class), this is a
    limitation on what the mocks can be used to test, not anything deeper.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig


class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection."""

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        if issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )
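
# A minimal sketch of building a connection from this config (the dataset
# type name and dimensions below are hypothetical):
#
#     cfg = DynamicConnectionConfig()
#     cfg.dataset_type_name = "mock_catalog"
#     cfg.dimensions = ["instrument", "visit"]
#     connection = cfg.make_connection(cT.Input)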


class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        connection_config: DynamicConnectionConfig
        for connection_name, connection_config in config.init_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
        for connection_name, connection_config in config.init_outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
        for connection_name, connection_config in config.prerequisite_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
        for connection_name, connection_config in config.inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Input))
        for connection_name, connection_config in config.outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Output))


class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
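
# A configuration sketch for a fully dynamic task (the dataset type names and
# dimensions below are hypothetical, and this assumes pex_config's usual
# support for setting fields via constructor keywords and ConfigDictField
# item assignment):
#
#     config = DynamicTestPipelineTaskConfig()
#     config.dimensions = ["instrument", "visit"]
#     config.inputs["input_catalog"] = DynamicConnectionConfig(
#         dataset_type_name="mock_input_catalog", dimensions=["instrument", "visit"]
#     )
#     config.outputs["output_catalog"] = DynamicConnectionConfig(
#         dataset_type_name="mock_output_catalog", dimensions=["instrument", "visit"]
#     )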


class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig