Coverage for tests/test_taskFactory.py: 46%
66 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-20 10:51 +0000
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-20 10:51 +0000
1# This file is part of ctrl_mpexec.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24import shutil
25import tempfile
26import unittest
27from typing import TYPE_CHECKING
29import lsst.daf.butler.tests as butlerTests
30import lsst.pex.config as pexConfig
31from lsst.ctrl.mpexec import TaskFactory
32from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections, TaskDef, connectionTypes
34if TYPE_CHECKING: 34 ↛ 35line 34 didn't jump to line 35, because the condition on line 34 was never true
35 from lsst.daf.butler import Butler, DatasetRef
37# Storage class to use for tests of fakes.
38_FAKE_STORAGE_CLASS = "StructuredDataDict"
class FakeConnections(PipelineTaskConnections, dimensions=set()):
    """Connections for a fake task: one init-input, one init-output, one
    input and one output, all dimensionless and all using
    ``_FAKE_STORAGE_CLASS``.
    """

    initInput = connectionTypes.InitInput(
        name="fakeInitInput",
        storageClass=_FAKE_STORAGE_CLASS,
        doc="",
    )
    initOutput = connectionTypes.InitOutput(
        name="fakeInitOutput",
        storageClass=_FAKE_STORAGE_CLASS,
        doc="",
    )
    input = connectionTypes.Input(
        name="fakeInput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
        doc="",
    )
    output = connectionTypes.Output(
        name="fakeOutput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
        doc="",
    )
class FakeConfig(PipelineTaskConfig, pipelineConnections=FakeConnections):
    """Config for the fake task, bound to `FakeConnections`.

    ``widget`` is a single float field (default 1.0) that tests can alter
    to distinguish a customized config from the default one.
    """

    widget = pexConfig.Field(
        dtype=float,
        default=1.0,
        doc="",
    )
def mockTaskClass():
    """Return a mock that stands in for a task class.

    The mock records every call made to it, so a test can check the
    arguments a factory used to construct the task. It carries the
    class-level attributes a task class is expected to have: ``__name__``,
    ``_DefaultName`` (``"FakeTask"``) and ``ConfigClass`` (`FakeConfig`).
    """
    # ``import unittest`` at module level does not load the ``unittest.mock``
    # submodule, so import it explicitly rather than relying on some other
    # module having already done so.
    from unittest import mock

    return mock.Mock(__name__="_TaskMock", _DefaultName="FakeTask", ConfigClass=FakeConfig)
class TaskFactoryTestCase(unittest.TestCase):
    """Tests for `TaskFactory.makeTask` using a mocked task class."""

    # Dataset types registered in the shared repository, one per connection
    # declared in FakeConnections.
    _DATASET_TYPES = ("fakeInitInput", "fakeInitOutput", "fakeInput", "fakeOutput")

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

        # One throw-away repository is shared by all tests in this class. The
        # cleanup is registered before the repo is created so the directory
        # is removed even if a later setup step fails.
        tmp = tempfile.mkdtemp()
        cls.addClassCleanup(shutil.rmtree, tmp, ignore_errors=True)
        cls.repo = butlerTests.makeTestRepo(tmp)
        for dataset_type in cls._DATASET_TYPES:
            butlerTests.addDatasetType(cls.repo, dataset_type, set(), _FAKE_STORAGE_CLASS)

    def setUp(self) -> None:
        super().setUp()

        self.factory = TaskFactory()
        # A fresh mock per test so recorded calls never carry over.
        self.constructor = mockTaskClass()

    @staticmethod
    def _alteredConfig() -> FakeConfig:
        """Return a `FakeConfig` whose ``widget`` differs from its default."""
        config = FakeConfig()
        config.widget = 42.0
        return config

    @staticmethod
    def _dummyCatalog() -> dict:
        """Return the (empty) payload stored for every fake dataset."""
        return {}

    def _tempButler(self) -> tuple[Butler, dict[str, DatasetRef]]:
        """Make a per-test butler collection holding one dataset per type.

        Returns
        -------
        butler : `Butler`
            Butler writing to a collection unique to the current test.
        refs : `dict` [`str`, `DatasetRef`]
            References to the stored datasets, keyed by dataset type name.
        """
        butler = butlerTests.makeTestCollection(self.repo, uniqueId=self.id())
        catalog = self._dummyCatalog()
        refs = {dataset_type: butler.put(catalog, dataset_type) for dataset_type in self._DATASET_TYPES}
        return butler, refs

    def testDefaultConfigLabel(self) -> None:
        """Without a label or config, the task gets its default name and a
        default-constructed config.
        """
        task_def = TaskDef(taskClass=self.constructor, label=None, config=None)
        butler, _ = self._tempButler()
        task = self.factory.makeTask(taskDef=task_def, butler=butler, initInputRefs=[])
        # Unlike assert_called_with, this also fails if the task class was
        # instantiated more than once.
        self.constructor.assert_called_once_with(config=FakeConfig(), initInputs={}, name="FakeTask")
        self.assertIs(task, self.constructor.return_value)

    def testAllArgs(self) -> None:
        """An explicit label, config and init-input are forwarded to the
        task constructor.
        """
        config = self._alteredConfig()
        task_def = TaskDef(taskClass=self.constructor, label="no-name", config=config)
        butler, refs = self._tempButler()
        task = self.factory.makeTask(
            taskDef=task_def, butler=butler, initInputRefs=[refs["fakeInitInput"]]
        )
        catalog = butler.get("fakeInitInput")
        self.constructor.assert_called_once_with(
            config=config, initInputs={"initInput": catalog}, name="no-name"
        )
        self.assertIs(task, self.constructor.return_value)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()