Coverage for python/lsst/utils/tests.py: 33%
355 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-25 09:27 +0000
1# This file is part of utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12"""Support code for running unit tests"""
14from __future__ import annotations
# Public API re-exported by ``from lsst.utils.tests import *``.
__all__ = [
    "init",
    "MemoryTestCase",
    "ExecutablesTestCase",
    "ImportTestCase",
    "getTempFilePath",
    "TestCase",
    "assertFloatsAlmostEqual",
    "assertFloatsNotEqual",
    "assertFloatsEqual",
    "debugger",
    "classParameters",
    "methodParameters",
    "temporaryDirectory",
]
32import contextlib
33import functools
34import gc
35import inspect
36import itertools
37import os
38import re
39import shutil
40import subprocess
41import sys
42import tempfile
43import unittest
44import warnings
45from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
46from importlib import resources
47from typing import Any, ClassVar
49import numpy
50import psutil
52from .doImport import doImport
# Baseline of files that were already open when leak tracking began;
# refreshed by init() and compared against in MemoryTestCase.
open_files = set()
def _get_open_files() -> set[str]:
    """Return a set containing the list of files currently open in this
    process.

    Returns
    -------
    open_files : `set`
        Set containing the list of open files.
    """
    process = psutil.Process()
    return {entry.path for entry in process.open_files()}
def init() -> None:
    """Initialize the memory tester and file descriptor leak tester."""
    global open_files
    # Record the files open right now as the new baseline; anything open
    # beyond this set later is treated as a potential leak.
    open_files = _get_open_files()
def sort_tests(tests) -> unittest.TestSuite:
    """Sort supplied test suites such that MemoryTestCases are at the end.

    `lsst.utils.tests.MemoryTestCase` tests should always run after any other
    tests in the module.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites.

    Returns
    -------
    suite : `unittest.TestSuite`
        A combined `~unittest.TestSuite` with
        `~lsst.utils.tests.MemoryTestCase` at the end.
    """
    ordered = unittest.TestSuite()
    deferred = []
    for candidate in tests:
        try:
            # Look at the first test method in the suite to decide whether
            # this is a MemoryTestCase. A for loop is used rather than
            # next() because a suite may contain no test methods at all,
            # and a loop avoids catching StopIteration.
            mro = None
            for item in candidate:
                mro = inspect.getmro(item.__class__)
                break
            if mro is not None and MemoryTestCase in mro:
                deferred.append(candidate)
            else:
                ordered.addTests(candidate)
        except TypeError:
            # Not iterable: candidate is a bare test case, not a suite.
            if isinstance(candidate, MemoryTestCase):
                deferred.append(candidate)
            else:
                ordered.addTest(candidate)
    # Memory tests run last so they observe every other test's leaks.
    ordered.addTests(deferred)
    return ordered
def _suiteClassWrapper(tests):
    # Delegate to sort_tests so MemoryTestCase suites always run last.
    return unittest.TestSuite(sort_tests(tests))


