Coverage for python/lsst/pipe/base/tests/mocks/_pipeline_task.py: 26%
258 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:49 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
27from __future__ import annotations
29from lsst.pipe.base.connectionTypes import BaseInput, Output
# Public names exported by this module, sorted in ASCII order.
__all__ = (
    "DynamicConnectionConfig",
    "DynamicTestPipelineTask",
    "DynamicTestPipelineTaskConfig",
    "ForcedFailure",
    "MockAlgorithmError",
    "MockPipelineTask",
    "MockPipelineTaskConfig",
    "mock_pipeline_graph",
)
42import dataclasses
43import logging
44import signal
45import time
46from collections.abc import Collection, Iterable, Mapping
47from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
49from astropy.units import Quantity
51from lsst.daf.butler import DataCoordinate, DatasetRef, DeferredDatasetHandle, Quantum, SerializedDatasetType
52from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField
53from lsst.utils.doImport import doImportType
54from lsst.utils.introspection import get_full_type_name
55from lsst.utils.iteration import ensure_iterable
57from ... import connectionTypes as cT
58from ..._status import AlgorithmError, AnnotatedPartialOutputsError
59from ...automatic_connection_constants import METADATA_OUTPUT_CONNECTION_NAME, METADATA_OUTPUT_STORAGE_CLASS
60from ...config import PipelineTaskConfig
61from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections
62from ...pipeline_graph import PipelineGraph
63from ...pipelineTask import PipelineTask
64from ._data_id_match import DataIdMatch
65from ._storage_class import (
66 ConvertedUnmockedDataset,
67 MockDataset,
68 MockDatasetQuantum,
69 MockStorageClass,
70 get_mock_name,
71)
# Module-level logger used for all mock-task execution messages.
_LOG = logging.getLogger(__name__)

if TYPE_CHECKING:
    from ..._quantumContext import QuantumContext

# Type variable bound to connection classes; lets make_connection return the
# same connection type it is given.
_T = TypeVar("_T", bound=cT.BaseConnection)
@dataclasses.dataclass
class ForcedFailure:
    """Description of an exception that one or more quanta should raise."""

    condition: str
    """Butler expression-language string that matches the data IDs that should
    raise.
    """

    exception_type: type[BaseException] | None = None
    """The type of exception to raise."""

    memory_required: Quantity | None = None
    """If not `None`, this failure simulates an out-of-memory failure by
    raising only if this value exceeds `ExecutionResources.max_mem`.
    """

    def set_config(self, config: MockPipelineTaskConfig) -> None:
        """Copy this failure specification into a mock-task config."""
        config.fail_condition = self.condition
        config.memory_required = self.memory_required
        # Only override the config's default exception when one was given.
        if self.exception_type:
            config.fail_exception = get_full_type_name(self.exception_type)
class MockAlgorithmError(AlgorithmError):
    """An `..AlgorithmError` subclass that is chained to
    `..AnnotatedPartialOutputsError` when `MockPipelineTask` is configured to
    raise the latter.
    """

    @property
    def metadata(self) -> dict[str, int]:
        # Fixed payload, so tests can check that error metadata propagates.
        return dict(badness=12)
def mock_pipeline_graph(
    original_graph: PipelineGraph,
    unmocked_dataset_types: Iterable[str] = (),
    force_failures: Mapping[str, ForcedFailure] | None = None,
) -> PipelineGraph:
    """Create mocks for a full pipeline graph.

    Parameters
    ----------
    original_graph : `~..pipeline_graph.PipelineGraph`
        Original tasks and configuration to mock.
    unmocked_dataset_types : `~collections.abc.Iterable` [ `str` ], optional
        Names of overall-input dataset types that should not be replaced with
        mocks. "Automatic" datasets written by the execution framework such
        as configs, logs, and metadata are implicitly included.
    force_failures : `~collections.abc.Mapping` [ `str`, `ForcedFailure` ]
        Mapping from original task label to information about an exception one
        or more quanta for this task should raise.

    Returns
    -------
    mocked : `~..pipeline_graph.PipelineGraph`
        Pipeline graph using `MockPipelineTask` configurations that target the
        original tasks. Never resolved.
    """
    skip_mocking = list(unmocked_dataset_types)
    failures = force_failures if force_failures is not None else {}
    # Framework-written "automatic" outputs of every task (config, log,
    # metadata) are never mocked; collect all of them before building any
    # task configs so each config sees the complete list.
    for task_node in original_graph.tasks.values():
        skip_mocking.append(task_node.init.config_output.dataset_type_name)
        if task_node.log_output is not None:
            skip_mocking.append(task_node.log_output.dataset_type_name)
        skip_mocking.append(task_node.metadata_output.dataset_type_name)
    mocked = PipelineGraph(description=original_graph.description)
    for original_task_node in original_graph.tasks.values():
        config = MockPipelineTaskConfig()
        # Point the mock at the original task class, then copy over the
        # original task's configuration.
        config.original.retarget(original_task_node.task_class)
        config.original = original_task_node.config
        config.unmocked_dataset_types.extend(skip_mocking)
        failure = failures.get(original_task_node.label)
        if failure is not None:
            failure.set_config(config)
        mocked.add_task(get_mock_name(original_task_node.label), MockPipelineTask, config=config)
    return mocked
class BaseTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """Empty connections class shared by the test-utility task configs in
    this module; subclasses define the actual connections.
    """

    pass
class BaseTestPipelineTaskConfig(PipelineTaskConfig, pipelineConnections=BaseTestPipelineTaskConnections):
    """Base configuration for the mock pipeline tasks.

    Holds the options that control simulated failures (data-ID condition,
    exception type, signal, memory threshold), an optional pre-execution
    sleep, and arbitrary values copied into every mock output dataset.
    """

    fail_condition = Field[str](
        dtype=str,
        default="",
        doc=(
            "Condition on Data ID to raise an exception. String expression which includes attributes of "
            "quantum data ID using a syntax of daf_butler user expressions (e.g. 'visit = 123')."
        ),
    )

    fail_exception = Field[str](
        dtype=str,
        default="builtins.ValueError",
        doc=(
            "Class name of the exception to raise when fail condition is triggered. Can be "
            "'lsst.pipe.base.NoWorkFound' to specify non-failure exception."
        ),
    )

    fail_signal = Field[int](dtype=int, optional=True, doc="Signal to raise instead of an exception.")

    memory_required = Field[str](
        dtype=str,
        default=None,
        optional=True,
        doc=(
            "If not None, simulate an out-of-memory failure by raising only if ExecutionResource.max_mem "
            "exceeds this value. This string should include units as parsed by astropy.units.Quantity "
            "(e.g. '4GB')."
        ),
    )

    sleep = Field[float](
        dtype=float,
        default=0.0,
        doc="Time to sleep (seconds) before mock execution reading inputs or failing.",
    )

    int_value = Field[int](
        "Arbitrary integer value to write into mock output datasets", dtype=int, optional=True, default=None
    )
    str_value = Field[str](
        "Arbitrary string value to write into mock output datasets", dtype=str, optional=True, default=None
    )

    def data_id_match(self) -> DataIdMatch | None:
        """Return a `DataIdMatch` built from `fail_condition`, or `None` when
        no failure condition is configured.
        """
        if not self.fail_condition:
            return None
        return DataIdMatch(self.fail_condition)
