# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Simple unit test for PipelineTask.
29"""
31import copy
32import pickle
33import unittest
34from typing import Any
36import astropy.units as u
37import lsst.pex.config as pexConfig
38import lsst.pipe.base as pipeBase
39import lsst.utils.logging
40import lsst.utils.tests
41from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum


class ButlerMock:
    """Mock version of a butler, only usable for this test."""

    def __init__(self) -> None:
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        dsdata = self.datasets.get(ref.datasetType.name)
        if dsdata:
            return dsdata.get(ref.dataId)
        return None

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None) -> None:
        key = dsRef.dataId
        name = dsRef.datasetType.name
        dsdata = self.datasets.setdefault(name, {})
        dsdata[key] = inMemoryDataset


class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the AddTask."""

    input = pipeBase.connectionTypes.Input(
        name="add_input",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Input dataset type for this task",
    )
    output = pipeBase.connectionTypes.Output(
        name="add_output",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
        storageClass="Catalog",
        doc="Output dataset type for this task",
    )


class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Config for the AddTask."""

    addend = pexConfig.Field[int](doc="amount to add", default=3)


class AddTask(pipeBase.PipelineTask):
    """Example task that overrides the run() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        self.metadata.add("add", self.config.addend)
        output = input + self.config.addend
        return pipeBase.Struct(output=output)


class AddTask2(pipeBase.PipelineTask):
    """Example task that overrides the runQuantum() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def runQuantum(
        self,
        butlerQC: pipeBase.QuantumContext,
        inputRefs: pipeBase.InputQuantizedConnection,
        outputRefs: pipeBase.OutputQuantizedConnection,
    ) -> None:
        self.metadata.add("add", self.config.addend)
        inputs = butlerQC.get(inputRefs)
        outputs = inputs["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=outputs), outputRefs)


class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int, universe: DimensionUniverse) -> DatasetRef:
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            universe=universe,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(self, config: pipeBase.PipelineTaskConfig, nquanta: int = 100) -> list[Quantum]:
        """Create a set of quanta."""
        universe = DimensionUniverse()
        connections = config.connections.ConnectionsClass(config=config)

        dstype0 = connections.input.makeDatasetType(universe)
        dstype1 = connections.output.makeDatasetType(universe)

        quanta = []
        for visit in range(nquanta):
            inputRef = self._makeDSRefVisit(dstype0, visit, universe)
            outputRef = self._makeDSRefVisit(dstype1, visit, universe)
            quantum = Quantum(
                inputs={inputRef.datasetType: [inputRef]}, outputs={outputRef.datasetType: [outputRef]}
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation."""
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make all quanta
        quanta = self._makeQuanta(task.config)

        # add input data to butler
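        # (each visit gets a distinct input value so per-quantum outputs can be checked)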
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quantum
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                # Force the periodic logger to issue messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0

                checked_get = True
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    input_data = butlerQC.get(inputRefs)
                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(input_data), len(inputRefs))

                # In this test there are no multiples returned.
                refs = [ref for _, ref in inputRefs]
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    list_get = butlerQC.get(refs)

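                # Set the interval back to a long value so later gets do not
                # emit periodic progress messages.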
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

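                # Unsupported argument types should be rejected.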
                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        outputName = connections.output.name
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain."""
        butler = ButlerMock()
        config1 = AddConfig()
        connections1 = config1.connections.ConnectionsClass(config=config1)
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
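        # Chain the second task's input to the first task's output dataset type.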
        config2.connections.input = task1.config.connections.output
        config2.connections.output = "add_output_2"
        task2 = AddTask2(config=config2)
        connections2 = config2.connections.ConnectionsClass(config=config2)

        # make all quanta
        quanta1 = self._makeQuanta(task1.config)
        quanta2 = self._makeQuanta(task2.config)

        # add input data to butler
        task1Connections = task1.config.connections.ConnectionsClass(config=task1.config)
        dstype0 = task1Connections.input.makeDatasetType(butler.dimensions)
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quantum
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections1.buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = connections2.buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        outputName = task1.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        outputName = task2.config.connections.output
        dsdata = butler.datasets[outputName]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[outputName][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for QuantumContext. Full and limited butlers share the
        get implementation, so only the full butler is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        connections = task.config.connections.ConnectionsClass(config=task.config)

        # make one quantum
        (quantum,) = self._makeQuanta(task.config, 1)

        # add input data to butler
        dstype0 = connections.input.makeDatasetType(butler.dimensions)
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass a ref as a single argument or in a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = connections.buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add a few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

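        # A bare number for max_mem is interpreted as bytes.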
        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self):
        """Test the ExecutionResources class."""
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

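        # An empty string means no memory limit.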
        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

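        # The resources object is immutable, so deepcopy can return the same
        # instance and attribute assignment should raise.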
        self.assertIs(res, copy.deepcopy(res))

        with self.assertRaises(AttributeError):
            res.num_cores = 4

        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")


class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file leak tests."""


def setup_module(module):
    """Configure pytest."""
    lsst.utils.tests.init()


if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()