# Install the wrapper as the suiteClass used by the default test loader so
# that test ordering can be adjusted. This has no effect when no memory
# test cases are present.
unittest.defaultTestLoader.suiteClass = _suiteClassWrapper
class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks."""

    ignore_regexps: ClassVar[list[str]] = []
    """List of regexps to ignore when checking for open files."""

    @classmethod
    def tearDownClass(cls) -> None:
        """Reset the leak counter when the tests have been completed"""
        init()

    def testFileDescriptorLeaks(self) -> None:
        """Check if any file descriptors are open since init() called.

        Ignores files with certain known path components and any files
        that match regexp patterns in class property ``ignore_regexps``.
        """
        gc.collect()
        global open_files
        currently_open = _get_open_files()

        def _expected(path: str) -> bool:
            # Some files are opened out of the control of the stack.
            if path.endswith((".car", ".ttf", "astropy.log", "mime/mime.cache", ".sqlite3")):
                return True
            if path.startswith("/proc/"):
                return True
            if path.startswith("/var/lib/") and path.endswith("/passwd"):
                return True
            return any(re.search(pattern, path) for pattern in self.ignore_regexps)

        currently_open = {path for path in currently_open if not _expected(path)}

        leaked = currently_open.difference(open_files)
        if leaked:
            for path in leaked:
                print(f"File open: {path}")
            self.fail("Failed to close %d file%s" % (len(leaked), "s" if len(leaked) != 1 else ""))
class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    # Sentinel value: -1 means automated discovery has not yet been run.
    TESTS_DISCOVERED = -1

    @classmethod
    def setUpClass(cls) -> None:
        """Abort testing if automated test creation was enabled and
        no tests were found.
        """
        if cls.TESTS_DISCOVERED == 0:
            raise RuntimeError("No executables discovered.")

    def testSanity(self) -> None:
        """Ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set up
        machinery to test whether there are some executables to test.
        """

    def assertExecutable(
        self,
        executable: str,
        root_dir: str | None = None,
        args: Sequence[str] | None = None,
        msg: str | None = None,
    ) -> None:
        """Check an executable runs and returns good status.

        Prints output to standard out. On bad exit status the test
        fails. If the executable can not be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Build the subprocess argument vector and a human-readable
        # description of the arguments for the log message.
        command = [executable]
        if args is None:
            arg_summary = "no arguments"
        else:
            command += list(args)
            arg_summary = 'arguments "' + " ".join(args) + '"'

        print(f"Running executable '{executable}' with {arg_summary}...")
        if not os.path.exists(executable):
            self.skipTest(f"Executable {executable} is unexpectedly missing")
        failmsg = None
        try:
            output = subprocess.check_output(command)
        except subprocess.CalledProcessError as e:
            output = e.output
            failmsg = f"Bad exit status from '{executable}': {e.returncode}"
        print(output.decode("utf-8"))
        if failmsg:
            self.fail(failmsg if msg is None else msg)

    @classmethod
    def _build_test_method(cls, executable: str, root_dir: str) -> None:
        """Build a test method and attach to class.

        A test method is created for the supplied executable located
        in the supplied root directory. This method is attached to the class
        so that the test runner will discover the test and run it.

        Parameters
        ----------
        cls : `object`
            The class in which to create the tests.
        executable : `str`
            Name of executable. Can be absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Derive a unique, discoverable test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        # This closure becomes the test method; the executable path is
        # captured from the enclosing scope.
        def test_executable_runs(*args: Any) -> None:
            self = args[0]
            self.assertExecutable(executable)

        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file: str, executables: Sequence[str] | None = None) -> None:
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        search_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Walk the tree looking for executable files, skipping Python
            # sources and shared libraries (which are executable but are
            # not standalone programs).
            found = []
            for dirpath, _, filenames in os.walk(search_dir):
                for name in filenames:
                    if name.endswith((".py", ".so")):
                        continue
                    candidate = os.path.join(dirpath, name)
                    if os.access(candidate, os.X_OK):
                        found.append(candidate)
            executables = found

        # Record the number of discovered tests so setUpClass can report a
        # failure. Raising here instead would abort before the test runner
        # could integrate the problem into its failure report.
        cls.TESTS_DISCOVERED = len(executables)

        for exe in executables:
            cls._build_test_method(exe, search_dir)
class ImportTestCase(unittest.TestCase):
    """Test that the named packages can be imported and all files within
    that package.

    The test methods are created dynamically. Callers must subclass this
    method and define the ``PACKAGES`` property.
    """

    PACKAGES: ClassVar[Iterable[str]] = ()
    """Packages to be imported."""

    _n_registered = 0
    """Number of packages registered for testing by this class."""

    def _test_no_packages_registered_for_import_testing(self) -> None:
        """Test when no packages have been registered.

        Without this, if no packages have been listed no tests will be
        registered and the test system will not report on anything. This
        test fails and reports why.
        """
        raise AssertionError("No packages registered with import test. Was the PACKAGES property set?")

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create the test methods based on the content of the ``PACKAGES``
        class property.
        """
        super().__init_subclass__(**kwargs)

        for package in cls.PACKAGES:
            name = "test_import_" + package.replace(".", "_")

            # Bind the current package as a default argument so each
            # generated method imports its own package, not the last one.
            def test_import(*args: Any, mod=package) -> None:
                args[0].assertImport(mod)

            test_import.__name__ = name
            setattr(cls, name, test_import)
            cls._n_registered += 1

        # An empty PACKAGES list is almost certainly a mistake, so register
        # a test that always fails to make the problem visible.
        if cls._n_registered == 0:
            cls.test_no_packages_registered = cls._test_no_packages_registered_for_import_testing

    def assertImport(self, root_pkg):
        for entry in resources.files(root_pkg).iterdir():
            filename = entry.name
            # Only consider Python modules, skipping dunder files such as
            # __init__.py and __main__.py.
            if not filename.endswith(".py") or filename.startswith("__"):
                continue
            stem, _ = os.path.splitext(filename)
            module_name = f"{root_pkg}.{stem}"
            with self.subTest(module=module_name):
                try:
                    doImport(module_name)
                except ImportError as e:
                    raise AssertionError(f"Error importing module {module_name}: {e}") from e
