Coverage for tests/test_pipelineTask.py: 18%
190 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-22 02:19 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-22 02:19 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26import warnings
27from types import SimpleNamespace
28from typing import Any
30import lsst.pex.config as pexConfig
31import lsst.pipe.base as pipeBase
32import lsst.utils.logging
33import lsst.utils.tests
34from lsst.daf.butler import (
35 DataCoordinate,
36 DatasetRef,
37 DatasetType,
38 DimensionUniverse,
39 Quantum,
40 UnresolvedRefWarning,
41)
class ButlerMock:
    """In-memory stand-in for a butler, only usable for this test.

    Stored objects live in ``datasets``, a nested mapping from dataset type
    name to data ID to the in-memory object.
    """

    def __init__(self) -> None:
        # dataset type name -> {data ID -> stored in-memory object}
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        # Only ``registry.dimensions`` is provided; the tests use it to build
        # dataset types from connections.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if there is none.

        Lookup uses only the dataset type name and the data ID of ``ref``;
        the ref must carry a dataset ID (i.e. be resolved).
        """
        assert ref.id is not None
        by_data_id = self.datasets.get(ref.datasetType.name) or {}
        return by_data_id.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the type name and data ID of
        ``dsRef``, replacing any previous object. ``producer`` is accepted
        for interface compatibility and ignored.
        """
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the test tasks: one ``Catalog`` input and one
    ``Catalog`` output, both with the same per-detector dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the test tasks, which add ``addend`` to their input."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")
# Example task which overrides only the run() method.
class AddTask(pipeBase.PipelineTask):
    """Example task implementing only ``run``.

    ``runQuantum`` is not defined here, so calls to it use the inherited
    `PipelineTask` implementation.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Return ``input + config.addend`` as the ``output`` Struct field.

        The parameter is named ``input`` to match the ``input`` connection in
        `AddConnections` (presumably required because the inherited
        ``runQuantum`` passes inputs by connection name).
        """
        addend = self.config.addend
        # Record the addend used in the task metadata.
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
# Example task which overrides the runQuantum() method instead of run().
class AddTask2(pipeBase.PipelineTask):
    """Example task implementing ``runQuantum`` directly.

    The input is read and the output written through the
    `ButlerQuantumContext` passed in, without going through ``run``.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Add ``config.addend`` to the input dataset and write the result.

        Parameters
        ----------
        butlerQC : `ButlerQuantumContext`
            Butler interface for the quantum being executed.
        inputRefs : `InputQuantizedConnection`
            Input dataset references, one attribute per input connection.
        outputRefs : `OutputQuantizedConnection`
            Output dataset references, one attribute per output connection.
        """
        self.metadata.add("add", self.config.addend)
        # ``get`` on a quantized connection returns a dict keyed by
        # connection name.
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        # The Struct field is named after the ``output`` connection so that
        # ``put`` can (presumably) pair it with the matching ref.
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(
        self, dstype: DatasetType, visitId: int, universe: DimensionUniverse, resolve: bool = False
    ) -> DatasetRef:
        """Make a dataset reference for one visit.

        Parameters
        ----------
        dstype : `DatasetType`
            Dataset type of the new reference.
        visitId : `int`
            Visit number for the data ID; all other data ID values are fixed.
        universe : `DimensionUniverse`
            Universe used to standardize the data ID.
        resolve : `bool`, optional
            If `True`, create the ref with ``run="test"``; otherwise create
            it without a run, suppressing `UnresolvedRefWarning`.

        Returns
        -------
        ref : `DatasetRef`
            The new reference.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        if resolve:
            return DatasetRef(datasetType=dstype, dataId=dataId, run="test")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UnresolvedRefWarning)
            return DatasetRef(datasetType=dstype, dataId=dataId)

    def _makeQuanta(
        self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100, resolve_outputs: bool = False
    ) -> list[Quantum]:
        """Create ``nquanta`` quanta, one per visit in ``range(nquanta)``.

        Each quantum has a single input ref (always resolved) and a single
        output ref (resolved only if ``resolve_outputs`` is `True`), built
        from the ``input`` and ``output`` connections of ``config``.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe, resolve=True)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe, resolve=resolve_outputs)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self) -> None:
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self) -> None:
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            If `True` use `ButlerQuantumContext.from_full`, otherwise
            `ButlerQuantumContext.from_limited` (with resolved output refs).
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quantum
        checked_get = False
        for quantum in quanta:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways
            # (only need to do this one time).
            if not checked_get:
                checked_get = True
                self._checkButlerQCGet(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + task.config.addend)

    def _checkButlerQCGet(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Check the argument forms accepted by ``butlerQC.get``.

        Covers a quantized connection, a list of refs and a single ref, and
        checks that an empty dict raises `TypeError` and that an output ref
        (not one of the quantum's inputs) raises `ValueError`.
        """
        # Force the periodic logger to issue messages so the "Completed"
        # message can be asserted. This is a class-level (global) setting, so
        # restore the previous value even if an assertion fails; otherwise
        # the change would leak into other tests.
        saved_interval = lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL
        lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0
        try:
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                input_data = butlerQC.get(inputRefs)
            self.assertIn("Completed", cm.output[-1])
            self.assertEqual(len(input_data), len(inputRefs))

            # In this test there are no multiples returned.
            refs = [ref for _, ref in inputRefs]
            with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                list_get = butlerQC.get(refs)
        finally:
            lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = saved_interval

        self.assertIn("Completed", cm.output[-1])
        self.assertEqual(len(list_get), len(input_data))
        self.assertIsInstance(list_get[0], int)
        scalar_get = butlerQC.get(refs[0])
        self.assertEqual(scalar_get, list_get[0])

        with self.assertRaises(TypeError):
            butlerQC.get({})

        # An output ref is not one of the quantum's inputs, so it cannot be
        # read through this context.
        outputs = [ref for _, ref in outputRefs]
        with self.assertRaises(ValueError):
            butlerQC.get(outputs[0])

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        # Previously passed full_butler=True, so the limited path was never
        # exercised.
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        The output of an `AddTask` is consumed by an `AddTask2`.

        Parameters
        ----------
        full_butler : `bool`
            If `True` use `ButlerQuantumContext.from_full`, otherwise
            `ButlerQuantumContext.from_limited` (with resolved output refs).
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config, resolve_outputs=not full_butler)
        quanta2 = self._makeQuanta(task2.config, resolve_outputs=not full_butler)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quantum
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the first task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend)

        # and at the output of the second task, which adds its own addend
        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend + config2.addend)

    def testButlerQC(self) -> None:
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add a few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run the checks inherited from `lsst.utils.tests.MemoryTestCase`;
    no additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook (pytest xunit-style ``setup_module``).

    Initializes the LSST test utilities via `lsst.utils.tests.init`;
    ``module`` is unused.
    """
    lsst.utils.tests.init()
# When run as a script (rather than through pytest, which calls
# setup_module instead), initialize the LSST test utilities before handing
# control to unittest.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()