Coverage for python/lsst/utils/tests.py: 34%
361 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-27 11:49 +0000
« prev ^ index » next coverage.py v7.4.3, created at 2024-02-27 11:49 +0000
1# This file is part of utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# Use of this source code is governed by a 3-clause BSD-style
10# license that can be found in the LICENSE file.
12"""Support code for running unit tests."""
14from __future__ import annotations
# Names exported by a wildcard import of this module.
__all__ = [
    "init",
    "MemoryTestCase",
    "ExecutablesTestCase",
    "ImportTestCase",
    "getTempFilePath",
    "TestCase",
    "assertFloatsAlmostEqual",
    "assertFloatsNotEqual",
    "assertFloatsEqual",
    "debugger",
    "classParameters",
    "methodParameters",
    "temporaryDirectory",
]
32import contextlib
33import functools
34import gc
35import inspect
36import itertools
37import os
38import re
39import shutil
40import subprocess
41import sys
42import tempfile
43import unittest
44import warnings
45from collections.abc import Callable, Container, Iterable, Iterator, Mapping, Sequence
47if sys.version_info < (3, 10, 0): 47 ↛ 48line 47 didn't jump to line 48, because the condition on line 47 was never true
48 import importlib_resources as resources
49else:
50 from importlib import resources
52from typing import Any, ClassVar
54import numpy
55import psutil
57from .doImport import doImport
# Baseline of open file paths used by the leak check. It starts empty and is
# replaced with a fresh snapshot whenever init() is called.
open_files: set[str] = set()
def _get_open_files() -> set[str]:
    """Return the paths of all files currently held open by this process.

    Returns
    -------
    open_files : `set` [`str`]
        The ``path`` of every entry reported by ``psutil`` for the current
        process.
    """
    current = psutil.Process()
    return set(handle.path for handle in current.open_files())
def init() -> None:
    """Record the currently open files as the baseline for leak checks.

    `MemoryTestCase` reports only files that are open now but were not
    open when this function was last called.
    """
    global open_files
    # Replace (not update) the module-level baseline with a fresh snapshot.
    open_files = _get_open_files()
def sort_tests(tests) -> unittest.TestSuite:
    """Combine test suites, moving `MemoryTestCase` suites to the end.

    Leak checks in `lsst.utils.tests.MemoryTestCase` must run after every
    other test in the module, so any suite whose first test is a
    `MemoryTestCase` is deferred.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites (or individual test cases).

    Returns
    -------
    suite : `unittest.TestSuite`
        A single suite containing all inputs, with the
        `~lsst.utils.tests.MemoryTestCase` entries appended last.
    """
    combined = unittest.TestSuite()
    deferred = []
    for entry in tests:
        try:
            # Only the first contained test is inspected; an empty suite
            # yields None and is treated as an ordinary suite. Non-iterable
            # entries (bare TestCase instances) raise TypeError here.
            first = next(iter(entry), None)
            if first is not None and MemoryTestCase in inspect.getmro(first.__class__):
                deferred.append(entry)
            else:
                combined.addTests(entry)
        except TypeError:
            if isinstance(entry, MemoryTestCase):
                deferred.append(entry)
            else:
                combined.addTest(entry)
    combined.addTests(deferred)
    return combined
def _suiteClassWrapper(tests):
    """Build a `unittest.TestSuite` whose `MemoryTestCase` tests come last.

    Installed below as the default test loader's ``suiteClass``.
    """
    ordered = sort_tests(tests)
    return unittest.TestSuite(ordered)


