Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 25%
111 statements

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("MockPipelineTask", "MockPipelineTaskConfig", "mock_task_defs")

import dataclasses
import logging
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar

from lsst.daf.butler import DeferredDatasetHandle
from lsst.pex.config import ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ...butlerQuantumContext import ButlerQuantumContext


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception]]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] ] ], optional
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed. The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
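
    Examples
    --------
    A minimal, illustrative invocation; the ``"isr"`` task label, the
    ``"raw"`` dataset type name, and the data ID expression below are
    placeholders, not values required by this function::

        mocked = mock_task_defs(
            original_task_defs,
            unmocked_dataset_types=["raw"],
            force_failures={"isr": ("visit = 123", ValueError)},
        )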
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code and it wouldn't make sense to lift it just
    to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    def __init__(self, *, config: MockPipelineTaskConfig):
        original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(original.dimensions)
        unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in original.allConnections.items():
            if name in original.initInputs or name in original.initOutputs:
                # We just ignore initInputs and initOutputs, because the task
                # is never given DatasetRefs for those and hence can't create
                # mocks.
                continue
            if connection.name not in unmocked_dataset_types:
                # We register the mock storage class with the global singleton
                # here, but can only put its name in the connection. That means
                # the same global singleton (or one that also has these
                # registrations) has to be available whenever this dataset type
                # is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7, since
                    # that requires the dataset type to have already been
                    # registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            setattr(self, name, connection)
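

# Illustrative sketch of the rewriting performed above; the connection shown
# here and the exact mocked names are hypothetical, since the naming scheme is
# an implementation detail of `get_mock_name`:
#
#     # Original connection on the mocked task's connections class:
#     calexp = Input(name="calexp", storageClass="ExposureF", dimensions=("visit", "detector"))
#
#     # What __init__ above effectively produces for it:
#     calexp = dataclasses.replace(
#         calexp,
#         name=get_mock_name("calexp"),
#         storageClass=MockStorageClass.get_or_register_mock("ExposureF").name,
#         dimensions=("visit", "detector"),  # a "skypix" entry would become "htm7"
#     )
#
# Connections whose dataset types are listed in `unmocked_dataset_types` are
# passed through unchanged, unless they are outputs, which raises ValueError.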


class MockPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID for raising an exception: a string expression that "
            "references data ID attributes using the daf_butler user-expression syntax "
            "(e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when the fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
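

# Example of configuring a simulated failure (illustrative only; the data ID
# expression uses the daf_butler user-expression syntax described in the
# `fail_condition` doc above):
#
#     config = MockPipelineTaskConfig()
#     config.fail_condition = "visit = 123"
#     config.fail_exception = "lsst.pipe.base.NoWorkFound"
#
# With this configuration, quanta whose data ID matches the expression raise
# the named exception instead of writing outputs.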


class MockPipelineTask(PipelineTask):
    """Implementation of `PipelineTask` used for running a mock pipeline.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself. Because `MockPipelineTask` never
    instantiates the task it mocks (just its connections class), this is a
    limitation on what the mocks can be used to test, not anything deeper.
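
    Schematically, each output dataset written by `runQuantum` carries
    provenance that looks roughly like this (field values here are
    placeholders)::

        MockDataset(
            ref=...,  # simplified ref of the output dataset
            quantum=MockDatasetQuantum(
                task_label="...",
                data_id={...},
                inputs={"<connection name>": [MockDataset(...), ...]},
            ),
            output_connection_name="...",
        )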
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig

    def __init__(
        self,
        *,
        config: MockPipelineTaskConfig,
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)

    config: MockPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: ButlerQuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.to_simple(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a single
                    # level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(ref=ref.to_simple())
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            if not isinstance(refs, list):
                refs = [refs]
            for ref in refs:
                output = MockDataset(
                    ref=ref.to_simple(), quantum=mock_dataset_quantum, output_connection_name=name
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)