Coverage for python/lsst/utils/tests.py : 27%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#
2# LSST Data Management System
3#
4# Copyright 2008-2017 AURA/LSST.
5#
6# This product includes software developed by the
7# LSST Project (http://www.lsst.org/).
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the LSST License Statement and
20# the GNU General Public License along with this program. If not,
21# see <https://www.lsstcorp.org/LegalNotices/>.
22#
23"""Support code for running unit tests"""
25import contextlib
26import gc
27import inspect
28import os
29import subprocess
30import sys
31import unittest
32import warnings
33import numpy
34import psutil
35import functools
36import tempfile
37import shutil
# Public API of this module; everything else is an implementation detail.
__all__ = ["init", "MemoryTestCase", "ExecutablesTestCase", "getTempFilePath",
           "TestCase", "assertFloatsAlmostEqual", "assertFloatsNotEqual", "assertFloatsEqual",
           "debugger", "classParameters", "methodParameters"]

# Initialize the list of open files to an empty set.
# This is the baseline snapshot used by MemoryTestCase.testFileDescriptorLeaks;
# init() replaces it with the set of files open at initialization time.
open_files = set()
def _get_open_files():
    """Return the set of paths of files currently open in this process.

    Returns
    -------
    open_files : `set`
        Set containing the paths of the currently open files.
    """
    return {entry.path for entry in psutil.Process().open_files()}
def init():
    """Initialize the memory tester and file descriptor leak tester."""
    global open_files
    # Reset the list of open files: record the current set as the baseline
    # so that MemoryTestCase.testFileDescriptorLeaks only reports files
    # opened after this call.
    open_files = _get_open_files()
def sort_tests(tests):
    """Sort supplied test suites such that MemoryTestCases are at the end.

    `lsst.utils.tests.MemoryTestCase` tests should always run after any other
    tests in the module.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites.

    Returns
    -------
    suite : `unittest.TestSuite`
        A combined `~unittest.TestSuite` with
        `~lsst.utils.tests.MemoryTestCase` at the end.
    """
    ordered = unittest.TestSuite()
    deferred = []
    for item in tests:
        try:
            # Inspect only the first test method in the suite: one MRO check
            # is enough to classify the whole suite. A for loop is used
            # rather than next() because a test class may contain no test
            # methods at all, and the community prefers a loop over
            # catching StopIteration.
            ancestry = None
            for test in item:
                ancestry = inspect.getmro(test.__class__)
                break
            if ancestry is not None and MemoryTestCase in ancestry:
                deferred.append(item)
            else:
                ordered.addTests(item)
        except TypeError:
            # Not iterable: item is a single test rather than a suite.
            if isinstance(item, MemoryTestCase):
                deferred.append(item)
            else:
                ordered.addTest(item)
    # Memory tests go last so resource-leak checks see all other tests' work.
    ordered.addTests(deferred)
    return ordered
def suiteClassWrapper(tests):
    # Combine the supplied suites, moving MemoryTestCase suites to the end,
    # and wrap everything in a single TestSuite for the test loader.
    return unittest.TestSuite(sort_tests(tests))
# Replace the suiteClass callable in the defaultTestLoader so that we can
# reorder the tests, putting MemoryTestCase suites last. This has no
# effect if no memory test cases are found.
unittest.defaultTestLoader.suiteClass = suiteClassWrapper
class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks."""

    @classmethod
    def tearDownClass(cls):
        """Reset the leak counter when the tests have been completed"""
        init()

    def testFileDescriptorLeaks(self):
        """Check if any file descriptors are open since init() called."""
        gc.collect()
        global open_files
        now_open = _get_open_files()

        # Some files are opened out of the control of the stack.
        def _ignorable(path):
            # Files that may legitimately remain open (system/library files).
            if path.startswith("/proc/"):
                return True
            if path.startswith("/var/lib/") and path.endswith("/passwd"):
                return True
            return path.endswith((".car", ".ttf", "astropy.log"))

        now_open = {path for path in now_open if not _ignorable(path)}

        diff = now_open.difference(open_files)
        if diff:
            for f in diff:
                print("File open: %s" % f)
            self.fail("Failed to close %d file%s" % (len(diff), "s" if len(diff) != 1 else ""))
class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    # Number of executables found by create_executable_tests();
    # -1 means automated discovery has not yet been attempted.
    TESTS_DISCOVERED = -1

    @classmethod
    def setUpClass(cls):
        """Abort testing if automated test creation was enabled and
        no tests were found."""
        if cls.TESTS_DISCOVERED == 0:
            raise RuntimeError("No executables discovered.")

    def testSanity(self):
        """This test exists to ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set up
        machinery to test whether there are some executables to test."""
        pass

    def assertExecutable(self, executable, root_dir=None, args=None, msg=None):
        """Check an executable runs and returns good status.

        Prints output to standard out. On bad exit status the test
        fails. If the executable can not be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Build the command line for subprocess and a description of it.
        sp_args = [executable]
        if args is None:
            argstr = "no arguments"
        else:
            sp_args += list(args)
            argstr = 'arguments "' + " ".join(args) + '"'

        print("Running executable '{}' with {}...".format(executable, argstr))
        if not os.path.exists(executable):
            self.skipTest("Executable {} is unexpectedly missing".format(executable))

        failmsg = None
        try:
            output = subprocess.check_output(sp_args)
        except subprocess.CalledProcessError as e:
            # Keep the child's output so it can still be echoed below.
            output = e.output
            failmsg = "Bad exit status from '{}': {}".format(executable, e.returncode)
        print(output.decode('utf-8'))
        if failmsg:
            self.fail(failmsg if msg is None else msg)

    @classmethod
    def _build_test_method(cls, executable, root_dir):
        """Build a test method and attach to class.

        A test method is created for the supplied executable located
        in the supplied root directory. This method is attached to the class
        so that the test runner will discover the test and run it.

        Parameters
        ----------
        cls : `object`
            The class in which to create the tests.
        executable : `str`
            Name of executable. Can be absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Create the test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        # This closure will become the test method.
        def test_executable_runs(self):
            self.assertExecutable(executable)

        # Give it a name and attach it to the class.
        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file, executables=None):
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        # Get the search directory from the reference file.
        ref_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Discover executables by walking the tree.
            executables = []
            for root, dirs, files in os.walk(ref_dir):
                for f in files:
                    # Python files and shared libraries are skipped even
                    # though shared libraries carry the executable bit.
                    if f.endswith(".py") or f.endswith(".so"):
                        continue
                    full_path = os.path.join(root, f)
                    if os.access(full_path, os.X_OK):
                        executables.append(full_path)

        # Store the number of tests found for later assessment.
        # Do not raise an exception if we have no executables as this would
        # cause the testing to abort before the test runner could properly
        # integrate it into the failure report.
        cls.TESTS_DISCOVERED = len(executables)

        # Create the test functions and attach them to the class.
        for e in executables:
            cls._build_test_method(e, ref_dir)
@contextlib.contextmanager
def getTempFilePath(ext, expectOutput=True):
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure results in a printed warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can be
    examined. The file name has a random component such that nested context
    managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Returns
    -------
    `str`
        Path for a temporary file. The path is a combination of the caller's
        file path and the name of the top-level function.

    Notes
    -----
    ::

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests
        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics.fits"
                    # otherwise tmpFile = "testFoo_testBasics.fits"
                    ...
                    # at the end of this "with" block the path tmpFile will be
                    # deleted, but only if the file exists and the "with"
                    # block terminated normally (rather than with an exception)
                    ...
    """
    stack = inspect.stack()
    # Get name of first function in the caller's file by walking outwards
    # from stack index 2 (presumably skipping this generator frame and the
    # contextlib wrapper frame -- TODO confirm) until a frame from a
    # different file is reached.
    for i in range(2, len(stack)):
        frameInfo = inspect.getframeinfo(stack[i][0])
        if i == 2:
            callerFilePath = frameInfo.filename
            callerFuncName = frameInfo.function
        elif callerFilePath == frameInfo.filename:
            # this function called the previous function
            callerFuncName = frameInfo.function
        else:
            break

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        # No ".tests" directory next to the caller: use the current directory.
        outDir = ""
    prefix = "%s_%s-" % (callerFileName, callerFuncName)
    # NOTE(review): tempfile.mktemp is race-prone by design; only a path
    # (not an open file handle) is needed here, and the random component
    # guards against collisions between nested context managers.
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # There should not be a file there given the randomizer. Warn and remove.
        # Use stacklevel 3 so that the warning is reported from the end of the with block
        warnings.warn("Unexpectedly found pre-existing tempfile named %r" % (outPath,),
                      stacklevel=3)
        try:
            os.remove(outPath)
        except OSError:
            pass

    yield outPath

    # Code below is only reached if the with block exited without an
    # exception (an exception propagates out through the yield above,
    # leaving the file on disk for examination).
    fileExists = os.path.exists(outPath)
    if expectOutput:
        if not fileExists:
            raise RuntimeError("Temp file expected named {} but none found".format(outPath))
    else:
        if fileExists:
            raise RuntimeError("Unexpectedly discovered temp file named {}".format(outPath))
    # Try to clean up the file regardless
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # Use stacklevel 3 so that the warning is reported from the end of the with block
            warnings.warn("Warning: could not remove file %r: %s" % (outPath, e), stacklevel=3)