@contextlib.contextmanager
def getTempFilePath(ext: str, expectOutput: bool = True) -> Iterator[str]:
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure results in a printed warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can
    be examined. The file name has a random component such that nested
    context managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Returns
    -------
    path : `str`
        Path for a temporary file. The path is a combination of the caller's
        file path and the name of the top-level function.

    Examples
    --------
    .. code-block:: python

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests

        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics.fits"
                    # otherwise tmpFile = "testFoo_testBasics.fits"
                    ...
                # at the end of this "with" block the path tmpFile will be
                # deleted, but only if the file exists and the "with"
                # block terminated normally (rather than with an exception)
                ...
    """
    frames = inspect.stack()
    # Walk up the stack to find the outermost function defined in the
    # calling file; frames from other files end the search.
    for depth in range(2, len(frames)):
        info = inspect.getframeinfo(frames[depth][0])
        if depth == 2:
            callerFilePath = info.filename
            callerFuncName = info.function
        elif callerFilePath == info.filename:
            # Same file: this function called the previous function.
            callerFuncName = info.function
        else:
            break

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    # Prefer a tests/.tests directory next to the caller if it exists,
    # otherwise fall back to the current working directory.
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        outDir = ""
    prefix = f"{callerFileName}_{callerFuncName}-"
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # The random component should prevent collisions; warn and remove
        # anything found. Stacklevel 3 reports the warning from the end of
        # the with block.
        warnings.warn(f"Unexpectedly found pre-existing tempfile named {outPath!r}", stacklevel=3)
        with contextlib.suppress(OSError):
            os.remove(outPath)

    yield outPath

    fileExists = os.path.exists(outPath)
    if expectOutput and not fileExists:
        raise RuntimeError(f"Temp file expected named {outPath} but none found")
    if not expectOutput and fileExists:
        raise RuntimeError(f"Unexpectedly discovered temp file named {outPath}")
    # Try to clean up the file regardless
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # Stacklevel 3 reports the warning from the end of the with
            # block.
            warnings.warn(f"Warning: could not remove file {outPath!r}: {e}", stacklevel=3)
class TestCase(unittest.TestCase):
    """Subclass of `unittest.TestCase` that adds some custom assertions for
    convenience.

    Free functions decorated with `inTestCase` in this module are attached
    to this class as methods.
    """
def inTestCase(func: Callable) -> Callable:
    """Add a free function to our custom TestCase class, while
    also making it available as a free function.

    Parameters
    ----------
    func : `Callable`
        Function to attach to `TestCase` under its own name.

    Returns
    -------
    func : `Callable`
        The supplied function, unchanged, so this can be used as a
        decorator.
    """
    setattr(TestCase, func.__name__, func)
    return func
def debugger(*exceptions):
    """Enter the debugger when there's an uncaught exception.

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    This defaults to just `AssertionError`, for use on `unittest.TestCase`
    methods.

    Code provided by "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.
    """
    # NOTE(review): the docstring says the default is AssertionError but the
    # fallback here is Exception — confirm which is intended.
    exceptions = exceptions or (Exception,)

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions:
                # Import lazily so pdb is only loaded when actually needed.
                import pdb
                import sys

                pdb.post_mortem(sys.exc_info()[2])

        return wrapper

    return decorator
def plotImageDiff(
    lhs: numpy.ndarray,
    rhs: numpy.ndarray,
    bad: numpy.ndarray | None = None,
    diff: numpy.ndarray | None = None,
    plotFileName: str | None = None,
) -> None:
    """Plot the comparison of two 2-d NumPy arrays.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array.
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array.
    bad : `numpy.ndarray`
        A 2-d boolean NumPy array of values to emphasize in the plots.
    diff : `numpy.ndarray`
        Difference array; a 2-d NumPy array, or `None` to show lhs-rhs.
    plotFileName : `str`
        Filename to save the plot to. If `None`, the plot will be displayed
        in a window.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    from matplotlib import pyplot

    if diff is None:
        diff = lhs - rhs
    pyplot.figure()
    badImage = None
    if bad is not None:
        # Build an RGBA overlay that is opaque red where ``bad`` is set and
        # fully transparent elsewhere.
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255 * bad
    # One shared color scale for the originals (top row) and another for
    # the difference image (bottom row).
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    panels = [(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]
    for position, (image, title) in enumerate(panels, start=1):
        pyplot.subplot(2, 3, position)
        im1 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin1, vmax=vmax1
        )
        if badImage is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
        pyplot.subplot(2, 3, position + 3)
        im2 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin2, vmax=vmax2
        )
        if badImage is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
    pyplot.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    cax1 = pyplot.axes([0.8, 0.55, 0.05, 0.4])
    pyplot.colorbar(im1, cax=cax1)
    cax2 = pyplot.axes([0.8, 0.05, 0.05, 0.4])
    pyplot.colorbar(im2, cax=cax2)
    if plotFileName:
        pyplot.savefig(plotFileName)
    else:
        pyplot.show()
