Coverage for tests/test_connections.py: 15%
104 statements
« prev ^ index » next coverage.py v7.3.0, created at 2023-08-23 10:31 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTaskConnections.
23"""
25import unittest
26import warnings
28import lsst.pipe.base as pipeBase
29import lsst.utils.tests
30import pytest
31from lsst.pex.config import Field
class TestConnectionsClass(unittest.TestCase):
    """Test connection classes."""

    # Dimensions shared by every connections class declared in these tests.
    # Kept as a class attribute (still reachable as ``self.test_dims``)
    # rather than by overriding ``TestCase.__init__``.
    test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Test the declaration of a connections class."""
        with self.assertRaises(TypeError):
            # This should raise because this connections class is created
            # with no dimensions.
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with self.assertRaises(TypeError):
            # This should raise because the connection name uses a template
            # for which no default value is declared.
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(
                    doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
                )

        # This declaration should raise no exceptions.
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise either: the template has a default value.
        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )

    def testConnectionsOnConnectionsClass(self):
        """Test that connections are grouped by kind on the class.

        Each connection attribute name must appear in the frozenset matching
        its connection type (``initInputs``, ``initOutputs``, ``inputs``,
        ``prerequisiteInputs``, ``outputs``).
        """

        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input", storageClass="Dummy"
            )
            initInput2 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input2", storageClass="Dummy"
            )

            initOutput1 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output1", storageClass="Dummy"
            )
            initOutput2 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output2", storageClass="Dummy"
            )

            # Each connection gets its own dataset type name; previously
            # input1/input2 and output1/output2 shared names by copy-paste.
            input1 = pipeBase.connectionTypes.Input(
                doc="test input", name="input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            input2 = pipeBase.connectionTypes.Input(
                doc="test input", name="input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            output1 = pipeBase.connectionTypes.Output(
                doc="test output", name="output1", dimensions=self.test_dims, storageClass="Dummy"
            )
            output2 = pipeBase.connectionTypes.Output(
                doc="test output", name="output2", dimensions=self.test_dims, storageClass="Dummy"
            )

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a connections instance configured through a config class.

        The connections class declares a templated input ``field`` and a
        ``multiple=True`` output ``field2``. The config sets the template
        value to ``"fromConfig"`` and renames ``field2`` to
        ``"field2FromConfig"``.

        Returns
        -------
        connections : `lsst.pipe.base.PipelineTaskConnections`
            Connections instance constructed from that config.
        """

        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )
            field2 = pipeBase.connectionTypes.Output(
                doc="Test", name="field2Type", dimensions=self.test_dims, storageClass="Dummy", multiple=True
            )

            def adjustQuantum(self, datasetRefMap):
                # Test-only override with a simplified signature: raise if
                # the ``field`` entry of the given refs has fewer than two
                # elements.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        return TestConnectionsWithTemplate(config=config)

    def testConnectionsInstantiation(self):
        """Test that config overrides reach the instantiated connections."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")
        self.assertEqual(connections.allConnections["field"].name, "fromConfigtest")
        self.assertEqual(connections.allConnections["field2"].name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test mapping quantum inputs/outputs onto connection attributes."""
        connections = self.buildTestConnections()

        # Stand-in for a quantum: inputs/outputs keyed by dataset type name.
        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        # ``field`` is a single-valued input, ``field2`` is ``multiple=True``.
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that the overridden ``adjustQuantum`` can reject refs."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )
        inputRefs, _ = connections.buildDatasetRefs(mockQuantum)
        # ``inputRefs.field`` holds a single entry, which the override
        # defined in ``buildTestConnections`` rejects.
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Test that invalid ``dimensions`` values are rejected.

        Each case must fail because of its dimensions. The class cases
        therefore derive from `PipelineTaskConnections` (which takes a
        ``dimensions`` class keyword), and the connection cases use the
        real ``doc`` keyword. Otherwise an unrelated TypeError (unexpected
        keyword argument) would satisfy the assertion.
        """
        with self.assertRaises(TypeError):
            # A bare string is not an iterable of dimension names.
            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections, dimensions="a"):
                pass

        with self.assertRaises(TypeError):
            # A non-iterable value is not an iterable of dimension names.
            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections, dimensions=2):
                pass

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions="a", name="output", storageClass="mock")

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=1, name="output", storageClass="mock")

    def test_deprecation(self) -> None:
        """Test support for deprecating connections."""

        class TestConnections(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"t1": "dataset_type_1"},
            deprecatedTemplates={"t1": "Deprecated in v600, will be removed after v601."},
        ):
            input1 = pipeBase.connectionTypes.Input(
                doc="Docs for input1",
                name="input1_{t1}",
                storageClass="StructuredDataDict",
                deprecated="Deprecated in v50000, will be removed after v50001.",
            )

            def __init__(self, config):
                # Optionally remove the deprecated connection.
                if config.drop_input1:
                    del self.input1

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnections):
            drop_input1 = Field("Remove the 'input1' connection if True", dtype=bool, default=False)

        config = TestConfig()
        # Overriding a deprecated connection name or a deprecated template
        # value in the config must warn.
        with self.assertWarns(FutureWarning):
            config.connections.input1 = "dataset_type_2"
        with self.assertWarns(FutureWarning):
            config.connections.t1 = "dataset_type_3"

        # Constructing connections that still contain the deprecated
        # connection must warn as well.
        with self.assertWarns(FutureWarning):
            TestConnections(config=config)

        config.drop_input1 = True

        # Once the deprecated connection is dropped, construction must not
        # emit any FutureWarning (escalated to an error here).
        with warnings.catch_warnings():
            warnings.simplefilter("error", FutureWarning)
            TestConnections(config=config)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Check this test module for leaked files via ``lsst.utils.tests``."""

    pass
def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run.

    Parameters
    ----------
    module : `module`
        The test module passed in by pytest; not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: prepare the LSST test utilities, then let unittest
    # discover and run the test cases defined above.
    lsst.utils.tests.init()
    unittest.main()