Coverage for tests/test_pipelineTask.py: 19%
181 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-06-08 09:15 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-06-08 09:15 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26from types import SimpleNamespace
27from typing import Any
29import lsst.pex.config as pexConfig
30import lsst.pipe.base as pipeBase
31import lsst.utils.logging
32import lsst.utils.tests
33from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
class ButlerMock:
    """Mock version of butler, only usable for this test.

    Stored objects live in ``datasets``: a mapping from dataset type name to
    a mapping from data ID to the in-memory object that was ``put``.
    """

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        # Only ``registry.dimensions`` is consulted by these tests.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if nothing is
        stored under its dataset type name and data ID.
        """
        by_data_id = self.datasets.get(ref.datasetType.name)
        if not by_data_id:
            return None
        return by_data_id.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under ``dsRef``'s dataset type name and
        data ID, overwriting any previous value; ``producer`` is ignored.
        """
        data_id = dsRef.dataId
        bucket = self.datasets.setdefault(dsRef.datasetType.name, {})
        bucket[data_id] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the example add tasks: a single ``Catalog`` input
    and a single ``Catalog`` output sharing the same dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the example add tasks."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")
75# example task which overrides run() method
class AddTask(pipeBase.PipelineTask):
    """Example task whose `run` adds ``config.addend`` to its input."""

    _DefaultName = "add_task"
    ConfigClass = AddConfig

    def run(self, input: int) -> pipeBase.Struct:
        """Return a `Struct` whose ``output`` is ``input + config.addend``.

        The addend is also recorded in the task metadata under ``"add"``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
# example task which overrides runQuantum() method
class AddTask2(pipeBase.PipelineTask):
    """Example task that adds ``config.addend`` to its input by overriding
    `runQuantum` directly instead of `run`.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read the ``input`` dataset, add ``config.addend`` and write the
        result as ``output``.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler access for this quantum.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input references, as produced by
            ``PipelineTaskConnections.buildDatasetRefs``; read as a whole,
            giving a dict keyed by connection name (``"input"``).
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output references, as produced by
            ``PipelineTaskConnections.buildDatasetRefs``.
        """
        # Record the addend in task metadata, mirroring AddTask.run.
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Make a `DatasetRef` in run ``"test"`` for one visit.

        All non-visit dimensions take fixed values (``detector="X"``,
        ``physical_filter="a"``, ``band="b"``,
        ``instrument="TestInstrument"``).

        Parameters
        ----------
        dstype : `DatasetType`
            Dataset type of the new reference.
        visitId : `int`
            Value of the ``visit`` dimension.
        universe : `DimensionUniverse`
            Universe used to standardize the data ID.

        Returns
        -------
        ref : `DatasetRef`
            The new reference.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create set of Quanta.

        One quantum is made per visit ``0 .. nquanta - 1``, each with exactly
        one input ref and one output ref taken from ``config``'s connections.

        Parameters
        ----------
        config : `PipelineTaskConfig`
            Task configuration whose connections define the dataset types.
        nquanta : `int`, optional
            Number of quanta to create.

        Returns
        -------
        quanta : `list` [`Quantum`]
            The created quanta, ordered by visit.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            Butler flavor requested by the caller.
            NOTE(review): this flag is currently unused -- both variants run
            against the same `ButlerMock`, so the "limited" test does not
            exercise a limited butler. TODO: wire it up or drop one test.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True
                periodic_logger = lsst.utils.logging.PeriodicLogger
                # LOGGING_INTERVAL is a class attribute (global state shared
                # with other tests), so remember it and restore it even if an
                # assertion below fails.
                saved_interval = periodic_logger.LOGGING_INTERVAL
                # Force the periodic logger to issue messages.
                periodic_logger.LOGGING_INTERVAL = 0.0
                try:
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        input_data = butlerQC.get(inputRefs)
                    self.assertIn("Completed", cm.output[-1])
                    self.assertEqual(len(input_data), len(inputRefs))

                    # In this test there are no multiples returned.
                    refs = [ref for _, ref in inputRefs]
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        list_get = butlerQC.get(refs)
                finally:
                    periodic_logger.LOGGING_INTERVAL = saved_interval

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task; 3 is AddConfig's default
        # addend.
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        ``AddTask`` (addend 3) writes ``add_output``, which ``AddTask2``
        (addend 200) reads and writes to ``add_output_2``.

        Parameters
        ----------
        full_butler : `bool`
            Butler flavor requested by the caller.
            NOTE(review): currently unused, as in ``_testRunQuantum``.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Chain the tasks: task2 consumes task1's output dataset type.
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Reuse `lsst.utils.tests.MemoryTestCase` unchanged for this module."""
def setup_module(module):
    """Call ``lsst.utils.tests.init()`` before this module's tests.

    The name matches pytest's module-level setup hook convention, so it is
    presumably invoked by pytest; ``module`` is accepted but not used.
    """
    init_tests = lsst.utils.tests.init
    init_tests()
# Allow running this file directly, outside of pytest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()