class TestCase(unittest.TestCase):
    """Subclass of unittest.TestCase that adds some custom assertions for
    convenience.

    The extra assertions are the free functions in this module decorated
    with `inTestCase` (e.g. `assertFloatsAlmostEqual`), which attaches
    each of them to this class as a method.
    """
def inTestCase(func):
    """A decorator to add a free function to our custom TestCase class, while also
    making it available as a free function.

    Parameters
    ----------
    func : callable
        Function to attach; its ``__name__`` becomes the method name.

    Returns
    -------
    func : callable
        The same function, unchanged, so it remains usable standalone.
    """
    setattr(TestCase, func.__name__, func)
    return func
def debugger(*exceptions):
    """Decorator to enter the debugger when there's an uncaught exception

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    When no exception classes are given, this defaults to catching
    `Exception` (i.e. nearly everything).

    Code provided by "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.

    After the post-mortem session the exception is swallowed: the wrapper
    returns `None` instead of re-raising.
    """
    if not exceptions:
        exceptions = (Exception, )

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions:
                # Import here so pdb is only loaded when a failure occurs;
                # sys is already imported at module level.
                import pdb
                pdb.post_mortem(sys.exc_info()[2])
        return wrapper
    return decorator
def plotImageDiff(lhs, rhs, bad=None, diff=None, plotFileName=None):
    """Plot the comparison of two 2-d NumPy arrays.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array
    bad : `numpy.ndarray`, optional
        A 2-d boolean NumPy array of values to emphasize in the plots
    diff : `numpy.ndarray`, optional
        difference array; a 2-d NumPy array, or None to show lhs-rhs
    plotFileName : `str`, optional
        Filename to save the plot to. If None, the plot will be displayed in
        a window.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    from matplotlib import pyplot
    if diff is None:
        diff = lhs - rhs
    pyplot.figure()
    if bad is not None:
        # make an rgba image that's red and transparent where not bad
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255*bad
    # Two greyscale ranges: vmin1/vmax1 spans the original arrays,
    # vmin2/vmax2 spans the difference array.
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    for n, (image, title) in enumerate([(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]):
        # Top row (subplots 1-3): each image scaled to the originals' range.
        pyplot.subplot(2, 3, n + 1)
        im1 = pyplot.imshow(image, cmap=pyplot.cm.gray, interpolation='nearest', origin='lower',
                            vmin=vmin1, vmax=vmax1)
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation='nearest', origin='lower')
        pyplot.axis("off")
        pyplot.title(title)
        # Bottom row (subplots 4-6): the same image scaled to the diff's range.
        pyplot.subplot(2, 3, n + 4)
        im2 = pyplot.imshow(image, cmap=pyplot.cm.gray, interpolation='nearest', origin='lower',
                            vmin=vmin2, vmax=vmax2)
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation='nearest', origin='lower')
        pyplot.axis("off")
        pyplot.title(title)
    pyplot.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    # One colorbar per scaling, placed to the right of the 2x3 grid.
    cax1 = pyplot.axes([0.8, 0.55, 0.05, 0.4])
    pyplot.colorbar(im1, cax=cax1)
    cax2 = pyplot.axes([0.8, 0.05, 0.05, 0.4])
    pyplot.colorbar(im2, cax=cax2)
    if plotFileName:
        pyplot.savefig(plotFileName)
    else:
        pyplot.show()
@inTestCase
def assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=sys.float_info.epsilon,
                            atol=sys.float_info.epsilon, relTo=None,
                            printFailures=True, plotOnFailure=False,
                            plotFileName=None, invert=False, msg=None):
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if all elements ``lhs`` and ``rhs`` are not
    equal to within the tolerances specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float`, optional
        Value to which comparison with rtol is relative.
    printFailures : `bool`, optional
        Upon failure, print all inequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed in
        a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if any elements *are*
        equal. Used to implement `~lsst.utils.tests.assertFloatsNotEqual`,
        which should generally be used instead for clarity.
    msg : `str`, optional
        String to append to the error message when assert fails.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    """
    # Non-finite inputs (NaN/inf) fail immediately rather than comparing.
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    # Signed difference is kept only for the failure plot; the comparison
    # itself uses the absolute difference.
    diff = lhs - rhs
    absDiff = numpy.abs(lhs - rhs)
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        bad = absDiff > rtol*relTo
        if atol is not None:
            # An element is bad only if it exceeds BOTH tolerances,
            # i.e. it passes if it is within either one.
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        # Inverted mode: fail when any elements ARE equal, and report the
        # equal elements rather than the differing ones.
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            # Scalar comparison: single-line message with the tolerances used.
            if rtol is None:
                errMsg = ["%s %s %s; diff=%s with atol=%s"
                          % (lhs, cmpStr, rhs, absDiff, atol)]
            elif atol is None:
                errMsg = ["%s %s %s; diff=%s/%s=%s with rtol=%s"
                          % (lhs, cmpStr, rhs, absDiff, relTo, absDiff/relTo, rtol)]
            else:
                errMsg = ["%s %s %s; diff=%s/%s=%s with rtol=%s, atol=%s"
                          % (lhs, cmpStr, rhs, absDiff, relTo, absDiff/relTo, rtol, atol)]
        else:
            # Array comparison: summary line, optional plot, then optionally
            # one line per failing element.
            errMsg = ["%d/%d elements %s with rtol=%s, atol=%s"
                      % (bad.sum(), bad.size, failStr, rtol, atol)]
            if plotOnFailure:
                if len(lhs.shape) != 2 or len(rhs.shape) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we can treat
                # them the same (diff and absDiff are arrays if either rhs or lhs is),
                # and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                if rtol is None:
                    for a, b, diff in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append("%s %s %s (diff=%s)" % (a, cmpStr, b, diff))
                else:
                    for a, b, diff, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append("%s %s %s (diff=%s/%s=%s)" % (a, cmpStr, b, diff, rel, diff/rel))

    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))