class BaseTestPipelineTask(PipelineTask):
    """A base class for test-utility `PipelineTask` classes that read and
    write mock datasets in `runQuantum`.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        The pipeline task config.
    initInputs : `~collections.abc.Mapping`
        The init inputs datasets.
    **kwargs : `~typing.Any`
        Keyword parameters passed to base class constructor.

    Notes
    -----
    This class overrides `runQuantum` to read inputs and write a bit of
    provenance into all of its outputs (always `MockDataset` instances). It
    can also be configured to raise exceptions on certain data IDs. It reads
    `MockDataset` inputs and simulates reading inputs of other types by
    creating `MockDataset` inputs from their DatasetRefs.

    Subclasses are responsible for defining connections, but init-input and
    init-output connections are not supported at runtime (they may be present
    as long as the task is never constructed). All output connections must
    use mock storage classes. `..Input` and `..PrerequisiteInput` connections
    that do not use mock storage classes will be handled by constructing a
    `MockDataset` from the `~lsst.daf.butler.DatasetRef` rather than actually
    reading them.
    """

    ConfigClass: ClassVar[type[PipelineTaskConfig]] = BaseTestPipelineTaskConfig

    def __init__(
        self,
        *,
        config: BaseTestPipelineTaskConfig,
        initInputs: Mapping[str, Any],
        **kwargs: Any,
    ):
        super().__init__(config=config, **kwargs)
        # Parse the failure-simulation options up front; the exception type
        # is only imported when some quanta can actually fail.
        self.fail_exception: type | None = None
        self.data_id_match = self.config.data_id_match()
        if self.data_id_match:
            self.fail_exception = doImportType(self.config.fail_exception)
        # Threshold parsed from a string with units (e.g. '4GB').
        self.memory_required: Quantity | None = (
            Quantity(self.config.memory_required) if self.config.memory_required is not None else None
        )
        # Look for, check, and record init-inputs.
        task_connections = self.ConfigClass.ConnectionsClass(config=config)
        # Provenance stub recorded in any init-outputs; the data ID is empty
        # at construction time.
        mock_dataset_quantum = MockDatasetQuantum(task_label=self.getName(), data_id={}, inputs={})
        for connection_name in task_connections.initInputs:
            input_dataset = initInputs[connection_name]
            if not isinstance(input_dataset, MockDataset):
                raise TypeError(
                    f"Expected MockDataset instance for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset!r} of type {type(input_dataset)!r}."
                )
            connection = task_connections.allConnections[connection_name]
            if input_dataset.dataset_type.name != connection.name:
                raise RuntimeError(
                    f"Incorrect dataset type name for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.dataset_type.name!r}, expected {connection.name!r}."
                )
            if input_dataset.storage_class != connection.storageClass:
                raise RuntimeError(
                    f"Incorrect storage class for init-input {self.getName()}.{connection_name}: "
                    f"got {input_dataset.storage_class!r}, expected {connection.storageClass!r}."
                )
            # To avoid very deep provenance we trim inputs to a single
            # level.
            input_dataset.quantum = None
            mock_dataset_quantum.inputs[connection_name] = [input_dataset]
        # Add init-outputs as task instance attributes.
        for connection_name in task_connections.initOutputs:
            connection = task_connections.allConnections[connection_name]
            output_dataset = MockDataset(
                dataset_id=None,  # the task has no way to get this
                dataset_type=SerializedDatasetType(
                    name=connection.name,
                    storageClass=connection.storageClass,
                    dimensions=[],
                ),
                data_id={},
                run=None,  # task also has no way to get this
                quantum=mock_dataset_quantum,
                output_connection_name=connection_name,
                int_value=self.config.int_value,
                str_value=self.config.str_value,
            )
            setattr(self, connection_name, output_dataset)

    # Re-annotate the inherited ``config`` attribute with the narrower
    # config type used by this class hierarchy.
    config: BaseTestPipelineTaskConfig

    def runQuantum(
        self,
        butlerQC: QuantumContext,
        inputRefs: InputQuantizedConnection,
        outputRefs: OutputQuantizedConnection,
    ) -> None:
        # docstring is inherited from the base class
        quantum = butlerQC.quantum

        _LOG.info("Mocking execution of task '%s' on quantum %s", self.getName(), quantum.dataId)

        if self.config.sleep:
            time.sleep(self.config.sleep)

        assert quantum.dataId is not None, "Quantum DataId cannot be None"

        # Possibly raise an exception.
        if self.data_id_match is not None and self.data_id_match.match(quantum.dataId):
            assert self.fail_exception is not None, "Exception type must be defined"

            if self.memory_required is not None:
                # When a memory threshold is configured, only fail if the
                # available memory is below it (simulating an OOM failure).
                if butlerQC.resources.max_mem < self.memory_required:
                    _LOG.info(
                        "Simulating out-of-memory failure for task '%s' on quantum %s",
                        self.getName(),
                        quantum.dataId,
                    )
                    self._fail(quantum)
            else:
                _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId)
                self._fail(quantum)

        # Populate the bit of provenance we store in all outputs.
        _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId)
        mock_dataset_quantum = MockDatasetQuantum(
            task_label=self.getName(), data_id=dict(quantum.dataId.mapping), inputs={}
        )
        for name, refs in inputRefs:
            inputs_list = []
            ref: DatasetRef
            for ref in ensure_iterable(refs):
                if isinstance(ref.datasetType.storageClass, MockStorageClass):
                    # Actually read mock inputs (unpacking deferred handles).
                    input_dataset = butlerQC.get(ref)
                    if isinstance(input_dataset, DeferredDatasetHandle):
                        input_dataset = input_dataset.get()
                    if isinstance(input_dataset, MockDataset):
                        # To avoid very deep provenance we trim inputs to a
                        # single level.
                        input_dataset.quantum = None
                    elif not isinstance(input_dataset, ConvertedUnmockedDataset):
                        raise TypeError(
                            f"Expected MockDataset or ConvertedUnmockedDataset instance for {ref}; "
                            f"got {input_dataset!r} of type {type(input_dataset)!r}."
                        )
                else:
                    # Simulate reading an unmocked input by building a
                    # MockDataset from the ref alone, without a butler read.
                    input_dataset = MockDataset(
                        dataset_id=ref.id,
                        dataset_type=ref.datasetType.to_simple(),
                        data_id=dict(ref.dataId.mapping),
                        run=ref.run,
                    )
                inputs_list.append(input_dataset)
            mock_dataset_quantum.inputs[name] = inputs_list

        # store mock outputs
        for name, refs in outputRefs:
            for ref in ensure_iterable(refs):
                output = MockDataset(
                    dataset_id=ref.id,
                    dataset_type=ref.datasetType.to_simple(),
                    data_id=dict(ref.dataId.mapping),
                    run=ref.run,
                    quantum=mock_dataset_quantum,
                    output_connection_name=name,
                    int_value=self.config.int_value,
                    str_value=self.config.str_value,
                )
                butlerQC.put(output, ref)

        _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId)

    def _fail(self, quantum: Quantum) -> None:
        """Raise the configured exception (or signal).

        Parameters
        ----------
        quantum : `lsst.daf.butler.Quantum`
            Quantum producing the error.
        """
        message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}"
        # Type annotations for optional config fields are broken, so MyPy
        # doesn't think fail_signal could be None.
        if self.config.fail_signal is not None:
            signal.raise_signal(signal.Signals(self.config.fail_signal))
        elif self.fail_exception is AnnotatedPartialOutputsError:  # type: ignore[unreachable]
            # This exception is expected to always chain another.
            try:
                raise MockAlgorithmError(message)
            except AlgorithmError as original:
                error = AnnotatedPartialOutputsError.annotate(original, self, log=self.log)
                raise error from original
        else:
            assert self.fail_exception is not None, "Method should not be called."
            raise self.fail_exception(message)
