Coverage for tests/test_pipelineTask.py: 23%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26from types import SimpleNamespace
28import lsst.utils.tests
29from lsst.daf.butler import DatasetRef, Quantum, DimensionUniverse, DataCoordinate
30import lsst.pex.config as pexConfig
31import lsst.pipe.base as pipeBase
class ButlerMock:
    """In-memory stand-in for a butler, only usable for this test.

    Stored objects live in ``self.datasets``, a two-level dict mapping
    dataset type name -> data ID -> stored object.
    """

    def __init__(self):
        self.datasets = {}
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def getDirect(self, ref):
        """Return the object stored for ``ref``, or `None` if nothing is stored.

        Lookup uses ``ref.datasetType.name`` and ``ref.dataId``.
        """
        byDataId = self.datasets.get(ref.datasetType.name)
        return byDataId.get(ref.dataId) if byDataId else None

    def put(self, inMemoryDataset, dsRef, producer=None):
        """Store ``inMemoryDataset`` under ``dsRef``'s dataset type and data ID.

        ``dsRef.datasetType`` may be either a plain name string or an object
        with a ``name`` attribute. ``producer`` is accepted but not used.
        """
        datasetType = dsRef.datasetType
        name = datasetType if isinstance(datasetType, str) else datasetType.name
        self.datasets.setdefault(name, {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections,
                     dimensions=["instrument", "visit"]):
    """Connections for the adding test tasks: one input and one output."""

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the adding test tasks."""

    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")
# Example task that overrides only the run() method.
class AddTask(pipeBase.PipelineTask):
    """Add ``config.addend`` to the input and record the addend in metadata."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        """Return a Struct whose ``output`` is ``input + config.addend``.

        The parameter is named ``input`` to match the connection name.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
# Example task that overrides the runQuantum() method instead of run().
class AddTask2(pipeBase.PipelineTask):
    """Variant of AddTask that implements ``runQuantum`` directly.

    It reads the ``input`` connection through the quantum context, adds
    ``config.addend``, records the addend in metadata, and writes the result
    to the ``output`` connection.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        # A single value, stored under the ``output`` connection below.
        output = inputs['input'] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=output), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask
    """

    def _makeDSRefVisit(self, dstype, visitId, universe):
        """Make a `DatasetRef` of type ``dstype`` for visit ``visitId``.

        All other data ID values are fixed, so refs differ only by visit.
        """
        return DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter='a',
                band='b',
                instrument='TestInstrument',
                universe=universe
            )
        )

    def _makeQuanta(self, config):
        """Create one Quantum per visit (visits 0-99).

        Each quantum has a single input ref of the config's input dataset
        type and a single output ref of its output dataset type, both for
        the same visit.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(100):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(inputs={inputRef.datasetType: [inputRef]},
                              outputs={outputRef.datasetType: [outputRef]})
            quanta.append(quantum)

        return quanta

    def _putInputs(self, butler, connections, quanta):
        """Store the integer ``100 + i`` as the input of the i-th quantum."""
        dstype = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype.name][0]
            butler.put(100 + i, pipeBase.Struct(datasetType=dstype.name, dataId=ref.dataId))

    def _runQuanta(self, task, connections, butler, quanta):
        """Call ``task.runQuantum`` once for every quantum in ``quanta``."""
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

    def _checkOutputs(self, butler, outputName, quanta, offset):
        """Check one output per quantum, with the i-th equal to ``100 + i + offset``."""
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + offset)

    def testRunQuantum(self):
        """Test the default PipelineTask.runQuantum() driving AddTask.run().
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        quanta = self._makeQuanta(task.config)
        self._putInputs(butler, connections, quanta)
        self._runQuanta(task, connections, butler, quanta)

        # Each output should be its input plus the configured addend.
        self._checkOutputs(butler, connections.output.name, quanta, task.config.addend)

    def testChain2(self):
        """Test for two-task chain.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # The second task consumes the first task's output.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # Only the first task's inputs are seeded; task2 reads task1's outputs.
        self._putInputs(butler, connections1, quanta1)

        self._runQuanta(task1, connections1, butler, quanta1)
        self._runQuanta(task2, connections2, butler, quanta2)

        addend1 = task1.config.addend
        addend2 = task2.config.addend
        self._checkOutputs(butler, task1.config.connections.output, quanta1, addend1)
        self._checkOutputs(butler, task2.config.connections.output, quanta2, addend1 + addend2)
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Module-level memory test case.

    The body is empty, so every test comes from
    `lsst.utils.tests.MemoryTestCase` (presumably resource/leak checks run
    after the other tests; confirm in lsst.utils).
    """
def setup_module(module):
    """Module-level setup: call ``lsst.utils.tests.init()``.

    ``setup_module`` is the name pytest recognizes as an xunit-style
    module setup function. The ``module`` argument is not used here.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Initialize lsst.utils test support, then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()