# python/lsst/scarlet/lite/parameters.py
# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = [
    "parameter",
    "Parameter",
    "FistaParameter",
    "AdaproxParameter",
    "FixedParameter",
    "relative_step",
    "phi_psi",
    "DEFAULT_ADAPROX_FACTOR",
]

from typing import Callable, Sequence, cast

import numpy as np
import numpy.typing as npt

from .bbox import Box

# The default factor used for adaprox parameter steps
DEFAULT_ADAPROX_FACTOR = 1e-2

def step_function_wrapper(step: float) -> Callable:
    """Wrapper to make a numerical step into a step function

    Parameters
    ----------
    step:
        The step to take for a given array.

    Returns
    -------
    step_function:
        The step function that takes an array and returns the
        numerical step.
    """
    return lambda x: step
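

# Example (illustrative sketch, not part of the original module): the
# wrapped step function ignores its array argument and always returns
# the constant step.
#
#     step_fn = step_function_wrapper(0.1)
#     step_fn(np.zeros(3))  # -> 0.1 for any input array
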

class Parameter:
    """A parameter in a `Component`

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    """

    def __init__(
        self,
        x: np.ndarray,
        helpers: dict[str, np.ndarray],
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
    ):
        self.x = x
        self.helpers = helpers

        # Accept ints as well as floats, since `copy` and `parameter`
        # pass a step of 0.
        if isinstance(step, (int, float)):
            _step = step_function_wrapper(step)
        else:
            _step = step

        self._step = _step
        self.grad = grad
        self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step

        Returns
        -------
        step:
            The step size calculated for the current value of `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def copy(self) -> Parameter:
        """Copy this parameter, including all of the helper arrays."""
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        return Parameter(self.x.copy(), helpers, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, proximal update,
        and any meta parameters that are stored as class
        attributes to update the parameter.

        Parameters
        ----------
        it:
            The current iteration
        input_grad:
            The gradient from the full model, passed to the parameter.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Grow the parameter and all of the helper parameters

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x

        for name, value in self.helpers.items():
            result = np.zeros(new_box.shape, dtype=self.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result
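

# Example (illustrative sketch, not part of the original module): a base
# `Parameter` wraps the array and exposes its metadata; only subclasses
# implement `update`.
#
#     p = Parameter(np.zeros((3, 3)), {}, 0.1)
#     p.step   # -> 0.1 (the constant, wrapped by step_function_wrapper)
#     p.shape  # -> (3, 3)
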

def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Convert a `np.ndarray` into a `Parameter`.

    Parameters
    ----------
    x:
        The array or parameter to convert into a `Parameter`.

    Returns
    -------
    result:
        `x`, converted into a `Parameter` if necessary.
    """
    if isinstance(x, Parameter):
        return x
    return Parameter(x, {}, 0)
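

# Example (illustrative sketch): arrays are wrapped with a zero step and
# no helpers, while existing parameters pass through unchanged.
#
#     p = parameter(np.ones(5))  # -> Parameter wrapping the array
#     parameter(p) is p          # -> True
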

class FistaParameter(Parameter):
    """A `Parameter` that updates itself using the Beck-Teboulle 2009
    FISTA proximal gradient method.

    See https://www.ceremade.dauphine.fr/~carlier/FISTA
    """

    def __init__(
        self,
        x: np.ndarray,
        step: float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        t0: float = 1,
        z0: np.ndarray | None = None,
    ):
        if z0 is None:
            z0 = x

        super().__init__(
            x,
            {"z": z0},
            step,
            grad,
            prox,
        )
        self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `Parameter` for the full description.
        """
        step = self.step / np.sum(args[0] * args[0])
        _x = self.x
        _z = self.helpers["z"]

        y = _z - step * cast(Callable, self.grad)(input_grad, _x, *args)
        if self.prox is not None:
            x = self.prox(y)
        else:
            x = y
        t = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t
        self.helpers["z"] = _x + omega * (x - _x)
        _x[:] = x
        self.t = t
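

# Example (illustrative sketch; the names ``a``, ``y`` and the quadratic
# objective are hypothetical): minimize 0.5 * ||a * x - y||^2 subject to
# x >= 0. ``input_grad`` carries the model residual, and the extra
# positional argument ``a`` supplies the scaling that normalizes ``step``.
#
#     a = np.array([2.0, 2.0])
#     y = np.array([4.0, -2.0])
#     param = FistaParameter(
#         np.zeros(2),
#         step=1.0,
#         grad=lambda residual, x, a: a * residual,
#         prox=lambda x: np.maximum(x, 0),
#     )
#     for it in range(100):
#         param.update(it, a * param.x - y, a)
#     # param.x -> approximately [2, 0]
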

# The following code block contains different update methods for
# various implementations of ADAM.
# We currently use the `_amsgrad_phi_psi` update by default,
# but it can easily be interchanged by passing a different
# variant name to the `AdaproxParameter`.


# noinspection PyUnusedLocal
def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = vhat**p
    return phi, psi


# noinspection PyUnusedLocal
def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
    vhat[:] = np.maximum(factor * vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


# noinspection PyUnusedLocal
def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    rho_inf = 2 / (1 - b2) - 1

    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    rho = rho_inf - 2 * t * b2**t / (1 - b2**t)

    if rho > 4:
        psi = np.sqrt(v / (1 - b2**t))
        r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
        psi /= r
    else:
        psi = np.ones(g.shape, g.dtype)
    # sanitize zero-gradient elements
    if eps > 0:
        psi = np.maximum(psi, np.sqrt(eps))
    return phi, psi


# Dictionary to link ADAM variation names to their functional algorithms.
phi_psi = {
    "adam": _adam_phi_psi,
    "nadam": _nadam_phi_psi,
    "amsgrad": _amsgrad_phi_psi,
    "padam": _padam_phi_psi,
    "adamx": _adamx_phi_psi,
    "radam": _radam_phi_psi,
}
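

# Each variant returns the pair (phi, psi) consumed by
# `AdaproxParameter.update`, which applies (roughly)
# x <- prox(x - step * phi / psi): phi is the (bias-corrected) first-moment
# estimate of the gradient and psi is the per-element preconditioner built
# from the second moment. Selecting a scheme is a dictionary lookup:
#
#     update_fn = phi_psi["amsgrad"]  # the default used by AdaproxParameter
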

class SingleItemArray:
    """Mock an array with only a single item"""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        return self.value
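

# Example (illustrative sketch): this lets a scalar ``b1`` be indexed per
# iteration like an array-valued schedule in the ``_phi_psi`` functions.
#
#     b1 = SingleItemArray(0.9)
#     b1[0]    # -> 0.9
#     b1[100]  # -> 0.9
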

class AdaproxParameter(Parameter):
    """A `Parameter` updated using the proximal ADAM algorithm

    Uses multiple variants of adaptive quasi-Newton gradient descent
    * Adam (Kingma & Ba 2015)
    * NAdam (Dozat 2016)
    * AMSGrad (Reddi, Kale & Kumar 2018)
    * PAdam (Chen & Gu 2018)
    * AdamX (Phuong & Phong 2019)
    * RAdam (Liu et al. 2019)
    See details of the algorithms in the respective papers.
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        p: float = 0.25,
        m0: np.ndarray | None = None,
        v0: np.ndarray | None = None,
        vhat0: np.ndarray | None = None,
        scheme: str = "amsgrad",
        prox_e_rel: float = 1e-6,
    ):
        shape = x.shape
        dtype = x.dtype
        if m0 is None:
            m0 = np.zeros(shape, dtype=dtype)

        if v0 is None:
            v0 = np.zeros(shape, dtype=dtype)

        if vhat0 is None:
            vhat0 = np.ones(shape, dtype=dtype) * -np.inf

        super().__init__(
            x,
            {
                "m": m0,
                "v": v0,
                "vhat": vhat0,
            },
            step,
            grad,
            prox,
        )

        if isinstance(b1, float):
            _b1 = SingleItemArray(b1)
        else:
            _b1 = b1

        self.b1 = _b1
        self.b2 = b2
        self.eps = eps
        self.p = p

        self.phi_psi = phi_psi[scheme]
        self.e_rel = prox_e_rel

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `~Parameter` for more.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # This is a scheme that Peter Melchior and I came up with to
            # dampen the known effect of ADAM, where the first iteration
            # is often much larger than desired.
            _x += -step * phi / psi / 10
        self.x = cast(Callable, self.prox)(_x)
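

# Example (illustrative sketch; ``target`` is hypothetical): minimize
# 0.5 * ||x - target||^2 with a non-negativity constraint. The residual is
# passed in as ``input_grad`` and returned unchanged by ``grad``.
#
#     target = np.array([1.0, -1.0, 2.0])
#     param = AdaproxParameter(
#         np.zeros(3),
#         step=0.1,
#         grad=lambda residual, x: residual,
#         prox=lambda x: np.maximum(x, 0),
#         scheme="amsgrad",
#     )
#     for it in range(200):
#         param.update(it, param.x - target)
#     # param.x -> approximately [1, 0, 2]
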

class FixedParameter(Parameter):
    """A parameter that is not updated"""

    def __init__(self, x: np.ndarray):
        super().__init__(x, {}, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        pass


def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` in direction `axis`"""
    return np.maximum(minimum, factor * x.mean(axis=axis))
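

# Example (illustrative sketch): with the defaults the step is 10% of the
# mean of the array, floored at ``minimum``. Passing the function itself as
# a `Parameter` step makes the step size track the current value of ``x``.
#
#     relative_step(np.array([1.0, 3.0]))  # -> 0.1 * 2.0 = 0.2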