Coverage for tests/test_pipelineTask.py: 19%
188 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-19 04:01 -0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-04-19 04:01 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
22"""Simple unit test for PipelineTask.
23"""
25import unittest
26from types import SimpleNamespace
27from typing import Any
29import lsst.pex.config as pexConfig
30import lsst.pipe.base as pipeBase
31import lsst.utils.logging
32import lsst.utils.tests
33from lsst.daf.butler import (
34 DataCoordinate,
35 DatasetIdFactory,
36 DatasetRef,
37 DatasetType,
38 DimensionUniverse,
39 Quantum,
40)
class ButlerMock:
    """Minimal in-memory Butler replacement, only usable for this test.

    Stored objects live in ``datasets``, a mapping of dataset type name to a
    mapping of data ID to the stored object. Dataset IDs are never used as
    lookup keys.
    """

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        # Only ``dimensions`` is provided on this stand-in registry.
        self.registry = SimpleNamespace(dimensions=DimensionUniverse())

    def get(self, ref: DatasetRef) -> Any:
        """Return the object stored for ``ref``, or `None` if there is none.

        ``ref`` must be resolved (its ``id`` must not be `None`).
        """
        # Requires resolved ref.
        assert ref.id is not None
        by_data_id = self.datasets.get(ref.datasetType.name)
        if not by_data_id:
            return None
        return by_data_id.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under ``dsRef``'s type name and data ID.

        ``producer`` is accepted but ignored.
        """
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the test tasks: one ``Catalog`` input, one output.

    Both dataset types share the same dimensions.
    """

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the test tasks, using `AddConnections`."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")
class AddTask(pipeBase.PipelineTask):
    """Example task which overrides only the ``run()`` method."""

    _DefaultName = "add_task"
    ConfigClass = AddConfig

    def run(self, input: int) -> pipeBase.Struct:
        """Add ``config.addend`` to ``input``.

        The addend is also recorded in the task metadata under ``"add"``.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct whose ``output`` field is ``input + config.addend``.
        """
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
class AddTask2(pipeBase.PipelineTask):
    """Example task which overrides the ``runQuantum()`` method directly.

    It reads the ``input`` connection through the quantum context, adds
    ``config.addend``, and writes the sum to the ``output`` connection.
    """

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.ButlerQuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        """Read inputs, add ``config.addend``, and write the outputs.

        Parameters
        ----------
        butlerQC : `lsst.pipe.base.ButlerQuantumContext`
            Context used to read inputs and write outputs.
        inputRefs : `lsst.pipe.base.InputQuantizedConnection`
            Input refs, keyed by connection name.
        outputRefs : `lsst.pipe.base.OutputQuantizedConnection`
            Output refs, keyed by connection name.
        """
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    datasetIdFactory = DatasetIdFactory()

    def _resolve_ref(self, ref: DatasetRef) -> DatasetRef:
        """Resolve ``ref`` with ``datasetIdFactory``.

        ``"test"`` is passed as the second argument to ``resolveRef``
        (presumably the run name).
        """
        return self.datasetIdFactory.resolveRef(ref, "test")

    def _makeDSRefVisit(
        self, dstype: DatasetType, visitId: int, universe: DimensionUniverse, resolve: bool = False
    ) -> DatasetRef:
        """Make a ref for ``dstype`` at visit ``visitId``.

        All other data ID values are fixed: detector "X", physical_filter
        "a", band "b", instrument "TestInstrument". The ref is resolved only
        if ``resolve`` is `True`.
        """
        ref = DatasetRef(
            datasetType=dstype,
            dataId=DataCoordinate.standardize(
                detector="X",
                visit=visitId,
                physical_filter="a",
                band="b",
                instrument="TestInstrument",
                universe=universe,
            ),
        )
        if resolve:
            ref = self._resolve_ref(ref)
        return ref

    def _makeQuanta(
        self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100, resolve_outputs: bool = False
    ) -> list[Quantum]:
        """Create ``nquanta`` quanta, one per visit ``0 .. nquanta - 1``.

        Each quantum has one resolved input ref and one output ref. The
        output ref is resolved only if ``resolve_outputs`` is `True`.
        """
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe, resolve=True)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe, resolve=resolve_outputs)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        Parameters
        ----------
        full_butler : `bool`
            If `True`, build contexts with
            `ButlerQuantumContext.from_full`. Otherwise use
            `ButlerQuantumContext.from_limited` and pre-resolve output refs.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta (outputs pre-resolved only for the limited butler)
        quanta = self._makeQuanta(task.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            if full_butler:
                butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)
            else:
                butlerQC = pipeBase.ButlerQuantumContext.from_limited(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                checked_get = True

                # Force the periodic logger to issue messages. Restore the
                # previous global value in ``finally`` so a failing assertion
                # cannot leak the change into other tests.
                saved_interval = lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0
                try:
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        input_data = butlerQC.get(inputRefs)
                    self.assertIn("Completed", cm.output[-1])
                    self.assertEqual(len(input_data), len(inputRefs))

                    # In this test there are no multiples returned.
                    refs = [ref for _, ref in inputRefs]
                    with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                        list_get = butlerQC.get(refs)
                finally:
                    lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = saved_interval

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + task.config.addend)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        # Must pass False; passing True would only repeat testChain2Full.
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        The output of `AddTask` feeds the input of `AddTask2`.

        Parameters
        ----------
        full_butler : `bool`
            If `True`, use `ButlerQuantumContext.from_full`; otherwise use
            `ButlerQuantumContext.from_limited`.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config, resolve_outputs=not full_butler)
        quanta2 = self._makeQuanta(task2.config, resolve_outputs=not full_butler)

        # add input data to butler
        dstype0 = connections1.input.makeDatasetType(butler.registry.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        butler_qc_factory = (
            pipeBase.ButlerQuantumContext.from_full
            if full_butler
            else pipeBase.ButlerQuantumContext.from_limited
        )

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = butler_qc_factory(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + config1.addend + config2.addend)

    def testButlerQC(self):
        """Test for ButlerQuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.registry.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.ButlerQuantumContext.from_full(butler, quantum)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, _ = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add a few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Inherit every test from `lsst.utils.tests.MemoryTestCase`.

    No additional tests are defined here.
    """
def setup_module(module):
    """Module-level setup hook: call ``lsst.utils.tests.init()``.

    The ``module`` argument is not used.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Direct execution: initialise lsst.utils.tests, then hand control to
    # unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()