Coverage for tests/test_connections.py: 15%

104 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-04 02:56 -0700

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This software is dual licensed under the GNU General Public License and also 

10# under a 3-clause BSD license. Recipients may choose which of these licenses 

11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, 

12# respectively. If you choose the GPL option then the following text applies 

13# (but note that there is still no warranty even if you opt for BSD instead): 

14# 

15# This program is free software: you can redistribute it and/or modify 

16# it under the terms of the GNU General Public License as published by 

17# the Free Software Foundation, either version 3 of the License, or 

18# (at your option) any later version. 

19# 

20# This program is distributed in the hope that it will be useful, 

21# but WITHOUT ANY WARRANTY; without even the implied warranty of 

22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

23# GNU General Public License for more details. 

24# 

25# You should have received a copy of the GNU General Public License 

26# along with this program. If not, see <http://www.gnu.org/licenses/>. 

27 

28"""Simple unit test for PipelineTaskConnections. 

29""" 

30 

31import unittest 

32import warnings 

33 

34import lsst.pipe.base as pipeBase 

35import lsst.utils.tests 

36import pytest 

37from lsst.pex.config import Field 

38 

39 

class TestConnectionsClass(unittest.TestCase):
    """Test connection classes."""

    def setUp(self):
        # Dimensions shared by every connections class declared in these
        # tests.
        self.test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Test the declaration of a Connections class."""
        with pytest.raises(TypeError):
            # This should raise because this Connections class is created
            # with no dimensions.
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with pytest.raises(TypeError):
            # This should raise because this Connections class is created
            # without defaults for the templates its connection names use.
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(
                    doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
                )

        # This declaration should raise no exceptions.
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise either, because the template has a default.
        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )

    def testConnectionsOnConnectionsClass(self):
        """Test that connection attribute names are grouped by connection type."""

        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input", storageClass="Dummy"
            )
            initInput2 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input2", storageClass="Dummy"
            )

            initOutput1 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output1", storageClass="Dummy"
            )
            initOutput2 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output2", storageClass="Dummy"
            )

            input1 = pipeBase.connectionTypes.Input(
                doc="test input", name="input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            input2 = pipeBase.connectionTypes.Input(
                doc="test input", name="input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            output1 = pipeBase.connectionTypes.Output(
                doc="test output", name="output1", dimensions=self.test_dims, storageClass="Dummy"
            )
            output2 = pipeBase.connectionTypes.Output(
                doc="test output", name="output2", dimensions=self.test_dims, storageClass="Dummy"
            )

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a templated connections instance configured for the tests below.

        The config overrides the ``template`` template value with
        ``"fromConfig"`` and renames ``field2`` to ``"field2FromConfig"``.

        Returns
        -------
        connections : `lsst.pipe.base.PipelineTaskConnections`
            The instantiated connections object.
        """

        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )
            field2 = pipeBase.connectionTypes.Output(
                doc="Test", name="field2Type", dimensions=self.test_dims, storageClass="Dummy", multiple=True
            )

            def adjustQuantum(self, datasetRefMap):
                # Test-only override: reject a ``field`` entry with fewer than
                # two elements.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        connections = TestConnectionsWithTemplate(config=config)
        return connections

    def testConnectionsInstantiation(self):
        """Test that config overrides are applied to connection names."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")
        self.assertEqual(connections.allConnections["field"].name, "fromConfigtest")
        self.assertEqual(connections.allConnections["field2"].name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test that refs are routed to the connections matching their names."""
        connections = self.buildTestConnections()

        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that the overridden ``adjustQuantum`` can reject a quantum."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        # ``inputRefs.field`` holds a single entry, so the override raises.
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Test that malformed dimensions are rejected.

        Each case is built so that the dimensions argument is the only
        thing wrong with it. The connections classes derive from
        ``PipelineTaskConnections``, which accepts a ``dimensions`` class
        keyword, and the connection constructors use the real ``doc``
        keyword. An unrelated ``TypeError`` therefore cannot satisfy the
        assertions.
        """
        with self.assertRaises(TypeError):
            # A bare string (e.g. a missing trailing comma) is expected to be
            # rejected.
            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections, dimensions="a"):
                pass

        with self.assertRaises(TypeError):
            # A non-iterable value is expected to be rejected.
            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections, dimensions=2):
                pass

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions="a", name="output", storageClass="mock")

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=1, name="output", storageClass="mock")

    def test_deprecation(self) -> None:
        """Test support for deprecating connections."""

        class TestConnections(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"t1": "dataset_type_1"},
            deprecatedTemplates={"t1": "Deprecated in v600, will be removed after v601."},
        ):
            input1 = pipeBase.connectionTypes.Input(
                doc="Docs for input1",
                name="input1_{t1}",
                storageClass="StructuredDataDict",
                deprecated="Deprecated in v50000, will be removed after v50001.",
            )

            def __init__(self, config):
                # Allow the config to remove the deprecated connection.
                if config.drop_input1:
                    del self.input1

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnections):
            drop_input1 = Field("Remove the 'input1' connection if True", dtype=bool, default=False)

        config = TestConfig()
        # Overriding a deprecated connection name or template should warn.
        with self.assertWarns(FutureWarning):
            config.connections.input1 = "dataset_type_2"
        with self.assertWarns(FutureWarning):
            config.connections.t1 = "dataset_type_3"

        # Instantiating while the deprecated connection is present should warn.
        with self.assertWarns(FutureWarning):
            TestConnections(config=config)

        config.drop_input1 = True

        # Once the deprecated connection is removed, instantiation must not
        # emit a FutureWarning (escalated to an error here).
        with warnings.catch_warnings():
            warnings.simplefilter("error", FutureWarning)
            TestConnections(config=config)

232 

233 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the file-leak checks inherited from ``lsst.utils.tests.MemoryTestCase``."""

    pass

236 

237 

def setup_module(module):
    """Prepare the LSST test utilities when pytest sets up this module.

    Parameters
    ----------
    module : `module`
        The module being set up, passed in by pytest; it is not used here.
    """
    init_test_utils = lsst.utils.tests.init
    init_test_utils()

241 

242 

if __name__ == "__main__":
    # When run as a script, initialize the LSST test utilities first and
    # then hand control to the standard unittest runner.
    lsst.utils.tests.init()
    unittest.main()