Coverage for python/lsst/utils/tests.py: 31% (357 statements)
# This file is part of utils.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# Use of this source code is governed by a 3-clause BSD-style
# license that can be found in the LICENSE file.

"""Support code for running unit tests."""

from __future__ import annotations

__all__ = [
    "init",
    "MemoryTestCase",
    "ExecutablesTestCase",
    "ImportTestCase",
    "getTempFilePath",
    "TestCase",
    "assertFloatsAlmostEqual",
    "assertFloatsNotEqual",
    "assertFloatsEqual",
    "debugger",
    "classParameters",
    "methodParameters",
    "temporaryDirectory",
]

import contextlib
import functools
import gc
import importlib.resources as resources
import inspect
import itertools
import os
import re
import shutil
import subprocess
import sys
import tempfile
import unittest
import warnings
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from typing import Any, ClassVar

import numpy
import psutil

from .doImport import doImport

# Initialize the list of open files to an empty set
open_files = set()

def _get_open_files() -> set[str]:
    """Return a set containing the list of files currently open in this
    process.

    Returns
    -------
    open_files : `set`
        Set containing the list of open files.
    """
    return {p.path for p in psutil.Process().open_files()}

def init() -> None:
    """Initialize the memory tester and file descriptor leak tester."""
    global open_files
    # Reset the list of open files
    open_files = _get_open_files()

def sort_tests(tests) -> unittest.TestSuite:
    """Sort supplied test suites such that MemoryTestCases are at the end.

    `lsst.utils.tests.MemoryTestCase` tests should always run after any other
    tests in the module.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites.

    Returns
    -------
    suite : `unittest.TestSuite`
        A combined `~unittest.TestSuite` with
        `~lsst.utils.tests.MemoryTestCase` at the end.
    """
    suite = unittest.TestSuite()
    memtests = []
    for test_suite in tests:
        try:
            # Just test the first test method in the suite for MemoryTestCase
            # Use loop rather than next as it is possible for a test class
            # to not have any test methods and the Python community prefers
            # for loops over catching a StopIteration exception.
            bases = None
            for method in test_suite:
                bases = inspect.getmro(method.__class__)
                break
            if bases is not None and MemoryTestCase in bases:
                memtests.append(test_suite)
            else:
                suite.addTests(test_suite)
        except TypeError:
            if isinstance(test_suite, MemoryTestCase):
                memtests.append(test_suite)
            else:
                suite.addTest(test_suite)
    suite.addTests(memtests)
    return suite

def _suiteClassWrapper(tests):
    return unittest.TestSuite(sort_tests(tests))


# Replace the suiteClass callable in the defaultTestLoader
# so that we can reorder the tests. This will have
# no effect if no memory test cases are found.
unittest.defaultTestLoader.suiteClass = _suiteClassWrapper

class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks."""

    ignore_regexps: list[str] = []
    """List of regexps to ignore when checking for open files."""

    @classmethod
    def tearDownClass(cls) -> None:
        """Reset the leak counter when the tests have been completed."""
        init()

    def testFileDescriptorLeaks(self) -> None:
        """Check if any file descriptors are open since init() was called.

        Ignores files with certain known path components and any files
        that match regexp patterns in the class property ``ignore_regexps``.
        """
        gc.collect()
        global open_files
        now_open = _get_open_files()

        # Some files are opened out of the control of the stack.
        now_open = {
            f
            for f in now_open
            if not f.endswith(".car")
            and not f.startswith("/proc/")
            and not f.endswith(".ttf")
            and not (f.startswith("/var/lib/") and f.endswith("/passwd"))
            and not f.endswith("astropy.log")
            and not f.endswith("mime/mime.cache")
            and not f.endswith(".sqlite3")
            and not any(re.search(r, f) for r in self.ignore_regexps)
        }

        diff = now_open.difference(open_files)
        if diff:
            for f in diff:
                print(f"File open: {f}")
            self.fail("Failed to close %d file%s" % (len(diff), "s" if len(diff) != 1 else ""))

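# Illustrative usage sketch (not part of this module): a test file typically
# opts in to the leak checks by subclassing MemoryTestCase and calling init(),
# along these lines::
#
#     import unittest
#     import lsst.utils.tests
#
#     class MemoryTester(lsst.utils.tests.MemoryTestCase):
#         pass
#
#     def setup_module(module):
#         lsst.utils.tests.init()
#
#     if __name__ == "__main__":
#         lsst.utils.tests.init()
#         unittest.main()
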
class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    TESTS_DISCOVERED = -1

    @classmethod
    def setUpClass(cls) -> None:
        """Abort testing if automated test creation was enabled and
        no tests were found.
        """
        if cls.TESTS_DISCOVERED == 0:
            raise RuntimeError("No executables discovered.")

    def testSanity(self) -> None:
        """Ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set up
        machinery to test whether there are some executables to test.
        """

    def assertExecutable(
        self,
        executable: str,
        root_dir: str | None = None,
        args: Sequence[str] | None = None,
        msg: str | None = None,
    ) -> None:
        """Check that an executable runs and returns good status.

        Prints output to standard out. On bad exit status the test
        fails. If the executable cannot be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Form the argument list for subprocess
        sp_args = [executable]
        argstr = "no arguments"
        if args is not None:
            sp_args.extend(args)
            argstr = 'arguments "' + " ".join(args) + '"'

        print(f"Running executable '{executable}' with {argstr}...")
        if not os.path.exists(executable):
            self.skipTest(f"Executable {executable} is unexpectedly missing")
        failmsg = None
        try:
            output = subprocess.check_output(sp_args)
        except subprocess.CalledProcessError as e:
            output = e.output
            failmsg = f"Bad exit status from '{executable}': {e.returncode}"
        print(output.decode("utf-8"))
        if failmsg:
            if msg is None:
                msg = failmsg
            self.fail(msg)

    @classmethod
    def _build_test_method(cls, executable: str, root_dir: str) -> None:
        """Build a test method and attach to class.

        A test method is created for the supplied executable located
        in the supplied root directory. This method is attached to the class
        so that the test runner will discover the test and run it.

        Parameters
        ----------
        cls : `object`
            The class in which to create the tests.
        executable : `str`
            Name of executable. Can be absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Create the test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        # This is the function that will become the test method
        def test_executable_runs(*args: Any) -> None:
            self = args[0]
            self.assertExecutable(executable)

        # Give it a name and attach it to the class
        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file: str, executables: Sequence[str] | None = None) -> None:
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        # Get the search directory from the reference file
        ref_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Look for executables to test by walking the tree
            executables = []
            for root, _, files in os.walk(ref_dir):
                for f in files:
                    # Skip Python files. Shared libraries are executable.
                    if not f.endswith(".py") and not f.endswith(".so"):
                        full_path = os.path.join(root, f)
                        if os.access(full_path, os.X_OK):
                            executables.append(full_path)

        # Store the number of tests found for later assessment.
        # Do not raise an exception if we have no executables as this would
        # cause the testing to abort before the test runner could properly
        # integrate it into the failure report.
        cls.TESTS_DISCOVERED = len(executables)

        # Create the test functions and attach them to the class
        for e in executables:
            cls._build_test_method(e, ref_dir)

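# Illustrative usage sketch (class name is hypothetical): a test file registers
# executables for testing by subclassing ExecutablesTestCase and calling
# create_executable_tests() at module level, before test discovery::
#
#     import unittest
#     import lsst.utils.tests
#
#     class BinaryTester(lsst.utils.tests.ExecutablesTestCase):
#         pass
#
#     # Scan the directory containing this test file for executables.
#     BinaryTester.create_executable_tests(__file__)
#
#     if __name__ == "__main__":
#         unittest.main()
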
class ImportTestCase(unittest.TestCase):
    """Test that the named packages, and all Python files within them, can be
    imported.

    The test methods are created dynamically. Callers must subclass this
    class and define the ``PACKAGES`` property.
    """

    PACKAGES: ClassVar[Iterable[str]] = ()
    """Packages to be imported."""

    _n_registered = 0
    """Number of packages registered for testing by this class."""

    def _test_no_packages_registered_for_import_testing(self) -> None:
        """Test when no packages have been registered.

        Without this, if no packages have been listed no tests will be
        registered and the test system will not report on anything. This
        test fails and reports why.
        """
        raise AssertionError("No packages registered with import test. Was the PACKAGES property set?")

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create the test methods based on the content of the ``PACKAGES``
        class property.
        """
        super().__init_subclass__(**kwargs)

        for mod in cls.PACKAGES:
            test_name = "test_import_" + mod.replace(".", "_")

            # Bind the current value of ``mod`` as a default so that each
            # generated test imports its own package rather than the last
            # one in the loop.
            def test_import(*args: Any, mod: str = mod) -> None:
                self = args[0]
                self.assertImport(mod)

            test_import.__name__ = test_name
            setattr(cls, test_name, test_import)
            cls._n_registered += 1

        # If there are no packages listed that is likely a mistake and
        # so register a failing test.
        if cls._n_registered == 0:
            setattr(cls, "test_no_packages_registered", cls._test_no_packages_registered_for_import_testing)

    def assertImport(self, root_pkg):
        """Check that every Python module in the given package imports,
        reporting each failure as a subtest.
        """
        for file in resources.files(root_pkg).iterdir():
            file = file.name
            if not file.endswith(".py"):
                continue
            if file.startswith("__"):
                continue
            root, _ = os.path.splitext(file)
            module_name = f"{root_pkg}.{root}"
            with self.subTest(module=module_name):
                try:
                    doImport(module_name)
                except ImportError as e:
                    raise AssertionError(f"Error importing module {module_name}: {e}") from e

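# Illustrative usage sketch (package names are hypothetical): a test file
# checks that packages import cleanly by subclassing ImportTestCase and
# listing the packages to scan in ``PACKAGES``::
#
#     import lsst.utils.tests
#
#     class ImportEverythingTestCase(lsst.utils.tests.ImportTestCase):
#         PACKAGES = ("lsst.utils",)
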
@contextlib.contextmanager
def getTempFilePath(ext: str, expectOutput: bool = True) -> Iterator[str]:
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure results in a printed warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can be
    examined. The file name has a random component such that nested context
    managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Returns
    -------
    path : `str`
        Path for a temporary file. The path is a combination of the caller's
        file path and the name of the top-level function.

    Examples
    --------
    .. code-block:: python

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests

        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics.fits"
                    # otherwise tmpFile = "testFoo_testBasics.fits"
                    ...
                # at the end of this "with" block the path tmpFile will be
                # deleted, but only if the file exists and the "with"
                # block terminated normally (rather than with an exception)
                ...
    """
    stack = inspect.stack()
    # get name of first function in the file
    for i in range(2, len(stack)):
        frameInfo = inspect.getframeinfo(stack[i][0])
        if i == 2:
            callerFilePath = frameInfo.filename
            callerFuncName = frameInfo.function
        elif callerFilePath == frameInfo.filename:
            # this function called the previous function
            callerFuncName = frameInfo.function
        else:
            break

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        outDir = ""
    prefix = f"{callerFileName}_{callerFuncName}-"
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # There should not be a file there given the randomizer. Warn and
        # remove.
        # Use stacklevel 3 so that the warning is reported from the end of the
        # with block.
        warnings.warn(f"Unexpectedly found pre-existing tempfile named {outPath!r}", stacklevel=3)
        try:
            os.remove(outPath)
        except OSError:
            pass

    yield outPath

    fileExists = os.path.exists(outPath)
    if expectOutput:
        if not fileExists:
            raise RuntimeError(f"Temp file expected named {outPath} but none found")
    else:
        if fileExists:
            raise RuntimeError(f"Unexpectedly discovered temp file named {outPath}")
    # Try to clean up the file regardless
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # Use stacklevel 3 so that the warning is reported from the end of
            # the with block.
            warnings.warn(f"Warning: could not remove file {outPath!r}: {e}", stacklevel=3)

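# Illustrative sketch: passing expectOutput=False asserts that the block does
# *not* leave a file behind, which is useful for exercising cleanup paths
# (the helper called here is hypothetical)::
#
#     with getTempFilePath(".txt", expectOutput=False) as tmpPath:
#         run_code_that_should_not_write(tmpPath)
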
class TestCase(unittest.TestCase):
    """Subclass of unittest.TestCase that adds some custom assertions for
    convenience.
    """


def inTestCase(func: Callable) -> Callable:
    """Add a free function to our custom TestCase class, while
    also making it available as a free function.
    """
    setattr(TestCase, func.__name__, func)
    return func

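# Illustrative sketch (the assertion below is hypothetical): decorating a free
# function with @inTestCase attaches it to lsst.utils.tests.TestCase as a
# method while leaving it usable as a plain function::
#
#     @inTestCase
#     def assertPositive(testCase, value):
#         testCase.assertGreater(value, 0)
#
#     # Now available both as TestCase.assertPositive(self, value) and as
#     # assertPositive(some_test_case, value).
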
def debugger(*exceptions):
    """Enter the debugger when there's an uncaught exception.

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    If no exception classes are given, any `Exception` is caught, which
    includes the `AssertionError` raised by failing `unittest.TestCase`
    methods.

    Code provided by "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.
    """
    if not exceptions:
        exceptions = (Exception,)

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions:
                import pdb
                import sys

                pdb.post_mortem(sys.exc_info()[2])

        return wrapper

    return decorator

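# Illustrative sketch: dropping into pdb when a particular test raises, by
# decorating the test method (the class and exception choices are examples)::
#
#     class FragileTestCase(lsst.utils.tests.TestCase):
#         @debugger(RuntimeError, AssertionError)
#         def testSomething(self):
#             ...
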
def plotImageDiff(
    lhs: numpy.ndarray,
    rhs: numpy.ndarray,
    bad: numpy.ndarray | None = None,
    diff: numpy.ndarray | None = None,
    plotFileName: str | None = None,
) -> None:
    """Plot the comparison of two 2-d NumPy arrays.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array.
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array.
    bad : `numpy.ndarray`, optional
        A 2-d boolean NumPy array of values to emphasize in the plots.
    diff : `numpy.ndarray`, optional
        Difference array; a 2-d NumPy array, or `None` to show ``lhs - rhs``.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed in
        a window.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    from matplotlib import pyplot

    if diff is None:
        diff = lhs - rhs
    pyplot.figure()
    if bad is not None:
        # make an rgba image that's red and transparent where not bad
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255 * bad
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    for n, (image, title) in enumerate([(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]):
        pyplot.subplot(2, 3, n + 1)
        im1 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin1, vmax=vmax1
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
        pyplot.subplot(2, 3, n + 4)
        im2 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin2, vmax=vmax2
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
    pyplot.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    cax1 = pyplot.axes([0.8, 0.55, 0.05, 0.4])
    pyplot.colorbar(im1, cax=cax1)
    cax2 = pyplot.axes([0.8, 0.05, 0.05, 0.4])
    pyplot.colorbar(im2, cax=cax2)
    if plotFileName:
        pyplot.savefig(plotFileName)
    else:
        pyplot.show()

@inTestCase
def assertFloatsAlmostEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    rtol: float | None = sys.float_info.epsilon,
    atol: float | None = sys.float_info.epsilon,
    relTo: float | None = None,
    printFailures: bool = True,
    plotOnFailure: bool = False,
    plotFileName: str | None = None,
    invert: bool = False,
    msg: str | None = None,
    ignoreNaNs: bool = False,
) -> None:
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if any elements of ``lhs`` and ``rhs`` are
    not equal to within the tolerances specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float`, optional
        Value to which comparison with rtol is relative.
    printFailures : `bool`, optional
        Upon failure, print all inequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed in
        a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if any elements *are*
        equal. Used to implement `~lsst.utils.tests.assertFloatsNotEqual`,
        which should generally be used instead for clarity.
    msg : `str`, optional
        String to append to the error message when assert fails.
    ignoreNaNs : `bool`, optional
        If `True` (`False` is default) mask out any NaNs from operand arrays
        before performing comparisons if they are in the same locations; NaNs
        in different locations trigger test assertion failures, even when
        ``invert=True``. Scalar NaNs are treated like arrays containing only
        NaNs of the same shape as the other operand, and no comparisons are
        performed if both sides are scalar NaNs.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    """
    if ignoreNaNs:
        lhsMask = numpy.isnan(lhs)
        rhsMask = numpy.isnan(rhs)
        if not numpy.all(lhsMask == rhsMask):
            testCase.fail(
                f"lhs has {lhsMask.sum()} NaN values and rhs has {rhsMask.sum()} NaN values, "
                "in different locations."
            )
        if numpy.all(lhsMask):
            assert numpy.all(rhsMask), "Should be guaranteed by previous if."
            # All operands are fully NaN (either scalar NaNs or arrays of only
            # NaNs).
            return
        assert not numpy.all(rhsMask), "Should be guaranteed by previous two ifs."
        # If either operand is an array select just its not-NaN values. Note
        # that these expressions are never True for scalar operands, because
        # if they are NaN then the numpy.all checks above will catch them.
        if numpy.any(lhsMask):
            lhs = lhs[numpy.logical_not(lhsMask)]
        if numpy.any(rhsMask):
            rhs = rhs[numpy.logical_not(rhsMask)]
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    diff = lhs - rhs
    absDiff = numpy.abs(lhs - rhs)
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        bad = absDiff > rtol * relTo
        if atol is not None:
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            if rtol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff} with atol={atol}"]
            elif atol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} with rtol={rtol}"]
            else:
                errMsg = [
                    f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} "
                    f"with rtol={rtol}, atol={atol}"
                ]
        else:
            errMsg = [f"{bad.sum()}/{bad.size} elements {failStr} with rtol={rtol}, atol={atol}"]
            if plotOnFailure:
                if len(lhs.shape) != 2 or len(rhs.shape) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we
                # can treat them the same (diff and absDiff are arrays if
                # either rhs or lhs is), and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                if rtol is None:
                    for a, b, diff in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={diff})")
                else:
                    for a, b, diff, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={diff}/{rel}={diff / rel})")

    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))

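# Illustrative sketch of the comparison semantics (values are examples only):
# an element passes if it satisfies either the relative or the absolute
# tolerance, so a non-zero atol is what absorbs small offsets near zero::
#
#     import numpy as np
#
#     class ToleranceExampleTestCase(lsst.utils.tests.TestCase):
#         def testTolerances(self):
#             a = np.array([1.0, 2.0, 0.0])
#             b = np.array([1.0 + 1e-9, 2.0, 1e-9])
#             # Passes: every element agrees to within rtol=1e-8 or atol=1e-8.
#             self.assertFloatsAlmostEqual(a, b, rtol=1e-8, atol=1e-8)
#             # Would fail with atol=None: for the last element relTo is only
#             # 1e-9, so the relative test alone cannot absorb the 1e-9 offset.
#             # self.assertFloatsAlmostEqual(a, b, rtol=1e-8, atol=None)
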
@inTestCase
def assertFloatsNotEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwds: Any,
) -> None:
    """Fail a test if the given floating point values are equal to within the
    given tolerances.

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``invert=True``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)

@inTestCase
def assertFloatsEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwargs: Any,
) -> None:
    """
    Assert that lhs == rhs (both numeric types, whether scalar or array).

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``rtol=atol=0``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are not equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=0, atol=0, **kwargs)

def _settingsIterator(settings: dict[str, Sequence[Any]]) -> Iterator[dict[str, Any]]:
    """Return an iterator for the provided test settings.

    Parameters
    ----------
    settings : `dict` (`str`: iterable)
        Lists of test parameters. Each should be an iterable of the same
        length. If a string is provided as an iterable, it will be converted
        to a list of a single string.

    Raises
    ------
    AssertionError
        If the ``settings`` are not of the same length.

    Yields
    ------
    parameters : `dict` (`str`: anything)
        Set of parameters.
    """
    for name, values in settings.items():
        if isinstance(values, str):
            # Probably meant as a single-element string, rather than an
            # iterable of chars.
            settings[name] = [values]
    num = len(next(iter(settings.values())))  # Number of settings
    for name, values in settings.items():
        assert len(values) == num, f"Length mismatch for setting {name}: {len(values)} vs {num}"
    for ii in range(num):
        values = [settings[kk][ii] for kk in settings]
        yield dict(zip(settings, values))

def classParameters(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """

    def decorator(cls: type) -> None:
        module = sys.modules[cls.__module__].__dict__
        for params in _settingsIterator(settings):
            name = f"{cls.__name__}_{'_'.join(str(vv) for vv in params.values())}"
            bindings = dict(cls.__dict__)
            bindings.update(params)
            module[name] = type(name, (cls,), bindings)

    return decorator

def methodParameters(**settings: Sequence[Any]) -> Callable:
    """Iterate over supplied settings to create subtests automatically.

    This decorator iterates over the supplied settings, using
    ``TestCase.subTest`` to communicate the values in the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Examples
    --------
    .. code-block:: python

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(self: unittest.TestCase, *args: Any, **kwargs: Any) -> None:
            for params in _settingsIterator(settings):
                kwargs.update(params)
                with self.subTest(**params):
                    func(self, *args, **kwargs)

        return wrapper

    return decorator

def _cartesianProduct(settings: Mapping[str, Sequence[Any]]) -> Mapping[str, Sequence[Any]]:
    """Return the cartesian product of the settings.

    Parameters
    ----------
    settings : `dict` mapping `str` to `iterable`
        Parameter combinations.

    Returns
    -------
    product : `dict` mapping `str` to `iterable`
        Parameter combinations covering the cartesian product (all possible
        combinations) of the input parameters.

    Examples
    --------
    .. code-block:: python

        _cartesianProduct({"foo": [1, 2], "bar": ["black", "white"]})

    will return:

    .. code-block:: python

        {"foo": [1, 1, 2, 2], "bar": ["black", "white", "black", "white"]}
    """
    product: dict[str, list[Any]] = {kk: [] for kk in settings}
    for values in itertools.product(*settings.values()):
        for kk, vv in zip(settings.keys(), values):
            product[kk].append(vv)
    return product

def classParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    cartesian product of the supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable.

    Examples
    --------
    .. code-block:: python

        @classParametersProduct(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate four classes, as if you wrote:

    .. code-block:: python

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_1_4(unittest.TestCase):
            foo = 1
            bar = 4
            ...

        class MyTestCase_2_3(unittest.TestCase):
            foo = 2
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    return classParameters(**_cartesianProduct(settings))

def methodParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Iterate over the cartesian product of settings, creating subtests.

    This decorator iterates over the cartesian product of the supplied
    settings, using `~unittest.TestCase.subTest` to communicate the values in
    the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The parameter combinations to test. Each should be an iterable.

    Examples
    --------
    .. code-block:: python

        @methodParametersProduct(foo=[1, 2], bar=["black", "white"])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar="black")
        testSomething(foo=1, bar="white")
        testSomething(foo=2, bar="black")
        testSomething(foo=2, bar="white")
    """
    return methodParameters(**_cartesianProduct(settings))

@contextlib.contextmanager
def temporaryDirectory() -> Iterator[str]:
    """Context manager that creates and destroys a temporary directory.

    The difference from `tempfile.TemporaryDirectory` is that this ignores
    errors when deleting a directory, which may happen with some filesystems.
    """
    tmpdir = tempfile.mkdtemp()
    yield tmpdir
    shutil.rmtree(tmpdir, ignore_errors=True)

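# Illustrative sketch: temporaryDirectory() is a drop-in alternative to
# tempfile.TemporaryDirectory when failures during cleanup should be ignored
# rather than raised::
#
#     with temporaryDirectory() as tmpdir:
#         path = os.path.join(tmpdir, "scratch.txt")
#         with open(path, "w") as fh:
#             fh.write("scratch data")
#     # tmpdir is removed here; errors during removal are silently ignored.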