Coverage for tests / test_mp_graph_executor.py: 13%
264 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:44 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:44 +0000
1# This file is part of pipe_base.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This software is dual licensed under the GNU General Public License and also
10# under a 3-clause BSD license. Recipients may choose which of these licenses
11# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
12# respectively. If you choose the GPL option then the following text applies
13# (but note that there is still no warranty even if you opt for BSD instead):
14#
15# This program is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# This program is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program. If not, see <https://www.gnu.org/licenses/>.
28from __future__ import annotations
30import logging
31import multiprocessing
32import multiprocessing.context
33import os
34import signal
35import sys
36import unittest
37import warnings
38from typing import Literal
40import psutil
42from lsst.pipe.base.exec_fixup_data_id import ExecFixupDataId
43from lsst.pipe.base.mp_graph_executor import MPGraphExecutor, MPGraphExecutorError, MPTimeoutError
44from lsst.pipe.base.quantum_reports import ExecutionStatus, Report
45from lsst.pipe.base.tests.mocks import (
46 DynamicConnectionConfig,
47 DynamicTestPipelineTask,
48 DynamicTestPipelineTaskConfig,
49 InMemoryRepo,
50)
# Show everything down to DEBUG level while the tests run.
logging.basicConfig(level=logging.DEBUG)

# Logger named after this module.
_LOG = logging.getLogger(__name__)

# Absolute path of the directory that contains this test module.
TESTDIR = os.path.dirname(os.path.abspath(__file__))
class NoMultiprocessingTask(DynamicTestPipelineTask):
    """Mock pipeline task that opts out of multiprocessing.

    It is a `DynamicTestPipelineTask` whose only change is that
    ``canMultiprocess`` is overridden to `False`.
    """

    # Declares that this task cannot be used in multiprocessing.
    canMultiprocess = False
def _count_status(report: Report, status: ExecutionStatus) -> int:
    """Return how many entries of ``report.quantaReports`` have ``status``.

    Parameters
    ----------
    report : `Report`
        Execution report whose ``quantaReports`` are inspected.
    status : `ExecutionStatus`
        Status to look for; compared by identity (``is``).

    Returns
    -------
    count : `int`
        Number of per-quantum reports whose ``status`` is ``status``.
    """
    return sum(1 for quantum_report in report.quantaReports if quantum_report.status is status)
class MPGraphExecutorTestCase(unittest.TestCase):
    """A test case for MPGraphExecutor class.

    Every test builds a small quantum graph with the `InMemoryRepo` helper,
    runs it through an `MPGraphExecutor`, and checks the datasets produced
    and/or the report returned by ``getReport``.
    """

    def test_mpexec_nomp(self) -> None:
        """Make simple graph and execute it in a single process."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(dimensions=["detector"])
        qgraph = helper.make_quantum_graph()
        qexec, butler = helper.make_single_quantum_executor()
        # run in single-process mode
        mpexec = MPGraphExecutor(num_proc=1, timeout=100, quantum_executor=qexec)
        mpexec.execute(qgraph)  # type: ignore[arg-type]
        self.assertCountEqual(
            [ref.dataId["detector"] for ref in butler.get_datasets("dataset_auto1")], [1, 2, 3, 4]
        )
        report = mpexec.getReport()
        assert report is not None
        self.assertEqual(report.status, ExecutionStatus.SUCCESS)
        self.assertIsNone(report.exitCode)
        self.assertIsNone(report.exceptionInfo)
        self.assertEqual(len(report.quantaReports), 4)
        self.assertTrue(all(qrep.status == ExecutionStatus.SUCCESS for qrep in report.quantaReports))
        # In-process execution is expected to leave the per-quantum exit code
        # unset.
        self.assertTrue(all(qrep.exitCode is None for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))
        self.assertTrue(all(qrep.taskLabel == "task_auto1" for qrep in report.quantaReports))

    def test_mpexec_mp(self) -> None:
        """Make simple graph and execute it with several processes."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(dimensions=["detector"])
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()

        methods: list[Literal["spawn", "forkserver"]] = ["spawn"]
        if sys.platform == "linux":
            methods.append("forkserver")

        for method in methods:
            with self.subTest(startMethod=method):
                # Run in multi-process mode, the order of results is not
                # defined.
                mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec, start_method=method)
                mpexec.execute(qg)  # type: ignore[arg-type]
                report = mpexec.getReport()
                assert report is not None
                self.assertEqual(report.status, ExecutionStatus.SUCCESS)
                self.assertIsNone(report.exitCode)
                self.assertIsNone(report.exceptionInfo)
                self.assertEqual(len(report.quantaReports), 4)
                self.assertTrue(all(qrep.status == ExecutionStatus.SUCCESS for qrep in report.quantaReports))
                # Unlike the single-process case, every quantum is expected
                # to carry an exit code of zero here.
                self.assertTrue(all(qrep.exitCode == 0 for qrep in report.quantaReports))
                self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))
                self.assertTrue(all(qrep.taskLabel == "task_auto1" for qrep in report.quantaReports))

    def test_mpexec_nompsupport(self) -> None:
        """Try to run MP for task that has no MP support which should fail."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(task_class=NoMultiprocessingTask, dimensions=["detector"])
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec)
        with self.assertRaisesRegex(
            MPGraphExecutorError, "Task 'task_auto1' does not support multiprocessing"
        ):
            mpexec.execute(qg)  # type: ignore[arg-type]

    def test_mpexec_fixup(self) -> None:
        """Make simple graph and execute, add dependencies by executing fixup
        code.
        """
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(dimensions=["detector"])
        qg = helper.make_quantum_graph()
        for reverse in (False, True):
            qexec, butler = helper.make_single_quantum_executor()
            fixup = ExecFixupDataId("task_auto1", "detector", reverse=reverse)
            mpexec = MPGraphExecutor(
                num_proc=1, timeout=100, quantum_executor=qexec, execution_graph_fixup=fixup
            )
            mpexec.execute(qg)  # type: ignore[arg-type]
            expected = [1, 2, 3, 4]
            if reverse:
                expected = list(reversed(expected))
            # assertEqual (not assertCountEqual): the order of the datasets is
            # what the fixup is expected to control.
            self.assertEqual(
                [ref.dataId["detector"] for ref in butler.get_datasets("dataset_auto1")], expected
            )

    def test_mpexec_fixup_old_qg(self) -> None:
        """Test using an old QuantumGraph object to initialize the executor,
        with an ordering fixup.
        """
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(dimensions=["detector"])
        qgraph = helper.make_quantum_graph_builder().build(attach_datastore_records=False)
        for reverse in (False, True):
            qexec, butler = helper.make_single_quantum_executor()
            fixup = ExecFixupDataId("task_auto1", "detector", reverse=reverse)
            mpexec = MPGraphExecutor(
                num_proc=1, timeout=100, quantum_executor=qexec, execution_graph_fixup=fixup
            )
            mpexec.execute(qgraph)  # type: ignore[arg-type]
            expected = [1, 2, 3, 4]
            if reverse:
                expected = list(reversed(expected))
            self.assertEqual(
                [ref.dataId["detector"] for ref in butler.get_datasets("dataset_auto1")], expected
            )

    def test_mpexec_timeout(self) -> None:
        """Fail due to timeout."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(label="a")
        helper.add_task(
            label="b",
            inputs={"input_connection": DynamicConnectionConfig(dataset_type_name="dataset_auto0")},
        )
        # Task 'c' sleeps far longer than either timeout used below.
        helper.add_task(
            label="c",
            inputs={"input_connection": DynamicConnectionConfig(dataset_type_name="dataset_auto0")},
            config=DynamicTestPipelineTaskConfig(sleep=100.0),
        )
        qg = helper.make_quantum_graph()

        # with failFast we'll get immediate MPTimeoutError
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=1, quantum_executor=qexec, fail_fast=True)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qg)  # type: ignore[arg-type]
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.TIMEOUT)
        self.assertEqual(report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPTimeoutError")
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.TIMEOUT), 1)
        self.assertTrue(any(qrep.exitCode is not None and qrep.exitCode < 0 for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

        # with failFast=False exception happens after last task finishes
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=3, quantum_executor=qexec, fail_fast=False)
        with self.assertRaises(MPTimeoutError):
            mpexec.execute(qg)  # type: ignore[arg-type]
        # Fetch the report of *this* run before checking anything, so that the
        # assertions below do not inspect the report of the fail-fast run.
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.TIMEOUT)
        self.assertEqual(report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPTimeoutError")
        # We expect two tasks ('a' and 'b') to finish successfully and one task
        # ('c') to timeout, which should get us all three reports.
        # Unfortunately on busy CPU there is no guarantee that tasks finish on
        # time, so expect more timeouts and issue a warning.
        if len(report.quantaReports) != 3:
            warnings.warn(
                f"Possibly timed out tasks, expected three reports, received {len(report.quantaReports)}."
            )
        self.assertGreater(_count_status(report, ExecutionStatus.TIMEOUT), 0)
        self.assertTrue(any(qrep.exitCode is not None and qrep.exitCode < 0 for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_failure(self) -> None:
        """Failure in one task should not stop other tasks."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(
            config=DynamicTestPipelineTaskConfig(fail_condition="detector=2"),
            dimensions=["detector"],
        )
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qg)  # type: ignore[arg-type]
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPGraphExecutorError"
        )
        self.assertGreater(len(report.quantaReports), 0)
        # One quantum (detector=2) fails, the remaining three still succeed.
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 3)
        self.assertTrue(any(qrep.exitCode is not None and qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_dep(self) -> None:
        """Failure in one task should skip dependents."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(
            "a", config=DynamicTestPipelineTaskConfig(fail_condition="detector=2"), dimensions=["detector"]
        )
        helper.add_task("b", dimensions=["detector"])  # depends on 'a', for the same detector.
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qg)  # type: ignore[arg-type]
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPGraphExecutorError"
        )
        # The dependent of the failed quantum is reported as SKIPPED; all
        # other quanta of both tasks succeed.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 6)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 1)
        self.assertTrue(any(qrep.exitCode is not None and qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_dep_nomp(self) -> None:
        """Failure in one task should skip dependents, in-process version."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(
            "a", config=DynamicTestPipelineTaskConfig(fail_condition="detector=2"), dimensions=["detector"]
        )
        helper.add_task("b", dimensions=["detector"])  # depends on 'a', for the same detector.
        qg = helper.make_quantum_graph()
        qexec, butler = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=1, timeout=100, quantum_executor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qg)  # type: ignore[arg-type]
        # Neither task produces an output for detector 2.
        self.assertCountEqual(
            [ref.dataId["detector"] for ref in butler.get_datasets("dataset_auto1")], [1, 3, 4]
        )
        self.assertCountEqual(
            [ref.dataId["detector"] for ref in butler.get_datasets("dataset_auto2")], [1, 3, 4]
        )
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPGraphExecutorError"
        )
        # The dependent of the failed quantum is reported as SKIPPED; all
        # other quanta of both tasks succeed.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 6)
        self.assertEqual(_count_status(report, ExecutionStatus.SKIPPED), 1)
        self.assertTrue(all(qrep.exitCode is None for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_failure_failfast(self) -> None:
        """Fast fail stops quickly.

        Timing delay of task 'b' should be sufficient to process
        failure and raise exception before task 'c'.
        """
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(
            "a", config=DynamicTestPipelineTaskConfig(fail_condition="detector=2"), dimensions=["detector"]
        )
        helper.add_task("b", config=DynamicTestPipelineTaskConfig(sleep=100.0), dimensions=["detector"])
        helper.add_task("c", dimensions=["detector"])  # depends on 'b', for the same detector.
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec, fail_fast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, exit code=1"):
            mpexec.execute(qg)  # type: ignore[arg-type]
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPGraphExecutorError"
        )
        # Only the failure itself is pinned down here; how many other quanta
        # were reported before the run stopped is not asserted.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(any(qrep.exitCode is not None and qrep.exitCode > 0 for qrep in report.quantaReports))
        self.assertTrue(any(qrep.exceptionInfo is not None for qrep in report.quantaReports))

    def test_mpexec_crash(self) -> None:
        """Check task crash due to signal."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(
            config=DynamicTestPipelineTaskConfig(fail_condition="detector=2", fail_signal=signal.SIGILL),
            dimensions=["detector"],
        )
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec)
        with self.assertRaisesRegex(MPGraphExecutorError, "One or more tasks failed"):
            mpexec.execute(qg)  # type: ignore[arg-type]
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPGraphExecutorError"
        )
        # The crashed quantum is reported as a failure with the negated signal
        # number as exit code and no exception info; the other three succeed.
        self.assertGreater(len(report.quantaReports), 0)
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertEqual(_count_status(report, ExecutionStatus.SUCCESS), 3)
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_crash_failfast(self) -> None:
        """Check task crash due to signal with --fail-fast."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        helper.add_task(
            "a",
            config=DynamicTestPipelineTaskConfig(fail_condition="detector=2", fail_signal=signal.SIGILL),
            dimensions=["detector"],
        )
        helper.add_task("b", config=DynamicTestPipelineTaskConfig(sleep=100.0), dimensions=["detector"])
        helper.add_task("c", dimensions=["detector"])  # depends on 'b', for the same detector.
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()
        mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec, fail_fast=True)
        with self.assertRaisesRegex(MPGraphExecutorError, "failed, killed by signal 4 .Illegal instruction"):
            mpexec.execute(qg)  # type: ignore[arg-type]
        report = mpexec.getReport()
        assert report is not None and report.exceptionInfo is not None
        self.assertEqual(report.status, ExecutionStatus.FAILURE)
        self.assertEqual(
            report.exceptionInfo.className, "lsst.pipe.base.mp_graph_executor.MPGraphExecutorError"
        )
        self.assertEqual(_count_status(report, ExecutionStatus.FAILURE), 1)
        self.assertTrue(any(qrep.exitCode == -signal.SIGILL for qrep in report.quantaReports))
        self.assertTrue(all(qrep.exceptionInfo is None for qrep in report.quantaReports))

    def test_mpexec_num_fd(self) -> None:
        """Check that number of open files stays reasonable."""
        helper = InMemoryRepo("base.yaml")
        self.enterContext(helper)
        # NOTE(review): both tasks use NoMultiprocessingTask yet are run with
        # num_proc=3 below, a combination that test_mpexec_nompsupport expects
        # to raise MPGraphExecutorError -- confirm this is intended.
        helper.add_task("a", task_class=NoMultiprocessingTask, dimensions=["detector", "visit"])
        helper.add_task("b", task_class=NoMultiprocessingTask, dimensions=["detector", "visit"])
        qg = helper.make_quantum_graph()
        qexec, _ = helper.make_single_quantum_executor()
        this_proc = psutil.Process()
        num_fds_0 = this_proc.num_fds()

        # run in multi-process mode, the order of results is not defined
        mpexec = MPGraphExecutor(num_proc=3, timeout=100, quantum_executor=qexec)
        mpexec.execute(qg)  # type: ignore[arg-type]

        num_fds_1 = this_proc.num_fds()
        # They should be the same but allow small growth just in case.
        # Without DM-26728 fix the difference would be equal to number of
        # quanta (20).
        self.assertLess(num_fds_1 - num_fds_0, 5)
def setup_module(module):
    """Make "spawn" the multiprocessing start method for this module.

    The method is set with ``force=True`` so that it applies when no start
    method is given explicitly. This hook can be removed when Python 3.14
    changes the default.

    Parameters
    ----------
    module : `~types.ModuleType`
        Module to set up.
    """
    multiprocessing.set_start_method(method="spawn", force=True)
if __name__ == "__main__":
    # When run as a standalone script there is no need to force the start
    # method.
    multiprocessing.set_start_method(method="spawn")
    unittest.main()