Coverage for tests/test_hermiteTransformMatrix.py: 32%

60 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-10 10:52 +0000

1# 

2# LSST Data Management System 

3# Copyright 2008-2017 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22 

23import unittest 

24 

25import numpy as np 

26 

27try: 

28 import scipy.special 

29except ImportError: 

30 scipy = None 

31 

32import lsst.utils.tests 

33import lsst.geom 

34import lsst.shapelet.tests 

35 

36 

class HermiteTransformMatrixTestCase(lsst.shapelet.tests.ShapeletTestCase):
    """Tests for `lsst.shapelet.HermiteTransformMatrix`."""

    def setUp(self):
        # Use a private, seeded generator instead of reseeding NumPy's global
        # RNG, so this test neither depends on nor perturbs global random
        # state.  RandomState(500) produces the same draws that
        # np.random.seed(500) followed by np.random.randn would.
        self.rng = np.random.RandomState(500)
        self.order = 4
        self.size = lsst.shapelet.computeSize(self.order)
        self.htm = lsst.shapelet.HermiteTransformMatrix(self.order)

    @staticmethod
    def ht(n):
        """Return a numpy Polynomial for the nth 'alternate' Hermite polynomial
        (i.e. Hermite polynomial with shapelet normalization).

        Parameters
        ----------
        n : `int`
            Order of the polynomial.

        Returns
        -------
        poly : `numpy.polynomial.Polynomial`
            The polynomial returned by ``scipy.special.hermite(n)``, scaled by
            ``(2**n * sqrt(pi) * n!)**(-1/2)``, with coefficients stored in
            increasing-power order.
        """
        hermite = scipy.special.hermite(n)
        # scipy currently returns an np.poly1d (highest power first); convert
        # it to a Polynomial (lowest power first) if necessary.
        if not isinstance(hermite, np.polynomial.Polynomial):
            hermite = np.polynomial.Polynomial(hermite.coef[::-1])
        return (2**n * np.pi**0.5 * scipy.special.gamma(n+1))**(-0.5) * hermite

    def _evaluateBasis(self, polys, point):
        """Evaluate the 2-d polynomial basis at a single point.

        Parameters
        ----------
        polys : `list` of `numpy.polynomial.Polynomial`
            1-d polynomials indexed by order (e.g. built with `ht`).
        point : `lsst.geom.Point2D`
            Position at which to evaluate the basis.

        Returns
        -------
        basis : `numpy.ndarray`
            Array of length ``self.size`` whose element ``i`` is
            ``polys[nx](x) * polys[ny](y)`` for each ``(i, nx, ny)`` triple
            produced by `lsst.shapelet.HermiteIndexGenerator`.
        """
        basis = np.zeros(self.size, dtype=float)
        for i, nx, ny in lsst.shapelet.HermiteIndexGenerator(self.order):
            basis[i] = polys[nx](point.getX()) * polys[ny](point.getY())
        return basis

    def testCoefficientMatrices(self):
        """Test that the coefficient matrix and its inverse are square,
        mutually inverse, and lower-triangular."""
        coeff = self.htm.getCoefficientMatrix()
        coeffInv = self.htm.getInverseCoefficientMatrix()
        shape = (self.order + 1, self.order + 1)
        self.assertEqual(coeff.shape, shape)
        self.assertEqual(coeffInv.shape, shape)
        self.assertFloatsAlmostEqual(np.identity(self.order+1), np.dot(coeff, coeffInv))
        # Both matrices should be lower-triangular: every entry strictly above
        # the diagonal must be exactly zero.
        np.testing.assert_array_equal(np.triu(coeff, k=1), 0.0)
        np.testing.assert_array_equal(np.triu(coeffInv, k=1), 0.0)

    @unittest.skipIf(scipy is None, "Test requires SciPy")
    def testCoefficientsAgainstHermite(self):
        """Test coefficient matrix values against scipy Hermite polynomials"""
        coeff = self.htm.getCoefficientMatrix()
        for n in range(0, self.order+1):
            poly = self.ht(n)
            # Row n holds the increasing-power coefficients of polynomial n.
            self.assertFloatsAlmostEqual(coeff[n, :n+1], poly.coef, atol=1E-15)

    @unittest.skipIf(scipy is None, "Test requires SciPy")
    def testTransformMatrix(self):
        """Test that the transform matrix maps the polynomial basis evaluated
        at a point onto the basis evaluated at the transformed point."""
        s = lsst.geom.LinearTransform.makeScaling(2.0, 1.5)
        r = lsst.geom.LinearTransform.makeRotation(0.30*lsst.geom.radians)
        transforms = [s, r, s*r*s]
        testPoints = self.rng.randn(10, 2)
        # Building each polynomial calls into scipy; do it once per order
        # rather than once per basis term per point.
        polys = [self.ht(n) for n in range(self.order + 1)]
        for transform in transforms:
            m = self.htm.compute(transform)
            for x, y in testPoints:
                origPoint = lsst.geom.Point2D(x, y)
                transPoint = transform(origPoint)
                expected = self._evaluateBasis(polys, transPoint)
                # Row i of m expresses basis term i at the transformed point
                # as a combination of basis terms at the original point.
                actual = np.dot(m, self._evaluateBasis(polys, origPoint))
                self.assertFloatsAlmostEqual(expected, actual, rtol=1E-11)

91 

92 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the inherited `lsst.utils.tests.MemoryTestCase` checks for this
    module; this subclass adds nothing of its own."""

95 

96 

def setup_module(module):
    """Initialize the LSST test utilities before this module's tests run.

    Presumably invoked by pytest's xunit-style module setup hook.

    Parameters
    ----------
    module : `module`
        The module being set up; not used here.
    """
    lsst.utils.tests.init()

99 

100 

if __name__ == "__main__":
    # Direct execution: initialize the LSST test utilities first, then hand
    # control to unittest's command-line runner.
    lsst.utils.tests.init()
    unittest.main()