Coverage for tests/test_connections.py : 21%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTaskConnections.
23"""
25import unittest
26import pytest
28import lsst.utils.tests
29import lsst.pipe.base as pipeBase
class TestConnectionsClass(unittest.TestCase):
    """Unit tests for declaring, configuring and using
    `PipelineTaskConnections` subclasses.
    """

    # Dimensions used by the connections classes declared in these tests.
    test_dims = ("a", "b")

    def testConnectionsDeclaration(self):
        """Test which connections class declarations are accepted."""
        with pytest.raises(TypeError):
            # This should raise because this Connections class is created
            # with no dimensions.
            class TestConnections(pipeBase.PipelineTaskConnections):
                pass

        with pytest.raises(TypeError):
            # This should raise because this Connections class uses a
            # "{template}" placeholder without template defaults.
            class TestConnectionsTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
                field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                       dimensions=self.test_dims,
                                                       storageClass='Dummy')

        # This declaration should raise no exceptions.
        class TestConnectionsWithDimensions(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            pass

        # This should not raise: the placeholder has a default template value.
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')

    def testConnectionsOnConnectionsClass(self):
        """Test that the class-level name sets (``initInputs``,
        ``initOutputs``, ``inputs``, ``prerequisiteInputs``, ``outputs``)
        list the attributes declared with the matching connection type.
        """
        class TestConnections(pipeBase.PipelineTaskConnections, dimensions=self.test_dims):
            initInput1 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input",
                                                            storageClass='Dummy')
            initInput2 = pipeBase.connectionTypes.InitInput(doc="Test Init input", name="init_input2",
                                                            storageClass='Dummy')

            initOutput1 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output1",
                                                              storageClass='Dummy')
            initOutput2 = pipeBase.connectionTypes.InitOutput(doc="Test Init output", name="init_output2",
                                                              storageClass='Dummy')

            # Each connection gets its own dataset type name.
            input1 = pipeBase.connectionTypes.Input(doc="test input", name="input1",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')
            input2 = pipeBase.connectionTypes.Input(doc="test input", name="input2",
                                                    dimensions=self.test_dims,
                                                    storageClass='Dummy')

            prereqInputs1 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input1",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')
            prereqInputs2 = pipeBase.connectionTypes.PrerequisiteInput(doc="test input", name="pre_input2",
                                                                       dimensions=self.test_dims,
                                                                       storageClass='Dummy')

            output1 = pipeBase.connectionTypes.Output(doc="test output", name="output1",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')
            output2 = pipeBase.connectionTypes.Output(doc="test output", name="output2",
                                                      dimensions=self.test_dims,
                                                      storageClass='Dummy')

        self.assertEqual(TestConnections.initInputs, frozenset(["initInput1", "initInput2"]))
        self.assertEqual(TestConnections.initOutputs, frozenset(["initOutput1", "initOutput2"]))
        self.assertEqual(TestConnections.inputs, frozenset(["input1", "input2"]))
        self.assertEqual(TestConnections.prerequisiteInputs, frozenset(["prereqInputs1", "prereqInputs2"]))
        self.assertEqual(TestConnections.outputs, frozenset(["output1", "output2"]))

    def buildTestConnections(self):
        """Build a templated connections instance from an overridden config.

        The config sets the ``template`` value to ``"fromConfig"`` and renames
        the ``field2`` connection to ``"field2FromConfig"``.

        Returns
        -------
        connections : `TestConnectionsWithTemplate`
            Connections instance constructed from that config.
        """
        class TestConnectionsWithTemplate(pipeBase.PipelineTaskConnections, dimensions=self.test_dims,
                                          defaultTemplates={"template": "working"}):
            field = pipeBase.connectionTypes.Input(doc="Test", name="{template}test",
                                                   dimensions=self.test_dims,
                                                   storageClass='Dummy')
            field2 = pipeBase.connectionTypes.Output(doc="Test", name="field2Type",
                                                     dimensions=self.test_dims,
                                                     storageClass='Dummy',
                                                     multiple=True)

            def adjustQuantum(self, datasetRefMap):
                # Reject a quantum whose ``field`` entry has fewer than two
                # elements.
                if len(datasetRefMap.field) < 2:
                    raise ValueError("This connection should have more than one entry")

        class TestConfig(pipeBase.PipelineTaskConfig, pipelineConnections=TestConnectionsWithTemplate):
            pass

        config = TestConfig()
        config.connections.template = "fromConfig"
        config.connections.field2 = "field2FromConfig"

        return TestConnectionsWithTemplate(config=config)

    def testConnectionsInstantiation(self):
        """Test that config overrides are applied to connection names."""
        connections = self.buildTestConnections()
        self.assertEqual(connections.field.name, "fromConfigtest")
        self.assertEqual(connections.field2.name, "field2FromConfig")

    def testBuildDatasetRefs(self):
        """Test that quantum datasets are mapped onto connection attributes."""
        connections = self.buildTestConnections()

        mockQuantum = pipeBase.Struct(predictedInputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})

        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        # ``field`` is not ``multiple`` so its single ref is unpacked;
        # ``field2`` is ``multiple`` so the whole list is kept.
        self.assertEqual(inputRefs.field, "a")
        self.assertEqual(outputRefs.field2, ["b", "c"])

    def testAdjustQuantum(self):
        """Test that an ``adjustQuantum`` override can reject a quantum."""
        connections = self.buildTestConnections()
        mockQuantum = pipeBase.Struct(predictedInputs={"fromConfigtest": ["a"]},
                                      outputs={"field2FromConfig": ["b", "c"]})
        inputRefs, outputRefs = connections.buildDatasetRefs(mockQuantum)
        with self.assertRaises(ValueError):
            connections.adjustQuantum(inputRefs)

    def testDimensionCheck(self):
        """Test that malformed dimensions are rejected with TypeError."""
        # ``("a")`` is the string "a", not a one-element tuple: the
        # missing-trailing-comma mistake these checks guard against.
        # The classes derive from PipelineTaskConnections (the type that takes
        # a ``dimensions`` class keyword, as testConnectionsDeclaration shows)
        # so the TypeError comes from dimension validation, not the base class.
        with self.assertRaises(TypeError):
            class TestConnectionsWithBrokenDimensionsStr(pipeBase.PipelineTaskConnections,
                                                         dimensions=("a")):
                pass

        with self.assertRaises(TypeError):
            class TestConnectionsWithBrokenDimensionsIter(pipeBase.PipelineTaskConnections,
                                                          dimensions=2):
                pass

        # Use the real ``doc`` keyword: a misspelled keyword argument could
        # raise TypeError on its own and let these checks pass vacuously.
        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=("a"), name="output",
                                            storageClass="mock")

        with self.assertRaises(TypeError):
            pipeBase.connectionTypes.Output(doc="mock doc", dimensions=1, name="output",
                                            storageClass="mock")
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the LSST memory-leak checks for this test module."""
def setup_module(module):
    """Module-level test setup hook.

    Parameters
    ----------
    module : `module`
        The test module being set up; not used here.
    """
    lsst.utils.tests.init()
# Allow running this file directly, outside of pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()