Coverage for tests / test_task_factory.py: 42%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 08:19 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <https://www.gnu.org/licenses/>. 

27 

28from __future__ import annotations 

29 

30import shutil 

31import tempfile 

32import unittest 

33from typing import TYPE_CHECKING 

34 

35import lsst.daf.butler.tests as butlerTests 

36import lsst.pex.config as pexConfig 

37from lsst.pipe.base import ( 

38 PipelineGraph, 

39 PipelineTaskConfig, 

40 PipelineTaskConnections, 

41 TaskFactory, 

42 connectionTypes, 

43) 

44 

45if TYPE_CHECKING: 

46 from lsst.daf.butler import Butler, DatasetRef 

47 

# Storage class shared by every fake dataset type and connection in this
# module; the payload actually written for each fake dataset is an empty dict.
_FAKE_STORAGE_CLASS = "StructuredDataDict"

50 

51 

class FakeConnections(PipelineTaskConnections, dimensions=set()):
    """Connections for a fake, dimensionless task used only by these tests.

    Exactly one connection of each kind (init-input, init-output, input,
    output) is declared, all using the module's fake storage class and no
    dimensions.
    """

    initInput = connectionTypes.InitInput(
        doc="test",
        name="fakeInitInput",
        storageClass=_FAKE_STORAGE_CLASS,
    )
    initOutput = connectionTypes.InitOutput(
        doc="test",
        name="fakeInitOutput",
        storageClass=_FAKE_STORAGE_CLASS,
    )
    input = connectionTypes.Input(
        doc="test",
        name="fakeInput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
    )
    output = connectionTypes.Output(
        doc="test",
        name="fakeOutput",
        storageClass=_FAKE_STORAGE_CLASS,
        dimensions=set(),
    )

65 

66 

class FakeConfig(PipelineTaskConfig, pipelineConnections=FakeConnections):
    """Minimal task configuration bound to `FakeConnections`.

    ``widget`` exists so tests can modify a field and distinguish a
    customized config from a default-constructed one.
    """

    widget = pexConfig.Field(doc="test", dtype=float, default=1.0)

71 

72 

def mockTaskClass():
    """Return a mock that stands in for a task class.

    The returned `unittest.mock.Mock` records every call made to it, so
    tests can assert on the arguments used to "construct" the task.  It is
    given ``__name__``, ``_DefaultName`` and ``ConfigClass`` attributes;
    presumably these are what `PipelineGraph.add_task` and `TaskFactory`
    read from a real task class (TODO confirm).

    Returns
    -------
    mock : `unittest.mock.Mock`
        A new mock on every call, using `FakeConfig` as its config class
        and ``"FakeTask"`` as its default name.
    """
    # ``import unittest`` does not load the ``unittest.mock`` submodule, so
    # import it explicitly instead of relying on another module having done so.
    from unittest import mock

    return mock.Mock(__name__="_TaskMock", _DefaultName="FakeTask", ConfigClass=FakeConfig)

80 

81 

class TaskFactoryTestCase(unittest.TestCase):
    """Tests for constructing tasks from pipeline-graph nodes via `TaskFactory`."""

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

        # One scratch repository is shared by all tests in this class and is
        # removed (ignoring errors) once the class has finished.
        repo_root = tempfile.mkdtemp()
        cls.addClassCleanup(shutil.rmtree, repo_root, ignore_errors=True)
        cls.repo = butlerTests.makeTestRepo(repo_root)
        # Register one dimensionless dataset type per fake connection.
        for dataset_type_name in ("fakeInitInput", "fakeInitOutput", "fakeInput", "fakeOutput"):
            butlerTests.addDatasetType(cls.repo, dataset_type_name, set(), _FAKE_STORAGE_CLASS)

    def setUp(self) -> None:
        super().setUp()

        self.factory = TaskFactory()
        # A new mock for every test, so recorded calls never leak between tests.
        self.constructor = mockTaskClass()

    @staticmethod
    def _alteredConfig() -> FakeConfig:
        """Return a `FakeConfig` whose ``widget`` differs from the default."""
        altered = FakeConfig()
        altered.widget = 42.0
        return altered

    @staticmethod
    def _dummyCatalog() -> dict:
        """Return the (empty) payload stored for every fake dataset."""
        return dict()

    def _tempButler(self) -> tuple[Butler, dict[str, DatasetRef]]:
        """Create a per-test butler and store one dataset of each fake type.

        Returns
        -------
        butler : `lsst.daf.butler.Butler`
            Butler for a collection unique to the running test; its context
            is exited via ``enterContext`` cleanup when the test ends.
        refs : `dict` [`str`, `lsst.daf.butler.DatasetRef`]
            The value returned by ``butler.put`` for each dataset, keyed by
            dataset type name.
        """
        butler = butlerTests.makeTestCollection(self.repo, uniqueId=self.id())
        self.enterContext(butler)
        payload = self._dummyCatalog()
        refs = {
            dataset_type_name: butler.put(payload, dataset_type_name)
            for dataset_type_name in ("fakeInitInput", "fakeInitOutput", "fakeInput", "fakeOutput")
        }
        return butler, refs

    def testDefaultConfigLabel(self) -> None:
        """An unlabeled task gets the mock's default name and a default config."""
        graph = PipelineGraph()
        node = graph.add_task(None, self.constructor)
        butler, _ = self._tempButler()
        self.factory.makeTask(node, butler=butler, initInputRefs=[])
        self.constructor.assert_called_with(config=FakeConfig(), initInputs={}, name="FakeTask")

    def testAllArgs(self) -> None:
        """A custom label, config and init-input are all passed to the task."""
        config = self._alteredConfig()
        graph = PipelineGraph()
        node = graph.add_task("no-name", self.constructor, config=config)
        butler, refs = self._tempButler()
        self.factory.makeTask(node, butler=butler, initInputRefs=[refs["fakeInitInput"]])
        # Read the stored init-input back only after the task has been made.
        expected_init_inputs = {"initInput": butler.get("fakeInitInput")}
        self.constructor.assert_called_with(
            config=config, initInputs=expected_init_inputs, name="no-name"
        )

137 

138 

if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()