Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 24% of 206 statements

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from lsst.pipe.base.connectionTypes import BaseInput, Output

__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_task_defs",
)

import dataclasses
import logging
from collections.abc import Collection, Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from lsst.daf.butler import (
    DataCoordinate,
    DatasetRef,
    DeferredDatasetHandle,
    SerializedDatasetType,
    SerializedDimensionGraph,
)
from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


_T = TypeVar("_T", bound=cT.BaseConnection)


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ], optional
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed. The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
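
    Examples
    --------
    A minimal sketch of typical use; ``original_task_defs`` (an iterable of
    `TaskDef` objects from a real pipeline) and the task label ``"makeWarp"``
    are hypothetical stand-ins:

    >>> mocked = mock_task_defs(
    ...     original_task_defs,
    ...     unmocked_dataset_types=["raw"],
    ...     force_failures={"makeWarp": ("visit = 123", ValueError)},
    ... )  # doctest: +SKIP

    Each returned `TaskDef` runs `MockPipelineTask` with its ``original``
    sub-config retargeted at the corresponding real task.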
96 """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results


class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    pass


class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID under which to raise an exception: a string expression "
            "on data ID attributes using the daf_butler user-expression syntax (e.g. 'visit = 123'). "
            "An empty string (the default) means the task never fails."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when the fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)


class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections, but init-input and
    init-output connections are not supported at runtime (they may be present
    as long as the task is never constructed). All output connections must
    use mock storage classes. `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
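
    Examples
    --------
    A minimal sketch of the failure-simulation configuration described above,
    shown on the concrete `MockPipelineTask` subclass (this base class defines
    no connections of its own); the data ID expression is a hypothetical
    example:

    >>> config = MockPipelineTaskConfig()
    >>> config.fail_condition = "detector = 42"
    >>> config.fail_exception = "builtins.RuntimeError"

    Quanta whose data IDs match ``detector = 42`` will then raise
    `RuntimeError` from `runQuantum`.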
166 """
168 ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single
            # level.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=SerializedDimensionGraph(names=[]),
                ),
                data_id={},
                run=None,  # task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
            )
            setattr(self, connection_name, output_dataset)

    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.full.byName(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a single
                    # level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=ref.dataId.full.byName(),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=ref.dataId.full.byName(),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``: that is generally a reasonable
    limitation for production code, and it would not make sense to lift it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.
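
    Examples
    --------
    A minimal sketch, assuming ``ExampleTask`` is a hypothetical real
    `PipelineTask` whose connections should be mocked:

    >>> config = MockPipelineTaskConfig()
    >>> config.original.retarget(ExampleTask)  # doctest: +SKIP
    >>> connections = MockPipelineTaskConnections(config=config)  # doctest: +SKIP

    Each connection of ``ExampleTask`` reappears on the result with a mocked
    dataset type name (via `get_mock_name`) and a mock storage class, except
    for dataset types listed in ``config.unmocked_dataset_types``.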
319 """
321 def __init__(self, *, config: MockPipelineTaskConfig):
322 self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
323 config=config.original.value
324 )
325 self.dimensions.update(self.original.dimensions)
326 self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
327 for name, connection in self.original.allConnections.items():
328 if connection.name not in self.unmocked_dataset_types:
329 # We register the mock storage class with the global singleton
330 # here, but can only put its name in the connection. That means
331 # the same global singleton (or one that also has these
332 # registrations) has to be available whenever this dataset type
333 # is used.
334 storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
335 kwargs: dict[str, Any] = {}
336 if hasattr(connection, "dimensions"):
337 connection_dimensions = set(connection.dimensions)
338 # Replace the generic "skypix" placeholder with htm7, since
339 # that requires the dataset type to have already been
340 # registered.
341 if "skypix" in connection_dimensions:
342 connection_dimensions.remove("skypix")
343 connection_dimensions.add("htm7")
344 kwargs["dimensions"] = connection_dimensions
345 connection = dataclasses.replace(
346 connection,
347 name=get_mock_name(connection.name),
348 storageClass=storage_class.name,
349 **kwargs,
350 )
351 elif name in self.original.outputs:
352 raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
353 elif name in self.original.initInputs:
354 raise ValueError(
355 f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
356 )
357 elif name in self.original.initOutputs:
358 raise ValueError(
359 f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
360 )
361 setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = {}
        for connection_name, (_, mock_refs) in inputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_inputs[connection_name] = (original_connection, refs)
        original_outputs = {}
        for connection_name, (_, mock_refs) in outputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_outputs[connection_name] = (original_connection, refs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_inputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_inputs[connection_name] = (getattr(self, connection_name), refs)
        adjusted_outputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_outputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_outputs[connection_name] = (getattr(self, connection_name), refs)
        return adjusted_inputs, adjusted_outputs


class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )


class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself. Because the mock task is never actually
    constructed (only its connections class is), this is a limitation on what
    the mocks can be used to test, not anything deeper.
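
    Examples
    --------
    A minimal sketch of mocking a single real task; ``ExampleTask`` is a
    hypothetical `PipelineTask` subclass:

    >>> config = MockPipelineTaskConfig()
    >>> config.original.retarget(ExampleTask)  # doctest: +SKIP
    >>> config.unmocked_dataset_types = ["raw"]

    In practice `mock_task_defs` builds such configurations for every task in
    an iterable of `TaskDef` objects.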
449 """
451 ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig


class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection.

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        if issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )


class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        connection_config: DynamicConnectionConfig
        for connection_name, connection_config in config.init_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
        for connection_name, connection_config in config.init_outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
        for connection_name, connection_config in config.prerequisite_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
        for connection_name, connection_config in config.inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Input))
        for connection_name, connection_config in config.outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Output))


class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for `DynamicTestPipelineTask`."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )


class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
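
    Examples
    --------
    A minimal sketch that wires up one input and one output connection; all
    names and dimensions here are hypothetical:

    >>> config = DynamicTestPipelineTaskConfig()
    >>> config.dimensions = ["visit", "detector"]
    >>> config.inputs["input_connection"] = DynamicConnectionConfig(
    ...     dataset_type_name="dynamic_input", dimensions=["visit", "detector"]
    ... )
    >>> config.outputs["output_connection"] = DynamicConnectionConfig(
    ...     dataset_type_name="dynamic_output", dimensions=["visit", "detector"]
    ... )

    The resulting connections class exposes ``input_connection`` and
    ``output_connection`` as regular input and output connections.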
561 """
563 ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig