Coverage for tests/test_pipelineTask.py: 15%
221 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-23 08:14 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Simple unit test for PipelineTask."""
25import copy
26import pickle
27import unittest
28from typing import Any
30import astropy.units as u
31import lsst.pex.config as pexConfig
32import lsst.pipe.base as pipeBase
33import lsst.utils.logging
34import lsst.utils.tests
35from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
class ButlerMock:
    """Minimal stand-in for a butler, usable only in this test.

    Stored datasets live in a nested mapping keyed first by dataset type
    name and then by data ID.
    """

    def __init__(self) -> None:
        # dataset type name -> {dataId: in-memory dataset}
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the dataset stored for ``ref``, or `None` if absent."""
        per_type = self.datasets.get(ref.datasetType.name)
        if not per_type:
            return None
        return per_type.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the dataset type and data ID of ``dsRef``."""
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the AddTask: one catalog input, one catalog output."""

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration for the AddTask."""

    # Constant that run() adds to its input value.
    addend = pexConfig.Field[int](default=3, doc="amount to add")
class AddTask(pipeBase.PipelineTask):
    """Example task which overrides the run() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured addend to ``input`` and return it in a Struct."""
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
class AddTask2(pipeBase.PipelineTask):
    """Example task which overrides runQuantum() instead of run()."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: DatasetRef, outputRefs: DatasetRef
    ) -> None:
        """Fetch inputs through the quantum context, add the addend, store the result."""
        self.metadata.add("add", self.config.addend)
        result = butlerQC.get(inputRefs)["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=result), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask"""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        """Return a DatasetRef for the given visit using fixed test data-ID values."""
        data_id = DataCoordinate.standardize(
            instrument="TestInstrument",
            detector="X",
            physical_filter="a",
            band="b",
            visit=visitId,
            universe=universe,
        )
        return DatasetRef(datasetType=dstype, dataId=data_id, run="test")
124 def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
125 """Create set of Quanta"""
126 universe = DimensionUniverse()
127 connections = config.connections.ConnectionsClass(config=config)
129 dstype0 = connections.input.makeDatasetType(universe)
130 dstype1 = connections.output.makeDatasetType(universe)
132 quanta = []
133 for visit in range(nquanta):
134 inputRef = self._makeDSRefVisit(dstype0, visit, universe)
135 outputRef = self._makeDSRefVisit(dstype1, visit, universe)
136 quantum = Quantum(
137 inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
138 )
139 quanta.append(quantum)
141 return quanta
143 def testRunQuantumFull(self):
144 """Test for AddTask.runQuantum() implementation with full butler."""
145 self._testRunQuantum(full_butler=True)
147 def testRunQuantumLimited(self):
148 """Test for AddTask.runQuantum() implementation with limited butler."""
149 self._testRunQuantum(full_butler=False)
151 def _testRunQuantum(self, full_butler: bool) -> None:
152 """Test for AddTask.runQuantum() implementation."""
153 butler = ButlerMock()
154 task = AddTask(config=AddConfig())
155 connections = task.config.connections.ConnectionsClass(config=task.config)
157 # make all quanta
158 quanta = self._makeQuanta(task.config)
160 # add input data to butler
161 dstype0 = connections.input.makeDatasetType(butler.dimensions)
162 for i, quantum in enumerate(quanta):
163 ref = quantum.inputs[dstype0.name][0]
164 butler.put(100 + i, ref)
166 # run task on each quanta
167 checked_get = False
168 for quantum in quanta:
169 butlerQC = pipeBase.QuantumContext(butler, quantum)
170 inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
171 task.runQuantum(butlerQC, inputRefs, outputRefs)
173 # Test getting of datasets in different ways.
174 # (only need to do this one time)
175 if not checked_get:
176 # Force the periodic logger to issue messages.
177 lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0
179 checked_get = True
180 with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
181 input_data = butlerQC.get(inputRefs)
182 self.assertIn("Completed", cm.output[-1])
183 self.assertEqual(len(input_data), len(inputRefs))
185 # In this test there are no multiples returned.
186 refs = [ref for _, ref in inputRefs]
187 with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
188 list_get = butlerQC.get(refs)
190 lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0
192 self.assertIn("Completed", cm.output[-1])
193 self.assertEqual(len(list_get), len(input_data))
194 self.assertIsInstance(list_get[0], int)
195 scalar_get = butlerQC.get(refs[0])
196 self.assertEqual(scalar_get, list_get[0])
198 with self.assertRaises(TypeError):
199 butlerQC.get({})
201 # Output ref won't be known to this quantum.
202 outputs = [ref for _, ref in outputRefs]
203 with self.assertRaises(ValueError):
204 butlerQC.get(outputs[0])
206 # look at the output produced by the task
207 outputName = connections.output.name
208 dsdata = butler.datasets[outputName]
209 self.assertEqual(len(dsdata), len(quanta))
210 for i, quantum in enumerate(quanta):
211 ref = quantum.outputs[outputName][0]
212 self.assertEqual(dsdata[ref.dataId], 100 + i + 3)
214 def testChain2Full(self) -> None:
215 """Test for two-task chain with full butler."""
216 self._testChain2(full_butler=True)
218 def testChain2Limited(self) -> None:
219 """Test for two-task chain with limited butler."""
220 self._testChain2(full_butler=False)
222 def _testChain2(self, full_butler: bool) -> None:
223 """Test for two-task chain."""
224 butler = ButlerMock()
225 config1 = AddConfig()
226 connections1 = config1.connections.ConnectionsClass(config=config1)
227 task1 = AddTask(config=config1)
228 config2 = AddConfig()
229 config2.addend = 200
230 config2.connections.input = task1.config.connections.output
231 config2.connections.output = "add_output_2"
232 task2 = AddTask2(config=config2)
233 connections2 = config2.connections.ConnectionsClass(config=config2)
235 # make all quanta
236 quanta1 = self._makeQuanta(task1.config)
237 quanta2 = self._makeQuanta(task2.config)
239 # add input data to butler
240 task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
241 dstype0 = task1Connections.input.makeDatasetType(butler.dimensions)
242 for i, quantum in enumerate(quanta1):
243 ref = quantum.inputs[dstype0.name][0]
244 butler.put(100 + i, ref)
246 # run task on each quanta
247 for quantum in quanta1:
248 butlerQC = pipeBase.QuantumContext(butler, quantum)
249 inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
250 task1.runQuantum(butlerQC, inputRefs, outputRefs)
251 for quantum in quanta2:
252 butlerQC = pipeBase.QuantumContext(butler, quantum)
253 inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
254 task2.runQuantum(butlerQC, inputRefs, outputRefs)
256 # look at the output produced by the task
257 outputName = task1.config.connections.output
258 dsdata = butler.datasets[outputName]
259 self.assertEqual(len(dsdata), len(quanta1))
260 for i, quantum in enumerate(quanta1):
261 ref = quantum.outputs[outputName][0]
262 self.assertEqual(dsdata[ref.dataId], 100 + i + 3)
264 outputName = task2.config.connections.output
265 dsdata = butler.datasets[outputName]
266 self.assertEqual(len(dsdata), len(quanta2))
267 for i, quantum in enumerate(quanta2):
268 ref = quantum.outputs[outputName][0]
269 self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)
271 def testButlerQC(self):
272 """Test for QuantumContext. Full and limited share get
273 implementation so only full is tested.
274 """
275 butler = ButlerMock()
276 task = AddTask(config=AddConfig())
277 connections = task.config.connections.ConnectionsClass(config=task.config)
279 # make one quantum
280 (quantum,) = self._makeQuanta(task.config, 1)
282 # add input data to butler
283 dstype0 = connections.input.makeDatasetType(butler.dimensions)
284 ref = quantum.inputs[dstype0.name][0]
285 butler.put(100, ref)
287 butlerQC = pipeBase.QuantumContext(butler, quantum)
288 self.assertEqual(butlerQC.resources.num_cores, 1)
289 self.assertIsNone(butlerQC.resources.max_mem)
291 # Pass ref as single argument or a list.
292 obj = butlerQC.get(ref)
293 self.assertEqual(obj, 100)
294 obj = butlerQC.get([ref])
295 self.assertEqual(obj, [100])
297 # Pass None instead of a ref.
298 obj = butlerQC.get(None)
299 self.assertIsNone(obj)
300 obj = butlerQC.get([None])
301 self.assertEqual(obj, [None])
303 # COmbine a ref and None.
304 obj = butlerQC.get([ref, None])
305 self.assertEqual(obj, [100, None])
307 # Use refs from a QuantizedConnection.
308 inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
309 obj = butlerQC.get(inputRefs)
310 self.assertEqual(obj, {"input": 100})
312 # Add few None values to a QuantizedConnection.
313 inputRefs.input = [None, ref]
314 inputRefs.input2 = None
315 obj = butlerQC.get(inputRefs)
316 self.assertEqual(obj, {"input": [None, 100], "input2": None})
318 # Set additional context.
319 resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
320 butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
321 self.assertEqual(butlerQC.resources.num_cores, 4)
322 self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)
324 resources = pipeBase.ExecutionResources(max_mem=5)
325 butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
326 self.assertEqual(butlerQC.resources.num_cores, 1)
327 self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)
329 def test_ExecutionResources(self):
330 res = pipeBase.ExecutionResources()
331 self.assertEqual(res.num_cores, 1)
332 self.assertIsNone(res.max_mem)
333 self.assertEqual(pickle.loads(pickle.dumps(res)), res)
335 res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
336 self.assertEqual(res.num_cores, 4)
337 self.assertEqual(res.max_mem.value, 1024 * 1024)
338 self.assertEqual(pickle.loads(pickle.dumps(res)), res)
340 res = pipeBase.ExecutionResources(max_mem=512)
341 self.assertEqual(res.num_cores, 1)
342 self.assertEqual(res.max_mem.value, 512)
343 self.assertEqual(pickle.loads(pickle.dumps(res)), res)
345 res = pipeBase.ExecutionResources(max_mem="")
346 self.assertIsNone(res.max_mem)
348 res = pipeBase.ExecutionResources(max_mem="32 KiB")
349 self.assertEqual(res.num_cores, 1)
350 self.assertEqual(res.max_mem.value, 32 * 1024)
351 self.assertEqual(pickle.loads(pickle.dumps(res)), res)
353 self.assertIs(res, copy.deepcopy(res))
355 with self.assertRaises(AttributeError):
356 res.num_cores = 4
358 with self.assertRaises(u.UnitConversionError):
359 pipeBase.ExecutionResources(max_mem=1 * u.m)
360 with self.assertRaises(ValueError):
361 pipeBase.ExecutionResources(num_cores=-32)
362 with self.assertRaises(u.UnitConversionError):
363 pipeBase.ExecutionResources(max_mem=1 * u.m)
364 with self.assertRaises(u.UnitConversionError):
365 pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
366 with self.assertRaises(u.UnitConversionError):
367 pipeBase.ExecutionResources(max_mem="32 Pa")
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file-descriptor leak tests."""
def setup_module(module):
    """Configure pytest.

    Parameters
    ----------
    module : `~types.ModuleType`
        The module being set up (unused).
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Initialize the LSST test machinery before handing off to unittest.
    lsst.utils.tests.init()
    unittest.main()