Coverage for tests/test_leastSquares.py: 17%

119 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-15 02:24 -0700

1# 

2# This file is part of afw. 

3# 

4# Developed for the LSST Data Management System. 

5# This product includes software developed by the LSST Project 

6# (https://www.lsst.org). 

7# See the COPYRIGHT file at the top-level directory of this distribution 

8# for details of code ownership. 

9# 

10# This program is free software: you can redistribute it and/or modify 

11# it under the terms of the GNU General Public License as published by 

12# the Free Software Foundation, either version 3 of the License, or 

13# (at your option) any later version. 

14# 

15# This program is distributed in the hope that it will be useful, 

16# but WITHOUT ANY WARRANTY; without even the implied warranty of 

17# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

18# GNU General Public License for more details. 

19# 

20# You should have received a copy of the GNU General Public License 

21# along with this program. If not, see <https://www.gnu.org/licenses/>. 

22# 

23 

24""" 

25Tests for math.LeastSquares 

26 

27Run with: 

28 python test_leastSquares.py 

29or 

30 pytest test_leastSquares.py 

31""" 

32 

33import unittest 

34import sys 

35 

36import numpy as np 

37 

38import lsst.utils.tests 

39import lsst.pex.exceptions 

40from lsst.afw.math import LeastSquares 

41from lsst.log import Log 

42 

# Raise the solver's logger to DEBUG so that any messages it emits at that
# level are visible while these tests run.
_LEAST_SQUARES_LOGGER = "lsst.afw.math.LeastSquares"
Log.getLogger(_LEAST_SQUARES_LOGGER).setLevel(Log.DEBUG)

44 

45 

