Coverage for tests/test_pipelineTask.py: 18%
177 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-15 02:49 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-15 02:49 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26from typing import Any
28import lsst.pex.config as pexConfig
29import lsst.pipe.base as pipeBase
30import lsst.utils.logging
31import lsst.utils.tests
32from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
class ButlerMock:
    """Mock version of butler, only usable for this test.

    Datasets are kept in memory in a two-level mapping: dataset type name,
    then data ID, then the stored object.
    """

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``.

        Parameters
        ----------
        ref : `DatasetRef`
            Reference whose ``datasetType.name`` and ``dataId`` select the
            stored object.

        Returns
        -------
        obj : `Any`
            The stored object, or `None` if nothing was stored for this
            dataset type name and data ID.
        """
        # An unknown dataset type and an unknown data ID both yield None.
        return self.datasets.get(ref.datasetType.name, {}).get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None) -> None:
        """Store ``inMemoryDataset`` under the name and data ID of ``dsRef``.

        Parameters
        ----------
        inMemoryDataset : `Any`
            Object to store; replaces any object already stored for the
            same dataset type name and data ID.
        dsRef : `DatasetRef`
            Reference supplying ``datasetType.name`` and ``dataId``.
        producer : `Any`, optional
            Accepted but not used by this mock.
        """
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the example add tasks: one input and one output,
    both using the ``Catalog`` storage class and the same dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the example add tasks."""

    # Integer added to the input value by the tasks using this config.
    addend = pexConfig.Field[int](default=3, doc="amount to add")
# Example task which overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    """Example task that adds the configured ``addend`` in ``run``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured addend to ``input``.

        Parameters
        ----------
        input : `int`
            Value to add to.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``output`` attribute holds the sum.
        """
        addend = self.config.addend
        # Record the addend in the task metadata under the key "add".
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
# Example task which overrides the runQuantum() method.
class AddTask2(pipeBase.PipelineTask):
    """Example task that does its own butler I/O in ``runQuantum``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read the input, add the configured addend and write the output.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context used to get the input and put the output.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input references; ``butlerQC.get`` on them yields a mapping
            from which the ``"input"`` entry is read.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output references the result struct is written to.
        """
        # Record the addend in the task metadata under the key "add".
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Build a dataset reference for one visit in run ``"test"``.

        Parameters
        ----------
        dstype : `DatasetType`
            Dataset type of the reference.
        visitId : `int`
            Visit number used in the data ID; the other data ID values are
            fixed.
        universe : `DimensionUniverse`
            Universe used to standardize the data ID.

        Returns
        -------
        ref : `DatasetRef`
            The new reference.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=dataId, run="test")

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create a list of quanta, one per visit number.

        Parameters
        ----------
        config : `lsst.pipe.base.PipelineTaskConfig`
            Task configuration whose ``input`` and ``output`` connections
            define the dataset types.
        nquanta : `int`, optional
            Number of quanta to make; visits are numbered from zero.

        Returns
        -------
        quanta : `list` [`Quantum`]
            Quanta with exactly one input ref and one output ref each.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        inputType = connections.input.makeDatasetType(universe)
        outputType = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(inputType, visit, universe)
            outputRef = self._makeDSRefVisit(outputType, visit, universe)
            quanta.append(
                Quantum(
                    inputs={inputRef.datasetType: [inputRef]},
                    outputs={outputRef.datasetType: [outputRef]},
                )
            )

        return quanta

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _checkQuantumContextGet(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: Any,
        outputRefs: Any,
    ) -> None:
        """Check the different ways of getting datasets from ``butlerQC``.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context for a quantum whose task has already been run.
        inputRefs, outputRefs
            The pair returned by ``connections.buildDatasetRefs`` for the
            same quantum.
        """
        periodic = lsst.utils.logging.PeriodicLogger
        saved_interval = periodic.LOGGING_INTERVAL
        # Force the periodic logger to issue messages.
        periodic.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            # Always restore the class-level interval, even when an
            # assertion above fails, so other tests see the previous value.
            periodic.LOGGING_INTERVAL = saved_interval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # Output ref won't be known to this quantum.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            Not read by this method; both callers run the same code against
            the same `ButlerMock`.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True
                self._checkQuantumContextGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        The output of the first task (addend 3) is the input of the second
        task (addend 200).

        Parameters
        ----------
        full_butler : `bool`
            Not read by this method; both callers run the same code against
            the same `ButlerMock`.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the first task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        # look at the output produced by the second task
        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self) -> None:
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Test case inheriting every check from
    `lsst.utils.tests.MemoryTestCase`; nothing is added here.
    """
def setup_module(module):
    """Initialize the LSST test utilities for this module.

    Parameters
    ----------
    module
        Not used by this function.
    """
    lsst.utils.tests.init()
# Script entry point: initialize the LSST test utilities, then hand over
# to the unittest runner.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()