Coverage for tests/test_pipelineTask.py: 25%
145 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-11 02:51 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26from types import SimpleNamespace
28import lsst.pex.config as pexConfig
29import lsst.pipe.base as pipeBase
30import lsst.utils.logging
31import lsst.utils.tests
32from lsst.daf.butler import DataCoordinate, DatasetRef, DimensionUniverse, Quantum
class ButlerMock:
    """Mock version of butler, only usable for this test.

    Datasets are held in memory in a nested mapping
    ``{dataset type name: {data ID: object}}``. Only ``put`` and
    ``getDirect`` are implemented.
    """

    def __init__(self):
        self.datasets = {}
        # The tests in this file read only ``registry.dimensions``.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    @staticmethod
    def _datasetTypeName(ref):
        """Return the dataset type name of ``ref``.

        ``ref.datasetType`` may be either a plain `str` (the tests seed
        inputs this way) or an object exposing a ``name`` attribute.
        Both ``put`` and ``getDirect`` go through this helper so that any
        ref accepted by one is also accepted by the other.
        """
        datasetType = ref.datasetType
        return datasetType if isinstance(datasetType, str) else datasetType.name

    def getDirect(self, ref):
        """Return the object stored for ``ref``.

        Returns
        -------
        obj : `object` or `None`
            The stored object, or `None` if nothing was stored for this
            dataset type name and data ID.
        """
        return self.datasets.get(self._datasetTypeName(ref), {}).get(ref.dataId)

    def put(self, inMemoryDataset, dsRef, producer=None):
        """Store ``inMemoryDataset`` under ``dsRef``'s type name and data ID.

        ``producer`` is accepted for interface compatibility and ignored.
        An existing object for the same key is overwritten.
        """
        self.datasets.setdefault(self._datasetTypeName(dsRef), {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the test ``add`` tasks.

    Declares one input and one output ``Catalog`` dataset type. Both have
    the dimensions instrument, visit, detector, physical_filter and band.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the test ``add`` tasks, wired to `AddConnections`."""

    addend = pexConfig.Field(dtype=int, default=3, doc="amount to add")
# Example task which overrides only run(); it relies on the inherited
# PipelineTask.runQuantum() to fetch the input and store the output.
class AddTask(pipeBase.PipelineTask):
    """Add ``config.addend`` to the input and record the addend in metadata."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input):
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
# Example task which overrides runQuantum() directly and does its own I/O
# through the ButlerQuantumContext instead of implementing run().
class AddTask2(pipeBase.PipelineTask):
    """Add ``config.addend`` to the input, implemented in ``runQuantum``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        addend = self.config.addend
        self.metadata.add("add", addend)
        total = butlerQC.get(inputRefs)["input"] + addend
        butlerQC.put(pipeBase.Struct(output=total), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype, visitId, universe):
        """Make a `DatasetRef` of ``dstype`` for visit ``visitId``.

        The detector, filter, band and instrument values are fixed.
        """
        return DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter="a",
                band="b",
                instrument="TestInstrument",
                universe=universe,
            ),
        )

    def _makeQuanta(self, config, nQuanta=100):
        """Create one quantum per visit ``0 .. nQuanta - 1``.

        Each quantum has a single input ref and a single output ref, using
        the dataset types declared by ``config``'s connections.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nQuanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def _putInputs(self, butler, connections, quanta, offset=100):
        """Seed ``butler`` with the integer ``offset + i`` as the input of quantum ``i``."""
        dstype = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype.name][0]
            butler.put(offset + i, pipeBase.Struct(datasetType=dstype.name, dataId=ref.dataId))

    def _runQuanta(self, task, connections, butler, quanta):
        """Run ``task.runQuantum`` once for every quantum in ``quanta``."""
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

    def _assertOutputs(self, butler, outputName, quanta, addend, offset=100):
        """Check that quantum ``i`` wrote ``offset + i + addend`` to ``outputName``."""
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], offset + i + addend)

    def _checkButlerQCGets(self, butlerQC, inputRefs, outputRefs):
        """Exercise the different argument forms of ``ButlerQuantumContext.get``."""
        periodicLogger = lsst.utils.logging.PeriodicLogger
        savedInterval = periodicLogger.LOGGING_INTERVAL
        # Force the periodic logger to issue messages. The class-level
        # interval is always restored (even if an assertion fails) so the
        # change cannot leak into other tests.
        periodicLogger.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            periodicLogger.LOGGING_INTERVAL = savedInterval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def testRunQuantum(self):
        """Test for AddTask.runQuantum() implementation."""
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        self._putInputs(butler, connections, quanta)

        # run task on each quantum
        for i, quantum in enumerate(quanta):
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways
            # (only need to do this one time).
            if i == 0:
                self._checkButlerQCGets(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        self._assertOutputs(butler, connections.output.name, quanta, task.config.addend)

    def testChain2(self):
        """Test for two-task chain."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        self._putInputs(butler, connections1, quanta1)

        # run both tasks; task2 consumes task1's output
        self._runQuanta(task1, connections1, butler, quanta1)
        self._runQuanta(task2, connections2, butler, quanta2)

        # look at the output produced by each task
        self._assertOutputs(butler, task1.config.connections.output, quanta1, config1.addend)
        self._assertOutputs(
            butler, task2.config.connections.output, quanta2, config1.addend + config2.addend
        )
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit the checks of `lsst.utils.tests.MemoryTestCase` unchanged.

    No additional tests are defined here.
    """
def setup_module(module):
    """Call ``lsst.utils.tests.init()`` before any test in this module runs.

    This is pytest's module-level xunit-style setup hook; ``module`` is
    required by the hook signature but is not used.
    """
    lsst.utils.tests.init()


if __name__ == "__main__":
    # unittest does not call pytest's ``setup_module`` hook, so perform the
    # same initialization explicitly before running the tests.
    setup_module(None)
    unittest.main()