Coverage for tests/test_connections.py: 15%
104 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-17 02:45 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-17 02:45 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Simple unit test for PipelineTaskConnections.
29"""
31import unittest
32import warnings
34import lsst.pipe.base as pipeBase
35import lsst.utils.tests
36import pytest
37from lsst.pex.config import Field
class TestConnectionsClass(unittest.TestCase):
    """Test connection classes."""

    def setUp(self):
        # Dimensions used by every connection class declared in these tests.
        self.test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Tests the declaration of a Connections Class."""
        with pytest.raises(TypeError):
            # This should raise because this Connections class is created with
            # no dimensions.
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with pytest.raises(TypeError):
            # This should raise because this Connections class uses a
            # "{template}" placeholder in a dataset name without supplying
            # defaultTemplates for it.
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(
                    doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
                )

        # Declaring dimensions alone is sufficient; this must not raise.
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # A templated dataset name with a default for the template must not
        # raise either.
        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )

    def testConnectionsOnConnectionsClass(self):
        """Test that each connection category collects the attribute names
        of the connections declared with that type.
        """

        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input", storageClass="Dummy"
            )
            initInput2 = pipeBase.connectionTypes.InitInput(
                doc="Test Init input", name="init_input2", storageClass="Dummy"
            )

            initOutput1 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output1", storageClass="Dummy"
            )
            initOutput2 = pipeBase.connectionTypes.InitOutput(
                doc="Test Init output", name="init_output2", storageClass="Dummy"
            )

            # Each connection gets a distinct dataset type name.
            input1 = pipeBase.connectionTypes.Input(
                doc="test input", name="input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            input2 = pipeBase.connectionTypes.Input(
                doc="test input", name="input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input1", dimensions=self.test_dims, storageClass="Dummy"
            )
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(
                doc="test input", name="pre_input2", dimensions=self.test_dims, storageClass="Dummy"
            )

            output1 = pipeBase.connectionTypes.Output(
                doc="test output", name="output1", dimensions=self.test_dims, storageClass="Dummy"
            )
            output2 = pipeBase.connectionTypes.Output(
                doc="test output", name="output2", dimensions=self.test_dims, storageClass="Dummy"
            )

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a connections instance whose names are overridden by config.

        Returns
        -------
        connections : `TestConnectionsWithTemplate`
            Connections with template ``"fromConfig"`` (so ``field`` is named
            ``"fromConfigtest"``) and ``field2`` renamed to
            ``"field2FromConfig"``.
        """

        class TestConnectionsWithTemplate(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"template": "working"},
        ):
            field = pipeBase.connectionTypes.Input(
                doc="Test", name="{template}test", dimensions=self.test_dims, storageClass="Dummy"
            )
            field2 = pipeBase.connectionTypes.Output(
                doc="Test", name="field2Type", dimensions=self.test_dims, storageClass="Dummy", multiple=True
            )

            def adjustQuantum(self, datasetRefMap):
                # Test-only override: testAdjustQuantum passes the input refs
                # struct, whose single-valued ``field`` is the mock ref "a",
                # so this raises in that test.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        # Override the template value and rename field2 directly.
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        connections = TestConnectionsWithTemplate(config=config)
        return connections

    def testConnectionsInstantiation(self):
        """Test that config overrides are reflected in connection names."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")
        self.assertEqual(connections.allConnections["field"].name, "fromConfigtest")
        self.assertEqual(connections.allConnections["field2"].name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test mapping a (mock) quantum's refs onto connection attributes."""
        connections = self.buildTestConnections()

        # The mock quantum is keyed by the config-resolved dataset names.
        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        # ``field`` is single-valued, so it yields the lone ref; ``field2`` is
        # declared ``multiple=True``, so it yields the whole list.
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that the overridden adjustQuantum can reject a quantum."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(
            inputs={"fromConfigtest": ["a"]}, outputs={"field2FromConfig": ["b", "c"]}
        )
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Test that malformed dimensions are rejected with TypeError.

        Both connections classes and individual connections must reject a
        bare string (typically a forgotten trailing comma in a one-element
        tuple) and a non-iterable value.
        """
        with self.assertRaises(TypeError):

            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections, dimensions="a"):
                pass

        with self.assertRaises(TypeError):

            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections, dimensions=2):
                pass

        # Use the correct ``doc`` keyword so that any TypeError comes from
        # dimension validation rather than from an unknown argument.
        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions="a", name="output", storageClass="mock")

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=1, name="output", storageClass="mock")

    def test_deprecation(self) -> None:
        """Test support for deprecating connections."""

        class TestConnections(
            pipeBase.PipelineTaskConnections,
            dimensions=self.test_dims,
            defaultTemplates={"t1": "dataset_type_1"},
            deprecatedTemplates={"t1": "Deprecated in v600, will be removed after v601."},
        ):
            input1 = pipeBase.connectionTypes.Input(
                doc="Docs for input1",
                name="input1_{t1}",
                storageClass="StructuredDataDict",
                deprecated="Deprecated in v50000, will be removed after v50001.",
            )

            def __init__(self, config):
                # Optionally remove the deprecated connection, which is also
                # the only user of the deprecated "t1" template.
                if config.drop_input1:
                    del self.input1

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnections):
            drop_input1 = Field("Remove the 'input1' connection if True", dtype=bool, default=False)

        config = TestConfig()
        # Overriding a deprecated connection name or template must warn.
        with self.assertWarns(FutureWarning):
            config.connections.input1 = "dataset_type_2"
        with self.assertWarns(FutureWarning):
            config.connections.t1 = "dataset_type_3"

        # Constructing connections that still contain input1 must warn.
        with self.assertWarns(FutureWarning):
            TestConnections(config=config)

        config.drop_input1 = True

        # With input1 removed, construction must not emit FutureWarning;
        # turn it into an error so any such warning fails the test.
        with warnings.catch_warnings():
            warnings.simplefilter("error", FutureWarning)
            TestConnections(config=config)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the lsst.utils file-leak checks for this test module."""
def setup_module(module):
    """Initialize lsst.utils test support when pytest loads this module.

    Parameters
    ----------
    module : `module`
        The module being set up; supplied by pytest and not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution (outside pytest): initialize lsst.utils test support
    # first, then hand control to the unittest runner.
    lsst.utils.tests.init()
    unittest.main()