Coverage for tests/test_pipelineTask.py: 15%
223 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-02 03:31 -0700
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Simple unit test for PipelineTask.
29"""
31import copy
32import pickle
33import unittest
34from typing import Any
36import astropy.units as u
37import lsst.pex.config as pexConfig
38import lsst.pipe.base as pipeBase
39import lsst.utils.logging
40import lsst.utils.tests
41from lsst.daf.butler import DataCoordinate, DatasetRef, DatasetType, DimensionUniverse, Quantum
class ButlerMock:
    """Minimal stand-in for a butler, usable only within this test module."""

    def __init__(self) -> None:
        # Stored datasets, keyed first by dataset type name, then by data ID.
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the dataset stored for ``ref``, or `None` if absent."""
        per_type = self.datasets.get(ref.datasetType.name)
        if not per_type:
            return None
        return per_type.get(ref.dataId)

    def put(self, inMemoryDataset: Any, dsRef: DatasetRef, producer: Any = None):
        """Store ``inMemoryDataset`` under the type and data ID of ``dsRef``."""
        per_type = self.datasets.setdefault(dsRef.datasetType.name, {})
        per_type[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connection definitions used by the AddTask variants in this module."""

    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Configuration shared by AddTask and AddTask2."""

    addend = pexConfig.Field[int](default=3, doc="amount to add")
class AddTask(pipeBase.PipelineTask):
    """Example task implemented by overriding the run() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        # Record the configured addend in task metadata so tests can check it.
        self.metadata.add("add", self.config.addend)
        total = input + self.config.addend
        return pipeBase.Struct(output=total)
class AddTask2(pipeBase.PipelineTask):
    """Example task implemented by overriding the runQuantum() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task_2"

    def runQuantum(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: DatasetRef, outputRefs: DatasetRef
    ) -> None:
        # Record the configured addend in task metadata so tests can check it.
        self.metadata.add("add", self.config.addend)
        fetched = butlerQC.get(inputRefs)
        summed = fetched["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=summed), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int) -> DatasetRef:
        """Return a resolved DatasetRef for ``dstype`` keyed by ``visitId``."""
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            dimensions=dstype.dimensions,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(
        self,
        task_node: pipeBase.pipeline_graph.TaskNode,
        pipeline_graph: pipeBase.PipelineGraph,
        nquanta: int = 100,
    ) -> list[Quantum]:
        """Create set of Quanta"""
        quanta = []
        for visit in range(nquanta):
            # One quantum per visit, with a single input and output ref for
            # every connection declared by the task.
            quantum = Quantum(
                inputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.inputs.values()
                },
                outputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.outputs.values()
                },
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation."""
        # NOTE(review): ``full_butler`` is currently unused in this body —
        # confirm whether the full/limited distinction still needs separate
        # code paths here.
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input
        output_name = task.config.connections.output

        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta = self._makeQuanta(task_node, pipeline_graph)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                # Force the periodic logger to issue messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0

                checked_get = True
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    input_data = butlerQC.get(inputRefs)
                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(input_data), len(inputRefs))

                # In this test there are no multiples returned.
                refs = [ref for _, ref in inputRefs]
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    list_get = butlerQC.get(refs)

                # Restore the default interval before making assertions.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task
        dsdata = butler.datasets[output_name]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[output_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain."""
        # NOTE(review): ``full_butler`` is currently unused here as well —
        # confirm intent (see _testRunQuantum).
        butler = ButlerMock()
        config1 = AddConfig()
        overall_input_name = config1.connections.input
        intermediate_name = config1.connections.output
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Chain task2's input to task1's output.
        config2.connections.input = intermediate_name
        overall_output_name = "add_output_2"
        config2.connections.output = overall_output_name
        task2 = AddTask2(config=config2)

        pipeline_graph = pipeBase.PipelineGraph()
        task1_node = pipeline_graph.add_task(None, type(task1), task1.config)
        task2_node = pipeline_graph.add_task(None, type(task2), task2.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta1 = self._makeQuanta(task1_node, pipeline_graph)
        quanta2 = self._makeQuanta(task2_node, pipeline_graph)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[overall_input_name].dataset_type
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task1_node.get_connections().buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task2_node.get_connections().buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task
        dsdata = butler.datasets[intermediate_name]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[intermediate_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        dsdata = butler.datasets[overall_output_name]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[overall_output_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for QuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input

        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), config=task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make one quantum
        (quantum,) = self._makeQuanta(task_node, pipeline_graph, 1)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self):
        """Test construction, pickling, and validation of
        ExecutionResources.
        """
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # A plain integer is interpreted as bytes.
        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # An empty string means no memory limit.
        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Immutable object: deepcopy is expected to return the same instance.
        self.assertIs(res, copy.deepcopy(res))

        with self.assertRaises(AttributeError):
            res.num_cores = 4

        # Invalid arguments must be rejected.
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Run file leak tests."""

    # Empty body: the checks are inherited from the base class.
def setup_module(module):
    """Configure pytest.

    Parameters
    ----------
    module : `types.ModuleType`
        The module being set up; unused by this hook.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Support running this test file directly, outside of pytest.
    lsst.utils.tests.init()
    unittest.main()