Coverage for tests/test_taskFactory.py: 46%

66 statements  

« prev     ^ index     » next       coverage.py v7.2.6, created at 2023-05-26 02:14 -0700

1# This file is part of ctrl_mpexec. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import shutil 

25import tempfile 

26import unittest 

27from typing import TYPE_CHECKING 

28 

29import lsst.daf.butler.tests as butlerTests 

30import lsst.pex.config as pexConfig 

31from lsst.ctrl.mpexec import TaskFactory 

32from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections, TaskDef, connectionTypes 

33 

34if TYPE_CHECKING: 34 ↛ 35line 34 didn't jump to line 35, because the condition on line 34 was never true

35 from lsst.daf.butler import Butler, DatasetRef 

36 

37# Storage class to use for tests of fakes. 

38_FAKE_STORAGE_CLASS = "StructuredDataDict" 

39 

40 

class FakeConnections(PipelineTaskConnections, dimensions=set()):
    """Connections for the fake task used in these tests.

    Declares one init-input, one init-output, one input and one output.
    All of them use ``_FAKE_STORAGE_CLASS``, and the regular input/output
    have no dimensions.
    """

    initInput = connectionTypes.InitInput(
        name="fakeInitInput",
        storageClass=_FAKE_STORAGE_CLASS,
        doc="",
    )
    initOutput = connectionTypes.InitOutput(
        name="fakeInitOutput",
        storageClass=_FAKE_STORAGE_CLASS,
        doc="",
    )
    input = connectionTypes.Input(
        name="fakeInput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
        doc="",
    )
    output = connectionTypes.Output(
        name="fakeOutput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
        doc="",
    )

50 

51 

class FakeConfig(PipelineTaskConfig, pipelineConnections=FakeConnections):
    """Config for the fake task, wired to `FakeConnections`.

    ``widget`` is a plain float field (default ``1.0``) that tests can change
    to produce a config that differs from the default.
    """

    widget = pexConfig.Field(doc="", dtype=float, default=1.0)

54 

55 

def mockTaskClass():
    """Return a mock standing in for a task class, recording calls to it.

    The mock is configured with ``__name__``, ``_DefaultName`` and
    ``ConfigClass`` attributes so it can be passed where a task class is
    expected. Every call made to it is recorded, so tests can assert how the
    "task" was constructed.

    Returns
    -------
    mock : `unittest.mock.Mock`
        A fresh mock whose ``ConfigClass`` is `FakeConfig` and whose
        ``_DefaultName`` is ``"FakeTask"``.
    """
    # ``import unittest`` does not load the ``unittest.mock`` submodule, so
    # import it explicitly. Do not rely on some other module having imported
    # it first.
    from unittest import mock

    return mock.Mock(__name__="_TaskMock", _DefaultName="FakeTask", ConfigClass=FakeConfig)

60 

61 

class TaskFactoryTestCase(unittest.TestCase):
    """Tests of `TaskFactory.makeTask` driven through a mock task class."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

        # A single scratch repository is shared by all tests in the class and
        # removed (best effort) once they have all run.
        repoRoot = tempfile.mkdtemp()
        cls.addClassCleanup(shutil.rmtree, repoRoot, ignore_errors=True)
        cls.repo = butlerTests.makeTestRepo(repoRoot)
        for datasetTypeName in ("fakeInitInput", "fakeInitOutput", "fakeInput", "fakeOutput"):
            butlerTests.addDatasetType(cls.repo, datasetTypeName, set(), _FAKE_STORAGE_CLASS)

    def setUp(self) -> None:
        super().setUp()

        self.factory = TaskFactory()
        # A fresh mock per test so recorded calls never leak between tests.
        self.constructor = mockTaskClass()

    @staticmethod
    def _alteredConfig() -> FakeConfig:
        """Return a `FakeConfig` whose ``widget`` is set to ``42.0``."""
        altered = FakeConfig()
        altered.widget = 42.0
        return altered

    @staticmethod
    def _dummyCatalog() -> dict:
        """Return the empty dict stored as every fake dataset's content."""
        return {}

    def _tempButler(self) -> tuple[Butler, dict[str, DatasetRef]]:
        """Make a per-test butler collection holding one of each fake dataset.

        Returns
        -------
        butler : `lsst.daf.butler.Butler`
            Butler writing to a collection unique to the current test.
        refs : `dict` [`str`, `lsst.daf.butler.DatasetRef`]
            Reference returned by ``put`` for each fake dataset type name.
        """
        butler = butlerTests.makeTestCollection(self.repo, uniqueId=self.id())
        payload = self._dummyCatalog()
        refs = {
            datasetTypeName: butler.put(payload, datasetTypeName)
            for datasetTypeName in ("fakeInitInput", "fakeInitOutput", "fakeInput", "fakeOutput")
        }
        return butler, refs

    def testDefaultConfigLabel(self) -> None:
        """With no label or config, expect a default config and name."""
        taskDef = TaskDef(taskClass=self.constructor, label=None, config=None)
        butler, _ = self._tempButler()
        self.factory.makeTask(taskDef=taskDef, butler=butler, initInputRefs=[])
        self.constructor.assert_called_with(config=FakeConfig(), initInputs={}, name="FakeTask")

    def testAllArgs(self) -> None:
        """Explicit label, config and init-input are all passed through."""
        config = self._alteredConfig()
        taskDef = TaskDef(taskClass=self.constructor, label="no-name", config=config)
        butler, refs = self._tempButler()
        self.factory.makeTask(taskDef=taskDef, butler=butler, initInputRefs=[refs["fakeInitInput"]])
        expectedInitInput = butler.get("fakeInitInput")
        self.constructor.assert_called_with(
            config=config,
            initInputs={"initInput": expectedInitInput},
            name="no-name",
        )

112 

113 

if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()