Coverage for tests/test_parameters.py: 16%
75 statements
« prev ^ index » next coverage.py v7.4.3, created at 2024-03-01 11:54 +0000
# This file is part of lsst.scarlet.lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23from lsst.scarlet.lite import Box
24from lsst.scarlet.lite.parameters import (
25 AdaproxParameter,
26 FistaParameter,
27 FixedParameter,
28 Parameter,
29 parameter,
30 phi_psi,
31)
32from numpy.testing import assert_array_equal
33from utils import ScarletTestCase
def prox_ceiling(x, thresh: float = 20):
    """Clip ``x`` from above at ``thresh``; a toy proximal operator for tests.

    Parameters
    ----------
    x:
        Array to clip. It is modified in place.
    thresh:
        Upper bound. Every element greater than ``thresh`` is replaced
        by ``thresh``.

    Returns
    -------
    result:
        The same ``x`` object, after clipping.
    """
    too_large = x > thresh
    x[too_large] = thresh
    return x
def grad(input_grad: np.ndarray, x: np.ndarray, *args):
    """Toy gradient used by the parameter tests.

    Parameters
    ----------
    input_grad:
        Upstream gradient that multiplies the local term.
    x:
        Current value of the parameter.
    *args:
        Accepted for signature compatibility and ignored.

    Returns
    -------
    result:
        ``2 * x * input_grad``.
    """
    doubled = 2 * x
    return doubled * input_grad
class TestParameters(ScarletTestCase):
    """Tests for the parameter classes in ``lsst.scarlet.lite.parameters``."""

    def _check_grad_and_prox(self, param, x, x2):
        """Check the ``x``/``grad``/``prox`` behavior shared by optimizer parameters.

        ``param`` is expected to have been constructed from ``x2`` with the
        module-level ``grad`` and ``prox_ceiling`` helpers.
        """
        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # prox_ceiling clips its argument in place, so hand it a copy.
        # Otherwise the caller's ``x2`` fixture would be silently clipped
        # for the rest of the test.
        assert_array_equal(param.prox(x2.copy()), truth)

    def test_parameter_class(self):
        """Exercise the base ``Parameter`` class and the ``parameter`` factory."""
        x = np.arange(15, dtype=float).reshape(3, 5)
        param = parameter(x)
        self.assertIsInstance(param, Parameter)
        assert_array_equal(param.x, x)
        self.assertTupleEqual(param.shape, (3, 5))
        # A numpy float64 dtype compares equal to the builtin ``float``.
        self.assertEqual(param.dtype, float)

        # The base class does not implement an update step.
        with self.assertRaises(NotImplementedError):
            param.update(1, np.zeros((3, 5)))

        # Test copy method: the copy must hold new array objects
        # (for both ``x`` and the helpers) with identical contents.
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)
        param_copy = param.copy()
        self.assertIsNot(param_copy.x, x)
        assert_array_equal(param_copy.x, x)
        self.assertIsNot(param_copy.helpers["y"], y)
        assert_array_equal(param_copy.helpers["y"], y)

        # Passing an existing Parameter to the factory returns it unchanged.
        param2 = parameter(param)
        self.assertIs(param2, param)

    def test_growing(self):
        """Resize a parameter into a larger and then a smaller bounding box."""
        x = np.arange(15, dtype=float).reshape(3, 5)
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)

        # Test growing in all dimensions: the original data is expected
        # at rows 2:5, columns 5:10 of a zero-filled (11, 20) array.
        old_box = Box((3, 5), (21, 15))
        new_box = Box((11, 20), (19, 10))
        param.resize(old_box, new_box)
        truth = np.zeros((11, 20), dtype=float)
        truth[2:5, 5:10] = x
        assert_array_equal(param.x, truth)

        # Test shrinking in all directions: only rows 1:2, columns 1:4
        # of the original data are expected to remain.
        param = Parameter(x, {"y": y}, 0)
        old_box = Box((3, 5), (21, 15))
        new_box = Box((1, 3), (22, 16))
        param.resize(old_box, new_box)
        truth = x[1:2, 1:4]
        assert_array_equal(param.x, truth)

    def test_fista_parameter(self):
        """A ``FistaParameter`` exposes its grad/prox and can take a step."""
        x = np.arange(10, dtype=float)
        x2 = x**2
        param = FistaParameter(
            x2,
            0.1,
            grad,
            prox_ceiling,
        )

        self._check_grad_and_prox(param, x, x2)
        # Smoke test: a single update step runs without raising.
        param.update(10, x, x2)

    def test_adprox_parameter(self):
        """An ``AdaproxParameter`` works with its default and every registered scheme."""
        x = np.arange(10, dtype=float)
        x2 = x**2
        param = AdaproxParameter(
            x2,
            0.1,
            grad,
            prox_ceiling,
        )

        self._check_grad_and_prox(param, x, x2)
        param.update(10, x, x2)

        # Smoke test: every scheme registered in ``phi_psi`` can take a step.
        for scheme in phi_psi:
            param = AdaproxParameter(
                x2,
                0.1,
                grad,
                prox_ceiling,
                scheme=scheme,
            )
            param.update(10, x, x2)

    def test_fixed_parameter(self):
        """A ``FixedParameter`` keeps its value when updated."""
        x = np.arange(10, dtype=float)
        param = FixedParameter(x)
        param.update(10, np.arange(10) * 2)
        assert_array_equal(param.x, x)