Coverage for tests/test_pipelineTask.py: 30%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26from types import SimpleNamespace
28import lsst.pex.config as pexConfig
29import lsst.pipe.base as pipeBase
30import lsst.utils.tests
31from lsst.daf.butler import DataCoordinate, DatasetRef, DimensionUniverse, Quantum
class ButlerMock:
    """In-memory stand-in for a butler, only usable for this test.

    Stored datasets live in a two-level dictionary: the outer key is the
    dataset type name, the inner key is the data ID.
    """

    def __init__(self):
        # {dataset type name: {data ID: in-memory dataset}}
        self.datasets = {}
        # Only ``registry.dimensions`` is provided by this mock.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def getDirect(self, ref):
        """Return the dataset stored for ``ref``, or `None` if there is none.

        ``ref.datasetType`` must have a ``name`` attribute.
        """
        byDataId = self.datasets.get(ref.datasetType.name)
        if not byDataId:
            return None
        return byDataId.get(ref.dataId)

    def put(self, inMemoryDataset, dsRef, producer=None):
        """Store ``inMemoryDataset`` under the dataset type and data ID of ``dsRef``.

        ``dsRef.datasetType`` may be a plain dataset type name (`str`) or an
        object with a ``name`` attribute. ``producer`` is accepted for
        interface compatibility and ignored.
        """
        datasetType = dsRef.datasetType
        typeName = datasetType if isinstance(datasetType, str) else datasetType.name
        self.datasets.setdefault(typeName, {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the test addition tasks: one input and one output catalog.

    Both datasets use the same dimensions (including ``detector``), while the
    connections class itself is declared with ``instrument`` and ``visit``.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the test addition tasks."""

    # Integer the tasks add to their input value.
    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")
# Example task that implements only run(); the inherited runQuantum() is what
# the tests call to fetch its inputs and store its outputs.
class AddTask(pipeBase.PipelineTask):
    """Task that adds ``config.addend`` to its input."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        """Return a Struct whose ``output`` is ``input + config.addend``.

        The parameter name mirrors the ``input`` connection of AddConnections,
        so it is kept even though it shadows the builtin.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
# Example task that overrides runQuantum() directly instead of run(): it reads
# its inputs from, and writes its outputs to, the butler itself.
class AddTask2(pipeBase.PipelineTask):
    """Task that adds ``config.addend`` to its input inside runQuantum()."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Add ``config.addend`` to the ``input`` dataset and put the result.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler wrapper for the quantum being processed.
        inputRefs, outputRefs
            Dataset references for the input and output connections, as
            produced by the connections' ``buildDatasetRefs``.
        """
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        # Field name ``output`` matches the output connection so put() can
        # route it to the corresponding reference in ``outputRefs``.
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    # Value stored as the input of the first quantum; quantum i gets base + i.
    _INPUT_BASE = 100

    def _makeDSRefVisit(self, dstype, visitId, universe):
        """Make a DatasetRef for ``dstype`` whose data ID varies only by visit."""
        return DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter="a",
                band="b",
                instrument="TestInstrument",
                universe=universe,
            ),
        )

    def _makeQuanta(self, config, nQuanta=100):
        """Create one Quantum per visit in ``range(nQuanta)``.

        Each quantum holds a single input reference (from the ``input``
        connection) and a single output reference (from ``output``).
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nQuanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quanta.append(
                Quantum(inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]})
            )
        return quanta

    def _putInputs(self, butler, connections, quanta):
        """Store ``_INPUT_BASE + i`` as the input dataset of the i-th quantum."""
        inputName = connections.input.name
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[inputName][0]
            butler.put(self._INPUT_BASE + i, pipeBase.Struct(datasetType=inputName, dataId=ref.dataId))

    def _runQuanta(self, butler, task, connections, quanta):
        """Run ``task.runQuantum`` on every quantum, reading and writing ``butler``."""
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

    def _assertOutputs(self, butler, outputName, quanta, offset):
        """Check the i-th quantum's output equals ``_INPUT_BASE + i + offset``."""
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], self._INPUT_BASE + i + offset)

    def testRunQuantum(self):
        """Test for AddTask.runQuantum() implementation."""
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        quanta = self._makeQuanta(task.config)

        self._putInputs(butler, connections, quanta)
        self._runQuanta(butler, task, connections, quanta)

        # Each output is its input plus the configured addend.
        self._assertOutputs(butler, connections.output.name, quanta, task.config.addend)

    def testChain2(self):
        """Test for two-task chain."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Chain the tasks: task2 consumes the dataset task1 produces.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # Only task1's inputs are seeded; task2 reads what task1 writes.
        self._putInputs(butler, connections1, quanta1)

        self._runQuanta(butler, task1, connections1, quanta1)
        self._runQuanta(butler, task2, connections2, quanta2)

        self._assertOutputs(butler, task1.config.connections.output, quanta1, config1.addend)
        self._assertOutputs(
            butler, task2.config.connections.output, quanta2, config1.addend + config2.addend
        )
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Memory checks inherited unchanged from `lsst.utils.tests.MemoryTestCase`."""
def setup_module(module):
    """Module-level setup hook; ``module`` is unused.

    Performs the same ``lsst.utils.tests.init()`` call as the script entry
    point below.
    """
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()