Coverage for tests/test_pipelineTask.py: 18%
185 statements
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Simple unit test for PipelineTask."""
import unittest
from types import SimpleNamespace
from typing import Any

import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.utils.logging
import lsst.utils.tests
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
class ButlerMock:
    """Mock version of butler, only usable for this test"""

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, ref: DatasetRef) -> Any:
        # Requires resolved ref.
        assert ref.id is not None
        dsdata = self.datasets.get(ref.datasetType.name)
        if dsdata:
            return dsdata.get(ref.dataId)
        return None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        key = dsRef.dataId
        name = dsRef.datasetType.name
        dsdata = self.datasets.setdefault(name, {})
        dsdata[key] = inMemoryDataset
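
# Illustrative round trip through the mock (the names below are only an
# example, not part of the test):
#
#     butler = ButlerMock()
#     butler.put(42, ref)           # stored under (dataset type name, dataId)
#     assert butler.get(ref) == 42  # the same resolved ref retrieves the value
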
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    input = pipeBase.connectionTypes.Input(
        name="add_input",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Input dataset type for this task",
    )
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Output dataset type for this task",
    )


class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    addend = pexConfig.Field[int](doc="amount to add", default=3)
# Example task that overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        return pipeBase.Struct(output=output)
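
# Note: AddTask does not override runQuantum(), so the base-class
# implementation is used; the tests below rely on it fetching the "input"
# connection, calling run(), and writing the returned Struct's "output"
# field back through the output connection.
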
# Example task that overrides the runQuantum() method.
class AddTask2(pipeBase.PipelineTask):
    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)
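
# Unlike AddTask, AddTask2 bypasses run() entirely: its runQuantum() reads the
# input values through the ButlerQuantumContext, adds the configured addend,
# and writes the result back through the output refs.
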
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create a set of Quanta."""
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta
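
    # Each quantum built above carries exactly one input and one output
    # DatasetRef for a single visit, which is why the tests below index the
    # ref lists with [0].
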
    def testRunQuantumFull(self):
        """Test the AddTask.runQuantum() implementation with a full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test the AddTask.runQuantum() implementation with a limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test the AddTask.runQuantum() implementation."""
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # Make all quanta.
        quanta = self._makeQuanta(task.config)

        # Add input data to the butler.
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # Run the task on each quantum.
        checked_get = False
        for quantum in quanta:
            if full_butler:
                butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)
            else:
                butlerQC = pipeBase.ButlerQuantumContext.from_limited(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting datasets in different ways
            # (only needs to be done once).
            if not checked_get:
                # Force the periodic logger to issue messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0

                checked_get = True
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    input_data = butlerQC.get(inputRefs)
                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(input_data), len(inputRefs))

                # In this test there are no multiples returned.
                refs = [ref for _, ref in inputRefs]
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    list_get = butlerQC.get(refs)

                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # The output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # Look at the output produced by the task.
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)
    def testChain2Full(self) -> None:
        """Test a two-task chain with a full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test a two-task chain with a limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test a two-task chain."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # Make all quanta.
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # Add input data to the butler.
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # Run each task on its quanta.
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # Look at the output produced by each task.
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)
    def testButlerQC(self):
        """Test for ButlerQuantumContext. Full and limited butlers share the
        get implementation, so only the full butler is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # Make one quantum.
        (quantum,) = self._makeQuanta(task.config, 1)

        # Add input data to the butler.
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass a ref as a single argument or in a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add a few None values to the QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()