Coverage for tests/test_pipelineTask.py: 16%
222 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:47 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (http://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <http://www.gnu.org/licenses/>.
28"""Simple unit test for PipelineTask."""
30import copy
31import pickle
32import unittest
33from typing import Any
35import astropy.units as u
37import lsst.pex.config as pexConfig
38import lsst.pipe.base as pipeBase
39import lsst.utils.logging
40import lsst.utils.tests
41from lsst.daf.butler import (
42 DataCoordinate,
43 DatasetProvenance,
44 DatasetRef,
45 DatasetType,
46 DimensionUniverse,
47 Quantum,
48)
class ButlerMock:
    """Mock version of butler, only usable for this test."""

    def __init__(self) -> None:
        # Mapping of dataset type name -> {dataId: stored python object}.
        self.datasets: dict[str, dict[DataCoordinate, Any]] = {}
        self.dimensions = DimensionUniverse()

    def get(self, ref: DatasetRef) -> Any:
        """Return the stored object for ``ref``, or None if it is absent."""
        per_type = self.datasets.get(ref.datasetType.name)
        return per_type.get(ref.dataId) if per_type else None

    def put(
        self,
        inMemoryDataset: Any,
        dsRef: DatasetRef,
        provenance: DatasetProvenance | None = None,
        producer: Any = None,
    ):
        """Store ``inMemoryDataset`` keyed by dataset type name and dataId.

        ``provenance`` and ``producer`` are accepted for API compatibility
        but ignored by this mock.
        """
        self.datasets.setdefault(dsRef.datasetType.name, {})[dsRef.dataId] = inMemoryDataset
class AddConnections(pipeBase.PipelineTaskConnections, dimensions=["instrument", "visit"]):
    """Connections for the AddTask."""

    # One scalar input and one scalar output, sharing the same dimensions.
    input = pipeBase.connectionTypes.Input(
        doc="Input dataset type for this task",
        name="add_input",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
    output = pipeBase.connectionTypes.Output(
        doc="Output dataset type for this task",
        name="add_output",
        storageClass="Catalog",
        dimensions=["instrument", "visit", "detector", "physical_filter", "band"],
    )
class AddConfig(pipeBase.PipelineTaskConfig, pipelineConnections=AddConnections):
    """Config for the AddTask."""

    # Constant added to each input value by the task.
    addend = pexConfig.Field[int](default=3, doc="amount to add")
class AddTask(pipeBase.PipelineTask):
    """Example task which overrides run() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task"

    def run(self, input: int) -> pipeBase.Struct:
        """Add the configured addend to ``input`` and return it as a Struct."""
        # Record the addend in task metadata before computing the sum.
        addend = self.config.addend
        self.metadata.add("add", addend)
        return pipeBase.Struct(output=input + addend)
class AddTask2(pipeBase.PipelineTask):
    """Example task which overrides runQuantum() method."""

    ConfigClass = AddConfig
    _DefaultName = "add_task_2"

    def runQuantum(
        self, butlerQC: pipeBase.QuantumContext, inputRefs: DatasetRef, outputRefs: DatasetRef
    ) -> None:
        """Fetch the input through the quantum context, add, and store."""
        self.metadata.add("add", self.config.addend)
        # Same arithmetic as AddTask.run(), but performed directly against
        # the QuantumContext instead of the default runQuantum machinery.
        result = butlerQC.get(inputRefs)["input"] + self.config.addend
        butlerQC.put(pipeBase.Struct(output=result), outputRefs)
class PipelineTaskTestCase(unittest.TestCase):
    """A test case for PipelineTask."""

    def _makeDSRefVisit(self, dstype: DatasetType, visitId: int) -> DatasetRef:
        """Make a DatasetRef for ``dstype`` at the given visit.

        All dimensions other than ``visit`` use fixed placeholder values;
        ``DataCoordinate.standardize`` keeps only those relevant to the
        dataset type's dimensions.
        """
        dataId = DataCoordinate.standardize(
            detector="X",
            visit=visitId,
            physical_filter="a",
            band="b",
            instrument="TestInstrument",
            dimensions=dstype.dimensions,
        )
        run = "test"
        ref = DatasetRef(datasetType=dstype, dataId=dataId, run=run)
        return ref

    def _makeQuanta(
        self,
        task_node: pipeBase.pipeline_graph.TaskNode,
        pipeline_graph: pipeBase.PipelineGraph,
        nquanta: int = 100,
    ) -> list[Quantum]:
        """Create set of Quanta, one per visit in ``range(nquanta)``.

        Each quantum gets one input ref and one output ref per connection
        edge of ``task_node``, all sharing the same visit number.
        """
        quanta = []
        for visit in range(nquanta):
            quantum = Quantum(
                inputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.inputs.values()
                },
                outputs={
                    pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type: [
                        self._makeDSRefVisit(
                            pipeline_graph.dataset_types[edge.dataset_type_name].dataset_type, visit
                        )
                    ]
                    for edge in task_node.outputs.values()
                },
            )
            quanta.append(quantum)

        return quanta

    def testRunQuantumFull(self):
        """Test for AddTask.runQuantum() implementation with full butler."""
        self._testRunQuantum(full_butler=True)

    def testRunQuantumLimited(self):
        """Test for AddTask.runQuantum() implementation with limited butler."""
        self._testRunQuantum(full_butler=False)

    def _testRunQuantum(self, full_butler: bool) -> None:
        """Test for AddTask.runQuantum() implementation.

        NOTE(review): ``full_butler`` is not referenced in this body; both
        callers currently exercise identical code paths — confirm whether the
        full/limited distinction still needs separate handling here.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input
        output_name = task.config.connections.output

        # Build and resolve a one-task pipeline graph against the mock
        # butler's dimension universe.
        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta = self._makeQuanta(task_node, pipeline_graph)

        # add input data to butler: value 100 + i for the i-th quantum
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        for i, quantum in enumerate(quanta):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta
        checked_get = False
        for quantum in quanta:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
            task.runQuantum(butlerQC, inputRefs, outputRefs)

            # Test getting of datasets in different ways.
            # (only need to do this one time)
            if not checked_get:
                # Force the periodic logger to issue messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 0.0

                checked_get = True
                # Fetching via a QuantizedConnection should log a
                # "Completed" message at VERBOSE level.
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    input_data = butlerQC.get(inputRefs)
                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(input_data), len(inputRefs))

                # In this test there are no multiples returned.
                refs = [ref for _, ref in inputRefs]
                with self.assertLogs("lsst.pipe.base", level=lsst.utils.logging.VERBOSE) as cm:
                    list_get = butlerQC.get(refs)

                # Restore the default interval before asserting, so later
                # iterations are not flooded with periodic log messages.
                lsst.utils.logging.PeriodicLogger.LOGGING_INTERVAL = 600.0

                self.assertIn("Completed", cm.output[-1])
                self.assertEqual(len(list_get), len(input_data))
                self.assertIsInstance(list_get[0], int)
                # A scalar ref should return a scalar, equal to the first
                # element of the list-based get.
                scalar_get = butlerQC.get(refs[0])
                self.assertEqual(scalar_get, list_get[0])

                # Unsupported argument types must be rejected.
                with self.assertRaises(TypeError):
                    butlerQC.get({})

                # Output ref won't be known to this quantum.
                outputs = [ref for _, ref in outputRefs]
                with self.assertRaises(ValueError):
                    butlerQC.get(outputs[0])

        # look at the output produced by the task: input + default addend (3)
        dsdata = butler.datasets[output_name]
        self.assertEqual(len(dsdata), len(quanta))
        for i, quantum in enumerate(quanta):
            ref = quantum.outputs[output_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

    def testChain2Full(self) -> None:
        """Test for two-task chain with full butler."""
        self._testChain2(full_butler=True)

    def testChain2Limited(self) -> None:
        """Test for two-task chain with limited butler."""
        self._testChain2(full_butler=False)

    def _testChain2(self, full_butler: bool) -> None:
        """Test for two-task chain.

        Task 1 (AddTask, addend=3) feeds its output to task 2 (AddTask2,
        addend=200) via a renamed intermediate connection.

        NOTE(review): ``full_butler`` is not referenced in this body; both
        callers currently exercise identical code paths.
        """
        butler = ButlerMock()
        config1 = AddConfig()
        overall_input_name = config1.connections.input
        intermediate_name = config1.connections.output
        task1 = AddTask(config=config1)
        config2 = AddConfig()
        config2.addend = 200
        # Wire task2's input to task1's output to form the chain.
        config2.connections.input = intermediate_name
        overall_output_name = "add_output_2"
        config2.connections.output = overall_output_name
        task2 = AddTask2(config=config2)

        pipeline_graph = pipeBase.PipelineGraph()
        task1_node = pipeline_graph.add_task(None, type(task1), task1.config)
        task2_node = pipeline_graph.add_task(None, type(task2), task2.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make all quanta
        quanta1 = self._makeQuanta(task1_node, pipeline_graph)
        quanta2 = self._makeQuanta(task2_node, pipeline_graph)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[overall_input_name].dataset_type
        for i, quantum in enumerate(quanta1):
            ref = quantum.inputs[dstype0.name][0]
            butler.put(100 + i, ref)

        # run task on each quanta: all of task1 first, then all of task2,
        # since task2 reads the intermediates produced by task1.
        for quantum in quanta1:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task1_node.get_connections().buildDatasetRefs(quantum)
            task1.runQuantum(butlerQC, inputRefs, outputRefs)
        for quantum in quanta2:
            butlerQC = pipeBase.QuantumContext(butler, quantum)
            inputRefs, outputRefs = task2_node.get_connections().buildDatasetRefs(quantum)
            task2.runQuantum(butlerQC, inputRefs, outputRefs)

        # look at the output produced by the task: intermediates are
        # input + 3, final outputs add another 200 on top.
        dsdata = butler.datasets[intermediate_name]
        self.assertEqual(len(dsdata), len(quanta1))
        for i, quantum in enumerate(quanta1):
            ref = quantum.outputs[intermediate_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3)

        dsdata = butler.datasets[overall_output_name]
        self.assertEqual(len(dsdata), len(quanta2))
        for i, quantum in enumerate(quanta2):
            ref = quantum.outputs[overall_output_name][0]
            self.assertEqual(dsdata[ref.dataId], 100 + i + 3 + 200)

    def testButlerQC(self):
        """Test for QuantumContext. Full and limited share get
        implementation so only full is tested.
        """
        butler = ButlerMock()
        task = AddTask(config=AddConfig())
        input_name = task.config.connections.input

        pipeline_graph = pipeBase.PipelineGraph()
        task_node = pipeline_graph.add_task(None, type(task), config=task.config)
        pipeline_graph.resolve(dimensions=butler.dimensions, dataset_types={})

        # make one quantum
        (quantum,) = self._makeQuanta(task_node, pipeline_graph, 1)

        # add input data to butler
        dstype0 = pipeline_graph.dataset_types[input_name].dataset_type
        ref = quantum.inputs[dstype0.name][0]
        butler.put(100, ref)

        # Default resources: single core, no memory limit.
        butlerQC = pipeBase.QuantumContext(butler, quantum)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertIsNone(butlerQC.resources.max_mem)

        # Pass ref as single argument or a list.
        obj = butlerQC.get(ref)
        self.assertEqual(obj, 100)
        obj = butlerQC.get([ref])
        self.assertEqual(obj, [100])

        # Pass None instead of a ref.
        obj = butlerQC.get(None)
        self.assertIsNone(obj)
        obj = butlerQC.get([None])
        self.assertEqual(obj, [None])

        # Combine a ref and None.
        obj = butlerQC.get([ref, None])
        self.assertEqual(obj, [100, None])

        # Use refs from a QuantizedConnection.
        inputRefs, outputRefs = task_node.get_connections().buildDatasetRefs(quantum)
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": 100})

        # Add few None values to a QuantizedConnection.
        inputRefs.input = [None, ref]
        inputRefs.input2 = None
        obj = butlerQC.get(inputRefs)
        self.assertEqual(obj, {"input": [None, 100], "input2": None})

        # Set additional context; MB-valued max_mem is normalized to bytes.
        resources = pipeBase.ExecutionResources(num_cores=4, max_mem=5 * u.MB)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 4)
        self.assertEqual(butlerQC.resources.max_mem, 5_000_000 * u.B)

        # A plain integer max_mem is interpreted as bytes.
        resources = pipeBase.ExecutionResources(max_mem=5)
        butlerQC = pipeBase.QuantumContext(butler, quantum, resources=resources)
        self.assertEqual(butlerQC.resources.num_cores, 1)
        self.assertEqual(butlerQC.resources.max_mem, 5 * u.B)

    def test_ExecutionResources(self):
        """Test construction, pickling, and validation of
        ExecutionResources.
        """
        # Defaults: one core, no memory limit; round-trips through pickle.
        res = pipeBase.ExecutionResources()
        self.assertEqual(res.num_cores, 1)
        self.assertIsNone(res.max_mem)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Quantity max_mem is converted to bytes.
        res = pipeBase.ExecutionResources(num_cores=4, max_mem=1 * u.MiB)
        self.assertEqual(res.num_cores, 4)
        self.assertEqual(res.max_mem.value, 1024 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Plain integers are treated as bytes.
        res = pipeBase.ExecutionResources(max_mem=512)
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 512)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Empty string means "no limit".
        res = pipeBase.ExecutionResources(max_mem="")
        self.assertIsNone(res.max_mem)

        # Strings with units are parsed.
        res = pipeBase.ExecutionResources(max_mem="32 KiB")
        self.assertEqual(res.num_cores, 1)
        self.assertEqual(res.max_mem.value, 32 * 1024)
        self.assertEqual(pickle.loads(pickle.dumps(res)), res)

        # Immutable object: deepcopy returns the same instance and
        # attribute assignment is rejected.
        self.assertIs(res, copy.deepcopy(res))

        with self.assertRaises(AttributeError):
            res.num_cores = 4

        # Invalid constructor arguments: non-memory units, negative cores.
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(ValueError):
            pipeBase.ExecutionResources(num_cores=-32)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1 * u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem=1, default_mem_units=u.m)
        with self.assertRaises(u.UnitConversionError):
            pipeBase.ExecutionResources(max_mem="32 Pa")
class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Check for file-descriptor/resource leaks after the tests run."""
def setup_module(module):
    """Initialize the LSST test framework when this module is collected
    by pytest.
    """
    lsst.utils.tests.init()
if __name__ == "__main__":
    # Allow running this test module directly with python.
    lsst.utils.tests.init()
    unittest.main()