class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()):
    """Empty connections class for `MockPipelineDefaultTargetTask`."""

    pass
class MockPipelineDefaultTargetConfig(
    PipelineTaskConfig, pipelineConnections=MockPipelineDefaultTargetConnections
):
    """Configuration for `MockPipelineDefaultTargetTask`."""

    pass
class MockPipelineDefaultTargetTask(PipelineTask):
    """A `~lsst.pipe.base.PipelineTask` class used as the default target for
    ``MockPipelineTaskConfig.original``.

    This is effectively a workaround for `lsst.pex.config.ConfigurableField`
    not supporting ``optional=True``, but that is generally a reasonable
    limitation for production code and it wouldn't make sense just to support
    test utilities.
    """

    # Placeholder only; this task exists to serve as a retarget default.
    ConfigClass = MockPipelineDefaultTargetConfig
class MockPipelineTaskConnections(BaseTestPipelineTaskConnections, dimensions=()):
    """A connections class that creates mock connections from the connections
    of a real PipelineTask.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        The config to use for the connection.
    """

    def __init__(self, *, config: MockPipelineTaskConfig):
        # Instantiate the original task's connections and mirror its
        # dimensions; each original connection is then either mocked or
        # passed through according to ``unmocked_dataset_types``.
        self.original: PipelineTaskConnections = config.original.connections.ConnectionsClass(
            config=config.original.value
        )
        self.dimensions.update(self.original.dimensions)
        self.unmocked_dataset_types = frozenset(config.unmocked_dataset_types)
        for name, connection in self.original.allConnections.items():
            if connection.name not in self.unmocked_dataset_types:
                # We register the mock storage class with the global
                # singleton here, but can only put its name in the
                # connection. That means the same global singleton (or one
                # that also has these registrations) has to be available
                # whenever this dataset type is used.
                storage_class_name = MockStorageClass.get_or_register_mock(connection.storageClass).name
                kwargs: dict[str, Any] = {}
                if hasattr(connection, "dimensions"):
                    connection_dimensions = set(connection.dimensions)
                    # Replace the generic "skypix" placeholder with htm7, since
                    # that requires the dataset type to have already been
                    # registered.
                    if "skypix" in connection_dimensions:
                        connection_dimensions.remove("skypix")
                        connection_dimensions.add("htm7")
                    kwargs["dimensions"] = connection_dimensions
                connection = dataclasses.replace(
                    connection,
                    name=get_mock_name(connection.name),
                    storageClass=storage_class_name,
                    **kwargs,
                )
            elif name in self.original.outputs:
                raise ValueError(f"Unmocked dataset type {connection.name!r} cannot be used as an output.")
            elif name in self.original.initInputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-input."
                )
            elif name in self.original.initOutputs:
                raise ValueError(
                    f"Unmocked dataset type {connection.name!r} cannot be used as an init-output."
                )
            elif (
                connection.name.endswith(METADATA_OUTPUT_CONNECTION_NAME)
                and connection.storageClass == METADATA_OUTPUT_STORAGE_CLASS
            ):
                # Task metadata does not use a mock storage class, because it's
                # written by the system, but it does end up with the _mock_*
                # prefix because the task label does.
                connection = dataclasses.replace(connection, name=get_mock_name(connection.name))
            setattr(self, name, connection)

    def getSpatialBoundsConnections(self) -> Iterable[str]:
        # Docstring inherited; delegate to the original connections.
        return self.original.getSpatialBoundsConnections()

    def getTemporalBoundsConnections(self) -> Iterable[str]:
        # Docstring inherited; delegate to the original connections.
        return self.original.getTemporalBoundsConnections()

    def adjustQuantum(
        self,
        inputs: dict[str, tuple[BaseInput, Collection[DatasetRef]]],
        outputs: dict[str, tuple[Output, Collection[DatasetRef]]],
        label: str,
        data_id: DataCoordinate,
    ) -> tuple[
        Mapping[str, tuple[BaseInput, Collection[DatasetRef]]],
        Mapping[str, tuple[Output, Collection[DatasetRef]]],
    ]:
        # Docstring inherited.  Round-trips through the original task's
        # adjustQuantum: un-mock the refs, delegate, then re-mock the result.
        # Convert the given mappings from the mock dataset types to the
        # original dataset types they were produced from.
        original_inputs = {}
        for connection_name, (_, mock_refs) in inputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_inputs[connection_name] = (original_connection, refs)
        original_outputs = {}
        for connection_name, (_, mock_refs) in outputs.items():
            original_connection = getattr(self.original, connection_name)
            if original_connection.name in self.unmocked_dataset_types:
                refs = mock_refs
            else:
                refs = MockStorageClass.unmock_dataset_refs(mock_refs)
            original_outputs[connection_name] = (original_connection, refs)
        # Call adjustQuantum on the original connections class.
        adjusted_original_inputs, adjusted_original_outputs = self.original.adjustQuantum(
            original_inputs, original_outputs, label, data_id
        )
        # Convert the results back to the mock dataset types.
        adjusted_inputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_inputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_inputs[connection_name] = (getattr(self, connection_name), refs)
        adjusted_outputs = {}
        for connection_name, (original_connection, original_refs) in adjusted_original_outputs.items():
            if original_connection.name in self.unmocked_dataset_types:
                refs = original_refs
            else:
                refs = MockStorageClass.mock_dataset_refs(original_refs)
            adjusted_outputs[connection_name] = (getattr(self, connection_name), refs)
        return adjusted_inputs, adjusted_outputs
