# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("MockPipelineTask", "MockPipelineTaskConfig", "mock_task_defs")

import dataclasses
import logging
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar

from lsst.daf.butler import DeferredDatasetHandle
from lsst.pex.config import ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ], optional
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed. The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results
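
# A minimal usage sketch (the pipeline object, the "isr" task label, and the
# "raw" dataset type are hypothetical placeholders, not part of this module):
#
#     task_defs = list(pipeline.toExpandedPipeline())
#     mocked = mock_task_defs(
#         task_defs,
#         unmocked_dataset_types=["raw"],
#         force_failures={"isr": ("detector = 4", ValueError)},
#     )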


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code, and it wouldn't make sense to relax it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    def __init__(self, *, config: MockPipelineTaskConfig):
        original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(original.dimensions)
        unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in original.allConnections.items():
            if name in original.initInputs or name in original.initOutputs:
                # We just ignore initInputs and initOutputs, because the task
                # is never given DatasetRefs for those and hence can't create
                # mocks.
                continue
            if connection.name not in unmocked_dataset_types:
                # We register the mock storage class with the global singleton
                # here, but can only put its name in the connection. That
                # means the same global singleton (or one that also has these
                # registrations) has to be available whenever this dataset
                # type is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7,
                    # since that requires the dataset type to have already
                    # been registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            setattr(self, name, connection)
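
# For illustration (hypothetical connection name and storage class): an input
# connection originally named "calexp" with storage class "ExposureF" is
# replaced by one named get_mock_name("calexp") whose storage class is the
# mock registered for "ExposureF"; connections whose dataset types appear in
# unmocked_dataset_types pass through unchanged.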


class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID under which to raise an exception: a string expression "
            "over data ID attributes, using the daf_butler user-expression syntax (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Fully-qualified name of the exception to raise when the fail condition is triggered. Can "
            "be 'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )

    def data_id_match(self) -> DataIdMatch | None:
        """Return a `DataIdMatch` built from ``fail_condition``, or `None`
        if no fail condition is configured.
        """
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
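
# A hypothetical config override making matching quanta fail (the data ID
# values here are placeholders):
#
#     config.fail_condition = "detector = 4 AND visit = 123"
#     config.fail_exception = "builtins.RuntimeError"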


class MockPipelineTask(PipelineTask):
    """Implementation of `~lsst.pipe.base.PipelineTask` used for running a
    mock pipeline.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself. Because `MockPipelineTask` never
    instantiates the mock task (just its connections class), this is a
    limitation on what the mocks can be used to test, not anything deeper.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig

    def __init__(
        self,
        *,
        config: MockPipelineTaskConfig,
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)

    config: MockPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Docstring is inherited from the base class.
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.to_simple(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a
                    # single level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(ref=ref.to_simple())
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)