Coverage for python / lsst / utils / tests.py: 30%
361 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:31 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:31 +0000
1# This file is part of utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12"""Support code for running unit tests."""
14from __future__ import annotations
# Names exported by ``from lsst.utils.tests import *``. Other helpers defined
# in this module (e.g. ``sort_tests``, ``inTestCase``, ``plotImageDiff``) are
# not listed here but remain importable by name.
__all__ = [
    "ExecutablesTestCase",
    "ImportTestCase",
    "MemoryTestCase",
    "TestCase",
    "assertFloatsAlmostEqual",
    "assertFloatsEqual",
    "assertFloatsNotEqual",
    "classParameters",
    "debugger",
    "getTempFilePath",
    "init",
    "methodParameters",
    "temporaryDirectory",
]
32import contextlib
33import functools
34import gc
35import inspect
36import itertools
37import os
38import re
39import shutil
40import subprocess
41import sys
42import tempfile
43import unittest
44import warnings
45from collections.abc import Callable, Container, Iterable, Iterator, Mapping, Sequence
46from importlib import resources
47from typing import Any, ClassVar
49import numpy
50import psutil
52from .doImport import doImport
# Baseline set of open file paths. It starts out empty, is refreshed by
# init(), and leak checks compare the currently open files against it.
open_files: set[str] = set()
def _get_open_files() -> set[str]:
    """Collect the paths of every file currently open in this process.

    Returns
    -------
    open_files : `set` [`str`]
        Paths of the open files, as reported by `psutil` for the
        current process.
    """
    handles = psutil.Process().open_files()
    return {handle.path for handle in handles}
def init() -> None:
    """Initialize the memory tester and file descriptor leak tester.

    Records the files open right now as the baseline. Files that are open
    later, but were not open at this point, are reported as leaks.
    """
    global open_files
    baseline = _get_open_files()
    open_files = baseline
def sort_tests(tests) -> unittest.TestSuite:
    """Combine test suites, moving `MemoryTestCase` suites to the end.

    `lsst.utils.tests.MemoryTestCase` tests should always run after any other
    tests in the module.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites. Bare test cases, which are not iterable,
        are also accepted.

    Returns
    -------
    suite : `unittest.TestSuite`
        A combined `~unittest.TestSuite` with
        `~lsst.utils.tests.MemoryTestCase` at the end.
    """
    no_first_test = object()
    combined = unittest.TestSuite()
    deferred = []
    for item in tests:
        try:
            # Classify a whole suite by its first member only. An empty suite
            # has no first member and is treated as an ordinary suite.
            # Iterating a bare TestCase raises TypeError, handled below.
            first = next(iter(item), no_first_test)
            if first is not no_first_test and MemoryTestCase in inspect.getmro(first.__class__):
                deferred.append(item)
            else:
                combined.addTests(item)
        except TypeError:
            if isinstance(item, MemoryTestCase):
                deferred.append(item)
            else:
                combined.addTest(item)
    combined.addTests(deferred)
    return combined
119def _suiteClassWrapper(tests):
120 return unittest.TestSuite(sort_tests(tests))
123# Replace the suiteClass callable in the defaultTestLoader
124# so that we can reorder the test ordering. This will have
125# no effect if no memory test cases are found.
126unittest.defaultTestLoader.suiteClass = _suiteClassWrapper
class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks.

    Compares the files currently open by this process with the baseline
    recorded by the module-level ``init()`` function. When tests are loaded
    through `unittest.defaultTestLoader`, suites of this class are moved to
    the end of the run by ``sort_tests``.
    """

    ignore_regexps: ClassVar[list[str]] = []
    """List of regexps to ignore when checking for open files."""

    @classmethod
    def tearDownClass(cls) -> None:
        """Reset the open-file baseline when the tests have been completed."""
        init()

    def testFileDescriptorLeaks(self) -> None:
        """Check if any files are open since init() called.

        Ignores files with certain known path components and any files
        that match regexp patterns in class property ``ignore_regexps``.
        Leaked paths are printed in sorted order, so the report is the same
        from run to run.
        """
        # Let file objects kept alive only by reference cycles be finalized
        # (and so closed) before looking at what is still open.
        gc.collect()

        # Some files are opened out of the control of the stack.
        ignored_suffixes = (".car", ".ttf", "astropy.log", "mime/mime.cache", ".sqlite3")
        ignored_prefixes = ("/proc/", "/sys/")
        now_open = {
            f
            for f in _get_open_files()
            if not f.endswith(ignored_suffixes)
            and not f.startswith(ignored_prefixes)
            and not (f.startswith("/var/lib/") and f.endswith("/passwd"))
            and not any(re.search(r, f) for r in self.ignore_regexps)
        }

        # ``open_files`` is the module-level baseline set by init(). It is
        # only read here, so no ``global`` declaration is needed. Sorting
        # avoids the hash-randomized ordering of set iteration.
        leaked = sorted(now_open.difference(open_files))
        if leaked:
            for f in leaked:
                print(f"File open: {f}")
            self.fail(f"Failed to close {len(leaked)} file{'s' if len(leaked) != 1 else ''}")
class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    # Number of executables found by create_executable_tests(); -1 means that
    # method has not been called for this class.
    TESTS_DISCOVERED = -1

    @classmethod
    def setUpClass(cls) -> None:
        """Abort testing if automated test creation was enabled and
        no tests were found.
        """
        if cls.TESTS_DISCOVERED == 0:
            raise RuntimeError("No executables discovered.")

    def testSanity(self) -> None:
        """Ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set up
        machinery to test whether there are some executables to test.
        """

    def assertExecutable(
        self,
        executable: str,
        root_dir: str | None = None,
        args: Sequence[str] | None = None,
        msg: str | None = None,
    ) -> None:
        """Check an executable runs and returns good status.

        Prints the executable's standard output to standard out. The output
        is decoded as UTF-8, and bytes that are not valid UTF-8 are replaced
        rather than raising. On bad exit status the test fails. If the
        executable can not be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Form the argument list for subprocess
        sp_args = [executable]
        argstr = "no arguments"
        if args is not None:
            sp_args.extend(args)
            argstr = 'arguments "' + " ".join(args) + '"'

        print(f"Running executable '{executable}' with {argstr}...")
        if not os.path.exists(executable):
            self.skipTest(f"Executable {executable} is unexpectedly missing")
        failmsg = None
        try:
            output = subprocess.check_output(sp_args)
        except subprocess.CalledProcessError as e:
            output = e.output
            failmsg = f"Bad exit status from '{executable}': {e.returncode}"
        # Executables may write arbitrary bytes. A decoding error here would
        # turn a passing run into an error and hide the real exit status.
        print(output.decode("utf-8", errors="replace"))
        if failmsg:
            if msg is None:
                msg = failmsg
            self.fail(msg)

    @classmethod
    def _build_test_method(cls, executable: str, root_dir: str) -> None:
        """Build a test method and attach to class.

        A test method is created for the supplied executable located
        in the supplied root directory. This method is attached to the class
        so that the test runner will discover the test and run it.

        Parameters
        ----------
        executable : `str`
            Name of executable. Can be absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Create the test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        # This is the function that will become the test method
        def test_executable_runs(*args: Any) -> None:
            self = args[0]
            self.assertExecutable(executable)

        # Give it a name and attach it to the class
        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file: str, executables: Sequence[str] | None = None) -> None:
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        (``.so``) are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        # Get the search directory from the reference file
        ref_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Look for executables to test by walking the tree
            executables = []
            for root, _, files in os.walk(ref_dir):
                for f in files:
                    # Skip Python files. Shared libraries are executable.
                    if not f.endswith(".py") and not f.endswith(".so"):
                        full_path = os.path.join(root, f)
                        if os.access(full_path, os.X_OK):
                            executables.append(full_path)

        # Store the number of tests found for later assessment.
        # Do not raise an exception if we have no executables as this would
        # cause the testing to abort before the test runner could properly
        # integrate it into the failure report.
        cls.TESTS_DISCOVERED = len(executables)

        # Create the test functions and attach them to the class
        for e in executables:
            cls._build_test_method(e, ref_dir)
class ImportTestCase(unittest.TestCase):
    """Test that each named package, and every module file directly inside
    it, can be imported.

    The test methods are created dynamically when the class is subclassed.
    Callers must subclass this class and define the ``PACKAGES`` property.
    """

    PACKAGES: ClassVar[Iterable[str]] = ()
    """Packages to be imported."""

    SKIP_FILES: ClassVar[Mapping[str, Container[str]]] = {}
    """Files to be skipped importing; specified as key-value pairs.

    The key is the package name and the value is a set of files names in that
    package to skip.

    Note: Files with names not ending in .py or beginning with leading double
    underscores are always skipped.
    """

    _n_registered = 0
    """Number of packages registered for testing by this class."""

    def _test_no_packages_registered_for_import_testing(self) -> None:
        """Fail unconditionally, explaining that no packages were registered.

        It is attached as a test only when a subclass registers no packages.
        Otherwise such a subclass would silently contribute no tests at all.
        """
        raise AssertionError("No packages registered with import test. Was the PACKAGES property set?")

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Attach one ``test_import_<package>`` method per entry in the
        ``PACKAGES`` class property.
        """
        super().__init_subclass__(**kwargs)

        for package in cls.PACKAGES:
            method_name = "test_import_" + package.replace(".", "_")

            # The default argument binds the current package, avoiding the
            # late-binding closure pitfall inside this loop.
            def test_import(*args: Any, mod=package) -> None:
                args[0].assertImport(mod)

            test_import.__name__ = method_name
            setattr(cls, method_name, test_import)
            cls._n_registered += 1

        # Listing no packages is most likely a mistake, so register a test
        # that fails and says why.
        if cls._n_registered == 0:
            cls.test_no_packages_registered = cls._test_no_packages_registered_for_import_testing

    def assertImport(self, root_pkg):
        """Import each eligible module file directly inside ``root_pkg``.

        Each module is checked in its own sub-test. An `ImportError` is
        re-raised as an `AssertionError` naming the module.
        """
        for entry in resources.files(root_pkg).iterdir():
            filename = entry.name
            if (
                not filename.endswith(".py")
                or filename.startswith("__")
                or filename in self.SKIP_FILES.get(root_pkg, ())
            ):
                continue
            module_name = f"{root_pkg}.{os.path.splitext(filename)[0]}"
            with self.subTest(module=module_name):
                try:
                    doImport(module_name)
                except ImportError as err:
                    raise AssertionError(f"Error importing module {module_name}: {err}") from err
@contextlib.contextmanager
def getTempFilePath(ext: str, expectOutput: bool = True) -> Iterator[str]:
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure to delete results in a warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can be
    examined. The file name has a random component such that nested context
    managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Yields
    ------
    path : `str`
        Path for a temporary file. The file name combines the caller's
        file name, the name of the outermost calling function in that file,
        and a random component.

    Examples
    --------
    .. code-block:: python

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests


        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics-XXXXXXXX.fits"
                    # otherwise tmpFile = "<cwd>/testFoo_testBasics-XXXXXXXX.fits"
                    # where XXXXXXXX is a random component.
                    ...
                # at the end of this "with" block the path tmpFile will be
                # deleted, but only if the file exists and the "with"
                # block terminated normally (rather than with an exception)

        ...
    """
    # Only each frame's file name and function name are needed. Asking for no
    # source context skips reading source files. Copying the two strings out
    # immediately also avoids keeping frame objects (and all their locals)
    # alive in this suspended generator while the caller's with block runs.
    frames = [(info.filename, info.function) for info in inspect.stack(context=0)]
    # When entered via ``with``, frames[0] is this generator and frames[1] is
    # the contextlib machinery that advanced it, so frames[2] is the caller.
    callerFilePath, callerFuncName = frames[2]
    # Walk outwards while still in the caller's file, so the outermost
    # function in that file (typically the test method) names the file.
    for filename, function in frames[3:]:
        if filename != callerFilePath:
            break
        callerFuncName = function

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        # No .tests directory implies we are not running with sconsUtils.
        # Need to use the current working directory, the callerDir, or
        # /tmp equivalent. If cwd is used it must be as an absolute path
        # in case the test code changes cwd.
        outDir = os.path.abspath(os.path.curdir)
    prefix = f"{callerFileName}_{callerFuncName}-"
    # Only a name is wanted here: the caller creates the file (or is expected
    # not to), so an API that creates the file would not fit.
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # There should not be a file there given the randomizer. Warn and
        # remove. Use stacklevel 3 so that the warning is attributed to the
        # caller's with statement.
        warnings.warn(f"Unexpectedly found pre-existing tempfile named {outPath!r}", stacklevel=3)
        with contextlib.suppress(OSError):
            os.remove(outPath)

    # Deliberately not wrapped in try/finally: if the with block raises, the
    # file is left on disk so it can be examined.
    yield outPath

    fileExists = os.path.exists(outPath)
    if expectOutput:
        if not fileExists:
            raise RuntimeError(f"Temp file expected named {outPath} but none found")
    else:
        if fileExists:
            raise RuntimeError(f"Unexpectedly discovered temp file named {outPath}")
    # Only reached with a file present when output was expected; clean it up.
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # Use stacklevel 3 so that the warning is reported from the end of
            # the with block.
            warnings.warn(f"Warning: could not remove file {outPath!r}: {e}", stacklevel=3)
class TestCase(unittest.TestCase):
    """Subclass of unittest.TestCase that adds some custom assertions for
    convenience.

    The class body defines nothing itself. Functions decorated with
    `inTestCase` in this module (for example ``assertFloatsAlmostEqual``)
    are attached to this class as methods when the module is imported.
    """
def inTestCase(func: Callable) -> Callable:
    """Register a free function as a method of this module's `TestCase`.

    The function is attached to `lsst.utils.tests.TestCase` under its own
    ``__name__`` and returned unchanged. It therefore remains usable as a
    free function that takes the test case as its first argument.

    Parameters
    ----------
    func : `~collections.abc.Callable`
        Function to attach to `lsst.utils.tests.TestCase`.

    Returns
    -------
    func : `~collections.abc.Callable`
        The same function object, unmodified.
    """
    attribute_name = func.__name__
    setattr(TestCase, attribute_name, func)
    return func
def debugger(*exceptions):
    """Enter the debugger when there's an uncaught exception.

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    If no classes are given, any `Exception` starts the debugger.

    When the post-mortem session ends, the original exception is re-raised.
    A decorated test therefore still fails rather than silently passing.

    Code provided by "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Parameters
    ----------
    *exceptions : `Exception`
        Specific exception classes to catch. Default is to catch
        `Exception`.

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.
    """
    if not exceptions:
        exceptions = (Exception,)

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions as exc:
                # Imported lazily so pdb is only loaded when it is needed.
                import pdb

                pdb.post_mortem(exc.__traceback__)
                # Propagate the failure. Swallowing it would make the wrapped
                # call (e.g. a failing test) appear to succeed.
                raise

        return wrapper

    return decorator
def plotImageDiff(
    lhs: numpy.ndarray,
    rhs: numpy.ndarray,
    bad: numpy.ndarray | None = None,
    diff: numpy.ndarray | None = None,
    plotFileName: str | None = None,
) -> None:
    """Plot the comparison of two 2-d NumPy arrays.

    A 2x3 grid is drawn showing ``lhs``, ``rhs`` and the difference. The top
    row uses a common stretch spanning the range of ``lhs`` and ``rhs``. The
    bottom row shows the same three images stretched to the range of the
    difference.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array.
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array.
    bad : `numpy.ndarray`
        A 2-d boolean NumPy array of values to emphasize in the plots.
    diff : `numpy.ndarray`
        Difference array; a 2-d NumPy array, or None to show lhs-rhs.
    plotFileName : `str`
        Filename to save the plot to. If None, the plot will be displayed in
        a window.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    if plotFileName is None:
        # We need to create an interactive plot with pyplot.
        from matplotlib import pyplot

        fig = pyplot.figure()
    else:
        # We can create a non-interactive figure. pyplot is not imported on
        # this path, so nothing below may reference it when saving to a file.
        from .plotting import make_figure

        fig = make_figure()

    if diff is None:
        diff = lhs - rhs

    if bad is not None:
        # make an rgba image that's red and transparent where not bad
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255 * bad
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    for n, (image, title) in enumerate([(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]):
        # Top row: common lhs/rhs stretch. A colormap name (rather than
        # pyplot.cm.gray) works on both the interactive and file paths.
        ax = fig.add_subplot(2, 3, n + 1)
        im1 = ax.imshow(image, cmap="gray", interpolation="nearest", origin="lower", vmin=vmin1, vmax=vmax1)
        if bad is not None:
            ax.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        ax.axis("off")
        ax.set_title(title)
        # Bottom row: difference stretch.
        ax = fig.add_subplot(2, 3, n + 4)
        im2 = ax.imshow(image, cmap="gray", interpolation="nearest", origin="lower", vmin=vmin2, vmax=vmax2)
        if bad is not None:
            ax.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        ax.axis("off")
        ax.set_title(title)
    fig.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    # Colorbar axes are placed by explicit figure-coordinate rectangles, which
    # is what Figure.add_axes is documented to accept.
    cax1 = fig.add_axes([0.8, 0.55, 0.05, 0.4])
    fig.colorbar(im1, cax=cax1)
    cax2 = fig.add_axes([0.8, 0.05, 0.05, 0.4])
    fig.colorbar(im2, cax=cax2)
    if plotFileName is None:
        pyplot.show()
    else:
        fig.savefig(plotFileName)
@inTestCase
def assertFloatsAlmostEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    rtol: float | None = sys.float_info.epsilon,
    atol: float | None = sys.float_info.epsilon,
    relTo: float | None = None,
    printFailures: bool = True,
    plotOnFailure: bool = False,
    plotFileName: str | None = None,
    invert: bool = False,
    msg: str | None = None,
    ignoreNaNs: bool = False,
) -> None:
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if any elements of ``lhs`` and ``rhs`` are
    not equal to within the tolerances specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension. Python lists and tuples are converted to NumPy arrays.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension. Python lists and tuples are converted to NumPy arrays.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float` or array-like, optional
        Value to which comparison with rtol is relative.
    printFailures : `bool`, optional
        Upon failure, list all unequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed in
        a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if all elements *are*
        equal; the assertion passes as soon as any element differs. Used to
        implement `~lsst.utils.tests.assertFloatsNotEqual`, which should
        generally be used instead for clarity.
    msg : `str`, optional
        String to append to the error message when assert fails.
    ignoreNaNs : `bool`, optional
        If `True` (`False` is default) mask out any NaNs from operand arrays
        before performing comparisons if they are in the same locations; NaNs
        in different locations trigger test assertion failures, even when
        ``invert=True``. Scalar NaNs are treated like arrays containing only
        NaNs of the same shape as the other operand, and no comparisons are
        performed if both sides are scalar NaNs.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    ValueError
        Raised if both ``rtol`` and ``atol`` are `None`, or if
        ``plotOnFailure`` is requested for operands that are not both 2-d.
    """
    # Plain Python sequences support neither elementwise subtraction nor
    # boolean-mask indexing, so promote them to arrays. Scalars and existing
    # arrays (including ndarray subclasses) are left untouched.
    if isinstance(lhs, (list, tuple)):
        lhs = numpy.asarray(lhs)
    if isinstance(rhs, (list, tuple)):
        rhs = numpy.asarray(rhs)
    if ignoreNaNs:
        lhsMask = numpy.isnan(lhs)
        rhsMask = numpy.isnan(rhs)
        if not numpy.all(lhsMask == rhsMask):
            testCase.fail(
                f"lhs has {lhsMask.sum()} NaN values and rhs has {rhsMask.sum()} NaN values, "
                "in different locations."
            )
        if numpy.all(lhsMask):
            assert numpy.all(rhsMask), "Should be guaranteed by previous if."
            # All operands are fully NaN (either scalar NaNs or arrays of only
            # NaNs).
            return
        assert not numpy.all(rhsMask), "Should be guaranteed by previous two ifs."
        # If either operand is an array select just its not-NaN values. Note
        # that these expressions are never True for scalar operands, because if
        # they are NaN then the numpy.all checks above will catch them.
        if numpy.any(lhsMask):
            lhs = lhs[numpy.logical_not(lhsMask)]
        if numpy.any(rhsMask):
            rhs = rhs[numpy.logical_not(rhsMask)]
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    diff = lhs - rhs
    absDiff = numpy.abs(diff)
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        bad = absDiff > rtol * relTo
        if atol is not None:
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            if rtol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff} with atol={atol}"]
            elif atol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} with rtol={rtol}"]
            else:
                errMsg = [
                    f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} "
                    f"with rtol={rtol}, atol={atol}"
                ]
        else:
            errMsg = [f"{bad.sum()}/{bad.size} elements {failStr} with rtol={rtol}, atol={atol}"]
            if plotOnFailure:
                # numpy.ndim also works when one operand is a Python scalar,
                # which has no ``shape`` attribute.
                if numpy.ndim(lhs) != 2 or numpy.ndim(rhs) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we
                # can treat them the same (diff and absDiff are arrays if
                # either rhs or lhs is), and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                if rtol is None:
                    for a, b, d in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={d})")
                else:
                    for a, b, d, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={d}/{rel}={d / rel})")

    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))
@inTestCase
def assertFloatsNotEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwds: Any,
) -> None:
    """Fail a test if the given floating point values are equal to within the
    given tolerances.

    This calls `~lsst.utils.tests.assertFloatsAlmostEqual` with
    ``invert=True``. The tolerances are those of that function (by default
    ``rtol`` and ``atol`` of double-precision epsilon) unless overridden
    via ``kwds``. For arrays, the assertion fails only if every element is
    equal within tolerance, and passes as soon as any element differs. See
    `~lsst.utils.tests.assertFloatsAlmostEqual` for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    **kwds : `~typing.Any`
        Keyword parameters forwarded to `assertFloatsAlmostEqual`.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)
@inTestCase
def assertFloatsEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwargs: Any,
) -> None:
    """Assert exact equality of numeric scalars or arrays.

    This is `~lsst.utils.tests.assertFloatsAlmostEqual` with both tolerances
    forced to zero (``rtol=atol=0``). Passing ``rtol`` or ``atol`` again via
    ``kwargs`` is therefore an error.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        Left-hand value(s); a scalar or an array-like of any dimension.
    rhs : scalar or array-like
        Right-hand value(s); a scalar or an array-like of any dimension.
    **kwargs : `~typing.Any`
        Further keyword parameters passed through to
        `assertFloatsAlmostEqual`.

    Raises
    ------
    AssertionError
        Raised if any element of ``lhs`` differs from ``rhs``.
    """
    return assertFloatsAlmostEqual(
        testCase,
        lhs,
        rhs,
        rtol=0,
        atol=0,
        **kwargs,
    )
906def _settingsIterator(settings: dict[str, Sequence[Any]]) -> Iterator[dict[str, Any]]:
907 """Return an iterator for the provided test settings
909 Parameters
910 ----------
911 settings : `dict` (`str`: iterable)
912 Lists of test parameters. Each should be an iterable of the same
913 length. If a string is provided as an iterable, it will be converted
914 to a list of a single string.
916 Raises
917 ------
918 AssertionError
919 If the ``settings`` are not of the same length.
921 Yields
922 ------
923 parameters : `dict` (`str`: anything)
924 Set of parameters.
925 """
926 for name, values in settings.items():
927 if isinstance(values, str): 927 ↛ 930line 927 didn't jump to line 930 because the condition on line 927 was never true
928 # Probably meant as a single-element string, rather than an
929 # iterable of chars.
930 settings[name] = [values]
931 num = len(next(iter(settings.values()))) # Number of settings
932 for name, values in settings.items():
933 assert len(values) == num, f"Length mismatch for setting {name}: {len(values)} vs {num}"
934 for ii in range(num):
935 values = [settings[kk][ii] for kk in settings]
936 yield dict(zip(settings, values))
def classParameters(**settings: Sequence[Any]) -> Callable:
    """Class decorator that generates one parameterized subclass per setting.

    Each generated subclass gets the supplied ``settings`` as class
    variables and is stored in the decorated class's module.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase): ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...


        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name. The decorator
    itself returns `None`, so the decorated name is bound to `None` and only
    the generated subclasses remain in the module.
    """

    def decorator(cls: type) -> None:
        namespace = sys.modules[cls.__module__].__dict__
        for params in _settingsIterator(settings):
            suffix = "_".join(str(value) for value in params.values())
            subclass_name = f"{cls.__name__}_{suffix}"
            # Copy the original class body, then let the parameters override
            # any same-named attributes.
            namespace[subclass_name] = type(subclass_name, (cls,), {**cls.__dict__, **params})

    return decorator
def methodParameters(**settings: Sequence[Any]) -> Callable:
    """Decorate a test method so it runs once per set of supplied settings.

    Each set of values is passed to the decorated method as keyword
    arguments. Each run is wrapped in `unittest.TestCase.subTest`, labelled
    with the ``repr`` of each value, so a failure reports which values were
    in use.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Returns
    -------
    decorator : `~collections.abc.Callable`
        Decorator that wraps a test method.

    Examples
    --------
    .. code-block:: python

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar): ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def runSubTests(self: unittest.TestCase, *args: Any, **kwargs: Any) -> None:
            for combination in _settingsIterator(settings):
                # The parameter values are merged into (and override) any
                # keyword arguments supplied by the caller.
                kwargs.update(combination)
                labels = {key: repr(value) for key, value in combination.items()}
                with self.subTest(**labels):
                    func(self, *args, **kwargs)

        return runSubTests

    return decorator
1025def _cartesianProduct(settings: Mapping[str, Sequence[Any]]) -> Mapping[str, Sequence[Any]]:
1026 """Return the cartesian product of the settings.
1028 Parameters
1029 ----------
1030 settings : `dict` mapping `str` to `iterable`
1031 Parameter combinations.
1033 Returns
1034 -------
1035 product : `dict` mapping `str` to `iterable`
1036 Parameter combinations covering the cartesian product (all possible
1037 combinations) of the input parameters.
1039 Examples
1040 --------
1041 .. code-block:: python
1043 cartesianProduct({"foo": [1, 2], "bar": ["black", "white"]})
1045 will return:
1047 .. code-block:: python
1049 {"foo": [1, 1, 2, 2], "bar": ["black", "white", "black", "white"]}
1050 """
1051 product: dict[str, list[Any]] = {kk: [] for kk in settings}
1052 for values in itertools.product(*settings.values()):
1053 for kk, vv in zip(settings.keys(), values):
1054 product[kk].append(vv)
1055 return product
def classParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    cartesian product of the supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable. Unlike `classParameters`, the iterables need
        not have the same length, because every combination is generated.

    Returns
    -------
    decorator : `~collections.abc.Callable`
        The `classParameters` decorator for the expanded combinations.

    Examples
    --------
    .. code-block:: python

        @classParametersProduct(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase): ...

    will generate four classes, as if you wrote:

    .. code-block:: python

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...


        class MyTestCase_1_4(unittest.TestCase):
            foo = 1
            bar = 4
            ...


        class MyTestCase_2_3(unittest.TestCase):
            foo = 2
            bar = 3
            ...


        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    # Expand to equal-length lists covering every combination, then reuse
    # the one-to-one class generator.
    return classParameters(**_cartesianProduct(settings))
def methodParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Iterate over cartesian product creating sub tests.

    This decorator iterates over the cartesian product of the supplied
    settings, using `~unittest.TestCase.subTest` to communicate the values in
    the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The parameter combinations to test. Each should be an iterable; the
        iterables need not have the same length.

    Returns
    -------
    decorator : `~collections.abc.Callable`
        The `methodParameters` decorator for the expanded combinations.

    Examples
    --------
    .. code-block:: python

        @methodParametersProduct(foo=[1, 2], bar=["black", "white"])
        def testSomething(self, foo, bar): ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar="black")
        testSomething(foo=1, bar="white")
        testSomething(foo=2, bar="black")
        testSomething(foo=2, bar="white")
    """
    # Expand to equal-length lists covering every combination, then reuse
    # the one-to-one sub-test generator.
    return methodParameters(**_cartesianProduct(settings))
@contextlib.contextmanager
def temporaryDirectory() -> Iterator[str]:
    """Context manager that creates and destroys a temporary directory.

    The difference from `tempfile.TemporaryDirectory` is that this ignores
    errors when deleting a directory, which may happen with some filesystems.

    The directory is removed when the ``with`` block exits, including when
    it exits because of an exception.

    Yields
    ------
    `str`
        Name of the temporary directory.
    """
    tmpdir = tempfile.mkdtemp()
    try:
        yield tmpdir
    finally:
        # An exception in the ``with`` body is raised at the ``yield``; the
        # ``finally`` guarantees cleanup still happens. Removal failures are
        # deliberately ignored (best-effort cleanup).
        shutil.rmtree(tmpdir, ignore_errors=True)