Coverage for tests/test_parameters.py: 11%
132 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-14 23:28 +0000
# This file is part of lsst.scarlet.lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23from lsst.scarlet.lite import Box
24from lsst.scarlet.lite.parameters import (
25 AdaproxParameter,
26 FistaParameter,
27 FixedParameter,
28 Parameter,
29 parameter,
30 phi_psi,
31)
32from numpy.testing import assert_array_equal
33from utils import ScarletTestCase
def prox_ceiling(x, thresh: float = 20):
    """Proximal operator used by the tests: cap ``x`` at ``thresh``.

    Every element strictly greater than ``thresh`` is replaced by
    ``thresh``. The clipping happens in place, so the caller's array is
    modified, and that same array object is returned.
    """
    too_large = x > thresh
    x[too_large] = thresh
    return x
def grad(input_grad: np.ndarray, x: np.ndarray, *args):
    """Gradient callback used by the tests.

    Returns ``2 * x * input_grad``, i.e. the derivative of ``x**2`` scaled
    by the incoming gradient. Any extra positional arguments are accepted
    and ignored.
    """
    doubled = 2 * x
    return doubled * input_grad
class TestParameters(ScarletTestCase):
    """Tests for the parameter classes in ``lsst.scarlet.lite.parameters``."""

    def test_parameter_class(self):
        """Check the base ``Parameter`` attributes, ``copy`` and ``parameter``."""
        x = np.arange(15, dtype=float).reshape(3, 5)
        param = parameter(x)
        self.assertIsInstance(param, Parameter)
        assert_array_equal(param.x, x)
        self.assertTupleEqual(param.shape, (3, 5))
        self.assertEqual(param.dtype, float)

        # The base class has no update rule.
        with self.assertRaises(NotImplementedError):
            param.update(1, np.zeros((3, 5)))

        # Test copy method: the copy holds new array objects with equal values.
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)
        self.assertIsNot(param.copy().x, x)
        assert_array_equal(param.copy().x, x)
        self.assertIsNot(param.copy().helpers["y"], y)
        assert_array_equal(param.copy().helpers["y"], y)

        # Passing an existing Parameter to `parameter` returns the same object.
        param2 = parameter(param)
        self.assertIs(param2, param)

    def test_growing(self):
        """Check ``Parameter.resize`` when growing and when shrinking the box."""
        x = np.arange(15, dtype=float).reshape(3, 5)
        y = np.zeros((3, 5), dtype=float)
        y[1, 3] = 1
        param = Parameter(x, {"y": y}, 0)

        # Test growing in all dimensions. Assuming Box(shape, origin), the old
        # box sits at offset (21 - 19, 15 - 10) = (2, 5) inside the new box;
        # everything outside the old box is expected to be zero.
        old_box = Box((3, 5), (21, 15))
        new_box = Box((11, 20), (19, 10))
        param.resize(old_box, new_box)
        truth = np.zeros((11, 20), dtype=float)
        truth[2:5, 5:10] = x
        assert_array_equal(param.x, truth)

        # Test shrinking in all directions: only the overlap with the new box
        # is kept.
        param = Parameter(x, {"y": y}, 0)
        old_box = Box((3, 5), (21, 15))
        new_box = Box((1, 3), (22, 16))
        param.resize(old_box, new_box)
        truth = x[1:2, 1:4]
        assert_array_equal(param.x, truth)

    def test_fista_parameter(self):
        """Check ``FistaParameter`` value, gradient, prox, and update."""
        x = np.arange(10, dtype=float)
        x2 = x**2
        param = FistaParameter(
            x2,
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # prox_ceiling clips its argument in place; pass a copy so the x2
        # fixture reused by the update below is not silently clipped.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

    def test_adprox_parameter(self):
        """Check ``AdaproxParameter`` value, gradient, prox, and every scheme."""
        x = np.arange(10, dtype=float)
        x2 = x**2
        param = AdaproxParameter(
            x2,
            0.1,
            grad,
            prox_ceiling,
        )

        assert_array_equal(param.x, x2)
        assert_array_equal(param.grad(np.full(x.shape, 0.1), x), 0.2 * x)
        truth = x2.copy()
        truth[truth > 20] = 20
        # prox_ceiling clips its argument in place; pass a copy so the x2
        # fixture reused by the updates and scheme loop below is unchanged.
        assert_array_equal(param.prox(x2.copy()), truth)
        param.update(10, x, x2)

        # Run one update with each scheme registered in `phi_psi`.
        schemes = tuple(phi_psi.keys())
        for scheme in schemes:
            param = AdaproxParameter(
                x2,
                0.1,
                grad,
                prox_ceiling,
                scheme=scheme,
            )
            param.update(10, x, x2)

    def test_fixed_parameter(self):
        """Check that ``FixedParameter.update`` leaves ``x`` unchanged."""
        x = np.arange(10, dtype=float)
        param = FixedParameter(x)
        param.update(10, np.arange(10) * 2)
        assert_array_equal(param.x, x)

    def test_shallow_copy(self):
        """Check that the default ``copy()`` keeps the class and all values."""
        x = np.arange(10, dtype=float)

        # FistaParameter
        param = FistaParameter(x, 0.1)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, FistaParameter)

        assert_array_equal(param.x, param_copy.x)
        assert_array_equal(param.helpers["z"], param_copy.helpers["z"])

        # AdaproxParameter
        param = AdaproxParameter(x, 0.1)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, AdaproxParameter)

        assert_array_equal(param.x, param_copy.x)
        for key in ("m", "v", "vhat"):
            with self.subTest(helper=key):
                assert_array_equal(param.helpers[key], param_copy.helpers[key])

        # FixedParameter
        param = FixedParameter(x)
        param_copy = param.copy()
        self.assertIsInstance(param_copy, FixedParameter)
        assert_array_equal(param.x, param_copy.x)

    def test_deep_copy(self):
        """Check that ``copy(deep=True)`` yields independent ``x`` and helpers.

        Each array in the copy is mutated in place; the original must not
        change. An in-place mutation is required: rebinding a dict key would
        pass even if the underlying arrays were shared.
        """
        x = np.arange(10, dtype=float)

        # FistaParameter
        param = FistaParameter(x, 0.1)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, FistaParameter)

        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)

        assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])
        param_deepcopy.helpers["z"] += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.helpers["z"], param_deepcopy.helpers["z"])

        # AdaproxParameter
        param = AdaproxParameter(x, 0.1)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, AdaproxParameter)

        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)

        for key in ("m", "v", "vhat"):
            with self.subTest(helper=key):
                assert_array_equal(param.helpers[key], param_deepcopy.helpers[key])
                # Mutate in place so that shared helper arrays are detected.
                param_deepcopy.helpers[key] += 1
                with self.assertRaises(AssertionError):
                    assert_array_equal(param.helpers[key], param_deepcopy.helpers[key])

        # FixedParameter
        param = FixedParameter(x)
        param_deepcopy = param.copy(deep=True)
        self.assertIsInstance(param_deepcopy, FixedParameter)
        assert_array_equal(param.x, param_deepcopy.x)
        param_deepcopy.x += 1
        with self.assertRaises(AssertionError):
            assert_array_equal(param.x, param_deepcopy.x)