@inTestCase
def assertFloatsAlmostEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    rtol: float | None = sys.float_info.epsilon,
    atol: float | None = sys.float_info.epsilon,
    relTo: float | None = None,
    printFailures: bool = True,
    plotOnFailure: bool = False,
    plotFileName: str | None = None,
    invert: bool = False,
    msg: str | None = None,
    ignoreNaNs: bool = False,
) -> None:
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if all elements ``lhs`` and ``rhs`` are not
    equal to within the tolerances specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float`, optional
        Value to which comparison with rtol is relative.
    printFailures : `bool`, optional
        Upon failure, print all inequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed
        in a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if any elements *are*
        equal. Used to implement `~lsst.utils.tests.assertFloatsNotEqual`,
        which should generally be used instead for clarity.
    msg : `str`, optional
        String to append to the error message when assert fails.
    ignoreNaNs : `bool`, optional
        If `True` (`False` is default) mask out any NaNs from operand arrays
        before performing comparisons if they are in the same locations; NaNs
        in different locations are trigger test assertion failures, even when
        ``invert=True``. Scalar NaNs are treated like arrays containing only
        NaNs of the same shape as the other operand, and no comparisons are
        performed if both sides are scalar NaNs.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    """
    if ignoreNaNs:
        lhsMask = numpy.isnan(lhs)
        rhsMask = numpy.isnan(rhs)
        if not numpy.all(lhsMask == rhsMask):
            testCase.fail(
                f"lhs has {lhsMask.sum()} NaN values and rhs has {rhsMask.sum()} NaN values, "
                "in different locations."
            )
        if numpy.all(lhsMask):
            assert numpy.all(rhsMask), "Should be guaranteed by previous if."
            # All operands are fully NaN (either scalar NaNs or arrays of only
            # NaNs).
            return
        # Fix: message previously read "prevoius".
        assert not numpy.all(rhsMask), "Should be guaranteed by previous two ifs."
        # If either operand is an array select just its not-NaN values. Note
        # that these expressions are never True for scalar operands, because
        # if they are NaN then the numpy.all checks above will catch them.
        if numpy.any(lhsMask):
            lhs = lhs[numpy.logical_not(lhsMask)]
        if numpy.any(rhsMask):
            rhs = rhs[numpy.logical_not(rhsMask)]
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    diff = lhs - rhs
    absDiff = numpy.abs(lhs - rhs)
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        bad = absDiff > rtol * relTo
        if atol is not None:
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            if rtol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff} with atol={atol}"]
            elif atol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} with rtol={rtol}"]
            else:
                errMsg = [
                    f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} "
                    f"with rtol={rtol}, atol={atol}"
                ]
        else:
            errMsg = [f"{bad.sum()}/{bad.size} elements {failStr} with rtol={rtol}, atol={atol}"]
            if plotOnFailure:
                if len(lhs.shape) != 2 or len(rhs.shape) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we
                # can treat them the same (diff and absDiff are arrays if
                # either rhs or lhs is), and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                # Fix: the loop variable was previously named ``diff``, which
                # clobbered the outer ``diff`` array while formatting
                # per-element failures.
                if rtol is None:
                    for a, b, d in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={d})")
                else:
                    for a, b, d, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={d}/{rel}={d / rel})")
    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))
@inTestCase
def assertFloatsNotEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwds: Any,
) -> None:
    """Fail a test if the given floating point values are equal to within the
    given tolerances.

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``rtol=atol=0``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    # Delegate with the comparison inverted so that equality fails.
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)
@inTestCase
def assertFloatsEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwargs: Any,
) -> None:
    """Assert that lhs == rhs (both numeric types, whether scalar or array).

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``rtol=atol=0``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are not equal.
    """
    # Exact comparison: delegate with both tolerances set to zero.
    return assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=0, atol=0, **kwargs)
def _settingsIterator(settings: dict[str, Sequence[Any]]) -> Iterator[dict[str, Any]]:
    """Return an iterator for the provided test settings.

    Parameters
    ----------
    settings : `dict` (`str`: iterable)
        Lists of test parameters. Each should be an iterable of the same
        length. If a string is provided as an iterable, it will be converted
        to a list of a single string.

    Raises
    ------
    AssertionError
        If the ``settings`` are not of the same length.

    Yields
    ------
    parameters : `dict` (`str`: anything)
        Set of parameters.
    """
    for key, values in settings.items():
        if isinstance(values, str):
            # A bare string was almost certainly meant as a single value,
            # not an iterable of characters; wrap it. Note that this
            # mutates the caller's dict in place.
            settings[key] = [values]
    num = len(next(iter(settings.values())))  # Number of settings
    for key, values in settings.items():
        assert len(values) == num, f"Length mismatch for setting {key}: {len(values)} vs {num}"
    for index in range(num):
        yield {key: settings[key][index] for key in settings}
def classParameters(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """

    def decorator(cls: type) -> None:
        # Inject each generated subclass into the module that defined the
        # decorated class so the test loader can discover it.
        namespace = sys.modules[cls.__module__].__dict__
        for params in _settingsIterator(settings):
            suffix = "_".join(str(value) for value in params.values())
            name = f"{cls.__name__}_{suffix}"
            members = dict(cls.__dict__)
            members.update(params)
            namespace[name] = type(name, (cls,), members)

    return decorator
def methodParameters(**settings: Sequence[Any]) -> Callable:
    """Iterate over supplied settings to create subtests automatically.

    This decorator iterates over the supplied settings, using
    ``TestCase.subTest`` to communicate the values in the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Examples
    --------
    .. code-block:: python

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(self: unittest.TestCase, *args: Any, **kwargs: Any) -> None:
            # Run the wrapped test once per parameter combination, inside
            # subTest so each combination is reported individually.
            for params in _settingsIterator(settings):
                kwargs.update(params)
                with self.subTest(**params):
                    func(self, *args, **kwargs)

        return wrapper

    return decorator
def _cartesianProduct(settings: Mapping[str, Sequence[Any]]) -> Mapping[str, Sequence[Any]]:
    """Return the cartesian product of the settings.

    Parameters
    ----------
    settings : `dict` mapping `str` to `iterable`
        Parameter combinations.

    Returns
    -------
    product : `dict` mapping `str` to `iterable`
        Parameter combinations covering the cartesian product (all possible
        combinations) of the input parameters.

    Examples
    --------
    .. code-block:: python

        cartesianProduct({"foo": [1, 2], "bar": ["black", "white"]})

    will return:

    .. code-block:: python

        {"foo": [1, 1, 2, 2], "bar": ["black", "white", "black", "white"]}
    """
    expanded: dict[str, list[Any]] = {key: [] for key in settings}
    # Flatten each product tuple back into per-key parallel lists.
    for combination in itertools.product(*settings.values()):
        for key, value in zip(settings.keys(), combination):
            expanded[key].append(value)
    return expanded
def classParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    cartesian product of the supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable.

    Examples
    --------
    .. code-block:: python

        @classParametersProduct(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate four classes, as if you wrote:

    .. code-block:: python

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_1_4(unittest.TestCase):
            foo = 1
            bar = 4
            ...

        class MyTestCase_2_3(unittest.TestCase):
            foo = 2
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    # Expand the settings to all combinations, then reuse classParameters.
    return classParameters(**_cartesianProduct(settings))
def methodParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Iterate over cartesian product creating sub tests.

    This decorator iterates over the cartesian product of the supplied
    settings, using `~unittest.TestCase.subTest` to communicate the values in
    the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The parameter combinations to test. Each should be an iterable.

    Examples
    --------
    .. code-block:: python

        @methodParametersProduct(foo=[1, 2], bar=["black", "white"])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar="black")
        testSomething(foo=1, bar="white")
        testSomething(foo=2, bar="black")
        testSomething(foo=2, bar="white")
    """
    # Expand the settings to all combinations, then reuse methodParameters.
    return methodParameters(**_cartesianProduct(settings))
@contextlib.contextmanager
def temporaryDirectory() -> Iterator[str]:
    """Context manager that creates and destroys a temporary directory.

    The difference from `tempfile.TemporaryDirectory` is that this ignores
    errors when deleting a directory, which may happen with some filesystems.

    Yields
    ------
    tmpdir : `str`
        Path to the newly created temporary directory.
    """
    tmpdir = tempfile.mkdtemp()
    try:
        yield tmpdir
    finally:
        # Fix: without try/finally an exception raised inside the with block
        # propagated at the yield and the directory was never removed.
        # Deletion errors are still ignored because some filesystems can
        # fail the rmtree.
        shutil.rmtree(tmpdir, ignore_errors=True)