# Make the default loader build its suites through sort_tests so that leak
# tests run after everything else. Ordering is unchanged when a module has no
# MemoryTestCase.
unittest.defaultTestLoader.suiteClass = _suiteClassWrapper
class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks.

    Reports files that are open now but were not open when `init` was last
    called.
    """

    # Regular expressions; an open-file path matching any of them (via
    # ``re.search``) is ignored by the leak check. Subclasses may override.
    ignore_regexps: ClassVar[list[str]] = []

    @classmethod
    def tearDownClass(cls) -> None:
        """Take a fresh open-file baseline once this class's tests finish."""
        init()

    def testFileDescriptorLeaks(self) -> None:
        """Fail if any files have been opened since init() was called.

        Paths with certain known components, and any path matching a pattern
        in the ``ignore_regexps`` class attribute, are not reported.
        """
        # Presumably run so that files held only by unreachable objects are
        # finalized (and closed) before the snapshot is taken.
        gc.collect()

        # Files opened outside the control of the stack.
        skipSuffixes = (".car", ".ttf", "astropy.log", "mime/mime.cache", ".sqlite3")

        def isIgnored(path: str) -> bool:
            if path.endswith(skipSuffixes) or path.startswith("/proc/"):
                return True
            if path.startswith("/var/lib/") and path.endswith("/passwd"):
                return True
            return any(re.search(pattern, path) for pattern in self.ignore_regexps)

        nowOpen = {path for path in _get_open_files() if not isIgnored(path)}
        leaked = nowOpen.difference(open_files)
        if not leaked:
            return
        for path in leaked:
            print(f"File open: {path}")
        plural = "s" if len(leaked) != 1 else ""
        self.fail("Failed to close %d file%s" % (len(leaked), plural))
class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    # Number of executables registered by create_executable_tests().
    # -1 means discovery has not run; 0 means it ran and found nothing.
    TESTS_DISCOVERED = -1

    @classmethod
    def setUpClass(cls) -> None:
        """Abort testing if automated test creation was enabled and
        no tests were found.

        Raises
        ------
        RuntimeError
            Raised if discovery ran but found no executables.
        """
        if cls.TESTS_DISCOVERED == 0:
            raise RuntimeError("No executables discovered.")

    def testSanity(self) -> None:
        """Ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set up
        machinery to test whether there are some executables to test.
        """

    def assertExecutable(
        self,
        executable: str,
        root_dir: str | None = None,
        args: Sequence[str] | None = None,
        msg: str | None = None,
    ) -> None:
        """Check an executable runs and returns good status.

        Prints output to standard out. On bad exit status the test
        fails. If the executable can not be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        unittest.SkipTest
            The executable does not exist.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Form the argument list for subprocess
        sp_args = [executable]
        argstr = "no arguments"
        if args is not None:
            sp_args.extend(args)
            argstr = 'arguments "' + " ".join(args) + '"'

        print(f"Running executable '{executable}' with {argstr}...")
        if not os.path.exists(executable):
            self.skipTest(f"Executable {executable} is unexpectedly missing")
        failmsg = None
        try:
            output = subprocess.check_output(sp_args)
        except subprocess.CalledProcessError as e:
            output = e.output
            failmsg = f"Bad exit status from '{executable}': {e.returncode}"
        # Executables may emit arbitrary bytes; a strict decode would raise
        # UnicodeDecodeError here and hide the real exit-status failure.
        print(output.decode("utf-8", errors="replace"))
        if failmsg:
            if msg is None:
                msg = failmsg
            self.fail(msg)

    @classmethod
    def _build_test_method(cls, executable: str, root_dir: str) -> None:
        """Build a test method and attach to class.

        A test method is created for the supplied executable located
        in the supplied root directory. This method is attached to the class
        so that the test runner will discover the test and run it.

        Parameters
        ----------
        executable : `str`
            Name of executable. Can be absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Create the test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        # This is the function that will become the test method; it closes
        # over the absolute ``executable`` path computed above.
        def test_executable_runs(*args: Any) -> None:
            self = args[0]
            self.assertExecutable(executable)

        # Give it a name and attach it to the class
        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file: str, executables: Sequence[str] | None = None) -> None:
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        # Get the search directory from the reference file
        ref_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Look for executables to test by walking the tree
            executables = []
            for root, _, files in os.walk(ref_dir):
                for f in files:
                    # Skip Python files. Shared libraries are executable.
                    if not f.endswith(".py") and not f.endswith(".so"):
                        full_path = os.path.join(root, f)
                        if os.access(full_path, os.X_OK):
                            executables.append(full_path)

        # Store the number of tests found for later assessment.
        # Do not raise an exception if we have no executables as this would
        # cause the testing to abort before the test runner could properly
        # integrate it into the failure report.
        cls.TESTS_DISCOVERED = len(executables)

        # Create the test functions and attach them to the class
        for e in executables:
            cls._build_test_method(e, ref_dir)
class ImportTestCase(unittest.TestCase):
    """Test that the named packages, and every module file directly inside
    them, can be imported.

    The test methods are created dynamically. Callers must subclass this
    class and define the ``PACKAGES`` property.
    """

    # Packages to be imported.
    PACKAGES: ClassVar[Iterable[str]] = ()

    # Files to skip importing, keyed by package name; each value is a
    # collection of file names within that package. Files whose names do not
    # end in ".py" or that begin with a double underscore are always skipped.
    SKIP_FILES: ClassVar[Mapping[str, Container[str]]] = {}

    # Number of packages registered for testing by this class.
    _n_registered = 0

    def _test_no_packages_registered_for_import_testing(self) -> None:
        """Fail because no packages were registered.

        Without this, an empty ``PACKAGES`` would register no tests and the
        test system would silently report nothing.
        """
        raise AssertionError("No packages registered with import test. Was the PACKAGES property set?")

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Attach one ``test_import_<package>`` method per entry in
        ``PACKAGES``.
        """
        super().__init_subclass__(**kwargs)

        for package in cls.PACKAGES:
            method_name = "test_import_" + package.replace(".", "_")

            # ``mod`` is bound as a default so each method keeps its own
            # package rather than the loop's final value.
            def test_import(*args: Any, mod=package) -> None:
                args[0].assertImport(mod)

            test_import.__name__ = method_name
            setattr(cls, method_name, test_import)
            cls._n_registered += 1

        # An empty package list is most likely a mistake, so register a test
        # that fails and explains why.
        if cls._n_registered == 0:
            cls.test_no_packages_registered = cls._test_no_packages_registered_for_import_testing

    def assertImport(self, root_pkg):
        """Import each top-level ``.py`` module of ``root_pkg`` in a subtest.

        Raises
        ------
        AssertionError
            Raised (per subtest) when a module fails with `ImportError`.
        """
        for entry in resources.files(root_pkg).iterdir():
            filename = entry.name
            if (
                not filename.endswith(".py")
                or filename.startswith("__")
                or filename in self.SKIP_FILES.get(root_pkg, ())
            ):
                continue
            stem = os.path.splitext(filename)[0]
            module_name = f"{root_pkg}.{stem}"
            with self.subTest(module=module_name):
                try:
                    doImport(module_name)
                except ImportError as e:
                    raise AssertionError(f"Error importing module {module_name}: {e}") from e
@contextlib.contextmanager
def getTempFilePath(ext: str, expectOutput: bool = True) -> Iterator[str]:
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure to delete results in a warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can be
    examined. The file name has a random component such that nested context
    managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Yields
    ------
    path : `str`
        Path for a temporary file. The file name combines the caller's file
        name, the name of the outermost calling function in that file, a
        random component and ``ext``.

    Raises
    ------
    RuntimeError
        Raised on exit if the file is missing when ``expectOutput`` is
        `True`, or present when ``expectOutput`` is `False`.

    Examples
    --------
    .. code-block:: python

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests
        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics-<random>.fits"
                    # otherwise tmpFile = "<cwd>/testFoo_testBasics-<random>.fits"
                    ...
                # at the end of this "with" block the path tmpFile will be
                # deleted, but only if the file exists and the "with"
                # block terminated normally (rather than with an exception)
                ...
    """
    # Only each frame's file and function name is needed. context=0 skips
    # reading source lines for every frame, and keeping plain strings rather
    # than frame objects avoids pinning the caller's frames in memory while
    # this generator is suspended at the yield below.
    callers = [(info.filename, info.function) for info in inspect.stack(context=0)]
    # callers[0] is this generator and callers[1] is contextlib's __enter__,
    # so callers[2] is the code that opened the "with" block. Walk outwards
    # while frames stay in that file to find the outermost such function.
    callerFilePath, callerFuncName = callers[2]
    for filename, function in callers[3:]:
        if filename != callerFilePath:
            break
        callerFuncName = function

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        # No .tests directory implies we are not running with sconsUtils.
        # Use the current working directory, as an absolute path in case the
        # test code changes cwd.
        outDir = os.path.abspath(os.path.curdir)
    prefix = f"{callerFileName}_{callerFuncName}-"
    # mktemp only generates a name and does not create the file, so the
    # expectOutput checks below start from a path that does not exist.
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # There should not be a file there given the randomizer. Warn and
        # remove. stacklevel 3 attributes the warning to the "with"
        # statement (generator -> __enter__ -> caller).
        warnings.warn(f"Unexpectedly found pre-existing tempfile named {outPath!r}", stacklevel=3)
        with contextlib.suppress(OSError):
            os.remove(outPath)

    yield outPath

    fileExists = os.path.exists(outPath)
    if expectOutput:
        if not fileExists:
            raise RuntimeError(f"Temp file expected named {outPath} but none found")
    else:
        if fileExists:
            raise RuntimeError(f"Unexpectedly discovered temp file named {outPath}")
    # Reached only when the file's presence matched expectations; remove it
    # on a best-effort basis.
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # stacklevel 3 attributes the warning to the end of the "with"
            # block (generator -> __exit__ -> caller).
            warnings.warn(f"Warning: could not remove file {outPath!r}: {e}", stacklevel=3)