class MockPipelineTaskConfig(BaseTestPipelineTaskConfig, pipelineConnections=MockPipelineTaskConnections):
    """Configuration class for `MockPipelineTask`."""

    # The real task whose connections are mocked; defaults to a placeholder
    # target because ConfigurableField cannot be optional.
    original: ConfigurableField = ConfigurableField(
        doc="The original task being mocked by this one.", target=MockPipelineDefaultTargetTask
    )

    unmocked_dataset_types = ListField[str](
        doc=(
            "Names of input dataset types that should be used as-is instead "
            "of being mocked. May include dataset types not relevant for "
            "this task, which will be ignored."
        ),
        default=(),
        optional=False,
    )
class MockPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with connections
    generated by mocking those of a real task.

    Notes
    -----
    At present `MockPipelineTask` simply drops any ``initInput`` and
    ``initOutput`` connections present on the original, since `MockDataset`
    creation for those would have to happen in the code that executes the task,
    not in the task itself. Because `MockPipelineTask` never instantiates the
    mock task (just its connections class), this is a limitation on what the
    mocks can be used to test, not anything deeper.
    """

    # All runtime behavior comes from BaseTestPipelineTask.runQuantum.
    ConfigClass: ClassVar[type[PipelineTaskConfig]] = MockPipelineTaskConfig
class DynamicConnectionConfig(Config):
    """A config class that defines a completely dynamic connection."""

    dataset_type_name = Field[str](doc="Name for the dataset type as seen by the butler.", dtype=str)
    dimensions = ListField[str](doc="Dimensions for the dataset type.", dtype=str, default=[])
    storage_class = Field[str](
        doc="Name of the butler storage class for the dataset type.", dtype=str, default="StructuredDataDict"
    )
    is_calibration = Field[bool](doc="Whether this dataset type is a calibration.", dtype=bool, default=False)
    multiple = Field[bool](
        doc="Whether this connection gets or puts multiple datasets for each quantum.",
        dtype=bool,
        default=False,
    )
    mock_storage_class = Field[bool](
        doc="Whether the storage class should actually be a mock of the storage class given.",
        dtype=bool,
        default=True,
    )
    minimum = Field[int](
        doc="Minimum number of datasets per quantum required for this connection. Ignored for non-inputs.",
        dtype=int,
        default=1,
    )

    def make_connection(self, cls: type[_T]) -> _T:
        """Build a connection instance of type ``cls`` from this config.

        Registers a mock storage class first when ``mock_storage_class`` is
        set.
        """
        storage_class = self.storage_class
        if self.mock_storage_class:
            storage_class = MockStorageClass.get_or_register_mock(storage_class).name
        # Branch order matters: the BaseInput case is checked first so input
        # connections get the ``minimum`` argument; other dimensioned
        # connections get dimensions but no ``minimum``; everything else gets
        # neither.
        if issubclass(cls, cT.BaseInput):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
                minimum=self.minimum,
            )
        elif issubclass(cls, cT.DimensionedConnection):
            return cls(  # type: ignore
                name=self.dataset_type_name,
                storageClass=storage_class,
                isCalibration=self.is_calibration,
                multiple=self.multiple,
                dimensions=frozenset(self.dimensions),
            )
        else:
            return cls(
                name=self.dataset_type_name,
                storageClass=storage_class,
                multiple=self.multiple,
            )
class DynamicTestPipelineTaskConnections(PipelineTaskConnections, dimensions=()):
    """A connections class whose dimensions and connections are wholly
    determined via configuration.

    Parameters
    ----------
    config : `PipelineTaskConfig`
        Config to use for this connections object.
    """

    def __init__(self, *, config: DynamicTestPipelineTaskConfig):
        self.dimensions.update(config.dimensions)
        # Each config sub-dict maps connection names to
        # DynamicConnectionConfig entries; build the matching connection
        # type for every entry, in the same group order as the config.
        groups = (
            (config.init_inputs, cT.InitInput),
            (config.init_outputs, cT.InitOutput),
            (config.prerequisite_inputs, cT.PrerequisiteInput),
            (config.inputs, cT.Input),
            (config.outputs, cT.Output),
        )
        for group, connection_cls in groups:
            for connection_name, connection_config in group.items():
                setattr(self, connection_name, connection_config.make_connection(connection_cls))
class DynamicTestPipelineTaskConfig(
    BaseTestPipelineTaskConfig, pipelineConnections=DynamicTestPipelineTaskConnections
):
    """Configuration for DynamicTestPipelineTask.

    Each of the per-kind dicts below maps a connection name (as seen by the
    task) to a `DynamicConnectionConfig` describing that connection.
    """

    dimensions = ListField[str](doc="Dimensions for the task's quanta.", dtype=str, default=[])
    init_inputs = ConfigDictField(
        doc=(
            "Init-input connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    init_outputs = ConfigDictField(
        doc=(
            "Init-output connections, keyed by the connection name as seen by the task. "
            "Must be empty if the task will be constructed."
        ),
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    prerequisite_inputs = ConfigDictField(
        doc="Prerequisite input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    inputs = ConfigDictField(
        doc="Regular input connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
    outputs = ConfigDictField(
        doc="Regular output connections, keyed by the connection name as seen by the task.",
        keytype=str,
        itemtype=DynamicConnectionConfig,
        default={},
    )
class DynamicTestPipelineTask(BaseTestPipelineTask):
    """A test-utility implementation of `PipelineTask` with dimensions and
    connections determined wholly from configuration.
    """

    # All runtime behavior comes from BaseTestPipelineTask.runQuantum.
    ConfigClass: ClassVar[type[PipelineTaskConfig]] = DynamicTestPipelineTaskConfig