Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 25%

110 statements  

coverage.py v7.2.7, created at 2023-07-12 11:14 -0700

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("MockPipelineTask", "MockPipelineTaskConfig", "mock_task_defs")

import dataclasses
import logging
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar

from lsst.daf.butler import DeferredDatasetHandle
from lsst.pex.config import ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ], optional
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed. The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results
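# Usage sketch (illustrative only, kept as a comment so nothing runs at import
# time): ``task_defs`` is assumed to be an iterable of `TaskDef` objects from
# an existing pipeline, and the "raw" dataset type and "isr" label are likewise
# assumptions for the example.
#
#     mocked = mock_task_defs(
#         task_defs,
#         unmocked_dataset_types=["raw"],
#         force_failures={"isr": ("visit = 42", ValueError)},
#     )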

class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code, and it would not make sense to lift it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig

class MockPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    def __init__(self, *, config: MockPipelineTaskConfig):
        original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(original.dimensions)
        unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in original.allConnections.items():
            if name in original.initInputs or name in original.initOutputs:
                # We just ignore initInputs and initOutputs, because the task
                # is never given DatasetRefs for those and hence can't create
                # mocks.
                continue
            if connection.name not in unmocked_dataset_types:
                # We register the mock storage class with the global singleton
                # here, but can only put its name in the connection. That means
                # the same global singleton (or one that also has these
                # registrations) has to be available whenever this dataset type
                # is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7, since
                    # that requires the dataset type to have already been
                    # registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            setattr(self, name, connection)
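# For illustration (a sketch, not executed; the connection name, storage class,
# and dimensions below are assumptions): given an original input connection
# named "exposure" with storage class "ExposureF" and dimensions
# {"visit", "detector", "skypix"}, the loop above registers a mock storage
# class for "ExposureF" and replaces the connection with one named
# get_mock_name("exposure") that carries the mock storage class and dimensions
# {"visit", "detector", "htm7"}. Anything listed in ``unmocked_dataset_types``
# is passed through unchanged (unmocked outputs raise `ValueError`).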

class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID for raising an exception: a string expression over data "
            "ID attributes, using the daf_butler user-expression syntax (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Fully-qualified name of the exception to raise when the fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )

    def data_id_match(self) -> DataIdMatch | None:
        # Build a DataIdMatch from the configured expression, or return None
        # if no fail condition is set.
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
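# Configuration sketch (illustrative; the data ID expression and dataset type
# name are assumptions): raise NoWorkFound instead of a hard failure whenever
# the quantum data ID matches the expression, and leave "camera" unmocked.
#
#     config = MockPipelineTaskConfig()
#     config.fail_condition = "detector = 7"
#     config.fail_exception = "lsst.pipe.base.NoWorkFound"
#     config.unmocked_dataset_types = ["camera"]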

class MockPipelineTask(PipelineTask):
    """Implementation of `~lsst.pipe.base.PipelineTask` used for running a
    mock pipeline.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the task,
    not in the task itself. Because `MockPipelineTask` never instantiates the
    mock task (just its connections class), this is a limitation on what the
    mocks can be used to test, not anything deeper.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig

    def __init__(
        self,
        *,
        config: MockPipelineTaskConfig,
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)

    config: MockPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.to_simple(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a single
                    # level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(ref=ref.to_simple())
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # store mock outputs
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)
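# Provenance sketch (illustrative; the concrete labels and data ID values are
# assumptions): each output written by `MockPipelineTask.runQuantum` is a
# `MockDataset` whose ``quantum`` field records the task label, the quantum
# data ID, and one level of inputs, roughly:
#
#     MockDataset(
#         ref=...,  # simplified DatasetRef of the output
#         quantum=MockDatasetQuantum(
#             task_label="mock_label",
#             data_id={"instrument": "X", "visit": 42},
#             inputs={"input_connection": [MockDataset(ref=...), ...]},
#         ),
#         output_connection_name="output_connection",
#     )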