class TestCase(unittest.TestCase):
    """`unittest.TestCase` subclass carrying extra convenience assertions.

    The extra assertions are attached to this class by the ``inTestCase``
    decorator, which also keeps them available as free functions.
    """
def inTestCase(func: Callable) -> Callable:
    """Register a free function as a method of this module's `TestCase`.

    The function is attached under its own ``__name__`` and returned
    unchanged, so it stays usable as a free function too.

    Parameters
    ----------
    func : `~collections.abc.Callable`
        Function to attach to `TestCase`; its first parameter receives the
        test case instance when called as a method.

    Returns
    -------
    func : `~collections.abc.Callable`
        The same function object that was passed in.
    """
    name = func.__name__
    setattr(TestCase, name, func)
    return func
def debugger(*exceptions):
    """Enter the debugger when there's an uncaught exception.

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    With no arguments any `Exception` triggers the debugger.

    Once the post-mortem session ends the exception is re-raised, so a
    decorated test still fails instead of silently passing.

    Code adapted from "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Parameters
    ----------
    *exceptions : `Exception`
        Specific exception classes to catch. Default is to catch
        `Exception`.

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.
    """
    if not exceptions:
        exceptions = (Exception,)

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions as exc:
                # Imported lazily so pdb is only loaded when actually needed.
                import pdb

                pdb.post_mortem(exc.__traceback__)
                # Propagate the original failure after debugging.
                raise

        return wrapper

    return decorator
