Coverage for tests/test_connections.py : 21%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTaskConnections.
23"""
25import unittest
26import pytest
28import lsst.utils.tests
29import lsst.pipe.base as pipeBase
class TestConnectionsClass(unittest.TestCase):
    """Tests for declaring, configuring and using PipelineTaskConnections."""

    # Dimensions shared by every connection declared in these tests.
    test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Test the declaration of a Connections class."""
        with self.assertRaises(TypeError):
            # This should raise because this Connections class is created
            # with no dimensions.
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with self.assertRaises(TypeError):
            # This should raise because a dataset name uses "{template}"
            # but the class declares no template defaults.
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                       dimensions=self.test_dims,
                                                       storageClass='Dummy')

        # This declaration should raise no exceptions.
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise either: the template has a default value.
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')

    def testConnectionsOnConnectionsClass(self):
        """Test that each connection kind is collected into the matching
        class-level set of attribute names.
        """
        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input",
                                                            storageClass='Dummy')
            initInput2 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input2",
                                                            storageClass='Dummy')

            initOutput1 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output1",
                                                              storageClass='Dummy')
            initOutput2 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output2",
                                                              storageClass='Dummy')

            # Each connection gets its own dataset name so the declaration
            # does not silently rely on duplicate names being accepted.
            input1 = pipeBase.connectionTypes.Input(doc="test input", name="input1",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')
            input2 = pipeBase.connectionTypes.Input(doc="test input", name="input2",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input1",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input2",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')

            output1 = pipeBase.connectionTypes.Output(doc="test output", name="output1",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')
            output2 = pipeBase.connectionTypes.Output(doc="test output", name="output2",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a connections instance configured through a task config.

        The config sets the ``template`` value to ``"fromConfig"`` and
        renames ``field2`` to ``"field2FromConfig"``.

        Returns
        -------
        connections : `TestConnectionsWithTemplate`
            Connections instance built from the overridden config.
        """
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')
            field2 = pipeBase.connectionTypes.Output(doc="Test", name="field2Type",
                                                     dimensions=self.test_dims,
                                                     storageClass='Dummy',
                                                     multiple=True)

            def adjustQuantum(self, datasetRefMap):
                # Reject a ``field`` entry with fewer than two elements so
                # testAdjustQuantum can observe the ValueError.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        return TestConnectionsWithTemplate(config=config)

    def testConnectionsInstantiation(self):
        """Test that config overrides reach the instantiated connections."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test that dataset refs are mapped from a quantum onto connection
        attributes, unwrapping single inputs and keeping multiple outputs
        as lists.
        """
        connections = self.buildTestConnections()

        mockQuantum = pipeBase.Struct(inputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that a user-defined adjustQuantum hook can reject refs."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(inputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Test that malformed dimensions are rejected with TypeError.

        Dimensions must be an iterable of names. A bare string (``("a")``,
        i.e. a forgotten trailing comma) and a non-iterable (``2``) are
        both expected to fail, on a Connections class and on a single
        connection.
        """
        # The classes below derive from PipelineTaskConnections, which is
        # the class that accepts the ``dimensions`` keyword. A base that
        # does not accept it would raise TypeError for an unrelated reason
        # and make these checks vacuous.
        with self.assertRaises(TypeError):
            # ("a") is the string "a", not a one-element tuple.
            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections,
                                                         dimensions=("a")):
                pass

        with self.assertRaises(TypeError):
            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections,
                                                          dimensions=2):
                pass

        # ``doc`` is spelled correctly so the only thing wrong with these
        # calls is the dimensions argument.
        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=("a"), name="output",
                                            storageClass="mock")

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=1, name="output",
                                            storageClass="mock")
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Reuse the checks inherited from ``lsst.utils.tests.MemoryTestCase``.

    Adds nothing of its own; subclassing is enough for the test runner to
    collect and run the inherited tests for this module.
    """
def setup_module(module):
    """Initialize LSST test utilities before the tests in this module run.

    Parameters
    ----------
    module : `module`
        The module being set up, as passed by pytest's module-level setup
        hook; it is not used here.
    """
    lsst.utils.tests.init()
# Allow running this file directly: initialize the LSST test utilities,
# then hand control to unittest's command-line runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()