Coverage for tests/test_connections.py: 15%

104 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-31 09:39 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22"""Simple unit test for PipelineTaskConnections. 

23""" 

24 

25import unittest 

26import warnings 

27 

28import lsst.pipe.base as pipeBase 

29import lsst.utils.tests 

30import pytest 

31from lsst.pex.config import Field 

32 

33 

class TestConnectionsClass(unittest.TestCase):
    """Test connection classes."""

    # Dimensions shared by every connections class declared in these tests.
    # Kept as an immutable tuple so it can safely live on the class; a bare
    # string here would be rejected (see testDimensionCheck).
    test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Tests the declaration of a Connections Class"""
        with self.assertRaises(TypeError):
            # This should raise because this Connections class is created with
            # no dimensions
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with self.assertRaises(TypeError):
            # This should raise because this Connections class is created
            # without template defaults for the "{template}" placeholder
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(
                    doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
                )

        # This declaration should raise no exceptions
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise: the template placeholder has a default
        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )

    def testConnectionsOnConnectionsClass(self):
        """Test that connection attributes are grouped by connection type.

        Each connection is given a distinct dataset type name so that the
        grouping asserted below cannot be confused by duplicate names.
        """

        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input", storageClass="Dummy"
            )
            initInput2 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input2", storageClass="Dummy"
            )

            initOutput1 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output1", storageClass="Dummy"
            )
            initOutput2 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output2", storageClass="Dummy"
            )

            input1 = pipeBase.connectionTypes.Input(
                doc="test input", name="input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            input2 = pipeBase.connectionTypes.Input(
                doc="test input", name="input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            output1 = pipeBase.connectionTypes.Output(
                doc="test output", name="output1", dimensions=self.test_dims, storageClass="Dummy"
            )
            output2 = pipeBase.connectionTypes.Output(
                doc="test output", name="output2", dimensions=self.test_dims, storageClass="Dummy"
            )

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a connections instance whose names are overridden by config.

        The connections class declares a templated input (``field``) and a
        ``multiple=True`` output (``field2``). The config overrides the
        template value with ``"fromConfig"`` and renames ``field2`` to
        ``"field2FromConfig"``.

        Returns
        -------
        connections : `TestConnectionsWithTemplate`
            Connections instantiated from the overriding config.
        """

        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )
            field2 = pipeBase.connectionTypes.Output(
                doc="Test", name="field2Type", dimensions=self.test_dims, storageClass="Dummy", multiple=True
            )

            def adjustQuantum(self, datasetRefMap):
                # Test-only override: reject a ref map whose ``field`` entry
                # has fewer than two elements.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        return TestConnectionsWithTemplate(config=config)

    def testConnectionsInstantiation(self):
        """Test that config overrides reach the instantiated connections."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")
        self.assertEqual(connections.allConnections["field"].name, "fromConfigtest")
        self.assertEqual(connections.allConnections["field2"].name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test mapping quantum entries back onto connection attributes.

        The single-valued input ``field`` yields its lone element, while the
        ``multiple=True`` output ``field2`` keeps the full list.
        """
        connections = self.buildTestConnections()

        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that the overridden adjustQuantum raises on too few refs."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Test that malformed dimensions are rejected with TypeError.

        Two malformed inputs are checked:

        - A bare string (``"a"`` instead of ``("a",)``, i.e. a missing
          trailing comma).
        - A non-iterable such as an int.

        Each is checked both when declaring a connections class and when
        constructing a connection. Every other argument is valid, so the
        TypeError can only come from the dimensions check.
        """
        with self.assertRaises(TypeError):

            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections, dimensions="a"):
                pass

        with self.assertRaises(TypeError):

            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections, dimensions=2):
                pass

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions="a", name="output", storageClass="mock")

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=1, name="output", storageClass="mock")

    def test_deprecation(self) -> None:
        """Test support for deprecating connections."""

        class TestConnections(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"t1": "dataset_type_1"},
            deprecatedTemplates={"t1": "Deprecated in v600, will be removed after v601."},
        ):
            input1 = pipeBase.connectionTypes.Input(
                doc="Docs for input1",
                name="input1_{t1}",
                storageClass="StructuredDataDict",
                deprecated="Deprecated in v50000, will be removed after v50001.",
            )

            def __init__(self, config):
                # Allow the deprecated connection to be removed via config.
                if config.drop_input1:
                    del self.input1

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnections):
            drop_input1 = Field("Remove the 'input1' connection if True", dtype=bool, default=False)

        config = TestConfig()
        # Overriding a deprecated connection or template warns.
        with self.assertWarns(FutureWarning):
            config.connections.input1 = "dataset_type_2"
        with self.assertWarns(FutureWarning):
            config.connections.t1 = "dataset_type_3"

        # Instantiating with the deprecated connection still present warns.
        with self.assertWarns(FutureWarning):
            TestConnections(config=config)

        config.drop_input1 = True

        # Once the deprecated connection is dropped, no FutureWarning may be
        # emitted; escalate any to an error to prove it.
        with warnings.catch_warnings():
            warnings.simplefilter("error", FutureWarning)
            TestConnections(config=config)

226 

227 

class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the file-leak checks inherited from
    ``lsst.utils.tests.MemoryTestCase``; nothing is overridden here.
    """

230 

231 

def setup_module(module):
    """Module-level pytest setup hook.

    Calls ``lsst.utils.tests.init()`` before the tests in this module run.
    The ``module`` argument is supplied by pytest and not used.
    """
    initialise = lsst.utils.tests.init
    initialise()

235 

236 

if __name__ == "__main__":
    # Direct execution (not via pytest): initialise the LSST test utilities,
    # then hand control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()