def plotImageDiff(
    lhs: numpy.ndarray,
    rhs: numpy.ndarray,
    bad: numpy.ndarray | None = None,
    diff: numpy.ndarray | None = None,
    plotFileName: str | None = None,
) -> None:
    """Plot the comparison of two 2-d NumPy arrays.

    The top row shows ``lhs``, ``rhs`` and ``diff`` on a common scale taken
    from ``lhs``/``rhs``; the bottom row shows the same images scaled to the
    range of ``diff``.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array.
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array.
    bad : `numpy.ndarray`
        A 2-d boolean NumPy array of values to emphasize in the plots.
    diff : `numpy.ndarray`
        Difference array; a 2-d NumPy array, or None to show lhs-rhs.
    plotFileName : `str`
        Filename to save the plot to. If None, the plot will be displayed in
        a window. When saved, the figure is closed afterwards.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    from matplotlib import pyplot

    if diff is None:
        diff = lhs - rhs
    fig = pyplot.figure()
    if bad is not None:
        # make an rgba image that's red and transparent where not bad
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255 * bad
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    for n, (image, title) in enumerate([(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]):
        pyplot.subplot(2, 3, n + 1)
        im1 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin1, vmax=vmax1
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
        pyplot.subplot(2, 3, n + 4)
        im2 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin2, vmax=vmax2
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
    pyplot.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    # One colorbar per row; every image in a row shares the same vmin/vmax,
    # so the last image of each row is representative.
    cax1 = pyplot.axes([0.8, 0.55, 0.05, 0.4])
    pyplot.colorbar(im1, cax=cax1)
    cax2 = pyplot.axes([0.8, 0.05, 0.05, 0.4])
    pyplot.colorbar(im2, cax=cax2)
    if plotFileName:
        pyplot.savefig(plotFileName)
        # Close the saved figure so repeated failures do not accumulate open
        # pyplot figures (and their memory).
        pyplot.close(fig)
    else:
        pyplot.show()
@inTestCase
def assertFloatsAlmostEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    rtol: float | None = sys.float_info.epsilon,
    atol: float | None = sys.float_info.epsilon,
    relTo: float | None = None,
    printFailures: bool = True,
    plotOnFailure: bool = False,
    plotFileName: str | None = None,
    invert: bool = False,
    msg: str | None = None,
    ignoreNaNs: bool = False,
) -> None:
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if any elements of ``lhs`` and ``rhs`` are
    not equal to within the tolerances specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension. Python lists and tuples are converted to arrays.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension. Python lists and tuples are converted to arrays.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float` or array-like, optional
        Value to which comparison with rtol is relative. An array with the
        same shape as ``lhs`` is masked together with the operands when
        ``ignoreNaNs`` removes NaN elements.
    printFailures : `bool`, optional
        Upon failure, print all inequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed in
        a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if any elements *are*
        equal. Used to implement `~lsst.utils.tests.assertFloatsNotEqual`,
        which should generally be used instead for clarity.
    msg : `str`, optional
        String to append to the error message when assert fails.
    ignoreNaNs : `bool`, optional
        If `True` (`False` is default) mask out any NaNs from operand arrays
        before performing comparisons if they are in the same locations; NaNs
        in different locations trigger test assertion failures, even when
        ``invert=True``. Scalar NaNs are treated like arrays containing only
        NaNs of the same shape as the other operand, and no comparisons are
        performed if both sides are scalar NaNs.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    ValueError
        Both ``rtol`` and ``atol`` are `None`, or ``plotOnFailure`` was
        requested for operands that are not 2-d.
    """
    # Plain sequences support neither elementwise subtraction nor boolean
    # mask indexing, so promote them to arrays up front.
    if isinstance(lhs, (list, tuple)):
        lhs = numpy.asarray(lhs)
    if isinstance(rhs, (list, tuple)):
        rhs = numpy.asarray(rhs)
    if ignoreNaNs:
        lhsMask = numpy.isnan(lhs)
        rhsMask = numpy.isnan(rhs)
        if not numpy.all(lhsMask == rhsMask):
            testCase.fail(
                f"lhs has {lhsMask.sum()} NaN values and rhs has {rhsMask.sum()} NaN values, "
                "in different locations."
            )
        if numpy.all(lhsMask):
            assert numpy.all(rhsMask), "Should be guaranteed by previous if."
            # All operands are fully NaN (either scalar NaNs or arrays of only
            # NaNs).
            return
        assert not numpy.all(rhsMask), "Should be guaranteed by prevoius two ifs."
        # If either operand is an array select just its not-NaN values. Note
        # that these expressions are never True for scalar operands, because if
        # they are NaN then the numpy.all checks above will catch them.
        if numpy.any(lhsMask):
            keep = numpy.logical_not(lhsMask)
            # A user-supplied relTo shaped like lhs must be masked the same
            # way, otherwise it no longer broadcasts against the masked data.
            if relTo is not None and numpy.shape(relTo) == numpy.shape(lhsMask):
                relTo = numpy.asarray(relTo)[keep]
            lhs = lhs[keep]
        if numpy.any(rhsMask):
            rhs = rhs[numpy.logical_not(rhsMask)]
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    diff = lhs - rhs
    absDiff = numpy.abs(diff)
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        # An element is bad only if it fails every enabled tolerance.
        bad = absDiff > rtol * relTo
        if atol is not None:
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            if rtol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff} with atol={atol}"]
            elif atol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} with rtol={rtol}"]
            else:
                errMsg = [
                    f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} "
                    f"with rtol={rtol}, atol={atol}"
                ]
        else:
            errMsg = [f"{bad.sum()}/{bad.size} elements {failStr} with rtol={rtol}, atol={atol}"]
            if plotOnFailure:
                if len(lhs.shape) != 2 or len(rhs.shape) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we
                # can treat them the same (diff and absDiff are arrays if
                # either rhs or lhs is), and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                if rtol is None:
                    for a, b, d in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={d})")
                else:
                    for a, b, d, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={d}/{rel}={d / rel})")

    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))
