Coverage for python/lsst/utils/tests.py: 34%

361 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-20 10:50 +0000

1# This file is part of utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# Use of this source code is governed by a 3-clause BSD-style 

10# license that can be found in the LICENSE file. 

11 

12"""Support code for running unit tests.""" 

13 

14from __future__ import annotations 

15 

# Public API of this module; names listed here are re-exported on
# ``from lsst.utils.tests import *`` (some are defined later in the file).
__all__ = [
    "init",
    "MemoryTestCase",
    "ExecutablesTestCase",
    "ImportTestCase",
    "getTempFilePath",
    "TestCase",
    "assertFloatsAlmostEqual",
    "assertFloatsNotEqual",
    "assertFloatsEqual",
    "debugger",
    "classParameters",
    "methodParameters",
    "temporaryDirectory",
]

31 

32import contextlib 

33import functools 

34import gc 

35import inspect 

36import itertools 

37import os 

38import re 

39import shutil 

40import subprocess 

41import sys 

42import tempfile 

43import unittest 

44import warnings 

45from collections.abc import Callable, Container, Iterable, Iterator, Mapping, Sequence 

46 

47if sys.version_info < (3, 10, 0): 47 ↛ 48line 47 didn't jump to line 48, because the condition on line 47 was never true

48 import importlib_resources as resources 

49else: 

50 from importlib import resources 

51 

52from typing import Any, ClassVar 

53 

54import numpy 

55import psutil 

56 

57from .doImport import doImport 

58 

# Initialize the list of open files to an empty set. This is the baseline
# against which MemoryTestCase.testFileDescriptorLeaks compares; init()
# refreshes it with the currently open files.
open_files = set()

61 

62 

def _get_open_files() -> set[str]:
    """Return the paths of all files currently open in this process.

    Returns
    -------
    open_files : `set`
        Set containing the list of open files.
    """
    paths = set()
    for entry in psutil.Process().open_files():
        paths.add(entry.path)
    return paths

73 

74 

def init() -> None:
    """Initialize the memory tester and file descriptor leak tester."""
    global open_files
    # Reset the list of open files: record the currently open files as the
    # new baseline so that later leak checks report only files opened after
    # this call.
    open_files = _get_open_files()

80 

81 

def sort_tests(tests) -> unittest.TestSuite:
    """Sort supplied test suites such that MemoryTestCases are at the end.

    `lsst.utils.tests.MemoryTestCase` tests should always run after any other
    tests in the module.

    Parameters
    ----------
    tests : sequence
        Sequence of test suites.

    Returns
    -------
    suite : `unittest.TestSuite`
        A combined `~unittest.TestSuite` with
        `~lsst.utils.tests.MemoryTestCase` at the end.
    """
    ordered = unittest.TestSuite()
    deferred = []
    for test_suite in tests:
        try:
            # Peek at the first test in the suite to decide whether this is
            # a MemoryTestCase suite. A loop (rather than next()) is used
            # because the suite may contain no test methods at all.
            first = None
            for first in test_suite:
                break
            if first is not None and MemoryTestCase in inspect.getmro(first.__class__):
                deferred.append(test_suite)
            else:
                ordered.addTests(test_suite)
        except TypeError:
            # Not iterable: a bare test case rather than a suite.
            if isinstance(test_suite, MemoryTestCase):
                deferred.append(test_suite)
            else:
                ordered.addTest(test_suite)
    # Memory tests go last so resource-leak checks see everything else.
    ordered.addTests(deferred)
    return ordered

122 

123 

def _suiteClassWrapper(tests):
    """Combine the given suites into one suite with memory tests last."""
    ordered = sort_tests(tests)
    return unittest.TestSuite(ordered)

126 

127 

# Replace the suiteClass callable in the defaultTestLoader
# so that we can reorder the test ordering. This will have
# no effect if no memory test cases are found.
# Note: this runs at import time, so merely importing this module changes
# the behavior of unittest's default loader for the whole process.
unittest.defaultTestLoader.suiteClass = _suiteClassWrapper

132 

133 

class MemoryTestCase(unittest.TestCase):
    """Check for resource leaks."""

    ignore_regexps: ClassVar[list[str]] = []
    """List of regexps to ignore when checking for open files."""

    @classmethod
    def tearDownClass(cls) -> None:
        """Reset the leak counter when the tests have been completed."""
        init()

    def testFileDescriptorLeaks(self) -> None:
        """Check if any file descriptors are open since init() called.

        Ignores files with certain known path components and any files
        that match regexp patterns in class property ``ignore_regexps``.
        """
        gc.collect()
        global open_files
        current = _get_open_files()

        # Some files are opened out of the control of the stack.
        ignored_suffixes = (".car", ".ttf", "astropy.log", "mime/mime.cache", ".sqlite3")
        tracked = set()
        for path in current:
            if path.startswith("/proc/"):
                continue
            if path.endswith(ignored_suffixes):
                continue
            if path.startswith("/var/lib/") and path.endswith("/passwd"):
                continue
            if any(re.search(pattern, path) for pattern in self.ignore_regexps):
                continue
            tracked.add(path)

        leaked = tracked.difference(open_files)
        if leaked:
            for f in leaked:
                print(f"File open: {f}")
            self.fail("Failed to close %d file%s" % (len(leaked), "s" if len(leaked) != 1 else ""))