class LeastSquaresTestCase(lsst.utils.tests.TestCase):
    """Compare `lsst.afw.math.LeastSquares` against NumPy reference results.

    The ``DIRECT_SVD``, ``NORMAL_EIGENSYSTEM`` and ``NORMAL_CHOLESKY``
    factorizations are exercised from design-matrix and normal-equation
    inputs, including in-place updates of already-constructed solvers.
    """

    def _assertClose(self, a, b, rtol=1E-5, atol=1E-8):
        """Assert that ``a`` and ``b`` agree within ``rtol``/``atol``."""
        self.assertFloatsAlmostEqual(
            a, b, rtol=rtol, atol=atol, msg=f"\n{a}\n!=\n{b}")

    def _assertNotClose(self, a, b, rtol=1E-5, atol=1E-8):
        """Assert that ``a`` and ``b`` differ by more than ``rtol``/``atol``."""
        # This assertion fails when the values *are* close, so the failure
        # message must report them as equal.
        self.assertFloatsNotEqual(
            a, b, rtol=rtol, atol=atol, msg=f"\n{a}\n==\n{b}")

    def setUp(self):
        # Use a private, seeded generator so the test is reproducible without
        # mutating NumPy's global random state. Legacy seeding makes
        # ``RandomState(500)`` produce the same stream as ``np.random.seed(500)``.
        self.rng = np.random.RandomState(500)

    def check(self, solver, solution, rank, fisher, cov, sv):
        """Verify one solver against reference results.

        Parameters
        ----------
        solver : `lsst.afw.math.LeastSquares`
            Solver under test.
        solution : `numpy.ndarray`
            Expected solution vector.
        rank : `int`
            Expected effective rank.
        fisher : `numpy.ndarray`
            Expected Fisher matrix (``design.T @ design``).
        cov : `numpy.ndarray`
            Expected covariance (inverse or pseudo-inverse of ``fisher``).
        sv : `numpy.ndarray`
            Singular values of the design matrix, in decreasing order.
        """
        self.assertEqual(solver.getRank(), rank)
        self.assertEqual(solver.getDimension(), solution.shape[0])
        self._assertClose(solver.getSolution(), solution)
        self._assertClose(solver.getFisherMatrix(), fisher)
        self._assertClose(solver.getCovariance(), cov)
        if solver.getFactorization() != LeastSquares.NORMAL_CHOLESKY:
            # Eigenvalues of design.T @ design are the squared singular values.
            self._assertClose(
                solver.getDiagnostic(LeastSquares.NORMAL_EIGENSYSTEM),
                sv**2)
            # The last retained diagnostic must exceed the relative threshold;
            # when the solution is truncated, the first dropped one must not.
            diagnostic = solver.getDiagnostic(solver.getFactorization())
            rcond = diagnostic[0] * solver.getThreshold()
            self.assertGreater(diagnostic[rank-1], rcond)
            if rank < solver.getDimension():
                self.assertLess(diagnostic[rank], rcond)
        else:
            # The test expects the Cholesky diagnostics to multiply to
            # prod(sv**2), i.e. det(fisher).
            self._assertClose(
                np.multiply.reduce(solver.getDiagnostic(LeastSquares.NORMAL_CHOLESKY)),
                np.multiply.reduce(sv**2))

    def _makeFullRankProblem(self, dimension, nData):
        """Draw a random least-squares problem and its NumPy reference solution.

        The design matrix is ``nData x dimension`` Gaussian noise, which is
        full rank with overwhelming probability for ``nData >> dimension``.

        Returns
        -------
        design, data, fisher, rhs, solution, rank, sv, cov
            Design matrix, data vector, normal-equation matrix and right-hand
            side, followed by `numpy.linalg.lstsq`'s solution, rank and
            singular values, and the inverse of ``fisher``.
        """
        # Draw order (design first, then data) fixes the random stream.
        design = self.rng.randn(dimension, nData).transpose()
        data = self.rng.randn(nData)
        fisher = np.dot(design.transpose(), design)
        rhs = np.dot(design.transpose(), data)
        solution, _, rank, sv = np.linalg.lstsq(design, data, rcond=None)
        cov = np.linalg.inv(fisher)
        return design, data, fisher, rhs, solution, rank, sv, cov

    def _checkAll(self, solvers, solution, rank, fisher, cov, sv):
        """Run `check` on each solver in ``solvers`` against one reference."""
        for solver in solvers:
            self.check(solver, solution, rank, fisher, cov, sv)

    def testFullRank(self):
        dimension = 10
        nData = 500
        design, data, fisher, rhs, solution, rank, sv, cov = \
            self._makeFullRankProblem(dimension, nData)
        s_svd = LeastSquares.fromDesignMatrix(
            design, data, LeastSquares.DIRECT_SVD)
        s_design_eigen = LeastSquares.fromDesignMatrix(
            design, data, LeastSquares.NORMAL_EIGENSYSTEM)
        s_design_cholesky = LeastSquares.fromDesignMatrix(
            design, data, LeastSquares.NORMAL_CHOLESKY)
        s_normal_eigen = LeastSquares.fromNormalEquations(
            fisher, rhs, LeastSquares.NORMAL_EIGENSYSTEM)
        s_normal_cholesky = LeastSquares.fromNormalEquations(
            fisher, rhs, LeastSquares.NORMAL_CHOLESKY)
        self._checkAll(
            [s_svd, s_design_eigen, s_design_cholesky, s_normal_eigen, s_normal_cholesky],
            solution, rank, fisher, cov, sv)

        # Update every solver in place with the same kind of inputs it was
        # constructed from.
        design, data, fisher, rhs, solution, rank, sv, cov = \
            self._makeFullRankProblem(dimension, nData)
        s_svd.setDesignMatrix(design, data)
        s_design_eigen.setDesignMatrix(design, data)
        s_design_cholesky.setDesignMatrix(design, data)
        s_normal_eigen.setNormalEquations(fisher, rhs)
        s_normal_cholesky.setNormalEquations(fisher, rhs)
        self._checkAll(
            [s_svd, s_design_eigen, s_design_cholesky, s_normal_eigen, s_normal_cholesky],
            solution, rank, fisher, cov, sv)

        # Update solvers in place with the opposite kind of inputs. s_svd is
        # not exercised here (presumably DIRECT_SVD requires a design matrix;
        # TODO confirm).
        design, data, fisher, rhs, solution, rank, sv, cov = \
            self._makeFullRankProblem(dimension, nData)
        s_normal_eigen.setDesignMatrix(design, data)
        s_normal_cholesky.setDesignMatrix(design, data)
        s_design_eigen.setNormalEquations(fisher, rhs)
        s_design_cholesky.setNormalEquations(fisher, rhs)
        self._checkAll(
            [s_design_eigen, s_design_cholesky, s_normal_eigen, s_normal_cholesky],
            solution, rank, fisher, cov, sv)

    def testSingular(self):
        dimension = 10
        nData = 100
        # Prescribe the singular values: the smallest is exactly zero and the
        # next is 1e-4 of the largest, so the effective rank depends on the
        # solver threshold.
        svIn = (self.rng.randn(dimension) + 1.0)**2 + 1.0
        svIn = np.sort(svIn)[::-1]
        svIn[-1] = 0.0
        svIn[-2] = svIn[0] * 1E-4
        # Use an SVD only to obtain a pair of orthogonal matrices; our own
        # singular values replace the computed ones so we control stability.
        u, _, vt = np.linalg.svd(self.rng.randn(dimension, nData),
                                 full_matrices=False)
        design = np.dot(u * svIn, vt).transpose()
        data = self.rng.randn(nData)
        fisher = np.dot(design.transpose(), design)
        rhs = np.dot(design.transpose(), data)
        threshold = 10 * sys.float_info.epsilon
        solution, _, rank, sv = np.linalg.lstsq(
            design, data, rcond=threshold)
        self._assertClose(svIn, sv)
        cov = np.linalg.pinv(fisher, rcond=threshold)
        s_svd = LeastSquares.fromDesignMatrix(
            design, data, LeastSquares.DIRECT_SVD)
        s_design_eigen = LeastSquares.fromDesignMatrix(
            design, data, LeastSquares.NORMAL_EIGENSYSTEM)
        s_normal_eigen = LeastSquares.fromNormalEquations(
            fisher, rhs, LeastSquares.NORMAL_EIGENSYSTEM)
        self._checkAll([s_svd, s_design_eigen, s_normal_eigen],
                       solution, rank, fisher, cov, sv)
        # Eigen-based diagnostics are squared singular values, so their
        # threshold is the square of the SVD one; both drop the 1e-4 mode.
        s_svd.setThreshold(1E-3)
        s_design_eigen.setThreshold(1E-6)
        s_normal_eigen.setThreshold(1E-6)
        self.assertEqual(s_svd.getRank(), dimension - 2)
        self.assertEqual(s_design_eigen.getRank(), dimension - 2)
        self.assertEqual(s_normal_eigen.getRank(), dimension - 2)
        # Only check that the solutions differ from before but are consistent
        # with each other; it is not clear how to make np.linalg.lstsq apply
        # the same thresholds so it could serve as a reference here.
        self._assertNotClose(s_svd.getSolution(), solution)
        self._assertNotClose(s_design_eigen.getSolution(), solution)
        self._assertNotClose(s_normal_eigen.getSolution(), solution)
        self._assertClose(s_svd.getSolution(), s_design_eigen.getSolution())
        self._assertClose(s_svd.getSolution(), s_normal_eigen.getSolution())

179 

180 

class MemoryTester(lsst.utils.tests.MemoryTestCase):
    """Run the standard `lsst.utils.tests.MemoryTestCase` checks unchanged."""

183 

184 

def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    ``module`` is the module object passed in by the test runner; it is not
    used here.
    """
    lsst.utils.tests.init()

187 

188 

def _main():
    """Initialize the LSST test utilities, then run the tests via unittest."""
    lsst.utils.tests.init()
    unittest.main()


if __name__ == "__main__":
    _main()