@inTestCase
def assertFloatsNotEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwds: Any,
) -> None:
    """Fail a test if the given floating point values are equal to within the
    given tolerances.

    This is `~lsst.utils.tests.assertFloatsAlmostEqual` called with
    ``invert=True``; see that function for the comparison rules. Unless
    overridden through ``kwds``, that function's default tolerances
    (``rtol=atol=sys.float_info.epsilon``) apply.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    **kwds : `~typing.Any`
        Keyword parameters forwarded to `assertFloatsAlmostEqual`; must not
        include ``invert``.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    # Invert the comparison: the failing elements become those that are
    # equal within tolerance.
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)
@inTestCase
def assertFloatsEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwargs: Any,
) -> None:
    """Assert exact equality of two numeric values, scalar or array.

    Equivalent to `~lsst.utils.tests.assertFloatsAlmostEqual` with both
    tolerances set to zero (``rtol=atol=0``); see that function for details.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Test case the assertion belongs to.
    lhs : scalar or array-like
        Left-hand value(s); a scalar or array-like of any dimension.
    rhs : scalar or array-like
        Right-hand value(s); a scalar or array-like of any dimension.
    **kwargs : `~typing.Any`
        Extra keyword parameters passed through to `assertFloatsAlmostEqual`
        (other than ``rtol`` and ``atol``, which are fixed at zero).

    Raises
    ------
    AssertionError
        Raised when any element differs.
    """
    # Zero tolerance on both terms means only exact equality passes.
    return assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=0, atol=0, **kwargs)
def _settingsIterator(settings: dict[str, Sequence[Any]]) -> Iterator[dict[str, Any]]:
    """Return an iterator for the provided test settings.

    Parameters
    ----------
    settings : `dict` (`str`: iterable)
        Lists of test parameters. Each should be an iterable of the same
        length. If a string is provided as an iterable, it will be treated
        as a list of a single string. The mapping itself is not modified.

    Raises
    ------
    AssertionError
        If the ``settings`` are not of the same length. Raised explicitly, so
        the check also holds when Python runs with ``-O``.

    Yields
    ------
    parameters : `dict` (`str`: anything)
        Set of parameters.
    """
    # Normalize into a private copy: a bare string is probably meant as a
    # single value rather than an iterable of characters, and the caller's
    # mapping must not be mutated as a side effect.
    normalized = {
        name: [values] if isinstance(values, str) else values for name, values in settings.items()
    }
    num = len(next(iter(normalized.values())))  # Number of settings
    for name, values in normalized.items():
        if len(values) != num:
            raise AssertionError(f"Length mismatch for setting {name}: {len(values)} vs {num}")
    for ii in range(num):
        yield {name: values[ii] for name, values in normalized.items()}
def classParameters(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    For each parameter combination in ``settings`` a subclass of the
    decorated class is created, with the parameters set as class variables,
    and stored in the decorated class's module. The decorator itself
    returns `None`, so the undecorated base class is not left under its
    original name.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """

    def decorator(cls: type) -> None:
        namespace = vars(sys.modules[cls.__module__])
        for params in _settingsIterator(settings):
            suffix = "_".join(str(value) for value in params.values())
            generatedName = f"{cls.__name__}_{suffix}"
            # Copy the original class body, then let the parameters override.
            bindings = {**cls.__dict__, **params}
            namespace[generatedName] = type(generatedName, (cls,), bindings)

    return decorator
