Coverage for tests/test_parameters.py: 16%

75 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:54 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23from lsst.scarlet.lite import Box 

24from lsst.scarlet.lite.parameters import ( 

25 AdaproxParameter, 

26 FistaParameter, 

27 FixedParameter, 

28 Parameter, 

29 parameter, 

30 phi_psi, 

31) 

32from numpy.testing import assert_array_equal 

33from utils import ScarletTestCase 

34 

35 

def prox_ceiling(x, thresh: float = 20):
    """Clamp ``x`` from above at ``thresh``, modifying it in place.

    Serves as a simple proximal operator for the parameter tests.

    Parameters
    ----------
    x:
        Array to clamp. Elements greater than ``thresh`` are overwritten.
    thresh:
        Upper bound applied to every element.

    Returns
    -------
    x:
        The same array object that was passed in, after clamping.
    """
    too_large = x > thresh
    x[too_large] = thresh
    return x

40 

41 

def grad(input_grad: np.ndarray, x: np.ndarray, *args):
    """Return ``2 * x * input_grad``, the chain-rule gradient of ``x**2``.

    Parameters
    ----------
    input_grad:
        Upstream gradient to propagate.
    x:
        Current parameter value.
    *args:
        Extra positional arguments; accepted and ignored.

    Returns
    -------
    result:
        ``(2 * x) * input_grad``.
    """
    twice_x = 2 * x
    return twice_x * input_grad

45 

46 

class TestParameters(ScarletTestCase):
    """Tests for the parameter classes in ``lsst.scarlet.lite.parameters``."""

    def test_parameter_class(self):
        x = np.arange(15, dtype=float).reshape(3, 5)
        param = parameter(x)
        self.assertIsInstance(param, Parameter)
        assert_array_equal(param.x, x)
        self.assertTupleEqual(param.shape, (3, 5))
        self.assertEqual(param.dtype, float)

        # The base class has no update rule of its own.
        with self.assertRaises(NotImplementedError):
            param.update(1, np.zeros((3, 5)))

        # Test copy method: the value and every helper array must be
        # duplicated, not shared with the originals.
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)
        param_copy = param.copy()
        self.assertIsNot(param_copy.x, x)
        assert_array_equal(param_copy.x, x)
        self.assertIsNot(param_copy.helpers["y"], y)
        assert_array_equal(param_copy.helpers["y"], y)

        # Wrapping an existing Parameter returns that same object.
        param2 = parameter(param)
        self.assertIs(param2, param)

    def test_growing(self):
        x = np.arange(15, dtype=float).reshape(3, 5)
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)

        # Test growing in all directions.
        # Box arguments appear to be (shape, origin): the old origin (21, 15)
        # sits at offset (2, 5) from the new origin (19, 10), which is where
        # the original values are expected to land.
        old_box = Box((3, 5), (21, 15))
        new_box = Box((11, 20), (19, 10))
        param.resize(old_box, new_box)
        truth = np.zeros((11, 20), dtype=float)
        truth[2:5, 5:10] = x
        assert_array_equal(param.x, truth)

        # Test shrinking in all directions: only the overlap (offset (1, 1),
        # shape (1, 3)) of the old array should survive.
        param = Parameter(x, {"y": y}, 0)
        old_box = Box((3, 5), (21, 15))
        new_box = Box((1, 3), (22, 16))
        param.resize(old_box, new_box)
        truth = x[1:2, 1:4]
        assert_array_equal(param.x, truth)

    def test_fista_parameter(self):
        x = np.arange(10, dtype=float)
        x2 = x**2
        # Give the parameter its own array so nothing below can alter ``x2``.
        param = FistaParameter(
            x2.copy(),
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # ``prox_ceiling`` clips its argument in place, so pass a copy to keep
        # the fixture ``x2`` unclipped for the update step.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

    def test_adprox_parameter(self):
        x = np.arange(10, dtype=float)
        x2 = x**2
        # Give the parameter its own array so nothing below can alter ``x2``.
        param = AdaproxParameter(
            x2.copy(),
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # ``prox_ceiling`` clips its argument in place, so pass a copy to keep
        # the fixture ``x2`` unclipped for the steps below.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

        # Every registered Adaprox scheme must be able to take a step. Each
        # scheme starts from a fresh copy of the same initial value, and
        # subTest reports which scheme failed.
        for scheme in phi_psi:
            with self.subTest(scheme=scheme):
                param = AdaproxParameter(
                    x2.copy(),
                    0.1,
                    grad,
                    prox_ceiling,
                    scheme=scheme,
                )
                param.update(10, x, x2)

    def test_fixed_parameter(self):
        x = np.arange(10, dtype=float)
        param = FixedParameter(x)
        # An update on a fixed parameter must leave its value untouched.
        param.update(10, np.arange(10) * 2)
        assert_array_equal(param.x, x)