@inTestCase
def assertFloatsNotEqual(testCase, lhs, rhs, **kwds):
    """Fail a test if the given floating point values are equal to within the
    given tolerances.

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``invert=True``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)
@inTestCase
def assertFloatsEqual(testCase, lhs, rhs, **kwargs):
    """
    Assert that lhs == rhs (both numeric types, whether scalar or array).

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``rtol=atol=0``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are not equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=0, atol=0, **kwargs)
705def _settingsIterator(settings):
706 """Return an iterator for the provided test settings
708 Parameters
709 ----------
710 settings : `dict` (`str`: iterable)
711 Lists of test parameters. Each should be an iterable of the same length.
712 If a string is provided as an iterable, it will be converted to a list
713 of a single string.
715 Raises
716 ------
717 AssertionError
718 If the ``settings`` are not of the same length.
720 Yields
721 ------
722 parameters : `dict` (`str`: anything)
723 Set of parameters.
724 """
725 for name, values in settings.items():
726 if isinstance(values, str): 726 ↛ 728line 726 didn't jump to line 728, because the condition on line 726 was never true
727 # Probably meant as a single-element string, rather than an iterable of chars
728 settings[name] = [values]
729 num = len(next(iter(settings.values()))) # Number of settings
730 for name, values in settings.items():
731 assert len(values) == num, f"Length mismatch for setting {name}: {len(values)} vs {num}"
732 for ii in range(num):
733 values = [settings[kk][ii] for kk in settings]
734 yield dict(zip(settings.keys(), values))
def classParameters(**settings):
    """Class decorator for generating unit tests

    This decorator generates classes with class variables according to the
    supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    def decorator(cls):
        namespace = sys.modules[cls.__module__].__dict__
        for params in _settingsIterator(settings):
            suffix = "_".join(str(value) for value in params.values())
            name = f"{cls.__name__}_{suffix}"
            members = dict(cls.__dict__)
            members.update(params)
            # Register each generated subclass in the defining module so
            # the test loader discovers it. The decorator returns None,
            # leaving the original class name unbound so its tests do not
            # also run unparameterized.
            namespace[name] = type(name, (cls,), members)
    return decorator
def methodParameters(**settings):
    """Method decorator for unit tests

    This decorator iterates over the supplied settings, using
    ``TestCase.subTest`` to communicate the values in the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Examples
    --------
    ::

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar):
            ...

    will run::

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """
    def decorator(testMethod):
        @functools.wraps(testMethod)
        def runParameterized(self, *args, **kwargs):
            for params in _settingsIterator(settings):
                kwargs.update(params)
                # subTest reports the parameter values if this iteration fails.
                with self.subTest(**params):
                    testMethod(self, *args, **kwargs)
        return runParameterized
    return decorator
@contextlib.contextmanager
def temporaryDirectory():
    """Context manager that creates and destroys a temporary directory.

    The difference from `tempfile.TemporaryDirectory` is that this ignores
    errors when deleting a directory, which may happen with some filesystems.

    Yields
    ------
    `str`
        Path to the newly-created temporary directory.
    """
    tmpdir = tempfile.mkdtemp()
    try:
        yield tmpdir
    finally:
        # Previously the cleanup was skipped if the with-block raised,
        # leaking the directory; try/finally guarantees removal is
        # attempted. Deletion errors are deliberately ignored (see
        # docstring).
        shutil.rmtree(tmpdir, ignore_errors=True)