Coverage for tests/test_dynamic_connections.py: 12%

180 statements  

« prev     ^ index     » next       coverage.py v7.3.0, created at 2023-08-31 09:39 +0000

1# This file is part of pipe_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (http://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <http://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import unittest 

25from collections.abc import Callable 

26from types import MappingProxyType 

27 

28from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections 

29from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput 

30 

# Keyword arguments used to declare the single test connection on the
# connections class.
DEFAULT_CONNECTION_KWARGS = {
    "doc": "field docs",
    "name": "unconfigured",
    "dimensions": (),
    "storageClass": "Dummy",
}

# Keyword arguments describing that same connection once configuration has
# been applied.  Every test renames the dataset type through the config, so a
# configuration override that gets dropped shows up as a test failure.
RENAMED_CONNECTION_KWARGS = {**DEFAULT_CONNECTION_KWARGS, "name": "configured"}

44 

45 

class TestDynamicConnectionsClass(unittest.TestCase):
    """Test modifying connections in derived __init__ implementations."""

    def build_dynamic_connections(
        self, init_callback: Callable[[PipelineTaskConnections], None] | None = None
    ) -> PipelineTaskConnections:
        """Define and construct a connections class instance with a callback
        run in ``__init__``.

        Parameters
        ----------
        init_callback : `~collections.abc.Callable`, optional
            Callback to invoke with the `PipelineTaskConnections` instance
            as its only argument. Return value is ignored. If `None`, the
            derived ``__init__`` does nothing.

        Returns
        -------
        connections : `PipelineTaskConnections`
            Constructed connections instance. Its lone connection,
            ``the_connection``, has its dataset type name overridden via
            configuration to ``RENAMED_CONNECTION_KWARGS["name"]``.
        """

        class ExampleConnections(PipelineTaskConnections, dimensions=()):
            the_connection = Input(**DEFAULT_CONNECTION_KWARGS)

            def __init__(self, config: ExampleConfig):
                # Calling super() is harmless but now unnecessary, so don't do
                # it to make sure that works. Old code calling it is fine.
                if init_callback is not None:
                    init_callback(self)

        class ExampleConfig(PipelineTaskConfig, pipelineConnections=ExampleConnections):
            pass

        config = ExampleConfig()
        config.connections.the_connection = RENAMED_CONNECTION_KWARGS["name"]
        return ExampleConnections(config=config)

    @staticmethod
    def _copy_connection(connection_type, connection, **overrides):
        """Build a connection of ``connection_type`` from an existing one.

        Parameters
        ----------
        connection_type : `type`
            Connection class to construct (e.g. `Input`,
            `PrerequisiteInput`, `Output`).
        connection
            Existing connection whose ``doc``, ``name``, ``dimensions`` and
            ``storageClass`` are copied.
        **overrides
            Replacement values for any of the copied fields.

        Returns
        -------
        new_connection
            Newly constructed instance of ``connection_type``.
        """
        kwargs = {
            "doc": connection.doc,
            "name": connection.name,
            "dimensions": connection.dimensions,
            "storageClass": connection.storageClass,
        }
        kwargs.update(overrides)
        return connection_type(**kwargs)

    @staticmethod
    def _make_new_output():
        """Return a fresh `Output` connection for tests that add one
        dynamically.
        """
        return Output(
            name="new_dataset_type",
            doc="new connection_docs",
            storageClass="Dummy",
            dimensions=(),
        )

    def test_freeze_after_construction(self):
        """Connection sets and mappings become immutable after __init__."""
        connections = self.build_dynamic_connections()
        self.assertIsInstance(connections.dimensions, frozenset)
        self.assertIsInstance(connections.inputs, frozenset)
        self.assertIsInstance(connections.prerequisiteInputs, frozenset)
        self.assertIsInstance(connections.outputs, frozenset)
        self.assertIsInstance(connections.initInputs, frozenset)
        self.assertIsInstance(connections.initOutputs, frozenset)
        self.assertIsInstance(connections.allConnections, MappingProxyType)

    def test_change_attr_after_construction(self):
        """Connections cannot be reassigned once construction is complete."""
        connections = self.build_dynamic_connections()
        with self.assertRaises(TypeError):
            connections.the_connection = PrerequisiteInput(**RENAMED_CONNECTION_KWARGS)

    def test_delete_attr_after_construction(self):
        """Connections cannot be deleted once construction is complete."""
        connections = self.build_dynamic_connections()
        with self.assertRaises(TypeError):
            del connections.the_connection

    def test_change_dimensions(self):
        """Task and connection dimensions can be replaced in __init__."""

        def callback(instance):
            instance.dimensions = {"new", "dimensions"}
            instance.the_connection = self._copy_connection(
                Input, instance.the_connection, dimensions={"new", "dimensions"}
            )

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.dimensions, {"new", "dimensions"})
        self.assertIsInstance(connections.dimensions, frozenset)
        self.assertEqual(connections.inputs, {"the_connection"})
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        updated_connection_kwargs = {**RENAMED_CONNECTION_KWARGS, "dimensions": {"new", "dimensions"}}
        self.assertEqual(connections.allConnections["the_connection"], Input(**updated_connection_kwargs))
        self.assertEqual(connections.the_connection, Input(**updated_connection_kwargs))

    def test_change_connection_type(self):
        """Replacing a connection with another type moves it between sets."""

        def callback(instance):
            instance.the_connection = self._copy_connection(PrerequisiteInput, instance.the_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, {"the_connection"})
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        self.assertEqual(
            connections.allConnections["the_connection"], PrerequisiteInput(**RENAMED_CONNECTION_KWARGS)
        )
        self.assertEqual(connections.the_connection, PrerequisiteInput(**RENAMED_CONNECTION_KWARGS))

    def test_change_connection_type_twice(self):
        """Only the last of several type changes is retained."""

        def callback(instance):
            instance.the_connection = self._copy_connection(PrerequisiteInput, instance.the_connection)
            instance.the_connection = self._copy_connection(Output, instance.the_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, set())
        self.assertEqual(connections.outputs, {"the_connection"})
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        self.assertEqual(connections.allConnections["the_connection"], Output(**RENAMED_CONNECTION_KWARGS))
        self.assertEqual(connections.the_connection, Output(**RENAMED_CONNECTION_KWARGS))

    def test_remove_from_set(self):
        """Removing a name from a connection set drops the connection."""

        def callback(instance):
            instance.inputs.remove("the_connection")
            # We can't make this remove corresponding attribute or the entry in
            # allConnections *immediately* without using a custom set class for
            # 'inputs' etc, which we haven't bothered to do, because even that
            # wouldn't be enough to have additions to those sets update the
            # attributes and allConnections. Instead updates to those happen
            # after __init__.

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection

    def test_delete_attr(self):
        """Deleting a connection attribute in __init__ removes it everywhere."""

        def callback(instance):
            del instance.the_connection
            # This updates the corresponding entry from the inputs set and
            # the allConnections dict.
            self.assertEqual(instance.inputs, set())
            self.assertEqual(instance.allConnections, {})

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection

    def test_delete_attr_twice(self):
        """Deleting an already-deleted connection raises AttributeError."""

        def callback(instance):
            del instance.the_connection
            with self.assertRaises(AttributeError):
                del instance.the_connection

        self.build_dynamic_connections(callback)

    def test_change_connection_type_then_remove_from_set(self):
        """A type change followed by set removal drops the connection."""

        def callback(instance):
            instance.the_connection = self._copy_connection(PrerequisiteInput, instance.the_connection)
            instance.prerequisiteInputs.remove("the_connection")

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection
        with self.assertRaises(KeyError):
            connections.allConnections["the_connection"]

    def test_change_connection_type_then_delete_attr(self):
        """A type change followed by attribute deletion drops the connection."""

        def callback(instance):
            instance.the_connection = self._copy_connection(PrerequisiteInput, instance.the_connection)
            del instance.the_connection
            # This updates the corresponding entry from the inputs set and
            # the allConnections dict.
            self.assertEqual(instance.inputs, set())
            self.assertEqual(instance.prerequisiteInputs, set())
            self.assertEqual(instance.allConnections, {})
            with self.assertRaises(AttributeError):
                instance.the_connection
            with self.assertRaises(KeyError):
                instance.allConnections["the_connection"]

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection
        with self.assertRaises(KeyError):
            connections.allConnections["the_connection"]

    def test_add_new_connection(self):
        """A brand-new connection can be added in __init__."""
        new_connection = self._make_new_output()

        def callback(instance):
            instance.new_connection = new_connection
            self.assertEqual(instance.outputs, {"new_connection"})
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, new_connection)
            self.assertIs(instance.allConnections["new_connection"], new_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.outputs, {"new_connection"})
        self.assertEqual(connections.allConnections.keys(), {"new_connection", "the_connection"})
        self.assertIs(connections.new_connection, new_connection)
        self.assertIs(connections.allConnections["new_connection"], new_connection)

    def test_add_and_change_new_connection(self):
        """A connection added in __init__ can then change type."""
        new_connection = self._make_new_output()
        changed_connection = self._copy_connection(PrerequisiteInput, new_connection)

        def callback(instance):
            instance.new_connection = new_connection
            self.assertEqual(instance.outputs, {"new_connection"})
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, new_connection)
            self.assertIs(instance.allConnections["new_connection"], new_connection)
            instance.new_connection = changed_connection
            # The connection must move from 'outputs' to 'prerequisiteInputs'.
            self.assertEqual(instance.outputs, set())
            self.assertEqual(instance.prerequisiteInputs, {"new_connection"})
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, changed_connection)
            self.assertIs(instance.allConnections["new_connection"], changed_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.outputs, set())
        self.assertEqual(connections.prerequisiteInputs, {"new_connection"})
        self.assertEqual(connections.allConnections.keys(), {"new_connection", "the_connection"})
        self.assertIs(connections.new_connection, changed_connection)
        self.assertIs(connections.allConnections["new_connection"], changed_connection)

    def test_add_and_remove_new_connection(self):
        """A connection added in __init__ can then be deleted again."""
        new_connection = self._make_new_output()

        def callback(instance):
            instance.new_connection = new_connection
            self.assertEqual(instance.outputs, {"new_connection"})
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, new_connection)
            self.assertIs(instance.allConnections["new_connection"], new_connection)
            del instance.new_connection
            self.assertEqual(instance.outputs, set())
            self.assertEqual(instance.allConnections.keys(), {"the_connection"})
            with self.assertRaises(AttributeError):
                instance.new_connection
            with self.assertRaises(KeyError):
                instance.allConnections["new_connection"]

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.outputs, set())
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        with self.assertRaises(AttributeError):
            connections.new_connection
        with self.assertRaises(KeyError):
            connections.allConnections["new_connection"]

331 

332 

if __name__ == "__main__":
    # Run the test cases in this module when it is executed as a script.
    unittest.main()