Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 24% of 206 statements (coverage.py v7.3.1, created at 2023-09-19 10:39 +0000)
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from lsst.pipe.base.connectionTypes import BaseInput, Output

__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_task_defs",
)

import dataclasses
import logging
from collections.abc import Collection, Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from lsst.daf.butler import (
    DataCoordinate,
    DatasetRef,
    DeferredDatasetHandle,
    SerializedDatasetType,
    SerializedDimensionGraph,
)
from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline import TaskDef
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


_T = TypeVar("_T", bound=cT.BaseConnection)


def mock_task_defs(
    originals: Iterable[TaskDef],
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, tuple[str, type[Exception] | None]] | None = None,
) -> list[TaskDef]:
    """Create mocks for an iterable of TaskDefs.

    Parameters
    ----------
    originals : `~collections.abc.Iterable` [ `TaskDef` ]
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `tuple` [ `str`, \
            `type` [ `Exception` ] or `None` ] ], optional
        Mapping from original task label to a 2-tuple indicating that some
        quanta should raise an exception when executed. The first entry is a
        data ID match using the butler expression language (i.e. a string of
        the sort passed as the ``where`` argument to butler query methods),
        while the second is the type of exception to raise when the quantum
        data ID matches the expression.

    Returns
    -------
    mocked : `list` [ `TaskDef` ]
        List of `TaskDef` objects using `MockPipelineTask` configurations that
        target the original tasks, in the same order.
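
    Examples
    --------
    A minimal sketch of typical usage; the task label, dataset type name, and
    data ID expression below are hypothetical, and ``original_task_defs`` is
    assumed to be an iterable of `TaskDef` objects from an expanded pipeline::

        mocked_task_defs = mock_task_defs(
            original_task_defs,
            unmocked_dataset_types=["raw"],
            force_failures={"isr": ("detector = 4", ValueError)},
        )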
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    results: list[TaskDef] = []
    for original_task_def in originals:
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_def.taskClass)
        config.original = original_task_def.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_def.label in force_failures:
            condition, exception_type = force_failures[original_task_def.label]
            config.fail_condition = condition
            if exception_type is not None:
                config.fail_exception = get_full_type_name(exception_type)
        mock_task_def = TaskDef(
            config=config, taskClass=MockPipelineTask, label=get_mock_name(original_task_def.label)
        )
        results.append(mock_task_def)
    return results


class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    pass


class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID under which to raise an exception: a string expression "
            "over data ID attributes, using daf_butler user-expression syntax (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when the fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
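
# A minimal illustration (with hypothetical values) of the failure knobs
# above, on a concrete config instance such as `MockPipelineTaskConfig`: any
# quantum whose data ID matches the expression raises the named exception
# from `BaseTestPipelineTask.runQuantum`.
#
#     config.fail_condition = "detector = 4 AND visit = 123"
#     config.fail_exception = "builtins.RuntimeError"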


class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections, but init-input and
    init-output connections are not supported at runtime (they may be present
    as long as the task is never constructed). All output connections must
    use mock storage classes. `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single
            # level.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=SerializedDimensionGraph(names=[]),
                ),
                data_id={},
                run=None,  # task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
            )
            setattr(self, connection_name, output_dataset)

    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            assert self.fail_exception is not None, "Exception type must be defined"
            raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=quantum.dataId.full.byName(), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a
                    # single level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=ref.dataId.full.byName(),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=ref.dataId.full.byName(),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``: that is generally a reasonable
    limitation for production code, and it would not make sense to change it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in self.original.allConnections.items():
            if connection.name not in self.unmocked_dataset_types:
                # We register the mock storage class with the global singleton
                # here, but can only put its name in the connection. That
                # means the same global singleton (or one that also has these
                # registrations) has to be available whenever this dataset
                # type is used.
                storage_class = MockStorageClass.get_or_register_mock(connection.storageClass)
                kwargs: dict[str, Any] = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7,
                    # since that requires the dataset type to have already
                    # been registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class.name,
                    **kwargs,
                )
            elif name in self.original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            elif name in self.original.initInputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                )
            elif name in self.original.initOutputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                )
            setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = {}
        for connection_name, (_, mock_refs) in inputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_inputs[connection_name] = (original_connection, refs)
        original_outputs = {}
        for connection_name, (_, mock_refs) in outputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_outputs[connection_name] = (original_connection, refs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_inputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_inputs[connection_name] = (getattr(self, connection_name), refs)
        adjusted_outputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_outputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_outputs[connection_name] = (getattr(self, connection_name), refs)
        return adjusted_inputs, adjusted_outputs


class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )


class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself. Because `MockPipelineTask` never
    instantiates the mock task (just its connections class), this is a
    limitation on what the mocks can be used to test, not anything deeper.
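
    Examples
    --------
    A sketch of direct configuration, assuming ``ExampleTask`` is some real
    `PipelineTask` subclass (hypothetical here; `mock_task_defs` performs the
    equivalent steps in bulk)::

        config = MockPipelineTaskConfig()
        config.original.retarget(ExampleTask)
        config.unmocked_dataset_types.extend(["raw"])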
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig


class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection."""

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        if issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )
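
# An illustrative sketch (hypothetical names) of `make_connection` above: the
# config below yields a `cT.Input` connection named "mock_image" with
# visit-level dimensions and a mock of its declared storage class.
#
#     cfg = DynamicConnectionConfig(dataset_type_name="mock_image", dimensions=["visit"])
#     connection = cfg.make_connection(cT.Input)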


class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        connection_config: DynamicConnectionConfig
        for connection_name, connection_config in config.init_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
        for connection_name, connection_config in config.init_outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
        for connection_name, connection_config in config.prerequisite_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
        for connection_name, connection_config in config.inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Input))
        for connection_name, connection_config in config.outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Output))


class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )


class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
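
    Examples
    --------
    A minimal sketch; the connection names, dataset type names, and task name
    here are hypothetical::

        config = DynamicTestPipelineTaskConfig()
        config.dimensions = ["visit", "detector"]
        config.inputs["input_image"] = DynamicConnectionConfig(
            dataset_type_name="mock_image", dimensions=["visit", "detector"]
        )
        config.outputs["output_catalog"] = DynamicConnectionConfig(
            dataset_type_name="mock_catalog", dimensions=["visit", "detector"]
        )
        task = DynamicTestPipelineTask(config=config, initInputs={}, name="example_task")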
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig