Coverage for python/lsst/utils/tests.py : 26%

#
# LSST Data Management System
#
# Copyright 2008-2017 AURA/LSST.
#
# This product includes software developed by the
# LSST Project (http://www.lsst.org/).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the LSST License Statement and
# the GNU General Public License along with this program. If not,
# see <https://www.lsstcorp.org/LegalNotices/>.
#
23"""Support code for running unit tests"""
25import contextlib
26import gc
27import inspect
28import os
29import subprocess
30import sys
31import unittest
32import warnings
33import numpy
34import functools
35import tempfile
37__all__ = ["init", "MemoryTestCase", "ExecutablesTestCase", "getTempFilePath",
38 "TestCase", "assertFloatsAlmostEqual", "assertFloatsNotEqual", "assertFloatsEqual",
39 "debugger", "classParameters", "methodParameters"]

# The file descriptor leak test will be skipped if psutil cannot be imported
try:
    import psutil
except ImportError:
    psutil = None

# Initialize the set of open files to an empty set
open_files = set()


def _get_open_files():
    """Return the set of files currently open in this process.

    Returns
    -------
    open_files : `set`
        Set of paths of the currently open files.
    """
    if psutil is None:
        return set()
    return set(p.path for p in psutil.Process().open_files())


def init():
    """Initialize the memory tester and file descriptor leak tester."""
    global open_files
    # Reset the baseline set of open files
    open_files = _get_open_files()


def sort_tests(tests):
    """Sort supplied test suites such that MemoryTestCases are at the end.

    `lsst.utils.tests.MemoryTestCase` tests should always run after any other
    tests in the module.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites.

    Returns
    -------
    suite : `unittest.TestSuite`
        A combined `~unittest.TestSuite` with
        `~lsst.utils.tests.MemoryTestCase` at the end.
    """
    suite = unittest.TestSuite()
    memtests = []
    for test_suite in tests:
        try:
            # Check only the first test method in the suite for
            # MemoryTestCase. Use a loop rather than next() because a test
            # class may have no test methods, and the Python community
            # prefers for loops over catching StopIteration.
            bases = None
            for method in test_suite:
                bases = inspect.getmro(method.__class__)
                break
            if bases is not None and MemoryTestCase in bases:
                memtests.append(test_suite)
            else:
                suite.addTests(test_suite)
        except TypeError:
            if isinstance(test_suite, MemoryTestCase):
                memtests.append(test_suite)
            else:
                suite.addTest(test_suite)
    suite.addTests(memtests)
    return suite


def suiteClassWrapper(tests):
    return unittest.TestSuite(sort_tests(tests))


# Replace the suiteClass callable in the defaultTestLoader
# so that we can reorder the tests. This has no effect if
# no memory test cases are found.
unittest.defaultTestLoader.suiteClass = suiteClassWrapper


class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks."""

    @classmethod
    def tearDownClass(cls):
        """Reset the leak counter when the tests have been completed."""
        init()

    def testFileDescriptorLeaks(self):
        """Check that no file descriptors have been left open since init()
        was called."""
        if psutil is None:
            self.skipTest("Unable to test file descriptor leaks. psutil unavailable.")
        gc.collect()
        global open_files
        now_open = _get_open_files()

        # Some files are opened outside the control of this test; ignore them.
        now_open = set(f for f in now_open if not f.endswith(".car") and
                       not f.startswith("/proc/") and
                       not f.endswith(".ttf") and
                       not (f.startswith("/var/lib/") and f.endswith("/passwd")) and
                       not f.endswith("astropy.log"))

        diff = now_open.difference(open_files)
        if diff:
            for f in diff:
                print("File open: %s" % f)
            self.fail("Failed to close %d file%s" % (len(diff), "s" if len(diff) != 1 else ""))
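

# Example (illustrative sketch): a test module typically opts in to the leak
# checks by subclassing MemoryTestCase and calling init(). The class and
# function names below follow common LSST convention but are only
# illustrative:
#
#     import unittest
#     import lsst.utils.tests
#
#     class MemoryTester(lsst.utils.tests.MemoryTestCase):
#         pass
#
#     def setup_module(module):
#         lsst.utils.tests.init()
#
#     if __name__ == "__main__":
#         lsst.utils.tests.init()
#         unittest.main()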


class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    TESTS_DISCOVERED = -1

    @classmethod
    def setUpClass(cls):
        """Abort testing if automated test creation was enabled and
        no tests were found."""
        if cls.TESTS_DISCOVERED == 0:
            raise Exception("No executables discovered.")

    def testSanity(self):
        """This test exists to ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set-up
        machinery and check whether there are executables to test."""
        pass

    def assertExecutable(self, executable, root_dir=None, args=None, msg=None):
        """Check that an executable runs and returns good status.

        Prints output to standard out. On bad exit status the test
        fails. If the executable cannot be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for the default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Form the argument list for subprocess
        sp_args = [executable]
        argstr = "no arguments"
        if args is not None:
            sp_args.extend(args)
            argstr = 'arguments "' + " ".join(args) + '"'

        print("Running executable '{}' with {}...".format(executable, argstr))
        if not os.path.exists(executable):
            self.skipTest("Executable {} is unexpectedly missing".format(executable))
        failmsg = None
        try:
            output = subprocess.check_output(sp_args)
        except subprocess.CalledProcessError as e:
            output = e.output
            failmsg = "Bad exit status from '{}': {}".format(executable, e.returncode)
        print(output.decode('utf-8'))
        if failmsg:
            if msg is None:
                msg = failmsg
            self.fail(msg)

    @classmethod
    def _build_test_method(cls, executable, root_dir):
        """Build a test method and attach it to the class.

        A test method is created for the supplied executable located
        in the supplied root directory. This method is attached to the class
        so that the test runner will discover the test and run it.

        Parameters
        ----------
        cls : `object`
            The class in which to create the tests.
        executable : `str`
            Name of executable. Can be an absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Create the test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        # This is the function that will become the test method
        def test_executable_runs(*args):
            self = args[0]
            self.assertExecutable(executable)

        # Give it a name and attach it to the class
        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file, executables=None):
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        # Get the search directory from the reference file
        ref_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Look for executables to test by walking the tree
            executables = []
            for root, dirs, files in os.walk(ref_dir):
                for f in files:
                    # Skip Python files. Shared libraries are executable.
                    if not f.endswith(".py") and not f.endswith(".so"):
                        full_path = os.path.join(root, f)
                        if os.access(full_path, os.X_OK):
                            executables.append(full_path)

        # Store the number of tests found for later assessment.
        # Do not raise an exception if we have no executables, as this would
        # cause the testing to abort before the test runner could properly
        # integrate it into the failure report.
        cls.TESTS_DISCOVERED = len(executables)

        # Create the test functions and attach them to the class
        for e in executables:
            cls._build_test_method(e, ref_dir)
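

# Example (illustrative sketch): registering executable tests from a test
# file. "MyBinariesTestCase" is a hypothetical subclass defined in the
# caller's tests/ directory:
#
#     import lsst.utils.tests
#
#     class MyBinariesTestCase(lsst.utils.tests.ExecutablesTestCase):
#         pass
#
#     # Must run before test discovery so the generated methods are found
#     MyBinariesTestCase.create_executable_tests(__file__)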


@contextlib.contextmanager
def getTempFilePath(ext, expectOutput=True):
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure to delete results in a printed warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can
    be examined. The file name has a random component such that nested
    context managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Returns
    -------
    `str`
        Path for a temporary file. The path is a combination of the caller's
        file path and the name of the top-level function.

    Notes
    -----
    ::

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests

        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics.fits"
                    # otherwise tmpFile = "testFoo_testBasics.fits"
                    ...
                # at the end of this "with" block the path tmpFile will be
                # deleted, but only if the file exists and the "with"
                # block terminated normally (rather than with an exception)
                ...
    """
    stack = inspect.stack()
    # Get the name of the first function in the file
    for i in range(2, len(stack)):
        frameInfo = inspect.getframeinfo(stack[i][0])
        if i == 2:
            callerFilePath = frameInfo.filename
            callerFuncName = frameInfo.function
        elif callerFilePath == frameInfo.filename:
            # This function called the previous function
            callerFuncName = frameInfo.function
        else:
            break

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        outDir = ""
    prefix = "%s_%s-" % (callerFileName, callerFuncName)
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # There should not be a file there given the randomizer. Warn and
        # remove. Use stacklevel 3 so that the warning is reported from the
        # end of the with block.
        warnings.warn("Unexpectedly found pre-existing tempfile named %r" % (outPath,),
                      stacklevel=3)
        try:
            os.remove(outPath)
        except OSError:
            pass

    yield outPath

    fileExists = os.path.exists(outPath)
    if expectOutput:
        if not fileExists:
            raise RuntimeError("Temp file expected named {} but none found".format(outPath))
    else:
        if fileExists:
            raise RuntimeError("Unexpectedly discovered temp file named {}".format(outPath))
    # Try to clean up the file regardless
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # Use stacklevel 3 so that the warning is reported from the end
            # of the with block.
            warnings.warn("Warning: could not remove file %r: %s" % (outPath, e), stacklevel=3)
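

# Example (illustrative sketch): expectOutput=False asserts that the code
# under test removed its own output before the block exits; writeThenDelete
# is a hypothetical function:
#
#     with lsst.utils.tests.getTempFilePath(".txt", expectOutput=False) as tmpFile:
#         writeThenDelete(tmpFile)  # must remove the file before the block exits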


class TestCase(unittest.TestCase):
    """Subclass of unittest.TestCase that adds some custom assertions for
    convenience.
    """


def inTestCase(func):
    """A decorator to add a free function to our custom TestCase class, while
    also making it available as a free function.
    """
    setattr(TestCase, func.__name__, func)
    return func


def debugger(*exceptions):
    """Decorator to enter the debugger when there's an uncaught exception.

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    If no exception classes are given, any `Exception` is caught.

    Code provided by "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.
    """
    if not exceptions:
        exceptions = (Exception, )

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions:
                import sys
                import pdb
                pdb.post_mortem(sys.exc_info()[2])
        return wrapper
    return decorator
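

# Example (illustrative sketch): drop into pdb when an assertion fails in a
# test method; computeValue is a hypothetical function under test:
#
#     class MyTestCase(unittest.TestCase):
#         @debugger(AssertionError)
#         def testSomething(self):
#             self.assertEqual(computeValue(), 42)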


def plotImageDiff(lhs, rhs, bad=None, diff=None, plotFileName=None):
    """Plot the comparison of two 2-d NumPy arrays.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array.
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array.
    bad : `numpy.ndarray`, optional
        A 2-d boolean NumPy array of values to emphasize in the plots.
    diff : `numpy.ndarray`, optional
        Difference array; a 2-d NumPy array, or `None` to show ``lhs - rhs``.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed
        in a window.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    from matplotlib import pyplot
    if diff is None:
        diff = lhs - rhs
    pyplot.figure()
    if bad is not None:
        # Make an RGBA image that is red and transparent where not bad
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255*bad
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    for n, (image, title) in enumerate([(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]):
        pyplot.subplot(2, 3, n + 1)
        im1 = pyplot.imshow(image, cmap=pyplot.cm.gray, interpolation='nearest', origin='lower',
                            vmin=vmin1, vmax=vmax1)
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation='nearest', origin='lower')
        pyplot.axis("off")
        pyplot.title(title)
        pyplot.subplot(2, 3, n + 4)
        im2 = pyplot.imshow(image, cmap=pyplot.cm.gray, interpolation='nearest', origin='lower',
                            vmin=vmin2, vmax=vmax2)
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation='nearest', origin='lower')
        pyplot.axis("off")
        pyplot.title(title)
    pyplot.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    cax1 = pyplot.axes([0.8, 0.55, 0.05, 0.4])
    pyplot.colorbar(im1, cax=cax1)
    cax2 = pyplot.axes([0.8, 0.05, 0.05, 0.4])
    pyplot.colorbar(im2, cax=cax2)
    if plotFileName:
        pyplot.savefig(plotFileName)
    else:
        pyplot.show()


@inTestCase
def assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=sys.float_info.epsilon,
                            atol=sys.float_info.epsilon, relTo=None,
                            printFailures=True, plotOnFailure=False,
                            plotFileName=None, invert=False, msg=None):
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if any elements of ``lhs`` and ``rhs`` are
    not equal to within the tolerances specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float`, optional
        Value to which comparison with rtol is relative.
    printFailures : `bool`, optional
        Upon failure, print all unequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed
        in a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if any elements *are*
        equal. Used to implement `~lsst.utils.tests.assertFloatsNotEqual`,
        which should generally be used instead for clarity.
    msg : `str`, optional
        String to append to the error message when the assert fails.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    """
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    diff = lhs - rhs
    absDiff = numpy.abs(lhs - rhs)
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        bad = absDiff > rtol*relTo
        if atol is not None:
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            if rtol is None:
                errMsg = ["%s %s %s; diff=%s with atol=%s"
                          % (lhs, cmpStr, rhs, absDiff, atol)]
            elif atol is None:
                errMsg = ["%s %s %s; diff=%s/%s=%s with rtol=%s"
                          % (lhs, cmpStr, rhs, absDiff, relTo, absDiff/relTo, rtol)]
            else:
                errMsg = ["%s %s %s; diff=%s/%s=%s with rtol=%s, atol=%s"
                          % (lhs, cmpStr, rhs, absDiff, relTo, absDiff/relTo, rtol, atol)]
        else:
            errMsg = ["%d/%d elements %s with rtol=%s, atol=%s"
                      % (bad.sum(), bad.size, failStr, rtol, atol)]
            if plotOnFailure:
                if len(lhs.shape) != 2 or len(rhs.shape) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we
                # can treat them the same (diff and absDiff are arrays if
                # either rhs or lhs is), and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                if rtol is None:
                    for a, b, diff in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append("%s %s %s (diff=%s)" % (a, cmpStr, b, diff))
                else:
                    for a, b, diff, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append("%s %s %s (diff=%s/%s=%s)" % (a, cmpStr, b, diff, rel, diff/rel))

    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))
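

# Example (illustrative sketch): because of the @inTestCase decorator, the
# assertion is also available as a method on lsst.utils.tests.TestCase; the
# class and values below are only illustrative:
#
#     import numpy as np
#
#     class ToleranceTestCase(lsst.utils.tests.TestCase):
#         def testClose(self):
#             a = np.array([1.0, 2.0, 3.0])
#             b = a + 1e-14
#             # Passes: all differences satisfy the relative tolerance
#             self.assertFloatsAlmostEqual(a, b, rtol=1e-10, atol=None)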


@inTestCase
def assertFloatsNotEqual(testCase, lhs, rhs, **kwds):
    """Fail a test if the given floating point values are equal to within
    the given tolerances.

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``invert=True``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)


@inTestCase
def assertFloatsEqual(testCase, lhs, rhs, **kwargs):
    """Assert that lhs == rhs (both numeric types, whether scalar or array).

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``rtol=atol=0``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.

    Raises
    ------
    AssertionError
        The values are not equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=0, atol=0, **kwargs)


def _settingsIterator(settings):
    """Return an iterator for the provided test settings.

    Parameters
    ----------
    settings : `dict` (`str`: iterable)
        Lists of test parameters. Each should be an iterable of the same
        length. If a string is provided as an iterable, it will be converted
        to a list of a single string.

    Raises
    ------
    AssertionError
        If the ``settings`` are not of the same length.

    Yields
    ------
    parameters : `dict` (`str`: anything)
        Set of parameters.
    """
    for name, values in settings.items():
        if isinstance(values, str):
            # Probably meant as a single-element string, rather than an
            # iterable of chars
            settings[name] = [values]
    num = len(next(iter(settings.values())))  # Number of settings
    for name, values in settings.items():
        assert len(values) == num, f"Length mismatch for setting {name}: {len(values)} vs {num}"
    for ii in range(num):
        values = [settings[kk][ii] for kk in settings]
        yield dict(zip(settings.keys(), values))


def classParameters(**settings):
    """Class decorator for generating unit tests.

    This decorator generates classes with class variables according to the
    supplied ``settings``.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    def decorator(cls):
        module = sys.modules[cls.__module__].__dict__
        for params in _settingsIterator(settings):
            name = f"{cls.__name__}_{'_'.join(str(vv) for vv in params.values())}"
            bindings = dict(cls.__dict__)
            bindings.update(params)
            module[name] = type(name, (cls,), bindings)
    return decorator


def methodParameters(**settings):
    """Method decorator for unit tests.

    This decorator iterates over the supplied settings, using
    ``TestCase.subTest`` to communicate the values in the event of a failure.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Examples
    --------
    ::

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar):
            ...

    will run::

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            for params in _settingsIterator(settings):
                kwargs.update(params)
                with self.subTest(**params):
                    func(self, *args, **kwargs)
        return wrapper
    return decorator
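

# Example (illustrative sketch): the two decorators can be combined, with
# classParameters fanning out classes and methodParameters fanning out
# sub-tests within each class. The names and values below are illustrative:
#
#     @classParameters(backend=["numpy", "python"])
#     class CombinedTestCase(lsst.utils.tests.TestCase):
#         @methodParameters(n=[1, 2], expected=[1, 4])
#         def testSquare(self, n, expected):
#             self.assertEqual(n**2, expected)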