Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 26%
225 statements
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from lsst.pipe.base.connectionTypes import BaseInput, Output

__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "ForcedFailure",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_pipeline_graph",
)

import dataclasses
import logging
from collections.abc import Collection, Iterable, Mapping
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from astropy.units import Quantity
from lsst.daf.butler import DataCoordinate, DatasetRef, DeferredDatasetHandle, SerializedDatasetType
from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
from lsst.utils.doImport import doImportType
from lsst.utils.introspection import get_full_type_name
from lsst.utils.iteration import ensure_iterable

from ... import automatic_connection_constants as acc
from ... import connectionTypes as cT
from ...config import PipelineTaskConfig
from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
from ...pipeline_graph import PipelineGraph
from ...pipelineTask import PipelineTask
from ._data_id_match import DataIdMatch
from ._storage_class import MockDataset, MockDatasetQuantum, MockStorageClass, get_mock_name

_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext


_T = TypeVar("_T", bound=cT.BaseConnection)


@dataclasses.dataclass
class ForcedFailure:
    """Information about an exception that should be raised by one or more
    quanta.
    """

    condition: str
    """Butler expression-language string that matches the data IDs that should
    raise.
    """

    exception_type: type[Exception] | None = None
    """The type of exception to raise."""

    memory_required: Quantity | None = None
    """If not `None`, this failure simulates an out-of-memory failure by
    raising only if this value exceeds `ExecutionResources.max_mem`.
    """

    def set_config(self, config: MockPipelineTaskConfig) -> None:
        config.fail_condition = self.condition
        if self.exception_type:
            config.fail_exception = get_full_type_name(self.exception_type)
        config.memory_required = self.memory_required
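

# Illustrative sketch (comments only, not executed): a `ForcedFailure` is
# typically built in test code and then applied to a `MockPipelineTaskConfig`;
# the data ID expression and exception type below are hypothetical.
#
#     failure = ForcedFailure(condition="visit = 42", exception_type=ValueError)
#     config = MockPipelineTaskConfig()
#     failure.set_config(config)  # fills in fail_condition and fail_exception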


def mock_pipeline_graph(
    original_graph: PipelineGraph,
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, ForcedFailure] | None = None,
) -> PipelineGraph:
    """Create mocks for a full pipeline graph.

    Parameters
    ----------
    original_graph : `~..pipeline_graph.PipelineGraph`
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks.
    force_failures : `~collections.abc.Mapping` [ `str`, `ForcedFailure` ], optional
        Mapping from original task label to information about an exception
        one or more quanta for that task should raise.

    Returns
    -------
    mocked : `~..pipeline_graph.PipelineGraph`
        Pipeline graph using `MockPipelineTask` configurations that target the
        original tasks. Never resolved.
    """
    unmocked_dataset_types = tuple(unmocked_dataset_types)
    if force_failures is None:
        force_failures = {}
    result = PipelineGraph(description=original_graph.description)
    for original_task_node in original_graph.tasks.values():
        config = MockPipelineTaskConfig()
        config.original.retarget(original_task_node.task_class)
        config.original = original_task_node.config
        config.unmocked_dataset_types.extend(unmocked_dataset_types)
        if original_task_node.label in force_failures:
            force_failures[original_task_node.label].set_config(config)
        result.add_task(get_mock_name(original_task_node.label), MockPipelineTask, config=config)
    return result
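

# Illustrative sketch (comments only, not executed) of mocking an existing
# pipeline graph; `original_graph` stands in for any `PipelineGraph`, and the
# task label "isr" and dataset type "raw" are hypothetical.
#
#     mocked = mock_pipeline_graph(
#         original_graph,
#         unmocked_dataset_types=["raw"],
#         force_failures={"isr": ForcedFailure(condition="detector = 5")},
#     )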


class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    pass


class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on the quantum data ID under which to raise an exception: a string expression "
            "in the daf_butler user-expression syntax that may reference attributes of the data ID "
            "(e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when the fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify a non-failure exception."
        ),
    )

    memory_required = Field[str](
        dtype=str,
        default=None,
        optional=True,
        doc=(
            "If not None, simulate an out-of-memory failure by raising only if this value exceeds "
            "ExecutionResources.max_mem. This string should include units as parsed by "
            "astropy.units.Quantity (e.g. '4GB')."
        ),
    )

    def data_id_match(self) -> DataIdMatch | None:
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
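

# Illustrative sketch (comments only, not executed): the failure-related
# fields above combine to simulate an out-of-memory quantum; the values are
# hypothetical.
#
#     config = MockPipelineTaskConfig()
#     config.fail_condition = "detector = 10"
#     config.fail_exception = "builtins.MemoryError"
#     config.memory_required = "4GB"  # raises only when max_mem < 4GB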


class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Parameters
    ----------
    config : `BaseTestPipelineTaskConfig`
        The pipeline task config.
    initInputs : `~collections.abc.Mapping`
        The init-input datasets, keyed by connection name.
    **kwargs : `~typing.Any`
        Keyword parameters passed to the base class constructor.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections. All output
    connections must use mock storage classes. Init-input connections must be
    satisfied by `MockDataset` instances at construction, and init-output
    connections are set as `MockDataset` attributes on the task instance.
    `..Input` and `..PrerequisiteInput` connections that do not use mock
    storage classes will be handled by constructing a `MockDataset` from the
    `~lsst.daf.butler.DatasetRef` rather than actually reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        self.memory_required = (
            Quantity(self.config.memory_required) if self.config.memory_required is not None else None
        )
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single
            # level.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=[],
                ),
                data_id={},
                run=None,  # the task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
            )
            setattr(self, connection_name, output_dataset)

    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # Docstring inherited from the base class.
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            assert self.fail_exception is not None, "Exception type must be defined"
            message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
            if self.memory_required is not None:
                if butlerQC.resources.max_mem < self.memory_required:
                    _LOG.info(
                        "Simulating out-of-memory failure for task '%s' on quantum %s",
                        self.getName(),
                        quantum.dataId,
                    )
                    raise self.fail_exception(message)
            else:
                _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
                raise self.fail_exception(message)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=dict(quantum.dataId.mapping), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if not isinstance(input_dataset, MockDataset):
                        raise TypeError(
                            f"Expected MockDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                    # To avoid very deep provenance we trim inputs to a
                    # single level.
                    input_dataset.quantum = None
                else:
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=dict(ref.dataId.mapping),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # Store mock outputs.
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=dict(ref.dataId.mapping),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)
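

# Illustrative sketch (comments only, not executed) of the provenance that
# `runQuantum` attaches to each output: every output `MockDataset` carries a
# `MockDatasetQuantum` recording the task label, the quantum data ID, and the
# (provenance-trimmed) input datasets, so tests can assert on what was read.
# The names `butler`, `mock_output_ref`, and `connections` are hypothetical.
#
#     output = butler.get(mock_output_ref)
#     assert output.quantum.task_label == "_mocked_task_label"
#     assert set(output.quantum.inputs) == set(connections.inputs)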


class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    pass


class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    pass


class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code, and it wouldn't make sense to lift it
    just to support test utilities.
    """

    ConfigClass = MockPipelineDefaultTargetConfig


class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.

    Parameters
    ----------
    config : `MockPipelineTaskConfig`
        The config to use for the connections.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in self.original.allConnections.items():
            if connection.name not in self.unmocked_dataset_types:
                if connection.storageClass in (
                    acc.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
                    acc.METADATA_OUTPUT_STORAGE_CLASS,
                    acc.LOG_OUTPUT_STORAGE_CLASS,
                ):
                    # We don't mock the automatic output connections, so if
                    # they're used as an input in any other connection, we
                    # can't mock them there either.
                    storage_class_name = connection.storageClass
                else:
                    # We register the mock storage class with the global
                    # singleton here, but can only put its name in the
                    # connection. That means the same global singleton (or one
                    # that also has these registrations) has to be available
                    # whenever this dataset type is used.
                    storage_class_name = MockStorageClass.get_or_register_mock(connection.storageClass).name
                kwargs: dict[str, Any] = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7, since
                    # that requires the dataset type to have already been
                    # registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class_name,
                    **kwargs,
                )
            elif name in self.original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            elif name in self.original.initInputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                )
            elif name in self.original.initOutputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                )
            setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = {}
        for connection_name, (_, mock_refs) in inputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_inputs[connection_name] = (original_connection, refs)
        original_outputs = {}
        for connection_name, (_, mock_refs) in outputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_outputs[connection_name] = (original_connection, refs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_inputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_inputs[connection_name] = (getattr(self, connection_name), refs)
        adjusted_outputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_outputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_outputs[connection_name] = (getattr(self, connection_name), refs)
        return adjusted_inputs, adjusted_outputs
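

# Illustrative sketch (comments only, not executed) of the round-trip that
# `adjustQuantum` above performs for each mocked connection:
#
#     original_refs = MockStorageClass.unmock_dataset_refs(mock_refs)
#     ...  # delegate to the original connections' adjustQuantum
#     mock_refs_again = MockStorageClass.mock_dataset_refs(original_refs)
#
# Unmocked dataset types bypass both conversions and are passed through as-is.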


class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )


class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the
    task, not in the task itself. Because `MockPipelineTask` never
    instantiates the mocked task (just its connections class), this is a
    limitation on what the mocks can be used to test, not anything deeper.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig
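

# Illustrative sketch (comments only, not executed): `MockPipelineTask` is
# normally created indirectly via `mock_pipeline_graph`, but it can also be
# configured by hand; `MyRealTask` and the dataset type name are hypothetical.
#
#     config = MockPipelineTaskConfig()
#     config.original.retarget(MyRealTask)
#     config.unmocked_dataset_types.extend(["raw"])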


class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection."""

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        if issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )
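

# Illustrative sketch (comments only, not executed) of defining a dynamic
# connection; the dataset type name and dimensions are hypothetical.
#
#     cc = DynamicConnectionConfig()
#     cc.dataset_type_name = "input_catalog"
#     cc.dimensions = ["visit", "detector"]
#     input_connection = cc.make_connection(cT.Input)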


class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.

    Parameters
    ----------
    config : `DynamicTestPipelineTaskConfig`
        Config to use for this connections object.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        connection_config: DynamicConnectionConfig
        for connection_name, connection_config in config.init_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitInput))
        for connection_name, connection_config in config.init_outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.InitOutput))
        for connection_name, connection_config in config.prerequisite_inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.PrerequisiteInput))
        for connection_name, connection_config in config.inputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Input))
        for connection_name, connection_config in config.outputs.items():
            setattr(self, connection_name, connection_config.make_connection(cT.Output))


class DynamicTestPipelineTaskConfig(
    PipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for `DynamicTestPipelineTask`."""

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "If the task will be constructed, each of these must be satisfied by a "
            "MockDataset init-input."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc="Init-output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
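

# Illustrative sketch (comments only, not executed) of configuring a fully
# dynamic test task; the connection keys and the `cc_in`/`cc_out` configs
# (instances of DynamicConnectionConfig, as above) are hypothetical.
#
#     config = DynamicTestPipelineTaskConfig()
#     config.dimensions = ["visit", "detector"]
#     config.inputs["input_catalog"] = cc_in
#     config.outputs["output_catalog"] = cc_out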


class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig