Coverage for tests / test_task_factory.py: 42%
64 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:49 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28from __future__ import annotations
30import shutil
31import tempfile
32import unittest
33from typing import TYPE_CHECKING
35import lsst.daf.butler.tests as butlerTests
36import lsst.pex.config as pexConfig
37from lsst.pipe.base import (
38 PipelineGraph,
39 PipelineTaskConfig,
40 PipelineTaskConnections,
41 TaskFactory,
42 connectionTypes,
43)
45if TYPE_CHECKING:
46 from lsst.daf.butler import Butler, DatasetRef
# Storage class shared by every connection of the fake task and by the
# matching dataset types registered in the test repository.
_FAKE_STORAGE_CLASS = "StructuredDataDict"
class FakeConnections(PipelineTaskConnections, dimensions=set()):
    """Connections of the fake task used by these tests.

    Declares one init-input, one init-output, one regular input and one
    regular output. All use ``_FAKE_STORAGE_CLASS``, and neither the task
    nor its regular input/output has any dimensions.
    """

    initInput = connectionTypes.InitInput(
        doc="test",
        name="fakeInitInput",
        storageClass=_FAKE_STORAGE_CLASS,
    )
    initOutput = connectionTypes.InitOutput(
        doc="test",
        name="fakeInitOutput",
        storageClass=_FAKE_STORAGE_CLASS,
    )
    input = connectionTypes.Input(
        doc="test",
        name="fakeInput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
    )
    output = connectionTypes.Output(
        doc="test",
        name="fakeOutput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
    )
class FakeConfig(PipelineTaskConfig, pipelineConnections=FakeConnections):
    """Task configuration bound to `FakeConnections`.

    Carries one float field, ``widget`` (default ``1.0``), whose value the
    tests can change to tell a customized config apart from the default.
    """

    widget = pexConfig.Field(
        doc="test",
        dtype=float,
        default=1.0,
    )
def mockTaskClass():
    """Return a mock that stands in for a task class.

    The mock records every call made to it, so a test can assert on the
    arguments used to "construct" the task. It also carries the class-level
    attributes the tests in this module rely on: ``_DefaultName`` (the
    default task label) and ``ConfigClass`` (the config type to build).

    Returns
    -------
    mock : `unittest.mock.Mock`
        A new mock with ``__name__ == "_TaskMock"``,
        ``_DefaultName == "FakeTask"`` and ``ConfigClass`` set to
        `FakeConfig`.
    """
    # ``import unittest`` does not load the ``unittest.mock`` submodule;
    # import it explicitly instead of relying on some other module having
    # imported it first.
    from unittest import mock

    return mock.Mock(__name__="_TaskMock", _DefaultName="FakeTask", ConfigClass=FakeConfig)
class TaskFactoryTestCase(unittest.TestCase):
    """Tests for `TaskFactory`."""

    @classmethod
    def setUpClass(cls) -> None:
        """Create one temporary repository shared by all tests in this class
        and register the four fake dataset types in it.
        """
        super().setUpClass()

        repo_root = tempfile.mkdtemp()
        # Remove the directory once the class is done; removal errors are
        # ignored rather than reported as test failures.
        cls.addClassCleanup(shutil.rmtree, repo_root, ignore_errors=True)
        cls.repo = butlerTests.makeTestRepo(repo_root)
        for dataset_type_name in ("fakeInitInput", "fakeInitOutput", "fakeInput", "fakeOutput"):
            butlerTests.addDatasetType(cls.repo, dataset_type_name, set(), _FAKE_STORAGE_CLASS)

    def setUp(self) -> None:
        super().setUp()

        # A new factory and a new mock task class for every test, so calls
        # recorded by one test are never seen by another.
        self.factory = TaskFactory()
        self.constructor = mockTaskClass()

    @staticmethod
    def _alteredConfig() -> FakeConfig:
        """Return a `FakeConfig` whose ``widget`` is set to ``42.0``."""
        altered = FakeConfig()
        altered.widget = 42.0
        return altered

    @staticmethod
    def _dummyCatalog() -> dict:
        """Return the (empty) payload stored for every fake dataset."""
        return dict()

    def _tempButler(self) -> tuple[Butler, dict[str, DatasetRef]]:
        """Make a per-test butler and store one dummy dataset of each fake
        dataset type in it.

        Returns
        -------
        butler : `lsst.daf.butler.Butler`
            Butler from ``makeTestCollection`` keyed on this test's id; it is
            entered as a context manager, and that context is exited during
            test cleanup.
        refs : `dict` [`str`, `lsst.daf.butler.DatasetRef`]
            The refs returned by ``butler.put``, keyed by dataset type name.
        """
        butler = butlerTests.makeTestCollection(self.repo, uniqueId=self.id())
        self.enterContext(butler)
        payload = self._dummyCatalog()
        refs = {
            name: butler.put(payload, name)
            for name in ("fakeInitInput", "fakeInitOutput", "fakeInput", "fakeOutput")
        }
        return butler, refs

    def testDefaultConfigLabel(self) -> None:
        """A task added with no label and no config is constructed with a
        default `FakeConfig`, no init-inputs, and its ``_DefaultName``.
        """
        graph = PipelineGraph()
        node = graph.add_task(None, self.constructor)
        butler, _ = self._tempButler()
        self.factory.makeTask(node, butler=butler, initInputRefs=[])
        self.constructor.assert_called_with(config=FakeConfig(), initInputs={}, name="FakeTask")

    def testAllArgs(self) -> None:
        """An explicit label, a customized config and an init-input ref all
        reach the task constructor; the init-input arrives as the dataset
        read back from the butler.
        """
        altered = self._alteredConfig()
        graph = PipelineGraph()
        node = graph.add_task("no-name", self.constructor, config=altered)
        butler, refs = self._tempButler()
        self.factory.makeTask(node, butler=butler, initInputRefs=[refs["fakeInitInput"]])
        expected_init_inputs = {"initInput": butler.get("fakeInitInput")}
        self.constructor.assert_called_with(
            config=altered, initInputs=expected_init_inputs, name="no-name"
        )
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()