174 

175 

class ExecutablesTestCase(unittest.TestCase):
    """Test that executables can be run and return good status.

    The test methods are dynamically created. Callers
    must subclass this class in their own test file and invoke
    the create_executable_tests() class method to register the tests.
    """

    # Number of executables registered by create_executable_tests();
    # -1 indicates that discovery has not yet been attempted.
    TESTS_DISCOVERED = -1

    @classmethod
    def setUpClass(cls) -> None:
        """Abort testing if automated test creation was enabled and
        no tests were found.
        """
        if cls.TESTS_DISCOVERED == 0:
            raise RuntimeError("No executables discovered.")

    def testSanity(self) -> None:
        """Ensure that there is at least one test to be
        executed. This allows the test runner to trigger the class set up
        machinery to test whether there are some executables to test.
        """

    def assertExecutable(
        self,
        executable: str,
        root_dir: str | None = None,
        args: Sequence[str] | None = None,
        msg: str | None = None,
    ) -> None:
        """Check an executable runs and returns good status.

        Prints output to standard out. On bad exit status the test
        fails. If the executable can not be located the test is skipped.

        Parameters
        ----------
        executable : `str`
            Path to an executable. ``root_dir`` is not used if this is an
            absolute path.
        root_dir : `str`, optional
            Directory containing executable. Ignored if `None`.
        args : `list` or `tuple`, optional
            Arguments to be provided to the executable.
        msg : `str`, optional
            Message to use when the test fails. Can be `None` for default
            message.

        Raises
        ------
        AssertionError
            The executable did not return 0 exit status.
        """
        if root_dir is not None and not os.path.isabs(executable):
            executable = os.path.join(root_dir, executable)

        # Build the command line for subprocess and a readable description
        # of the arguments for the log message.
        command = [executable]
        if args is None:
            arg_description = "no arguments"
        else:
            command.extend(args)
            arg_description = 'arguments "' + " ".join(args) + '"'

        print(f"Running executable '{executable}' with {arg_description}...")
        if not os.path.exists(executable):
            self.skipTest(f"Executable {executable} is unexpectedly missing")
        failure_message = None
        try:
            output = subprocess.check_output(command)
        except subprocess.CalledProcessError as exc:
            output = exc.output
            failure_message = f"Bad exit status from '{executable}': {exc.returncode}"
        print(output.decode("utf-8"))
        if failure_message:
            # Prefer the caller-supplied message when one was given.
            self.fail(failure_message if msg is None else msg)

    @classmethod
    def _build_test_method(cls, executable: str, root_dir: str) -> None:
        """Build a test method and attach it to this class.

        A test method is created for the supplied executable located
        in the supplied root directory, and attached to the class so that
        the test runner will discover and run it.

        Parameters
        ----------
        executable : `str`
            Name of executable. Can be absolute path.
        root_dir : `str`
            Path to executable. Not used if executable path is absolute.
        """
        if not os.path.isabs(executable):
            executable = os.path.abspath(os.path.join(root_dir, executable))

        # Derive a unique, discoverable test name from the executable path.
        test_name = "test_exe_" + executable.replace("/", "_")

        def test_executable_runs(*args: Any) -> None:
            args[0].assertExecutable(executable)

        # Give it a name and attach it to the class.
        test_executable_runs.__name__ = test_name
        setattr(cls, test_name, test_executable_runs)

    @classmethod
    def create_executable_tests(cls, ref_file: str, executables: Sequence[str] | None = None) -> None:
        """Discover executables to test and create corresponding test methods.

        Scans the directory containing the supplied reference file
        (usually ``__file__`` supplied from the test class) to look for
        executables. If executables are found a test method is created
        for each one. That test method will run the executable and
        check the returned value.

        Executable scripts with a ``.py`` extension and shared libraries
        are ignored by the scanner.

        This class method must be called before test discovery.

        Parameters
        ----------
        ref_file : `str`
            Path to a file within the directory to be searched.
            If the files are in the same location as the test file, then
            ``__file__`` can be used.
        executables : `list` or `tuple`, optional
            Sequence of executables that can override the automated
            detection. If an executable mentioned here is not found, a
            skipped test will be created for it, rather than a failed
            test.

        Examples
        --------
        >>> cls.create_executable_tests(__file__)
        """
        # Search relative to the directory containing the reference file.
        ref_dir = os.path.abspath(os.path.dirname(ref_file))

        if executables is None:
            # Walk the tree looking for executable files, skipping Python
            # sources and shared libraries (executable on some platforms
            # but not runnable on their own).
            executables = []
            for dirpath, _, filenames in os.walk(ref_dir):
                for filename in filenames:
                    if filename.endswith(".py") or filename.endswith(".so"):
                        continue
                    candidate = os.path.join(dirpath, filename)
                    if os.access(candidate, os.X_OK):
                        executables.append(candidate)

        # Record the number of tests found for later assessment. No
        # exception is raised here for an empty list so that the failure
        # can be reported through the normal test machinery (setUpClass).
        cls.TESTS_DISCOVERED = len(executables)

        # Create one test method per executable.
        for exe in executables:
            cls._build_test_method(exe, ref_dir)

