Coverage for python/lsst/ctrl/mpexec/taskFactory.py: 28%
33 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-10 03:29 -0700
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28from __future__ import annotations
30__all__ = ["TaskFactory"]
32import logging
33import warnings
34from collections.abc import Iterable
35from typing import TYPE_CHECKING, Any
37from lsst.pipe.base import TaskDef
38from lsst.pipe.base import TaskFactory as BaseTaskFactory
39from lsst.pipe.base.pipeline_graph import TaskNode
40from lsst.utils.introspection import find_outside_stacklevel
42if TYPE_CHECKING:
43 from lsst.daf.butler import DatasetRef, LimitedButler
44 from lsst.pipe.base import PipelineTask
# Module-level logger named after this module.
# NOTE(review): not referenced by any code visible in this file; confirm
# whether it is still needed.
_LOG = logging.getLogger(__name__)
class TaskFactory(BaseTaskFactory):
    """Class instantiating PipelineTasks.

    Tasks are constructed from a `~lsst.pipe.base.pipeline_graph.TaskNode`
    (or, deprecated, a `~lsst.pipe.base.TaskDef`). Any init-input datasets
    the task declares are read from the butler, using the supplied dataset
    references, and passed to the task constructor as ``initInputs``.
    """

    def makeTask(
        self,
        task_node: TaskDef | TaskNode,
        /,
        butler: LimitedButler,
        initInputRefs: Iterable[DatasetRef] | None,
    ) -> PipelineTask:
        # docstring inherited
        config = task_node.config
        init_inputs: dict[str, Any] = {}
        # Index the provided refs by dataset type name so each init-input
        # connection can find its ref. If several refs share a dataset type
        # name, the last one wins.
        init_input_refs_by_dataset_type: dict[str, DatasetRef] = {}
        if initInputRefs is not None:
            init_input_refs_by_dataset_type = {ref.datasetType.name: ref for ref in initInputRefs}
        if isinstance(task_node, TaskDef):
            # TODO: remove this block on DM-40443, along with type annotation.
            warnings.warn(
                "Passing TaskDef to TaskFactory is deprecated and will not be supported after v27.",
                FutureWarning,
                find_outside_stacklevel("lsst.pipe.base"),
            )
            task_class = task_node.taskClass
            # Explicit check instead of `assert`: asserts are stripped under
            # ``python -O``, and a None class would then fail below with an
            # opaque "'NoneType' object is not callable".
            if task_class is None:
                raise TypeError(f"TaskDef for task {task_node.label!r} has no task class to instantiate.")
            if init_input_refs_by_dataset_type:
                connections = config.connections.ConnectionsClass(config=config)
                for name in connections.initInputs:
                    attribute = getattr(connections, name)
                    init_inputs[name] = self._read_init_input(
                        butler,
                        init_input_refs_by_dataset_type,
                        attribute.name,
                        name,
                        task_node.label,
                    )
        else:
            task_class = task_node.task_class
            if init_input_refs_by_dataset_type:
                for read_edge in task_node.init.inputs.values():
                    init_inputs[read_edge.connection_name] = self._read_init_input(
                        butler,
                        init_input_refs_by_dataset_type,
                        read_edge.dataset_type_name,
                        read_edge.connection_name,
                        task_node.label,
                    )
        # make task instance
        return task_class(config=config, initInputs=init_inputs, name=task_node.label)

    @staticmethod
    def _read_init_input(
        butler: LimitedButler,
        refs_by_dataset_type: dict[str, DatasetRef],
        dataset_type_name: str,
        connection_name: str,
        task_label: str,
    ) -> Any:
        """Read one init-input dataset for a task from the butler.

        Parameters
        ----------
        butler : `~lsst.daf.butler.LimitedButler`
            Butler used to read the dataset.
        refs_by_dataset_type : `dict` [`str`, `~lsst.daf.butler.DatasetRef`]
            Init-input references keyed by dataset type name.
        dataset_type_name : `str`
            Dataset type name of the init-input to read.
        connection_name : `str`
            Name of the task connection the dataset is read for; used only
            in the error message.
        task_label : `str`
            Label of the task being constructed; used only in the error
            message.

        Returns
        -------
        dataset : `~typing.Any`
            Object returned by ``butler.get`` for the matching reference.

        Raises
        ------
        KeyError
            Raised if no reference for ``dataset_type_name`` was provided.
        """
        try:
            ref = refs_by_dataset_type[dataset_type_name]
        except KeyError:
            # Same exception type as before, but the message now identifies
            # which task and connection lacked a reference.
            raise KeyError(
                f"No init-input reference for dataset type {dataset_type_name!r} "
                f"(connection {connection_name!r}) of task {task_label!r}."
            ) from None
        return butler.get(ref)