def methodParameters(**settings: Sequence[Any]) -> Callable:
    """Iterate over supplied settings to create subtests automatically.

    The decorated test method is invoked once per parameter set. Each call is
    wrapped in ``TestCase.subTest`` so that a failure reports the parameter
    values that were in use.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Examples
    --------
    .. code-block:: python

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(self: unittest.TestCase, *args: Any, **kwargs: Any) -> None:
            for combo in _settingsIterator(settings):
                # Parameter values take precedence over same-named keyword
                # arguments supplied by the caller.
                callKwargs = {**kwargs, **combo}
                with self.subTest(**combo):
                    func(self, *args, **callKwargs)

        return wrapper

    return decorator
def _cartesianProduct(settings: Mapping[str, Sequence[Any]]) -> Mapping[str, Sequence[Any]]:
    """Return the cartesian product of the settings.

    Parameters
    ----------
    settings : `dict` mapping `str` to `iterable`
        Parameter combinations.

    Returns
    -------
    product : `dict` mapping `str` to `list`
        Parameter combinations covering the cartesian product (all possible
        combinations) of the input parameters. Every key of ``settings`` is
        present, and all lists have the same length (the product of the input
        lengths). Combinations are ordered as by `itertools.product`, so the
        last key varies fastest. An empty ``settings`` gives an empty `dict`.
        If any input is empty, every key maps to an empty `list`.

    Examples
    --------
    .. code-block:: python

        _cartesianProduct({"foo": [1, 2], "bar": ["black", "white"]})

    will return:

    .. code-block:: python

        {"foo": [1, 1, 2, 2], "bar": ["black", "white", "black", "white"]}
    """
    # Materialize the combinations once; each key then takes its own column.
    combinations = list(itertools.product(*settings.values()))
    return {key: [combo[index] for combo in combinations] for index, key in enumerate(settings)}
def classParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    cartesian product of the supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable; unlike `classParameters`, the lengths need not
        match.

    Examples
    --------
    .. code-block:: python

        @classParametersProduct(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate four classes, as if you wrote:

    .. code-block:: python

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_1_4(unittest.TestCase):
            foo = 1
            bar = 4
            ...

        class MyTestCase_2_3(unittest.TestCase):
            foo = 2
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    # Expand to equal-length, aligned lists (one entry per combination) and
    # delegate class generation to classParameters.
    return classParameters(**_cartesianProduct(settings))
def methodParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Iterate over cartesian product creating sub tests.

    This decorator iterates over the cartesian product of the supplied
    settings, using `~unittest.TestCase.subTest` to communicate the values in
    the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The parameter combinations to test. Each should be an iterable; the
        lengths need not match.

    Examples
    --------
    .. code-block:: python

        @methodParametersProduct(foo=[1, 2], bar=["black", "white"])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar="black")
        testSomething(foo=1, bar="white")
        testSomething(foo=2, bar="black")
        testSomething(foo=2, bar="white")
    """
    # Expand to equal-length, aligned lists (one entry per combination) and
    # delegate subtest generation to methodParameters.
    return methodParameters(**_cartesianProduct(settings))
@contextlib.contextmanager
def temporaryDirectory() -> Iterator[str]:
    """Context manager that creates and destroys a temporary directory.

    The difference from `tempfile.TemporaryDirectory` is that this ignores
    errors when deleting a directory, which may happen with some filesystems.

    The directory is removed when the ``with`` block exits, whether the block
    completes normally or raises an exception.

    Yields
    ------
    `str`
        Name of the temporary directory.
    """
    tmpdir = tempfile.mkdtemp()
    try:
        yield tmpdir
    finally:
        # An exception in the with-block is re-raised at the yield above;
        # without this finally the directory would be leaked. Removal errors
        # are deliberately ignored.
        shutil.rmtree(tmpdir, ignore_errors=True)