Coverage for tests/test_pipelineTask.py: 16% (222 statements)
coverage.py v7.2.7, created at 2023-06-25 09:14 +0000
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Simple unit test for PipelineTask.
"""

import copy
import pickle
import unittest
from typing import Any

import astropy.units as u
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.utils.logging
import lsst.utils.tests
from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum


class ButlerMock:
    """Mock version of butler, only usable for this test"""

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

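    # Return the stored object for the ref's dataset type and data ID, or
    # None if nothing has been put for it.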
    def get(self, ref: DatasetRef) -> Any:
        dsdata = self.datasets.get(ref.datasetType.name)
        if dsdata:
            return dsdata.get(ref.dataId)
        return None

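    # Store the in-memory object keyed by dataset type name and data ID;
    # the ``producer`` argument is accepted for API compatibility and ignored.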
    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None) -> None:
        key = dsRef.dataId
        name = dsRef.datasetType.name
        dsdata = self.datasets.setdefault(name, {})
        dsdata[key] = inMemoryDataset


class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    input = pipeBase.connectionTypes.Input(
        name="add_input",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Input dataset type for this task",
    )
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Output dataset type for this task",
    )


class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    addend = pexConfig.Field[int](doc="amount to add", default=3)


# example task which overrides run() method
class AddTask(pipeBase.PipelineTask):
    ConfigClass = AddConfig
    _DefaultName = "add_task"

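    # Relies on the default PipelineTask.runQuantum(), which fetches the
    # inputs and calls run() with them as keyword arguments keyed by
    # connection name, then stores the returned Struct.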
    def run(self, input: int) -> pipeBase.Struct:
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        return pipeBase.Struct(output=output)


# example task which overrides runQuantum() method
class AddTask2(pipeBase.PipelineTask):
    ConfigClass = AddConfig
    _DefaultName = "add_task"

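    # Override runQuantum() directly: fetch the inputs through the
    # QuantumContext, do the work, and put the resulting Struct back.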
    def runQuantum(
        self,
        butlerQC: pipeBase.QuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)


class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create a set of Quanta."""
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

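        # Construct the DatasetTypes for the input and output connections.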
        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation."""

        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quantum
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                # Force the periodic logger to issue messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0

                checked_get = True
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    input_data = butlerQC.get(inputRefs)
                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(input_data), len(inputRefs))

                # In this test there are no multiples returned.
                refs = [ref for _, ref in inputRefs]
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    list_get = butlerQC.get(refs)

                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run each task on its quanta
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by each task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for QuantumContext. The full and limited butlers share the
        get implementation, so only the full butler is tested.
        """

        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add a few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

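        # A bare number is interpreted in the default memory units (bytes),
        # as the assertion below verifies.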
        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self):
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

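        # An empty string means no memory limit.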
        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

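        # The instance is immutable: deepcopy returns the same object and
        # attribute assignment is rejected, as the next assertions check.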
        self.assertIs(res, copy.deepcopy(res))

        with self.assertRaises(AttributeError):
            res.num_cores = 4

        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    pass


def setup_module(module):
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()