Coverage for tests/test_pipelineTask.py: 19%
192 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-11 10:14 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-11 10:14 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26from types import SimpleNamespace
27from typing import Any
29import lsst.pex.config as pexConfig
30import lsst.pipe.base as pipeBase
31import lsst.utils.logging
32import lsst.utils.tests
33from lsst.daf.butler import (
34 DataCoordinate,
35 DatasetIdFactory,
36 DatasetRef,
37 DatasetType,
38 DimensionUniverse,
39 Quantum,
40)
class ButlerMock:
    """Minimal in-memory stand-in for a data butler, only usable by this test.

    Objects are stored in a nested dictionary keyed first by dataset type
    name and then by data ID.  Dataset IDs are only checked for presence or
    absence (to mimic resolved/unresolved ref requirements), never used as
    storage keys.
    """

    def __init__(self) -> None:
        # dataset type name -> {data ID -> stored object}
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        # Only ``registry.dimensions`` is consulted by the tests in this file.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def getDirect(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if there is none.

        ``ref`` must be resolved (have a dataset ID).
        """
        assert ref.id is not None
        by_data_id = self.datasets.get(ref.datasetType.name)
        if not by_data_id:
            return None
        return by_data_id.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the type name and data ID of ``dsRef``.

        ``dsRef`` must be unresolved (no dataset ID).  ``producer`` is accepted
        for interface compatibility and ignored.
        """
        assert dsRef.id is None
        data_id = dsRef.dataId
        type_name = dsRef.datasetType.name
        self.datasets.setdefault(type_name, {})[data_id] = inMemoryDataset

    def putDirect(self, obj: Any, ref: DatasetRef):
        """Store ``obj`` via a resolved ``ref``; the ref is unresolved before storing."""
        assert ref.id is not None
        unresolved = ref.unresolved()
        self.put(obj, unresolved)
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the test "add" tasks: one input and one output dataset."""

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by the test "add" tasks."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")
# Example task that overrides the run() method.
class AddTask(pipeBase.PipelineTask):
    """Test task that adds ``config.addend`` to its input inside ``run``."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Return a struct whose ``output`` is ``input + config.addend``.

        The addend is also recorded in the task metadata under ``"add"``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
# Example task that overrides the runQuantum() method instead of run().
class AddTask2(pipeBase.PipelineTask):
    """Test task that does all of its work directly in ``runQuantum``.

    It reads the ``input`` connection, adds ``config.addend`` and writes the
    result to the ``output`` connection without going through ``run``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read ``input``, add ``config.addend`` and put the result as ``output``.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Butler wrapper for the quantum being processed.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input dataset references, addressed by connection name.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output dataset references, addressed by connection name.
        """
        self.metadata.add("add", self.config.addend)
        # ``get`` on a quantized connection returns a mapping keyed by
        # connection name.
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    datasetIdFactory = DatasetIdFactory()

    def _resolve_ref(self, ref: DatasetRef) -> DatasetRef:
        """Return ``ref`` resolved by the shared factory in run ``"test"``."""
        return self.datasetIdFactory.resolveRef(ref, "test")

    def _makeDSRefVisit(
        self, dstype: DatasetType, visitId: int, universe: DimensionUniverse, resolve: bool = False
    ) -> DatasetRef:
        """Make a reference to ``dstype`` for visit ``visitId``.

        The remaining data ID values (detector, physical_filter, band,
        instrument) are fixed test constants.  The reference is resolved only
        when ``resolve`` is true.
        """
        ref = DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter="a",
                band="b",
                instrument="TestInstrument",
                universe=universe,
            ),
        )
        if resolve:
            ref = self._resolve_ref(ref)
        return ref

    def _makeQuanta(
        self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100, resolve_outputs: bool = False
    ) -> list[Quantum]:
        """Create ``nquanta`` quanta, one per visit ID ``0 .. nquanta - 1``.

        Each quantum has one resolved input ref and one output ref; output
        refs are resolved only when ``resolve_outputs`` is true.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe, resolve=True)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe, resolve=resolve_outputs)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            If `True` use `ButlerQuantumContext.from_full` with unresolved
            output refs, otherwise `ButlerQuantumContext.from_limited` with
            resolved output refs.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.putDirect(100 + i, ref)

        # run the task on each quantum
        checked_get = False
        for quantum in quanta:
            if full_butler:
                butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)
            else:
                butlerQC = pipeBase.ButlerQuantumContext.from_limited(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True
                # Force the periodic logger to issue messages.  The interval
                # is a class attribute (global state), so always restore the
                # previous value, even if an assertion below fails; otherwise
                # later tests would inherit the zero interval.
                periodic_logger = lsst.utils.logging.PeriodicLogger
                saved_interval = periodic_logger.LOGGING_INTERVAL
                periodic_logger.LOGGING_INTERVAL = 0.0
                try:
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        input_data = butlerQC.get(inputRefs)
                    self.assertIn("Completed", cm.output[-1])
                    self.assertEqual(len(input_data), len(inputRefs))

                    # In this test there are no multiples returned.
                    refs = [ref for _, ref in inputRefs]
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        list_get = butlerQC.get(refs)
                finally:
                    periodic_logger.LOGGING_INTERVAL = saved_interval

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        # Previously passed full_butler=True, so the limited-butler path of
        # the chain was never exercised.
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        ``AddTask`` (addend 3) feeds ``AddTask2`` (addend 200) through the
        ``add_output`` dataset type.

        Parameters
        ----------
        full_butler : `bool`
            If `True` use `ButlerQuantumContext.from_full` with unresolved
            output refs, otherwise `ButlerQuantumContext.from_limited` with
            resolved output refs.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config, resolve_outputs=not full_butler)
        quanta2 = self._makeQuanta(task2.config, resolve_outputs=not full_butler)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.putDirect(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run each task on every one of its quanta
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self) -> None:
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.putDirect(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add a few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase` unchanged."""
def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    ``module`` is part of the hook signature (presumably supplied by pytest)
    and is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Running the file directly: initialize the LSST test helpers first
    # (mirroring setup_module), then hand control to unittest.
    lsst.utils.tests.init()
    unittest.main()