Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 25%

111 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-11 02:00 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("MockPipelineTask", "MockPipelineTaskConfig", "mock_task_defs") 

24 

25import dataclasses 

26import logging 

27from collections.abc import Iterable, Mapping 

28from typing import TYPE_CHECKING, Any, ClassVar 

29 

30from lsst.daf.butler import DeferredDatasetHandle 

31from lsst.pex.config import ConfigurableField, Field, ListField 

32from lsst.utils.doImport import doImportType 

33from lsst.utils.introspection import get_full_type_name 

34from lsst.utils.iteration import ensure_iterable 

35 

36from ...config import PipelineTaskConfig 

37from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections 

38from ...pipeline import TaskDef 

39from ...pipelineTask import PipelineTask 

40from ._data_id_match import DataIdMatch 

41from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name 

42 

43_LOG = logging.getLogger(__name__) 

44 

45if TYPE_CHECKING: 

46 from ...butlerQuantumContext import ButlerQuantumContext 

47 

48 

def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception]]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] ] ]
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed.  The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
    """
    unmocked = tuple(unmocked_dataset_types)
    failures = {} if force_failures is None else force_failures
    mocked: list[TaskDef] = []
    for task_def in originals:
        mock_config = MockPipelineTaskConfig()
        # Retarget the ConfigurableField at the original task class, then
        # assign the original config so its values are carried over.
        mock_config.original.retarget(task_def.taskClass)
        mock_config.original = task_def.config
        mock_config.unmocked_dataset_types.extend(unmocked)
        failure_spec = failures.get(task_def.label)
        if failure_spec is not None:
            where, exception_type = failure_spec
            mock_config.fail_condition = where
            mock_config.fail_exception = get_full_type_name(exception_type)
        mocked.append(
            TaskDef(
                config=mock_config,
                taskClass=MockPipelineTask,
                label=get_mock_name(task_def.label),
            )
        )
    return mocked

96 

97 

class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    """Empty connections class used by `MockPipelineDefaultTargetConfig`."""

    pass

100 

101 

class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    """Empty configuration class used by `MockPipelineDefaultTargetTask`."""

    pass

106 

107 

class MockPipelineDefaultTargetTask(PipelineTask):
    """A `PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code and it wouldn't make sense just to support
    test utilities.
    """

    # Trivial placeholder config; real targets are set via ``retarget``.
    ConfigClass = MockPipelineDefaultTargetConfig

119 

120 

class MockPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Connections for `MockPipelineTask`, derived at construction time from
    the connections of the original (mocked) task.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        mocked: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(mocked.dimensions)
        do_not_mock = frozenset(config.unmocked_dataset_types)
        for name, connection in mocked.allConnections.items():
            if name in mocked.initInputs or name in mocked.initOutputs:
                # We just ignore initInputs and initOutputs, because the task
                # is never given DatasetRefs for those and hence can't create
                # mocks.
                continue
            if connection.name in do_not_mock:
                # Unmocked dataset types pass through unchanged, but only as
                # inputs; producing an unmocked output makes no sense here.
                if name in mocked.outputs:
                    raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
                setattr(self, name, connection)
                continue
            # We register the mock storage class with the global singleton
            # here, but can only put its name in the connection.  That means
            # the same global singleton (or one that also has these
            # registrations) has to be available whenever this dataset type
            # is used.
            storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
            replacements: dict[str, Any] = {}
            if hasattr(connection, "dimensions"):
                new_dimensions = set(connection.dimensions)
                # Replace the generic "skypix" placeholder with htm7, since
                # that requires the dataset type to have already been
                # registered.
                if "skypix" in new_dimensions:
                    new_dimensions.remove("skypix")
                    new_dimensions.add("htm7")
                replacements["dimensions"] = new_dimensions
            setattr(
                self,
                name,
                dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **replacements,
                ),
            )

160 

161 

class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    # Butler expression evaluated against each quantum's data ID; quanta that
    # match raise ``fail_exception`` instead of executing (see
    # MockPipelineTask.runQuantum).  Empty string disables failure simulation.
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on Data ID to raise an exception. String expression which includes attributes of "
            "quantum data ID using a syntax of daf_butler user expressions (e.g. 'visit = 123')."
        ),
    )

    # Fully-qualified name of the exception class to raise when
    # ``fail_condition`` matches; imported lazily via doImportType.
    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify non-failure exception."
        ),
    )

    # The task being mocked; retargeted to the real task class by
    # mock_task_defs.
    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    # Dataset type names passed through unmocked (inputs only; see
    # MockPipelineTaskConnections).
    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )

    def data_id_match(self) -> DataIdMatch | None:
        """Return a `DataIdMatch` built from ``fail_condition``, or `None`
        if no failure condition is configured.
        """
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)

201 

202 

203class MockPipelineTask(PipelineTask): 

204 """Implementation of `PipelineTask` used for running a mock pipeline. 

205 

206 Notes 

207 ----- 

208 This class overrides `runQuantum` to read inputs and write a bit of 

209 provenance into all of its outputs (always `MockDataset` instances). It 

210 can also be configured to raise exceptions on certain data IDs. It reads 

211 `MockDataset` inputs and simulates reading inputs of other types by 

212 creating `MockDataset` inputs from their DatasetRefs. 

213 

214 At present `MockPipelineTask` simply drops any ``initInput`` and 

215 ``initOutput`` connections present on the original, since `MockDataset` 

216 creation for those would have to happen in the code that executes the task, 

217 not in the task itself. Because `MockPipelineTask` never instantiates the 

218 mock task (just its connections class), this is a limitation on what the 

219 mocks can be used to test, not anything deeper. 

220 """ 

221 

222 ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig 

223 

224 def __init__( 

225 self, 

226 *, 

227 config: MockPipelineTaskConfig, 

228 **kwargs: Any, 

229 ): 

230 super().__init__(config=config, **kwargs) 

231 self.fail_exception: type | None = None 

232 self.data_id_match = self.config.data_id_match() 

233 if self.data_id_match: 

234 self.fail_exception = doImportType(self.config.fail_exception) 

235 

236 config: MockPipelineTaskConfig 

237 

238 def runQuantum( 

239 self, 

240 butlerQC: ButlerQuantumContext, 

241 inputRefs: InputQuantizedConnection, 

242 outputRefs: OutputQuantizedConnection, 

243 ) -> None: 

244 # docstring is inherited from the base class 

245 quantum = butlerQC.quantum 

246 

247 _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId) 

248 

249 assert quantum.dataId is not None, "Quantum DataId cannot be None" 

250 

251 # Possibly raise an exception. 

252 if self.data_id_match is not None and self.data_id_match.match(quantum.dataId): 

253 _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId) 

254 message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}" 

255 assert self.fail_exception is not None, "Exception type must be defined" 

256 raise self.fail_exception(message) 

257 

258 # Populate the bit of provenance we store in all outputs. 

259 _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId) 

260 mock_dataset_quantum = MockDatasetQuantum( 

261 task_label=self.getName(), data_id=quantum.dataId.to_simple(), inputs={} 

262 ) 

263 for name, refs in inputRefs: 

264 inputs_list = [] 

265 for ref in ensure_iterable(refs): 

266 if isinstance(ref.datasetType.storageClass, MockStorageClass): 

267 input_dataset = butlerQC.get(ref) 

268 if isinstance(input_dataset, DeferredDatasetHandle): 

269 input_dataset = input_dataset.get() 

270 if not isinstance(input_dataset, MockDataset): 

271 raise TypeError( 

272 f"Expected MockDataset instance for {ref}; " 

273 f"got {input_dataset!r} of type {type(input_dataset)!r}." 

274 ) 

275 # To avoid very deep provenance we trim inputs to a single 

276 # level. 

277 input_dataset.quantum = None 

278 else: 

279 input_dataset = MockDataset(ref=ref.to_simple()) 

280 inputs_list.append(input_dataset) 

281 mock_dataset_quantum.inputs[name] = inputs_list 

282 

283 # store mock outputs 

284 for name, refs in outputRefs: 

285 if not isinstance(refs, list): 

286 refs = [refs] 

287 for ref in refs: 

288 output = MockDataset( 

289 ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name 

290 ) 

291 butlerQC.put(output, ref) 

292 

293 _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)