Coverage for python/lsst/scarlet/lite/parameters.py: 22% (187 statements)

# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "parameter",
    "Parameter",
    "FistaParameter",
    "AdaproxParameter",
    "FixedParameter",
    "relative_step",
    "phi_psi",
    "DEFAULT_ADAPROX_FACTOR",
]

from copy import deepcopy
from typing import Any, Callable, Sequence, cast

import numpy as np
import numpy.typing as npt

from .bbox import Box

# The default factor used for adaprox parameter steps
DEFAULT_ADAPROX_FACTOR = 1e-2


def step_function_wrapper(step: float) -> Callable:
    """Wrapper to make a numerical step into a step function.

    Parameters
    ----------
    step:
        The step to take for a given array.

    Returns
    -------
    step_function:
        The step function that takes an array and returns the
        numerical step.
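
    Examples
    --------
    A minimal illustration: the returned function ignores its input
    array and always yields the wrapped value.

    >>> import numpy as np
    >>> step = step_function_wrapper(0.5)
    >>> step(np.zeros(3))
    0.5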
60 """
61 return lambda x: step


class Parameter:
    """A parameter in a `Component`.

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
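
    Examples
    --------
    A minimal sketch with a constant numerical step and no helpers:

    >>> import numpy as np
    >>> p = Parameter(np.zeros((2, 2)), {}, 0.1)
    >>> p.step
    0.1
    >>> p.shape
    (2, 2)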
81 """
83 def __init__(
84 self,
85 x: np.ndarray,
86 helpers: dict[str, np.ndarray],
87 step: Callable | float,
88 grad: Callable | None = None,
89 prox: Callable | None = None,
90 ):
91 self.x = x
92 self.helpers = helpers
94 if isinstance(step, float):
95 _step = step_function_wrapper(step)
96 else:
97 _step = step
99 self._step = _step
100 self.grad = grad
101 self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step.

        Returns
        -------
        step:
            The numerical step for the current value of `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def __copy__(self) -> Parameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        return Parameter(self.x.copy(), helpers, 0)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Parameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        helpers = {k: deepcopy(v, memo) for k, v in self.helpers.items()}
        return Parameter(deepcopy(self.x, memo), helpers, 0)

    def copy(self, deep: bool = False) -> Parameter:
        """Copy this parameter, including all of the helper arrays.

        Parameters
        ----------
        deep:
            If `True`, a deep copy is made.
            If `False`, a shallow copy is made.

        Returns
        -------
        parameter:
            A copy of this parameter.
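
        Examples
        --------
        A deep copy duplicates the underlying array:

        >>> import numpy as np
        >>> p = Parameter(np.zeros(3), {}, 0.1)
        >>> p.copy(deep=True).x is p.x
        False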
163 """
164 if deep:
165 return self.__deepcopy__({})
166 return self.__copy__()

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, proximal update,
        and any meta parameters that are stored as class
        attributes to update the parameter.

        Parameters
        ----------
        it:
            The current iteration.
        input_grad:
            The gradient from the full model, passed to the parameter.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Resize the parameter and all of the helper arrays.

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
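
        Examples
        --------
        A sketch, assuming a `Box` can be constructed from a shape
        alone, with the origin defaulting to zero:

        >>> import numpy as np
        >>> p = Parameter(np.ones((2, 2)), {}, 0.1)
        >>> p.resize(Box((2, 2)), Box((3, 3)))
        >>> p.shape
        (3, 3)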
193 """
194 slices = new_box.overlapped_slices(old_box)
195 x = np.zeros(new_box.shape, dtype=self.dtype)
196 x[slices[0]] = self.x[slices[1]]
197 self.x = x
199 for name, value in self.helpers.items():
200 result = np.zeros(new_box.shape, dtype=self.dtype)
201 result[slices[0]] = value[slices[1]]
202 self.helpers[name] = result


def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Convert a `np.ndarray` into a `Parameter`.

    Parameters
    ----------
    x:
        The array or parameter to convert into a `Parameter`.

    Returns
    -------
    result:
        `x`, converted into a `Parameter` if necessary.
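
    Examples
    --------
    Arrays are wrapped, while existing parameters pass through
    unchanged:

    >>> import numpy as np
    >>> p = parameter(np.zeros(3))
    >>> isinstance(p, Parameter)
    True
    >>> parameter(p) is p
    True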
217 """
218 if isinstance(x, Parameter):
219 return x
220 return Parameter(x, {}, 0)


class FistaParameter(Parameter):
    """A `Parameter` that updates itself using the Beck-Teboulle 2009
    FISTA proximal gradient method.

    See https://www.ceremade.dauphine.fr/~carlier/FISTA
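
    Examples
    --------
    A toy example minimizing f(x) = (x - 1)**2 / 2, whose gradient
    with respect to `x` is `x - 1` (the `input_grad` argument is
    unused by this toy gradient):

    >>> import numpy as np
    >>> param = FistaParameter(np.zeros(1), step=1.0, grad=lambda ig, x: x - 1)
    >>> for it in range(10):
    ...     param.update(it, None)
    >>> bool(np.allclose(param.x, 1))
    True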
228 """
230 def __init__(
231 self,
232 x: np.ndarray,
233 step: float,
234 grad: Callable | None = None,
235 prox: Callable | None = None,
236 t0: float = 1,
237 z0: np.ndarray | None = None,
238 ):
239 if z0 is None:
240 z0 = x.copy()
242 super().__init__(
243 x,
244 {"z": z0},
245 step,
246 grad,
247 prox,
248 )
249 self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM.

        See `Parameter` for the full description.
        """
        if len(args) == 0:
            step = self.step
        else:
            # Normalize the step by the squared norm of the first
            # extra argument.
            step = self.step / np.sum(args[0] * args[0])
        _x = self.x
        _z = self.helpers["z"]

        # Take a gradient step from the extrapolated point z
        # (the gradient itself is evaluated at the current x).
        y = _z - step * cast(Callable, self.grad)(input_grad, _x, *args)
        # Apply the proximal operator, if one was given.
        if self.prox is not None:
            x = self.prox(y)
        else:
            x = y
        # Nesterov-style acceleration: update the momentum term t
        # and extrapolate the next point z beyond the new x.
        t = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t
        self.helpers["z"] = _x + omega * (x - _x)
        _x[:] = x
        self.t = t

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FistaParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return FistaParameter(
            deepcopy(self.x, memo),
            self.step,
            self.grad,
            self.prox,
            self.t,
            deepcopy(self.helpers["z"], memo),
        )

    def __copy__(self) -> FistaParameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return FistaParameter(
            self.x.copy(),
            self.step,
            self.grad,
            self.prox,
            self.t,
            self.helpers["z"].copy(),
        )


# The following code block contains different update methods for
# various implementations of ADAM.
# We currently use the `amsgrad_phi_psi` update by default,
# but it can easily be interchanged by passing a different
# variant name to the `AdaproxParameter`.


# noinspection PyUnusedLocal
def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = vhat**p
    return phi, psi


# noinspection PyUnusedLocal
def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
    vhat[:] = np.maximum(factor * vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


# noinspection PyUnusedLocal
def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    rho_inf = 2 / (1 - b2) - 1

    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    rho = rho_inf - 2 * t * b2**t / (1 - b2**t)

    if rho > 4:
        psi = np.sqrt(v / (1 - b2**t))
        r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
        psi /= r
    else:
        psi = np.ones(g.shape, g.dtype)
    # sanitize zero-gradient elements
    if eps > 0:
        psi = np.maximum(psi, np.sqrt(eps))
    return phi, psi


# Dictionary to link ADAM variation names to their functional algorithms.
phi_psi = {
    "adam": _adam_phi_psi,
    "nadam": _nadam_phi_psi,
    "amsgrad": _amsgrad_phi_psi,
    "padam": _padam_phi_psi,
    "adamx": _adamx_phi_psi,
    "radam": _radam_phi_psi,
}


class SingleItemArray:
    """Mock an array with only a single item.

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        return self.value


class AdaproxParameter(Parameter):
    """Operator updated using the proximal ADAM algorithm.

    Uses multiple variants of adaptive quasi-Newton gradient descent:

    * Adam (Kingma & Ba 2015)
    * NAdam (Dozat 2016)
    * AMSGrad (Reddi, Kale & Kumar 2018)
    * PAdam (Chen & Gu 2018)
    * AdamX (Phuong & Phong 2019)
    * RAdam (Liu et al. 2019)

    See details of the algorithms in the respective papers.
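
    Examples
    --------
    A minimal sketch: the `scheme` name selects one of the update
    functions from the `phi_psi` dictionary, and the optimizer state
    is kept in the helper arrays.

    >>> import numpy as np
    >>> param = AdaproxParameter(np.zeros(3), step=0.1, scheme="adam")
    >>> param.phi_psi is phi_psi["adam"]
    True
    >>> sorted(param.helpers)
    ['m', 'v', 'vhat']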
448 """
450 def __init__(
451 self,
452 x: np.ndarray,
453 step: Callable | float,
454 grad: Callable | None = None,
455 prox: Callable | None = None,
456 b1: float | SingleItemArray = 0.9,
457 b2: float = 0.999,
458 eps: float = 1e-8,
459 p: float = 0.25,
460 m0: np.ndarray | None = None,
461 v0: np.ndarray | None = None,
462 vhat0: np.ndarray | None = None,
463 scheme: str = "amsgrad",
464 prox_e_rel: float = 1e-6,
465 ):
466 shape = x.shape
467 dtype = x.dtype
468 if m0 is None:
469 m0 = np.zeros(shape, dtype=dtype)
471 if v0 is None:
472 v0 = np.zeros(shape, dtype=dtype)
474 if vhat0 is None:
475 vhat0 = np.ones(shape, dtype=dtype) * -np.inf
477 super().__init__(
478 x,
479 {
480 "m": m0,
481 "v": v0,
482 "vhat": vhat0,
483 },
484 step,
485 grad,
486 prox,
487 )
489 if isinstance(b1, float):
490 _b1 = SingleItemArray(b1)
491 else:
492 _b1 = b1
494 self.b1 = _b1
495 self.b2 = b2
496 self.eps = eps
497 self.p = p
499 self.scheme = scheme
500 self.phi_psi = phi_psi[scheme]
501 self.e_rel = prox_e_rel

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM.

        See `~Parameter` for more.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # This is a scheme that Peter Melchior and I came up with to
            # dampen the known effect of ADAM, where the first iteration
            # is often much larger than desired.
            _x += -step * phi / psi / 10
        self.x = cast(Callable, self.prox)(_x)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> AdaproxParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return AdaproxParameter(
            deepcopy(self.x, memo),
            self.step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            deepcopy(self.helpers["m"], memo),
            deepcopy(self.helpers["v"], memo),
            deepcopy(self.helpers["vhat"], memo),
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )

    def __copy__(self) -> AdaproxParameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return AdaproxParameter(
            self.x,
            self.step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )


class FixedParameter(Parameter):
    """A parameter that is not updated"""

    def __init__(self, x: np.ndarray):
        super().__init__(x, {}, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        pass

    def __copy__(self) -> FixedParameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return FixedParameter(self.x)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FixedParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return FixedParameter(deepcopy(self.x, memo))


def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` along `axis`,
    floored at `minimum`.
    return np.maximum(minimum, factor * x.mean(axis=axis))