341 

342 

class ImportTestCase(unittest.TestCase):
    """Test that the named packages can be imported and all files within
    that package.

    The test methods are created dynamically. Callers must subclass this
    method and define the ``PACKAGES`` property.
    """

    PACKAGES: ClassVar[Iterable[str]] = ()
    """Packages to be imported."""

    SKIP_FILES: ClassVar[Mapping[str, Container[str]]] = {}
    """Files to be skipped importing; specified as key-value pairs.

    The key is the package name and the value is a set of files names in that
    package to skip.

    Note: Files with names not ending in .py or beginning with leading double
    underscores are always skipped.
    """

    _n_registered = 0
    """Number of packages registered for testing by this class."""

    def _test_no_packages_registered_for_import_testing(self) -> None:
        """Test when no packages have been registered.

        Without this, if no packages have been listed no tests will be
        registered and the test system will not report on anything. This
        test fails and reports why.
        """
        raise AssertionError("No packages registered with import test. Was the PACKAGES property set?")

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Create the test methods based on the content of the ``PACKAGES``
        class property.
        """
        super().__init_subclass__(**kwargs)

        for package in cls.PACKAGES:
            method_name = "test_import_" + package.replace(".", "_")

            # Bind the current package via a default argument so that each
            # generated test captures its own value rather than the final
            # value of the loop variable.
            def test_import(*args: Any, package=package) -> None:
                args[0].assertImport(package)

            test_import.__name__ = method_name
            setattr(cls, method_name, test_import)
            cls._n_registered += 1

        # An empty PACKAGES is likely a mistake, so register a test that
        # always fails with an explanation.
        if cls._n_registered == 0:
            cls.test_no_packages_registered = cls._test_no_packages_registered_for_import_testing

    def assertImport(self, root_pkg):
        for resource in resources.files(root_pkg).iterdir():
            filename = resource.name
            # Only consider regular Python sources; dunder files such as
            # __init__.py are covered by importing the package itself.
            if not filename.endswith(".py") or filename.startswith("__"):
                continue
            if filename in self.SKIP_FILES.get(root_pkg, ()):
                continue
            stem, _ = os.path.splitext(filename)
            module_name = f"{root_pkg}.{stem}"
            with self.subTest(module=module_name):
                try:
                    doImport(module_name)
                except ImportError as e:
                    raise AssertionError(f"Error importing module {module_name}: {e}") from e

416 

417 

@contextlib.contextmanager
def getTempFilePath(ext: str, expectOutput: bool = True) -> Iterator[str]:
    """Return a path suitable for a temporary file and try to delete the
    file on success.

    If the with block completes successfully then the file is deleted,
    if possible; failure results in a printed warning.
    If a file remains when it should not, a RuntimeError exception is
    raised. This exception is also raised if a file is not present on context
    manager exit when one is expected to exist.
    If the block exits with an exception the file is left on disk so it can
    be examined. The file name has a random component such that nested context
    managers can be used with the same file suffix.

    Parameters
    ----------
    ext : `str`
        File name extension, e.g. ``.fits``.
    expectOutput : `bool`, optional
        If `True`, a file should be created within the context manager.
        If `False`, a file should not be present when the context manager
        exits.

    Yields
    ------
    path : `str`
        Path for a temporary file. The path is a combination of the caller's
        file path and the name of the top-level function.

    Examples
    --------
    .. code-block:: python

        # file tests/testFoo.py
        import unittest
        import lsst.utils.tests
        class FooTestCase(unittest.TestCase):
            def testBasics(self):
                self.runTest()

            def runTest(self):
                with lsst.utils.tests.getTempFilePath(".fits") as tmpFile:
                    # if tests/.tests exists then
                    # tmpFile = "tests/.tests/testFoo_testBasics.fits"
                    # otherwise tmpFile = "testFoo_testBasics.fits"
                    ...
                    # at the end of this "with" block the path tmpFile will be
                    # deleted, but only if the file exists and the "with"
                    # block terminated normally (rather than with an exception)
                ...
    """
    stack = inspect.stack()
    # get name of first function in the file
    # Walk outwards from index 2 (skipping this generator frame and the
    # contextlib machinery that advances it) until the frames leave the
    # caller's file; the last in-file frame names the top-level function.
    for i in range(2, len(stack)):
        frameInfo = inspect.getframeinfo(stack[i][0])
        if i == 2:
            callerFilePath = frameInfo.filename
            callerFuncName = frameInfo.function
        elif callerFilePath == frameInfo.filename:
            # this function called the previous function
            callerFuncName = frameInfo.function
        else:
            break

    callerDir, callerFileNameWithExt = os.path.split(callerFilePath)
    callerFileName = os.path.splitext(callerFileNameWithExt)[0]
    outDir = os.path.join(callerDir, ".tests")
    if not os.path.isdir(outDir):
        # No .tests directory implies we are not running with sconsUtils.
        # Need to use the current working directory, the callerDir, or
        # /tmp equivalent. If cwd is used if must be as an absolute path
        # in case the test code changes cwd.
        outDir = os.path.abspath(os.path.curdir)
    prefix = f"{callerFileName}_{callerFuncName}-"
    # tempfile.mktemp only reserves a name, not a file; that is deliberate
    # here because the caller is expected to create the file itself.
    outPath = tempfile.mktemp(dir=outDir, suffix=ext, prefix=prefix)
    if os.path.exists(outPath):
        # There should not be a file there given the randomizer. Warn and
        # remove.
        # Use stacklevel 3 so that the warning is reported from the end of the
        # with block
        warnings.warn(f"Unexpectedly found pre-existing tempfile named {outPath!r}", stacklevel=3)
        with contextlib.suppress(OSError):
            os.remove(outPath)

    yield outPath

    # Execution resumes here only if the with-block body did not raise;
    # on an exception the file is intentionally left in place.
    fileExists = os.path.exists(outPath)
    if expectOutput:
        if not fileExists:
            raise RuntimeError(f"Temp file expected named {outPath} but none found")
    else:
        if fileExists:
            raise RuntimeError(f"Unexpectedly discovered temp file named {outPath}")
    # Try to clean up the file regardless
    if fileExists:
        try:
            os.remove(outPath)
        except OSError as e:
            # Use stacklevel 3 so that the warning is reported from the end of
            # the with block.
            warnings.warn(f"Warning: could not remove file {outPath!r}: {e}", stacklevel=3)

519 

520 

class TestCase(unittest.TestCase):
    """Subclass of unittest.TestCase that adds some custom assertions for
    convenience.

    Free functions elsewhere in this module decorated with `inTestCase`
    (e.g. `assertFloatsAlmostEqual`) are attached to this class as methods.
    """

525 

526 

def inTestCase(func: Callable) -> Callable:
    """Register a free function as a method on the custom `TestCase` class.

    The function remains usable as a free function; it is additionally
    attached to `TestCase` under its own name.

    Parameters
    ----------
    func : `~collections.abc.Callable`
        Function to be added to the `TestCase` class.

    Returns
    -------
    func : `~collections.abc.Callable`
        The given function, unchanged.
    """
    method_name = func.__name__
    setattr(TestCase, method_name, func)
    return func

543 

544 

def debugger(*exceptions):
    """Enter the debugger when there's an uncaught exception.

    To use, just slap a ``@debugger()`` on your function.

    You may provide specific exception classes to catch as arguments to
    the decorator function, e.g.,
    ``@debugger(RuntimeError, NotImplementedError)``.
    With no arguments, any `Exception` triggers the debugger, which in
    particular covers the `AssertionError` raised by failing
    `unittest.TestCase` methods.

    Code provided by "Rosh Oxymoron" on StackOverflow:
    http://stackoverflow.com/questions/4398967/python-unit-testing-automatically-running-the-debugger-when-a-test-fails

    Parameters
    ----------
    *exceptions : `Exception`
        Specific exception classes to catch. Default is to catch
        `Exception`.

    Notes
    -----
    Consider using ``pytest --pdb`` instead of this decorator.

    When an exception is caught, the decorated function returns `None`
    after the post-mortem session ends; the exception is not re-raised.
    """
    if not exceptions:
        # Catch everything by default; the previous docstring incorrectly
        # claimed the default was AssertionError only.
        exceptions = (Exception,)

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except exceptions:
                # Imported lazily so pdb is only loaded when needed.
                import pdb
                import sys

                pdb.post_mortem(sys.exc_info()[2])

        return wrapper

    return decorator

586 

587 

def plotImageDiff(
    lhs: numpy.ndarray,
    rhs: numpy.ndarray,
    bad: numpy.ndarray | None = None,
    diff: numpy.ndarray | None = None,
    plotFileName: str | None = None,
) -> None:
    """Plot the comparison of two 2-d NumPy arrays.

    Parameters
    ----------
    lhs : `numpy.ndarray`
        LHS values to compare; a 2-d NumPy array.
    rhs : `numpy.ndarray`
        RHS values to compare; a 2-d NumPy array.
    bad : `numpy.ndarray`
        A 2-d boolean NumPy array of values to emphasize in the plots.
    diff : `numpy.ndarray`
        Difference array; a 2-d NumPy array, or None to show lhs-rhs.
    plotFileName : `str`
        Filename to save the plot to. If None, the plot will be displayed in
        a window.

    Notes
    -----
    This method uses `matplotlib` and imports it internally; it should be
    wrapped in a try/except block within packages that do not depend on
    `matplotlib` (including `~lsst.utils`).
    """
    # Imported here (not at module scope) so the module works without
    # matplotlib installed; see Notes above.
    from matplotlib import pyplot

    if diff is None:
        diff = lhs - rhs
    pyplot.figure()
    if bad is not None:
        # make an rgba image that's red and transparent where not bad
        badImage = numpy.zeros(bad.shape + (4,), dtype=numpy.uint8)
        badImage[:, :, 0] = 255
        badImage[:, :, 1] = 0
        badImage[:, :, 2] = 0
        badImage[:, :, 3] = 255 * bad
    # Shared color scales: row 1 spans the data range of lhs/rhs, row 2
    # spans the range of the difference image.
    vmin1 = numpy.minimum(numpy.min(lhs), numpy.min(rhs))
    vmax1 = numpy.maximum(numpy.max(lhs), numpy.max(rhs))
    vmin2 = numpy.min(diff)
    vmax2 = numpy.max(diff)
    for n, (image, title) in enumerate([(lhs, "lhs"), (rhs, "rhs"), (diff, "diff")]):
        pyplot.subplot(2, 3, n + 1)
        im1 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin1, vmax=vmax1
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
        # NOTE(review): the bottom row re-plots the same `image` with the
        # diff color scale (vmin2/vmax2) — presumably intentional so each
        # panel can be read on the residual scale; confirm if modifying.
        pyplot.subplot(2, 3, n + 4)
        im2 = pyplot.imshow(
            image, cmap=pyplot.cm.gray, interpolation="nearest", origin="lower", vmin=vmin2, vmax=vmax2
        )
        if bad is not None:
            pyplot.imshow(badImage, alpha=0.2, interpolation="nearest", origin="lower")
        pyplot.axis("off")
        pyplot.title(title)
    pyplot.subplots_adjust(left=0.05, bottom=0.05, top=0.92, right=0.75, wspace=0.05, hspace=0.05)
    # Two shared colorbars on the right: one per color scale.
    cax1 = pyplot.axes([0.8, 0.55, 0.05, 0.4])
    pyplot.colorbar(im1, cax=cax1)
    cax2 = pyplot.axes([0.8, 0.05, 0.05, 0.4])
    pyplot.colorbar(im2, cax=cax2)
    if plotFileName:
        pyplot.savefig(plotFileName)
    else:
        pyplot.show()

659 

660 

@inTestCase
def assertFloatsAlmostEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    rtol: float | None = sys.float_info.epsilon,
    atol: float | None = sys.float_info.epsilon,
    relTo: float | None = None,
    printFailures: bool = True,
    plotOnFailure: bool = False,
    plotFileName: str | None = None,
    invert: bool = False,
    msg: str | None = None,
    ignoreNaNs: bool = False,
) -> None:
    """Highly-configurable floating point comparisons for scalars and arrays.

    The test assertion will fail if all elements ``lhs`` and ``rhs`` are not
    equal to within the tolerances specified by ``rtol`` and ``atol``.
    More precisely, the comparison is:

    ``abs(lhs - rhs) <= relTo*rtol OR abs(lhs - rhs) <= atol``

    If ``rtol`` or ``atol`` is `None`, that term in the comparison is not
    performed at all.

    When not specified, ``relTo`` is the elementwise maximum of the absolute
    values of ``lhs`` and ``rhs``. If set manually, it should usually be set
    to either ``lhs`` or ``rhs``, or a scalar value typical of what is
    expected.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rtol : `float`, optional
        Relative tolerance for comparison; defaults to double-precision
        epsilon.
    atol : `float`, optional
        Absolute tolerance for comparison; defaults to double-precision
        epsilon.
    relTo : `float`, optional
        Value to which comparison with rtol is relative.
    printFailures : `bool`, optional
        Upon failure, print all inequal elements as part of the message.
    plotOnFailure : `bool`, optional
        Upon failure, plot the originals and their residual with matplotlib.
        Only 2-d arrays are supported.
    plotFileName : `str`, optional
        Filename to save the plot to. If `None`, the plot will be displayed in
        a window.
    invert : `bool`, optional
        If `True`, invert the comparison and fail only if any elements *are*
        equal. Used to implement `~lsst.utils.tests.assertFloatsNotEqual`,
        which should generally be used instead for clarity.
    msg : `str`, optional
        String to append to the error message when assert fails.
    ignoreNaNs : `bool`, optional
        If `True` (`False` is default) mask out any NaNs from operand arrays
        before performing comparisons if they are in the same locations; NaNs
        in different locations trigger test assertion failures, even when
        ``invert=True``. Scalar NaNs are treated like arrays containing only
        NaNs of the same shape as the other operand, and no comparisons are
        performed if both sides are scalar NaNs.

    Raises
    ------
    AssertionError
        The values are not almost equal.
    """
    if ignoreNaNs:
        # NaN positions must agree on both sides; mismatched NaNs are a
        # failure regardless of invert.
        lhsMask = numpy.isnan(lhs)
        rhsMask = numpy.isnan(rhs)
        if not numpy.all(lhsMask == rhsMask):
            testCase.fail(
                f"lhs has {lhsMask.sum()} NaN values and rhs has {rhsMask.sum()} NaN values, "
                "in different locations."
            )
        if numpy.all(lhsMask):
            assert numpy.all(rhsMask), "Should be guaranteed by previous if."
            # All operands are fully NaN (either scalar NaNs or arrays of only
            # NaNs).
            return
        assert not numpy.all(rhsMask), "Should be guaranteed by prevoius two ifs."
        # If either operand is an array select just its not-NaN values. Note
        # that these expressions are never True for scalar operands, because if
        # they are NaN then the numpy.all checks above will catch them.
        if numpy.any(lhsMask):
            lhs = lhs[numpy.logical_not(lhsMask)]
        if numpy.any(rhsMask):
            rhs = rhs[numpy.logical_not(rhsMask)]
    if not numpy.isfinite(lhs).all():
        testCase.fail("Non-finite values in lhs")
    if not numpy.isfinite(rhs).all():
        testCase.fail("Non-finite values in rhs")
    diff = lhs - rhs
    absDiff = numpy.abs(lhs - rhs)
    # Build the elementwise failure mask `bad` from whichever tolerance
    # terms are enabled; the relative and absolute checks are OR-ed
    # (an element passes if either tolerance accepts it).
    if rtol is not None:
        if relTo is None:
            relTo = numpy.maximum(numpy.abs(lhs), numpy.abs(rhs))
        else:
            relTo = numpy.abs(relTo)
        bad = absDiff > rtol * relTo
        if atol is not None:
            bad = numpy.logical_and(bad, absDiff > atol)
    else:
        if atol is None:
            raise ValueError("rtol and atol cannot both be None")
        bad = absDiff > atol
    failed = numpy.any(bad)
    if invert:
        # For assertFloatsNotEqual: fail when everything compared equal.
        failed = not failed
        bad = numpy.logical_not(bad)
        cmpStr = "=="
        failStr = "are the same"
    else:
        cmpStr = "!="
        failStr = "differ"
    errMsg = []
    if failed:
        if numpy.isscalar(bad):
            if rtol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff} with atol={atol}"]
            elif atol is None:
                errMsg = [f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} with rtol={rtol}"]
            else:
                errMsg = [
                    f"{lhs} {cmpStr} {rhs}; diff={absDiff}/{relTo}={absDiff / relTo} "
                    f"with rtol={rtol}, atol={atol}"
                ]
        else:
            errMsg = [f"{bad.sum()}/{bad.size} elements {failStr} with rtol={rtol}, atol={atol}"]
            if plotOnFailure:
                if len(lhs.shape) != 2 or len(rhs.shape) != 2:
                    raise ValueError("plotOnFailure is only valid for 2-d arrays")
                try:
                    plotImageDiff(lhs, rhs, bad, diff=diff, plotFileName=plotFileName)
                except ImportError:
                    errMsg.append("Failure plot requested but matplotlib could not be imported.")
            if printFailures:
                # Make sure everything is an array if any of them are, so we
                # can treat them the same (diff and absDiff are arrays if
                # either rhs or lhs is), and we don't get here if neither is.
                if numpy.isscalar(relTo):
                    relTo = numpy.ones(bad.shape, dtype=float) * relTo
                if numpy.isscalar(lhs):
                    lhs = numpy.ones(bad.shape, dtype=float) * lhs
                if numpy.isscalar(rhs):
                    rhs = numpy.ones(bad.shape, dtype=float) * rhs
                if rtol is None:
                    for a, b, diff in zip(lhs[bad], rhs[bad], absDiff[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={diff})")
                else:
                    for a, b, diff, rel in zip(lhs[bad], rhs[bad], absDiff[bad], relTo[bad]):
                        errMsg.append(f"{a} {cmpStr} {b} (diff={diff}/{rel}={diff / rel})")

    if msg is not None:
        errMsg.append(msg)
    testCase.assertFalse(failed, msg="\n".join(errMsg))

827 

828 

@inTestCase
def assertFloatsNotEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwds: Any,
) -> None:
    """Fail a test if the given floating point values are equal to within the
    given tolerances.

    This is `~lsst.utils.tests.assertFloatsAlmostEqual` called with
    ``invert=True``; see that function for the meaning of the tolerance
    keywords.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    **kwds : `~typing.Any`
        Keyword parameters forwarded to `assertFloatsAlmostEqual`.

    Raises
    ------
    AssertionError
        The values are almost equal.
    """
    return assertFloatsAlmostEqual(testCase, lhs, rhs, invert=True, **kwds)

861 

862 

@inTestCase
def assertFloatsEqual(
    testCase: unittest.TestCase,
    lhs: float | numpy.ndarray,
    rhs: float | numpy.ndarray,
    **kwargs: Any,
) -> None:
    """Assert that lhs == rhs (both numeric types, whether scalar or array).

    See `~lsst.utils.tests.assertFloatsAlmostEqual` (called with
    ``rtol=atol=0``) for more information.

    Parameters
    ----------
    testCase : `unittest.TestCase`
        Instance the test is part of.
    lhs : scalar or array-like
        LHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    rhs : scalar or array-like
        RHS value(s) to compare; may be a scalar or array-like of any
        dimension.
    **kwargs : `~typing.Any`
        Keyword parameters forwarded to `assertFloatsAlmostEqual`.

    Raises
    ------
    AssertionError
        The values are not equal.
    """
    # Exact comparison: both tolerances are forced to zero.
    return assertFloatsAlmostEqual(testCase, lhs, rhs, rtol=0, atol=0, **kwargs)

895 

896 

def _settingsIterator(settings: dict[str, Sequence[Any]]) -> Iterator[dict[str, Any]]:
    """Return an iterator over the provided test settings.

    Parameters
    ----------
    settings : `dict` (`str`: iterable)
        Lists of test parameters. Each should be an iterable of the same
        length. If a string is provided as an iterable, it will be converted
        to a list of a single string.

    Raises
    ------
    AssertionError
        If the ``settings`` are not of the same length.

    Yields
    ------
    parameters : `dict` (`str`: anything)
        Set of parameters.
    """
    # A bare string is almost certainly intended as a single value rather
    # than an iterable of characters; promote it to a one-element list.
    for key, values in settings.items():
        if isinstance(values, str):
            settings[key] = [values]
    # Every setting must supply the same number of values.
    num = len(next(iter(settings.values())))
    for key, values in settings.items():
        assert len(values) == num, f"Length mismatch for setting {key}: {len(values)} vs {num}"
    # Pair up the index-th value of every setting into one parameter dict.
    for index in range(num):
        yield {key: settings[key][index] for key in settings}

928 

929 

def classParameters(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    For each index across the supplied ``settings``, a subclass of the
    decorated class is created with the corresponding values installed as
    class attributes, and injected into the decorated class's module under a
    name that embeds those values.  (The decorator itself returns `None`, so
    the original name is not left bound to a runnable test class.)

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable of the same length.

    Examples
    --------
    ::

        @classParameters(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate two classes, as if you wrote::

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """

    def decorator(cls: type) -> None:
        # Inject the generated subclasses into the module that defined the
        # decorated class, so unittest discovery can find them.
        namespace = sys.modules[cls.__module__].__dict__
        for combo in _settingsIterator(settings):
            suffix = "_".join(str(value) for value in combo.values())
            subclassName = f"{cls.__name__}_{suffix}"
            attributes = dict(cls.__dict__)
            attributes.update(combo)
            namespace[subclassName] = type(subclassName, (cls,), attributes)

    return decorator

974 

975 

def methodParameters(**settings: Sequence[Any]) -> Callable:
    """Run a test method once per set of supplied parameters.

    Each invocation is wrapped in ``TestCase.subTest`` so a failure reports
    which parameter combination was being exercised.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters. Each should be an iterable of the same
        length.

    Examples
    --------
    .. code-block:: python

        @methodParameters(foo=[1, 2], bar=[3, 4])
        def testSomething(self, foo, bar):
            ...

    will run:

    .. code-block:: python

        testSomething(foo=1, bar=3)
        testSomething(foo=2, bar=4)
    """

    def decorator(testMethod: Callable) -> Callable:
        @functools.wraps(testMethod)
        def wrapper(self: unittest.TestCase, *args: Any, **kwargs: Any) -> None:
            for combo in _settingsIterator(settings):
                # The combination keys overwrite any like-named kwargs (and
                # persist into later iterations) — this matches the original
                # in-place ``kwargs.update`` behaviour.
                kwargs.update(combo)
                with self.subTest(**combo):
                    testMethod(self, *args, **kwargs)

        return wrapper

    return decorator

1015 

1016 

1017def _cartesianProduct(settings: Mapping[str, Sequence[Any]]) -> Mapping[str, Sequence[Any]]: 

1018 """Return the cartesian product of the settings. 

1019 

1020 Parameters 

1021 ---------- 

1022 settings : `dict` mapping `str` to `iterable` 

1023 Parameter combinations. 

1024 

1025 Returns 

1026 ------- 

1027 product : `dict` mapping `str` to `iterable` 

1028 Parameter combinations covering the cartesian product (all possible 

1029 combinations) of the input parameters. 

1030 

1031 Examples 

1032 -------- 

1033 .. code-block:: python 

1034 

1035 cartesianProduct({"foo": [1, 2], "bar": ["black", "white"]}) 

1036 

1037 will return: 

1038 

1039 .. code-block:: python 

1040 

1041 {"foo": [1, 1, 2, 2], "bar": ["black", "white", "black", "white"]} 

1042 """ 

