# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_task_defs",
)

import dataclasses
import logging
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from lsst.daf.butler import DatasetRef, DeferredDatasetHandle
from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


_T = TypeVar("_T", bound=cT.BaseConnection)


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ]
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed. The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
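
    Examples
    --------
    A minimal sketch (assuming ``task_defs`` is an iterable of `TaskDef`
    objects for real tasks, one of which is labeled ``"example"``; the label,
    dataset type name, and expression here are hypothetical)::

        mocked = mock_task_defs(
            task_defs,
            unmocked_dataset_types=["raw"],
            force_failures={"example": ("visit = 123", ValueError)},
        )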
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results


class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    pass


class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
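    """Base configuration class for test-utility mock tasks, with optional
    simulated failures on quanta whose data IDs match a configured
    expression.

    Examples
    --------
    A minimal sketch of failure injection (the expression and exception here
    are hypothetical; ``config`` is an instance of a subclass of this
    config)::

        config.fail_condition = "visit = 123"
        config.fail_exception = "builtins.RuntimeError"
        assert config.data_id_match() is not None
    """
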
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID for raising an exception: a string expression over data "
            "ID attributes, using the daf_butler user-expression syntax (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Fully-qualified name of the exception class to raise when the fail condition is triggered. "
            "Can be 'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)


class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections, but init-input and
    init-output connections are not supported at runtime (they may be present
    as long as the task is never constructed). All output connections must
    use mock storage classes. `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
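
    Examples
    --------
    A sketch of the provenance each output carries after execution (assuming
    ``output`` is a `MockDataset` read back from the butler; the task label
    and connection names here are hypothetical)::

        assert output.quantum.task_label == "mock_example"
        assert "input_connection" in output.quantum.inputs
        assert output.output_connection_name == "output_connection"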
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)

    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.to_simple(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a single
                    # level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(ref=ref.to_simple())
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code, and it wouldn't make sense to lift it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.
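
    Examples
    --------
    A minimal sketch (``ExampleTask`` is a hypothetical `PipelineTask` whose
    connections are to be mocked)::

        config = MockPipelineTaskConfig()
        config.original.retarget(ExampleTask)
        connections = MockPipelineTaskConnections(config=config)

    Each mocked connection keeps its attribute name but targets a renamed
    dataset type (``get_mock_name(<original name>)``) with a mock storage
    class; unmocked dataset types are passed through unchanged.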
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(original.dimensions)
        unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in original.allConnections.items():
            if name in original.initInputs or name in original.initOutputs:
                # We just ignore initInputs and initOutputs, because the task
                # is never given DatasetRefs for those and hence can't create
                # mocks.
                continue
            if connection.name not in unmocked_dataset_types:
                # We register the mock storage class with the global singleton
                # here, but can only put its name in the connection. That
                # means the same global singleton (or one that also has these
                # registrations) has to be available whenever this dataset
                # type is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7,
                    # since that requires the dataset type to have already
                    # been registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            setattr(self, name, connection)


class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )


class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself. Because `MockPipelineTask` never
    instantiates the original task (just its connections class), this is a
    limitation on what the mocks can be used to test, not anything deeper.
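
    Examples
    --------
    A minimal sketch of mocking a concrete task directly, mirroring what
    `mock_task_defs` does for each task it is given (``ExampleTask`` is a
    hypothetical `PipelineTask` subclass)::

        config = MockPipelineTaskConfig()
        config.original.retarget(ExampleTask)
        config.unmocked_dataset_types.append("raw")
        task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name("example")
        )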
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig


class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection.
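
    Examples
    --------
    A minimal sketch of building a connection from configuration (the dataset
    type name and dimensions here are hypothetical)::

        cc = DynamicConnectionConfig()
        cc.dataset_type_name = "example_dataset"
        cc.dimensions = ["visit"]
        connection = cc.make_connection(cT.Input)
    """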

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        if issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )


class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        connection_config: DynamicConnectionConfig
        for connection_name, connection_config in config.init_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
        for connection_name, connection_config in config.init_outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
        for connection_name, connection_config in config.prerequisite_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
        for connection_name, connection_config in config.inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Input))
        for connection_name, connection_config in config.outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Output))


class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for `DynamicTestPipelineTask`."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )


class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
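
    Examples
    --------
    A minimal sketch of a task with one regular input and one regular output
    (connection and dataset type names here are hypothetical)::

        config = DynamicTestPipelineTaskConfig()
        config.dimensions = ["visit", "detector"]
        cc_in = DynamicConnectionConfig()
        cc_in.dataset_type_name = "input_dataset"
        cc_in.dimensions = ["visit", "detector"]
        config.inputs["input_connection"] = cc_in
        cc_out = DynamicConnectionConfig()
        cc_out.dataset_type_name = "output_dataset"
        cc_out.dimensions = ["visit", "detector"]
        config.outputs["output_connection"] = cc_out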
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig