Coverage for python/lsst/utils/tests.py: 33%
355 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-17 07:53 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-17 07:53 +0000
1# This file is part of utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12"""Support code for running unit tests"""
14from __future__ import annotations
# Public API of this module. The cartesian-product variants of the
# parameterization decorators are public, documented functions as well.
__all__ = [
    "init",
    "MemoryTestCase",
    "ExecutablesTestCase",
    "ImportTestCase",
    "getTempFilePath",
    "TestCase",
    "assertFloatsAlmostEqual",
    "assertFloatsNotEqual",
    "assertFloatsEqual",
    "debugger",
    "classParameters",
    "classParametersProduct",
    "methodParameters",
    "methodParametersProduct",
    "temporaryDirectory",
]
32import contextlib
33import functools
34import gc
35import inspect
36import itertools
37import os
38import re
39import shutil
40import subprocess
41import sys
42import tempfile
43import unittest
44import warnings
45from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
46from importlib import resources
47from typing import Any, ClassVar
49import numpy
50import psutil
52from .doImport import doImport
# Paths of the files that were open when init() was last called. It starts
# out empty; MemoryTestCase treats any file open at check time that is not
# in this set (and not otherwise ignored) as a leak.
open_files: set[str] = set()
def _get_open_files() -> set[str]:
    """Collect the paths of all files this process currently has open.

    Returns
    -------
    open_files : `set` [`str`]
        Paths reported by `psutil` for the current process.
    """
    current_process = psutil.Process()
    paths = set()
    for open_file in current_process.open_files():
        paths.add(open_file.path)
    return paths
def init() -> None:
    """Record the currently open files as the leak-check baseline.

    Files open at this point are not reported by `MemoryTestCase`.
    """
    global open_files
    baseline = _get_open_files()
    open_files = baseline
def sort_tests(tests) -> unittest.TestSuite:
    """Combine test suites so that `MemoryTestCase` suites run last.

    `lsst.utils.tests.MemoryTestCase` tests should always run after any other
    tests in the module.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites (bare test cases are also accepted).

    Returns
    -------
    suite : `unittest.TestSuite`
        A combined `~unittest.TestSuite` with
        `~lsst.utils.tests.MemoryTestCase` at the end.
    """
    no_tests = object()  # Marker for a suite that yields no tests at all.
    ordinary = unittest.TestSuite()
    deferred = []
    for item in tests:
        try:
            # Only the first test of a suite is inspected; a suite may also
            # be empty, in which case it is treated as an ordinary suite.
            first = next(iter(item), no_tests)
            if first is not no_tests and MemoryTestCase in inspect.getmro(first.__class__):
                deferred.append(item)
            else:
                ordinary.addTests(item)
        except TypeError:
            # Not iterable, so this is a single test case, not a suite.
            if isinstance(item, MemoryTestCase):
                deferred.append(item)
            else:
                ordinary.addTest(item)
    ordinary.addTests(deferred)
    return ordinary
def _suiteClassWrapper(tests) -> unittest.TestSuite:
    """Build a loader suite whose `MemoryTestCase` suites are ordered last.

    Installed as ``suiteClass`` of `unittest.defaultTestLoader`.
    """
    ordered = sort_tests(tests)
    return unittest.TestSuite(ordered)
# Replace the suiteClass callable of the default test loader so that every
# suite it builds is passed through sort_tests, which moves MemoryTestCase
# suites to the end. The ordering is unchanged when no MemoryTestCase is
# present. Note this is an import-time side effect: it modifies the shared
# unittest.defaultTestLoader object for the whole process.
unittest.defaultTestLoader.suiteClass = _suiteClassWrapper
class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks.

    Compares the files open when the test runs against the baseline recorded
    by ``init()``. Subclasses can extend ``ignore_regexps`` to exclude further
    paths from the check.
    """

    ignore_regexps: ClassVar[list[str]] = []
    """List of regexps to ignore when checking for open files."""

    @classmethod
    def tearDownClass(cls) -> None:
        """Reset the leak counter when the tests have been completed."""
        init()

    def testFileDescriptorLeaks(self) -> None:
        """Check if any file descriptors are open since init() called.

        Ignores files with certain known path components and any files
        that match regexp patterns in class property ``ignore_regexps``.
        The names of any leaked files are printed and included in the
        failure message.
        """
        # Run the garbage collector first so that unreachable file objects
        # get a chance to be finalized (and closed) before the check.
        gc.collect()
        now_open = _get_open_files()

        # Some files are opened out of the control of the stack.
        ignored_suffixes = (".car", ".ttf", "astropy.log", "mime/mime.cache", ".sqlite3")
        now_open = {
            f
            for f in now_open
            if not f.endswith(ignored_suffixes)
            and not f.startswith("/proc/")
            and not (f.startswith("/var/lib/") and f.endswith("/passwd"))
            and not any(re.search(r, f) for r in self.ignore_regexps)
        }

        # Sorted so the report is deterministic (set order is arbitrary).
        leaked = sorted(now_open.difference(open_files))
        if leaked:
            for f in leaked:
                print(f"File open: {f}")
            n_leaked = len(leaked)
            self.fail(
                "Failed to close %d file%s: %s"
                % (n_leaked, "s" if n_leaked != 1 else "", ", ".join(leaked))
            )
class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    # Number of executables registered by create_executable_tests(); -1
    # means create_executable_tests() has not been called for this class.
    TESTS_DISCOVERED: ClassVar[int] = -1

    @classmethod
    def setUpClass(cls) -> None:
        """Abort testing if automated test creation was enabled and
        no tests were found.
        """
        if cls.TESTS_DISCOVERED == 0:
            raise RuntimeError("No executables discovered.")

    def testSanity(self) -> None:
        """Ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set up
        machinery to test whether there are some executables to test.
        """

    def assertExecutable(
        self,
        executable: str,
        root_dir: str | None = None,
        args: Sequence[str] | None = None,
        msg: str | None = None,
    ) -> None:
        """Check an executable runs and returns good status.

        Prints output to standard out. On bad exit status the test
        fails. If the executable can not be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Form the argument list for subprocess
        sp_args = [executable]
        argstr = "no arguments"
        if args is not None:
            sp_args.extend(args)
            argstr = 'arguments "' + " ".join(args) + '"'

        print(f"Running executable '{executable}' with {argstr}...")
        if not os.path.exists(executable):
            self.skipTest(f"Executable {executable} is unexpectedly missing")
        failmsg = None
        try:
            output = subprocess.check_output(sp_args)
        except subprocess.CalledProcessError as e:
            output = e.output
            failmsg = f"Bad exit status from '{executable}': {e.returncode}"
        # Executables may emit bytes that are not valid UTF-8; replace them
        # so a UnicodeDecodeError cannot mask the real pass/fail result.
        print(output.decode("utf-8", errors="replace"))
        if failmsg:
            if msg is None:
                msg = failmsg
            self.fail(msg)

    @classmethod
    def _build_test_method(cls, executable: str, root_dir: str) -> None:
        """Build a test method and attach to class.

        A test method is created for the supplied executable located
        in the supplied root directory. This method is attached to the class
        so that the test runner will discover the test and run it.

        Parameters
        ----------
        executable : `str`
            Name of executable. Can be absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Create the test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        # This is the function that will become the test method
        def test_executable_runs(*args: Any) -> None:
            self = args[0]
            self.assertExecutable(executable)

        # Give it a name and attach it to the class
        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file: str, executables: Sequence[str] | None = None) -> None:
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        # Get the search directory from the reference file
        ref_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Look for executables to test by walking the tree
            executables = []
            for root, _, files in os.walk(ref_dir):
                for f in files:
                    # Skip Python files. Shared libraries are executable.
                    if not f.endswith((".py", ".so")):
                        full_path = os.path.join(root, f)
                        if os.access(full_path, os.X_OK):
                            executables.append(full_path)

        # Store the number of tests found for later assessment.
        # Do not raise an exception if we have no executables as this would
        # cause the testing to abort before the test runner could properly
        # integrate it into the failure report.
        cls.TESTS_DISCOVERED = len(executables)

        # Create the test functions and attach them to the class
        for e in executables:
            cls._build_test_method(e, ref_dir)
class ImportTestCase(unittest.TestCase):
    """Test that the named packages, and the Python modules directly inside
    each of them, can be imported.

    The test methods are created dynamically. Callers must subclass this
    class and define the ``PACKAGES`` property.
    """

    PACKAGES: ClassVar[Iterable[str]] = ()
    """Packages to be imported."""

    _n_registered = 0
    """Number of packages registered for testing by this class."""

    def _test_no_packages_registered_for_import_testing(self) -> None:
        """Fail because no packages were registered.

        Without this, a subclass with an empty ``PACKAGES`` would have no
        tests at all and the test system would report nothing; this test
        fails and explains why.
        """
        raise AssertionError("No packages registered with import test. Was the PACKAGES property set?")

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Attach one ``test_import_*`` method per entry in ``PACKAGES``."""
        super().__init_subclass__(**kwargs)

        for package_name in cls.PACKAGES:
            method_name = "test_import_" + package_name.replace(".", "_")

            # The default argument binds the current package name, avoiding
            # the late-binding closure pitfall inside this loop.
            def test_import(*args: Any, mod=package_name) -> None:
                args[0].assertImport(mod)

            test_import.__name__ = method_name
            setattr(cls, method_name, test_import)
            cls._n_registered += 1

        # An empty PACKAGES is most likely a mistake, so add a failing test.
        if not cls._n_registered:
            cls.test_no_packages_registered = cls._test_no_packages_registered_for_import_testing

    def assertImport(self, root_pkg):
        """Assert that each ``.py`` module directly in ``root_pkg`` imports.

        Files whose names start with ``__`` (such as ``__init__.py``) are
        skipped, as are subdirectories. Each module is checked in its own
        ``subTest``.

        Parameters
        ----------
        root_pkg : `str`
            Name of the package whose modules are imported.
        """
        for entry in resources.files(root_pkg).iterdir():
            filename = entry.name
            if not filename.endswith(".py") or filename.startswith("__"):
                continue
            module_name = f"{root_pkg}.{os.path.splitext(filename)[0]}"
            with self.subTest(module=module_name):
                try:
                    doImport(module_name)
                except ImportError as e:
                    raise AssertionError(f"Error importing module {module_name}: {e}") from e
@contextlib.contextmanager
def getTempFilePath(ext: str, expectOutput: bool = True) -> Iterator[str]:
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure to delete results in a warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can be
    examined. The file name has a random component such that nested context
    managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Yields
    ------
    path : `str`
        Path for a temporary file (not yet created). The file name combines
        the caller's file name, the name of the outermost consecutive calling
        function in that file, a random component, and ``ext``.

    Raises
    ------
    RuntimeError
        Raised on exit if the file's presence does not match
        ``expectOutput``.

    Examples
    --------
    .. code-block:: python

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests
        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics-<random>.fits"
                    # otherwise
                    # tmpFile = "<cwd>/testFoo_testBasics-<random>.fits"
                    ...
                # at the end of this "with" block the path tmpFile will be
                # deleted, but only if the file exists and the "with"
                # block terminated normally (rather than with an exception)
                ...
    """
    # Only file and function names are needed, so skip reading source
    # context (context=0). Skip this generator's frame and the context
    # manager's __enter__ frame, then copy out plain strings and drop the
    # stack: FrameInfo objects reference live frames, and this generator
    # stays suspended for the whole with block.
    stack = inspect.stack(context=0)
    frames = [(info.filename, info.function) for info in stack[2:]]
    del stack

    # get name of first function in the file
    callerFilePath, callerFuncName = frames[0]
    for filename, function in frames[1:]:
        if filename != callerFilePath:
            break
        # this function called the previous function
        callerFuncName = function

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        # No .tests directory implies we are not running with sconsUtils.
        # Need to use the current working directory, the callerDir, or
        # /tmp equivalent. If cwd is used it must be as an absolute path
        # in case the test code changes cwd.
        outDir = os.path.abspath(os.path.curdir)
    prefix = f"{callerFileName}_{callerFuncName}-"
    # NOTE: mktemp only generates a name without creating the file; that is
    # what is wanted here because the caller is expected to create it.
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # There should not be a file there given the randomizer. Warn and
        # remove.
        # Use stacklevel 3 so that the warning is reported from the start of
        # the with block.
        warnings.warn(f"Unexpectedly found pre-existing tempfile named {outPath!r}", stacklevel=3)
        with contextlib.suppress(OSError):
            os.remove(outPath)

    yield outPath

    fileExists = os.path.exists(outPath)
    if expectOutput:
        if not fileExists:
            raise RuntimeError(f"Temp file expected named {outPath} but none found")
    else:
        if fileExists:
            raise RuntimeError(f"Unexpectedly discovered temp file named {outPath}")
    # Only reached when the file's presence matched expectations; clean up.
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # Use stacklevel 3 so that the warning is reported from the end of
            # the with block.
            warnings.warn(f"Warning: could not remove file {outPath!r}: {e}", stacklevel=3)
class TestCase(unittest.TestCase):
    """A `unittest.TestCase` extended with extra convenience assertions.

    Functions decorated with ``inTestCase`` are attached to this class as
    methods.
    """
def inTestCase(func: Callable) -> Callable:
    """Attach ``func`` to `TestCase` as a method and return it unchanged.

    The function stays usable as a free function. When it is called as a
    method, its first argument receives the test case instance.
    """
    method_name = func.__name__
    setattr(TestCase, method_name, func)
    return func
def debugger(*exceptions):
    """Enter the debugger when there's an uncaught exception.

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    This defaults to `Exception` (which includes the `AssertionError`
    raised by failing `unittest.TestCase` assertions).

    When the post-mortem debugging session ends, the exception is re-raised
    so the failure is still reported (a decorated test still fails).

    Code adapted from "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.
    """
    if not exceptions:
        exceptions = (Exception,)

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions:
                import pdb

                pdb.post_mortem(sys.exc_info()[2])
                # Re-raise: swallowing the exception would turn a failure
                # into a silent pass once the debugger exits.
                raise

        return wrapper

    return decorator
def plotImageDiff(
    lhs: numpy.ndarray,
    rhs: numpy.ndarray,
    bad: numpy.ndarray | None = None,
    diff: numpy.ndarray | None = None,
    plotFileName: str | None = None,
) -> None:
    """Plot the comparison of two 2-d NumPy arrays.

    The top row shows ``lhs``, ``rhs`` and ``diff`` on the common value range
    of ``lhs`` and ``rhs``. The bottom row shows the same images on the value
    range of ``diff``.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array
    bad : `numpy.ndarray`
        A 2-d boolean NumPy array of values to emphasize in the plots
    diff : `numpy.ndarray`
        difference array; a 2-d NumPy array, or None to show lhs-rhs
    plotFileName : `str`
        Filename to save the plot to. If None, the plot will be displayed in
        a window.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    from matplotlib import pyplot

    if diff is None:
        diff = lhs - rhs
    fig = pyplot.figure()
    if bad is not None:
        # make an rgba image that's red and transparent where not bad
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255 * bad
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    for n, (image, title) in enumerate([(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]):
        pyplot.subplot(2, 3, n + 1)
        im1 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin1, vmax=vmax1
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
        pyplot.subplot(2, 3, n + 4)
        im2 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin2, vmax=vmax2
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
    pyplot.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    cax1 = pyplot.axes([0.8, 0.55, 0.05, 0.4])
    pyplot.colorbar(im1, cax=cax1)
    cax2 = pyplot.axes([0.8, 0.05, 0.05, 0.4])
    pyplot.colorbar(im2, cax=cax2)
    if plotFileName:
        pyplot.savefig(plotFileName)
        # pyplot keeps every figure it creates open until it is explicitly
        # closed; close it once saved so repeated failures do not pile up.
        pyplot.close(fig)
    else:
        pyplot.show()
@inTestCase
def assertFloatsAlmostEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    rtol: float | None = sys.float_info.epsilon,
    atol: float | None = sys.float_info.epsilon,
    relTo: float | None = None,
    printFailures: bool = True,
    plotOnFailure: bool = False,
    plotFileName: str | None = None,
    invert: bool = False,
    msg: str | None = None,
    ignoreNaNs: bool = False,
) -> None:
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if any element of ``lhs`` and the
    corresponding element of ``rhs`` are not equal to within the tolerances
    specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension. Non-scalar inputs (e.g. lists) are converted with
        `numpy.asanyarray`.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension. Converted in the same way as ``lhs``.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float`, optional
        Value to which comparison with rtol is relative.
    printFailures : `bool`, optional
        Upon failure, print all inequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed in
        a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if *all* elements are
        equal to within the tolerances. Used to implement
        `~lsst.utils.tests.assertFloatsNotEqual`, which should generally be
        used instead for clarity.
    msg : `str`, optional
        String to append to the error message when assert fails.
    ignoreNaNs : `bool`, optional
        If `True` (`False` is default) mask out any NaNs from operand arrays
        before performing comparisons if they are in the same locations; NaNs
        in different locations trigger test assertion failures, even when
        ``invert=True``. Scalar NaNs are treated like arrays containing only
        NaNs of the same shape as the other operand, and no comparisons are
        performed if both sides are scalar NaNs.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    ValueError
        Both ``rtol`` and ``atol`` are `None`, or ``plotOnFailure`` is
        requested on failure for operands that are not both 2-d arrays.
    """
    # Turn array-likes such as lists into arrays so that the arithmetic,
    # boolean masking and indexing below work. Scalars are left as-is because
    # later branches rely on ``numpy.isscalar``; asanyarray returns existing
    # arrays (including subclasses) unchanged.
    if not numpy.isscalar(lhs):
        lhs = numpy.asanyarray(lhs)
    if not numpy.isscalar(rhs):
        rhs = numpy.asanyarray(rhs)
    if ignoreNaNs:
        lhsMask = numpy.isnan(lhs)
        rhsMask = numpy.isnan(rhs)
        if not numpy.all(lhsMask == rhsMask):
            testCase.fail(
                f"lhs has {lhsMask.sum()} NaN values and rhs has {rhsMask.sum()} NaN values, "
                "in different locations."
            )
        if numpy.all(lhsMask):
            assert numpy.all(rhsMask), "Should be guaranteed by previous if."
            # All operands are fully NaN (either scalar NaNs or arrays of only
            # NaNs).
            return
        assert not numpy.all(rhsMask), "Should be guaranteed by prevoius two ifs."
        # If either operand is an array select just its not-NaN values. Note
        # that these expressions are never True for scalar operands, because if
        # they are NaN then the numpy.all checks above will catch them.
        if numpy.any(lhsMask):
            lhs = lhs[numpy.logical_not(lhsMask)]
        if numpy.any(rhsMask):
            rhs = rhs[numpy.logical_not(rhsMask)]
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    diff = lhs - rhs
    absDiff = numpy.abs(diff)
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        bad = absDiff > rtol * relTo
        if atol is not None:
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        # Fail only when no element differs; report the equal elements.
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            if rtol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff} with atol={atol}"]
            elif atol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} with rtol={rtol}"]
            else:
                errMsg = [
                    f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} "
                    f"with rtol={rtol}, atol={atol}"
                ]
        else:
            errMsg = [f"{bad.sum()}/{bad.size} elements {failStr} with rtol={rtol}, atol={atol}"]
            if plotOnFailure:
                if len(lhs.shape) != 2 or len(rhs.shape) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we
                # can treat them the same (diff and absDiff are arrays if
                # either rhs or lhs is), and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                if rtol is None:
                    for a, b, delta in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={delta})")
                else:
                    for a, b, delta, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={delta}/{rel}={delta / rel})")

    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))
@inTestCase
def assertFloatsNotEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwds: Any,
) -> None:
    """Fail a test if the given floating point values are equal to within the
    given tolerances.

    This calls `~lsst.utils.tests.assertFloatsAlmostEqual` with
    ``invert=True`` (and that function's default tolerances unless
    overridden); see it for more information. The test fails only if *all*
    elements are equal to within the tolerances.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    **kwds
        Additional keyword arguments (e.g. ``rtol``, ``atol``, ``relTo``,
        ``msg``) forwarded to `~lsst.utils.tests.assertFloatsAlmostEqual`.
        ``invert`` must not be given.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)
@inTestCase
def assertFloatsEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwargs: Any,
) -> None:
    """Assert exact equality, ``lhs == rhs``, for numeric scalars or arrays.

    This is `~lsst.utils.tests.assertFloatsAlmostEqual` with zero relative
    and absolute tolerance (``rtol=atol=0``); see that function for details.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Test case instance the assertion belongs to.
    lhs : scalar or array-like
        Left-hand value(s); a scalar or an array-like of any dimension.
    rhs : scalar or array-like
        Right-hand value(s); a scalar or an array-like of any dimension.
    **kwargs
        Further keyword arguments for
        `~lsst.utils.tests.assertFloatsAlmostEqual` (other than the
        tolerances).

    Raises
    ------
    AssertionError
        Raised when the values are not exactly equal.
    """
    exact_tolerances = {"rtol": 0, "atol": 0}
    return assertFloatsAlmostEqual(testCase, lhs, rhs, **exact_tolerances, **kwargs)
def _settingsIterator(settings: dict[str, Sequence[Any]]) -> Iterator[dict[str, Any]]:
    """Return an iterator for the provided test settings.

    Parameters
    ----------
    settings : `dict` (`str`: iterable)
        Lists of test parameters. Each should be an iterable of the same
        length. If a string is provided as an iterable, it will be converted
        to a list of a single string. ``settings`` itself is not modified.

    Raises
    ------
    AssertionError
        If the ``settings`` are not of the same length.
    ValueError
        If ``settings`` is empty.

    Yields
    ------
    parameters : `dict` (`str`: anything)
        Set of parameters.
    """
    # Work on a normalized copy so the caller's mapping is left untouched.
    # A bare string is probably meant as a single value rather than an
    # iterable of characters.
    normalized = {
        name: [values] if isinstance(values, str) else values for name, values in settings.items()
    }
    if not normalized:
        raise ValueError("No settings provided.")
    num = len(next(iter(normalized.values())))  # Number of settings
    for name, values in normalized.items():
        # Explicit raise rather than ``assert`` so the check survives -O.
        if len(values) != num:
            raise AssertionError(f"Length mismatch for setting {name}: {len(values)} vs {num}")
    for ii in range(num):
        yield {name: values[ii] for name, values in normalized.items()}
def classParameters(**settings: Sequence[Any]) -> Callable:
    """Class decorator that generates parameterized test classes.

    One subclass is created for each position in the supplied ``settings``,
    with those settings set as class variables.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name. The generated
    classes are added to the decorated class's module; the decorator itself
    returns `None`, so the decorated name is bound to `None`.
    """

    def decorator(cls: type) -> None:
        namespace = vars(sys.modules[cls.__module__])
        for params in _settingsIterator(settings):
            suffix = "_".join(str(value) for value in params.values())
            subclass_name = f"{cls.__name__}_{suffix}"
            attributes = {**cls.__dict__, **params}
            namespace[subclass_name] = type(subclass_name, (cls,), attributes)

    return decorator
def methodParameters(**settings: Sequence[Any]) -> Callable:
    """Run a test method once per position in the settings, as subtests.

    Each run is wrapped in ``TestCase.subTest`` with the current values, so
    a failure reports which parameters were in use.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Examples
    --------
    .. code-block:: python

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(self: unittest.TestCase, *args: Any, **kwargs: Any) -> None:
            for params in _settingsIterator(settings):
                merged = {**kwargs, **params}
                with self.subTest(**params):
                    func(self, *args, **merged)

        return wrapper

    return decorator
def _cartesianProduct(settings: Mapping[str, Sequence[Any]]) -> Mapping[str, Sequence[Any]]:
    """Expand the settings into their cartesian product.

    Parameters
    ----------
    settings : `dict` mapping `str` to `iterable`
        Parameter combinations.

    Returns
    -------
    product : `dict` mapping `str` to `iterable`
        Parameter combinations covering the cartesian product (all possible
        combinations) of the input parameters, as equal-length lists.

    Examples
    --------
    .. code-block:: python

        _cartesianProduct({"foo": [1, 2], "bar": ["black", "white"]})

    will return:

    .. code-block:: python

        {"foo": [1, 1, 2, 2], "bar": ["black", "white", "black", "white"]}
    """
    names = list(settings)
    combinations = list(itertools.product(*settings.values()))
    return {name: [combo[index] for combo in combinations] for index, name in enumerate(names)}
def classParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    cartesian product of the supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable.

    Examples
    --------
    .. code-block:: python

        @classParametersProduct(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate four classes, as if you wrote:

    .. code-block:: python

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_1_4(unittest.TestCase):
            foo = 1
            bar = 4
            ...

        class MyTestCase_2_3(unittest.TestCase):
            foo = 2
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    return classParameters(**_cartesianProduct(settings))
def methodParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Iterate over cartesian product creating sub tests.

    This decorator iterates over the cartesian product of the supplied
    settings, using `~unittest.TestCase.subTest` to communicate the values in
    the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The parameter combinations to test. Each should be an iterable.

    Examples
    --------
    .. code-block:: python

        @methodParametersProduct(foo=[1, 2], bar=["black", "white"])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar="black")
        testSomething(foo=1, bar="white")
        testSomething(foo=2, bar="black")
        testSomething(foo=2, bar="white")
    """
    return methodParameters(**_cartesianProduct(settings))
@contextlib.contextmanager
def temporaryDirectory() -> Iterator[str]:
    """Context manager that creates and destroys a temporary directory.

    The difference from `tempfile.TemporaryDirectory` is that this ignores
    errors when deleting a directory, which may happen with some filesystems.
    The directory is removed even if the ``with`` block raises.

    Yields
    ------
    tmpdir : `str`
        Path to the newly created temporary directory.
    """
    tmpdir = tempfile.mkdtemp()
    try:
        yield tmpdir
    finally:
        # Without the finally an exception in the with block would leave the
        # directory behind.
        shutil.rmtree(tmpdir, ignore_errors=True)