Coverage for tests/test_dynamic_connections.py: 12%
180 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-11 02:00 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22from __future__ import annotations
24import unittest
25from types import MappingProxyType
26from typing import Callable
28from lsst.pipe.base import PipelineTaskConfig, PipelineTaskConnections
29from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput
# Keyword arguments for defining our lone test connection in its class.
DEFAULT_CONNECTION_KWARGS = dict(
    doc="field docs",
    name="unconfigured",
    dimensions=(),
    storageClass="Dummy",
)

# Keyword arguments for making our lone test connection post-configuration.
# We always rename the dataset type via configuration to make sure those
# changes are never dropped.  Built as a new dict so the defaults above are
# left untouched.
RENAMED_CONNECTION_KWARGS = {**DEFAULT_CONNECTION_KWARGS, "name": "configured"}
class TestDynamicConnectionsClass(unittest.TestCase):
    """Test modifying connections in derived __init__ implementations."""

    @staticmethod
    def _rebuild_connection(connection_type, connection, **overrides):
        """Construct a new connection that copies an existing one's
        definition, optionally with some fields replaced.

        Parameters
        ----------
        connection_type : `type`
            Connection class to instantiate (e.g. `Input`,
            `PrerequisiteInput` or `Output`).
        connection
            Existing connection whose ``doc``, ``name``, ``dimensions`` and
            ``storageClass`` attributes are copied.
        **overrides
            Keyword arguments that replace the copied values.

        Returns
        -------
        new_connection
            Newly-constructed instance of ``connection_type``.
        """
        kwargs = dict(
            doc=connection.doc,
            name=connection.name,
            dimensions=connection.dimensions,
            storageClass=connection.storageClass,
        )
        kwargs.update(overrides)
        return connection_type(**kwargs)

    def build_dynamic_connections(
        self, init_callback: Callable[[PipelineTaskConnections], None] | None = None
    ) -> PipelineTaskConnections:
        """Define and construct a connections class instance with a callback
        run in ``__init__``.

        Parameters
        ----------
        init_callback : `Callable` or `None`, optional
            Callback to invoke with the `PipelineTaskConnections` instance
            as its only argument. Return value is ignored. If `None`, no
            callback is run.

        Returns
        -------
        connections : `PipelineTaskConnections`
            Constructed connections instance.
        """

        class ExampleConnections(PipelineTaskConnections, dimensions=()):
            the_connection = Input(**DEFAULT_CONNECTION_KWARGS)

            def __init__(self, config: ExampleConfig):
                # Calling super() is harmless but now unnecessary, so don't do
                # it to make sure that works. Old code calling it is fine.
                if init_callback is not None:
                    init_callback(self)

        class ExampleConfig(PipelineTaskConfig, pipelineConnections=ExampleConnections):
            pass

        config = ExampleConfig()
        # Rename the dataset type via configuration so every test also checks
        # that configuration overrides survive changes made in __init__.
        config.connections.the_connection = RENAMED_CONNECTION_KWARGS["name"]
        return ExampleConnections(config=config)

    def test_freeze_after_construction(self):
        connections = self.build_dynamic_connections()
        self.assertIsInstance(connections.dimensions, frozenset)
        self.assertIsInstance(connections.inputs, frozenset)
        self.assertIsInstance(connections.prerequisiteInputs, frozenset)
        self.assertIsInstance(connections.outputs, frozenset)
        self.assertIsInstance(connections.initInputs, frozenset)
        self.assertIsInstance(connections.initOutputs, frozenset)
        self.assertIsInstance(connections.allConnections, MappingProxyType)

    def test_change_attr_after_construction(self):
        connections = self.build_dynamic_connections()
        with self.assertRaises(TypeError):
            connections.the_connection = PrerequisiteInput(**RENAMED_CONNECTION_KWARGS)

    def test_delete_attr_after_construction(self):
        connections = self.build_dynamic_connections()
        with self.assertRaises(TypeError):
            del connections.the_connection

    def test_change_dimensions(self):
        def callback(instance):
            instance.dimensions = {"new", "dimensions"}
            instance.the_connection = self._rebuild_connection(
                Input, instance.the_connection, dimensions={"new", "dimensions"}
            )

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.dimensions, {"new", "dimensions"})
        self.assertIsInstance(connections.dimensions, frozenset)
        self.assertEqual(connections.inputs, {"the_connection"})
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        updated_connection_kwargs = RENAMED_CONNECTION_KWARGS.copy()
        updated_connection_kwargs["dimensions"] = {"new", "dimensions"}
        self.assertEqual(connections.allConnections["the_connection"], Input(**updated_connection_kwargs))
        self.assertEqual(connections.the_connection, Input(**updated_connection_kwargs))

    def test_change_connection_type(self):
        def callback(instance):
            instance.the_connection = self._rebuild_connection(PrerequisiteInput, instance.the_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, {"the_connection"})
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        self.assertEqual(
            connections.allConnections["the_connection"], PrerequisiteInput(**RENAMED_CONNECTION_KWARGS)
        )
        self.assertEqual(connections.the_connection, PrerequisiteInput(**RENAMED_CONNECTION_KWARGS))

    def test_change_connection_type_twice(self):
        def callback(instance):
            instance.the_connection = self._rebuild_connection(PrerequisiteInput, instance.the_connection)
            instance.the_connection = self._rebuild_connection(Output, instance.the_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, set())
        self.assertEqual(connections.outputs, {"the_connection"})
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        self.assertEqual(connections.allConnections["the_connection"], Output(**RENAMED_CONNECTION_KWARGS))
        self.assertEqual(connections.the_connection, Output(**RENAMED_CONNECTION_KWARGS))

    def test_remove_from_set(self):
        def callback(instance):
            instance.inputs.remove("the_connection")
            # We can't make this remove the corresponding attribute or the
            # entry in allConnections *immediately* without using a custom
            # set class for 'inputs' etc, which we haven't bothered to do,
            # because even that wouldn't be enough to have additions to those
            # sets update the attributes and allConnections. Instead updates
            # to those happen after __init__.

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection

    def test_delete_attr(self):
        def callback(instance):
            del instance.the_connection
            # This removes the corresponding entry from the inputs set and
            # the allConnections dict.
            self.assertEqual(instance.inputs, set())
            self.assertEqual(instance.allConnections, {})

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection

    def test_delete_attr_twice(self):
        def callback(instance):
            del instance.the_connection
            with self.assertRaises(AttributeError):
                del instance.the_connection

        connections = self.build_dynamic_connections(callback)
        # Check the post-construction state too, so this test cannot pass
        # vacuously if the callback were never invoked.
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection

    def test_change_connection_type_then_remove_from_set(self):
        def callback(instance):
            instance.the_connection = self._rebuild_connection(PrerequisiteInput, instance.the_connection)
            instance.prerequisiteInputs.remove("the_connection")

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection
        with self.assertRaises(KeyError):
            connections.allConnections["the_connection"]

    def test_change_connection_type_then_delete_attr(self):
        def callback(instance):
            instance.the_connection = self._rebuild_connection(PrerequisiteInput, instance.the_connection)
            del instance.the_connection
            # This removes the corresponding entry from the
            # prerequisiteInputs set and the allConnections dict.
            self.assertEqual(instance.inputs, set())
            self.assertEqual(instance.prerequisiteInputs, set())
            self.assertEqual(instance.allConnections, {})
            with self.assertRaises(AttributeError):
                instance.the_connection
            with self.assertRaises(KeyError):
                instance.allConnections["the_connection"]

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.inputs, set())
        self.assertEqual(connections.prerequisiteInputs, set())
        self.assertEqual(connections.allConnections, {})
        with self.assertRaises(AttributeError):
            connections.the_connection
        with self.assertRaises(KeyError):
            connections.allConnections["the_connection"]

    def test_add_new_connection(self):
        new_connection = Output(
            name="new_dataset_type",
            doc="new connection_docs",
            storageClass="Dummy",
            dimensions=(),
        )

        def callback(instance):
            instance.new_connection = new_connection
            self.assertEqual(instance.outputs, {"new_connection"})
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, new_connection)
            self.assertIs(instance.allConnections["new_connection"], new_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.outputs, {"new_connection"})
        self.assertEqual(connections.allConnections.keys(), {"new_connection", "the_connection"})
        self.assertIs(connections.new_connection, new_connection)
        self.assertIs(connections.allConnections["new_connection"], new_connection)

    def test_add_and_change_new_connection(self):
        new_connection = Output(
            name="new_dataset_type",
            doc="new connection_docs",
            storageClass="Dummy",
            dimensions=(),
        )
        changed_connection = PrerequisiteInput(
            name="new_dataset_type",
            doc="new connection_docs",
            storageClass="Dummy",
            dimensions=(),
        )

        def callback(instance):
            instance.new_connection = new_connection
            self.assertEqual(instance.outputs, {"new_connection"})
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, new_connection)
            self.assertIs(instance.allConnections["new_connection"], new_connection)
            instance.new_connection = changed_connection
            self.assertEqual(instance.outputs, set())
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, changed_connection)
            self.assertIs(instance.allConnections["new_connection"], changed_connection)

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.outputs, set())
        self.assertEqual(connections.allConnections.keys(), {"new_connection", "the_connection"})
        self.assertIs(connections.new_connection, changed_connection)
        self.assertIs(connections.allConnections["new_connection"], changed_connection)

    def test_add_and_remove_new_connection(self):
        new_connection = Output(
            name="new_dataset_type",
            doc="new connection_docs",
            storageClass="Dummy",
            dimensions=(),
        )

        def callback(instance):
            instance.new_connection = new_connection
            self.assertEqual(instance.outputs, {"new_connection"})
            self.assertEqual(instance.allConnections.keys(), {"new_connection", "the_connection"})
            self.assertIs(instance.new_connection, new_connection)
            self.assertIs(instance.allConnections["new_connection"], new_connection)
            del instance.new_connection
            self.assertEqual(instance.outputs, set())
            self.assertEqual(instance.allConnections.keys(), {"the_connection"})
            with self.assertRaises(AttributeError):
                instance.new_connection
            with self.assertRaises(KeyError):
                instance.allConnections["new_connection"]

        connections = self.build_dynamic_connections(callback)
        self.assertEqual(connections.outputs, set())
        self.assertEqual(connections.allConnections.keys(), {"the_connection"})
        with self.assertRaises(AttributeError):
            connections.new_connection
        with self.assertRaises(KeyError):
            connections.allConnections["new_connection"]
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()