1043 product: dict[str, list[Any]] = {kk: [] for kk in settings} 

1044 for values in itertools.product(*settings.values()): 

1045 for kk, vv in zip(settings.keys(), values): 

1046 product[kk].append(vv) 

1047 return product 

1048 

1049 

def classParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Class decorator for generating unit tests.

    Like `classParameters`, but the generated classes cover the cartesian
    product (every possible combination) of the supplied ``settings`` rather
    than matching them up index by index.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The lists of test parameters to set as class variables in turn. Each
        should be an iterable.

    Examples
    --------
    .. code-block:: python

        @classParametersProduct(foo=[1, 2], bar=[3, 4])
        class MyTestCase(unittest.TestCase):
            ...

    will generate four classes, as if you wrote::

    .. code-block:: python

        class MyTestCase_1_3(unittest.TestCase):
            foo = 1
            bar = 3
            ...

        class MyTestCase_1_4(unittest.TestCase):
            foo = 1
            bar = 4
            ...

        class MyTestCase_2_3(unittest.TestCase):
            foo = 2
            bar = 3
            ...

        class MyTestCase_2_4(unittest.TestCase):
            foo = 2
            bar = 4
            ...

    Note that the values are embedded in the class name.
    """
    # Expand to the full product first, then reuse the index-wise decorator.
    return classParameters(**_cartesianProduct(settings))

1097 

1098 

def methodParametersProduct(**settings: Sequence[Any]) -> Callable:
    """Iterate over cartesian product creating sub tests.

    Like `methodParameters`, but the test is run for every possible
    combination of the supplied ``settings``, each wrapped in
    `~unittest.TestCase.subTest` to identify the failing combination.

    Parameters
    ----------
    **settings : `dict` (`str`: iterable)
        The parameter combinations to test. Each should be an iterable.

    Examples
    --------
    @methodParametersProduct(foo=[1, 2], bar=["black", "white"])
    def testSomething(self, foo, bar):
        ...

    will run:

    testSomething(foo=1, bar="black")
    testSomething(foo=1, bar="white")
    testSomething(foo=2, bar="black")
    testSomething(foo=2, bar="white")
    """
    # Expand to the full product first, then reuse the index-wise decorator.
    return methodParameters(**_cartesianProduct(settings))

1125 

1126 

@contextlib.contextmanager
def temporaryDirectory() -> Iterator[str]:
    """Context manager that creates and destroys a temporary directory.

    The difference from `tempfile.TemporaryDirectory` is that this ignores
    errors when deleting a directory, which may happen with some filesystems.

    Yields
    ------
    `str`
        Name of the temporary directory.
    """
    tmpdir = tempfile.mkdtemp()
    try:
        yield tmpdir
    finally:
        # The try/finally guarantees cleanup even when the caller's ``with``
        # body raises (the original version leaked the directory in that
        # case).  ignore_errors covers filesystems where removal can fail.
        shutil.rmtree(tmpdir, ignore_errors=True)