Coverage for tests/test_pipelineTask.py : 22%

# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""

import unittest
from types import SimpleNamespace

import lsst.utils.tests
from lsst.daf.butler import DatasetRef, Quantum, DimensionUniverse, DataCoordinate
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase


class ButlerMock:
    """Mock version of butler, only usable for this test.
    """
    def __init__(self):
        # Maps dataset type name to a dict of {dataId: stored in-memory dataset}.
        self.datasets = {}
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

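    # Return the dataset stored for a DatasetRef, or for a dataset type name
    # plus an explicit dataId; return None if nothing has been stored.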
    def get(self, datasetRefOrType, dataId=None):
        if isinstance(datasetRefOrType, DatasetRef):
            dataId = datasetRefOrType.dataId
            dsTypeName = datasetRefOrType.datasetType.name
        else:
            dsTypeName = datasetRefOrType
        key = dataId
        dsdata = self.datasets.get(dsTypeName)
        if dsdata:
            return dsdata.get(key)
        return None

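    # Store an in-memory dataset, keyed by dataset type name and dataId.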
    def put(self, inMemoryDataset, dsRef, producer=None):
        key = dsRef.dataId
        if isinstance(dsRef.datasetType, str):
            name = dsRef.datasetType
        else:
            name = dsRef.datasetType.name
        dsdata = self.datasets.setdefault(name, {})
        dsdata[key] = inMemoryDataset


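# Connections for the add tasks: a single input catalog and a single output
# catalog, declared with identical dimensions.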
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    input = pipeBase.connectionTypes.Input(name="add_input",
                                           dimensions=["instrument", "visit", "detector",
                                                       "physical_filter", "band"],
                                           storageClass="Catalog",
                                           doc="Input dataset type for this task")
    output = pipeBase.connectionTypes.Output(name="add_output",
                                             dimensions=["instrument", "visit", "detector",
                                                         "physical_filter", "band"],
                                             storageClass="Catalog",
                                             doc="Output dataset type for this task")


class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    addend = pexConfig.Field(doc="amount to add", dtype=int, default=3)


# example task which overrides run() method
class AddTask(pipeBase.PipelineTask):
    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        return pipeBase.Struct(output=output)


# example task which overrides runQuantum() method
class AddTask2(pipeBase.PipelineTask):
    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs['input'] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)


class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask
    """

    def _makeDSRefVisit(self, dstype, visitId, universe):
        return DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter='a',
                band='b',
                instrument='TestInstrument',
                universe=universe
            )
        )

    def _makeQuanta(self, config):
        """Create a set of Quanta, one per visit.
        """
        universe = DimensionUniverse()
        run = "run1"
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(100):
            quantum = Quantum(run=run)
            quantum.addPredictedInput(self._makeDSRefVisit(dstype0, visit, universe))
            quantum.addOutput(self._makeDSRefVisit(dstype1, visit, universe))
            quanta.append(quantum)

        return quanta

    def testRunQuantum(self):
        """Test for AddTask.runQuantum() implementation.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.predictedInputs[dstype0.name][0]
            butler.put(100 + i, pipeBase.Struct(datasetType=dstype0.name, dataId=ref.dataId))

        # run the task on each quantum
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2(self):
        """Test for a two-task chain.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
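        # chain the tasks: the input of task2 is the dataset type that task1 writes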
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.predictedInputs[dstype0.name][0]
            butler.put(100 + i, pipeBase.Struct(datasetType=dstype0.name, dataId=ref.dataId))

        # run each task on each of its quanta
        for quantum in quanta1:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by each task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)


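# Standard LSST test boilerplate: MemoryTestCase checks for resource leaks
# (e.g. open files) left behind by the tests above.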
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()