Coverage for python/lsst/scarlet/lite/parameters.py: 22% (187 statements)


# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "parameter",
    "Parameter",
    "FistaParameter",
    "AdaproxParameter",
    "FixedParameter",
    "relative_step",
    "phi_psi",
    "DEFAULT_ADAPROX_FACTOR",
]

from copy import deepcopy
from typing import Any, Callable, Sequence, cast

import numpy as np
import numpy.typing as npt

from .bbox import Box

# The default factor used for adaprox parameter steps
DEFAULT_ADAPROX_FACTOR = 1e-2


def step_function_wrapper(step: float) -> Callable:
    """Wrapper to make a numerical step into a step function

    Parameters
    ----------
    step:
        The step to take for a given array.

    Returns
    -------
    step_function:
        The step function that takes an array and returns the
        numerical step.
    """
    return lambda x: step


class Parameter:
    """A parameter in a `Component`

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    """

    def __init__(
        self,
        x: np.ndarray,
        helpers: dict[str, np.ndarray],
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
    ):
        self.x = x
        self.helpers = helpers

        # Wrap a numerical step (int or float) so that `_step` is always callable
        if isinstance(step, (int, float)):
            _step = step_function_wrapper(step)
        else:
            _step = step

        self._step = _step
        self.grad = grad
        self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step

        Returns
        -------
        step:
            The numerical step for the current value of `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def __copy__(self) -> Parameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        return Parameter(self.x.copy(), helpers, 0)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Parameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        helpers = {k: deepcopy(v, memo) for k, v in self.helpers.items()}
        return Parameter(deepcopy(self.x, memo), helpers, 0)

    def copy(self, deep: bool = False) -> Parameter:
        """Copy this parameter, including all of the helper arrays.

        Parameters
        ----------
        deep:
            If `True`, a deep copy is made.
            If `False`, a shallow copy is made.

        Returns
        -------
        parameter:
            A copy of this parameter.
        """
        if deep:
            return self.__deepcopy__({})
        return self.__copy__()

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, the proximal update,
        and updates to any meta-parameters that are stored as
        attributes of the class.

        Parameters
        ----------
        it:
            The current iteration.
        input_grad:
            The gradient from the full model, passed to the parameter.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Resize the parameter and all of the helper parameters

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x

        for name, value in self.helpers.items():
            result = np.zeros(new_box.shape, dtype=self.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result


def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Convert a `np.ndarray` into a `Parameter`.

    Parameters
    ----------
    x:
        The array or parameter to convert into a `Parameter`.

    Returns
    -------
    result:
        `x`, converted into a `Parameter` if necessary.
    """
    if isinstance(x, Parameter):
        return x
    return Parameter(x, {}, 0)
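

# Editorial example: a minimal sketch (not part of the original module)
# showing how `parameter` wraps a raw array and how a float step is
# promoted to a step function by `Parameter`.
def _example_parameter():
    x = np.zeros((3, 4), dtype=np.float32)
    param = parameter(x)
    assert param.shape == (3, 4)
    # Passing an existing Parameter through is a no-op
    assert parameter(param) is param
    # A float step is wrapped so that `step` is computed from `x`
    assert Parameter(x, {}, 0.5).step == 0.5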


class FistaParameter(Parameter):
    """A `Parameter` that updates itself using the Beck-Teboulle 2009
    FISTA proximal gradient method.

    See https://www.ceremade.dauphine.fr/~carlier/FISTA
    """

    def __init__(
        self,
        x: np.ndarray,
        step: float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        t0: float = 1,
        z0: np.ndarray | None = None,
    ):
        if z0 is None:
            z0 = x.copy()

        super().__init__(
            x,
            {"z": z0},
            step,
            grad,
            prox,
        )
        self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the
        proximal gradient method (PGM).

        See `Parameter` for the full description.
        """
        if len(args) == 0:
            step = self.step
        else:
            step = self.step / np.sum(args[0] * args[0])
        _x = self.x
        _z = self.helpers["z"]

        y = _z - step * cast(Callable, self.grad)(input_grad, _x, *args)
        if self.prox is not None:
            x = self.prox(y)
        else:
            x = y
        t = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t
        self.helpers["z"] = _x + omega * (x - _x)
        _x[:] = x
        self.t = t

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FistaParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return FistaParameter(
            deepcopy(self.x, memo),
            self.step,
            self.grad,
            self.prox,
            self.t,
            deepcopy(self.helpers["z"], memo),
        )

    def __copy__(self) -> FistaParameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return FistaParameter(
            self.x.copy(),
            self.step,
            self.grad,
            self.prox,
            self.t,
            self.helpers["z"].copy(),
        )
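

# Editorial example: a minimal sketch (not part of the original module)
# of driving a FistaParameter by hand. It minimizes 0.5 * ||x - a||^2
# subject to x >= 0, whose solution is np.maximum(a, 0). The gradient
# callable follows the `grad(input_grad, x, *args)` convention used in
# `update`; `input_grad` is unused in this toy problem, and the names
# `a` and `param` are illustrative only.
def _example_fista():
    a = np.array([1.0, -2.0, 3.0])
    param = FistaParameter(
        np.zeros(3),
        step=0.5,
        grad=lambda input_grad, x: x - a,
        prox=lambda x: np.maximum(x, 0),
    )
    for it in range(100):
        param.update(it, np.zeros(3))
    # The iterates approach the constrained optimum
    assert np.allclose(param.x, [1.0, 0.0, 3.0], atol=1e-2)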


# The following code block contains different update methods for
# various implementations of ADAM.
# We currently use the `_amsgrad_phi_psi` update by default,
# but it can easily be interchanged by passing a different
# variant name to the `AdaproxParameter`.


# noinspection PyUnusedLocal
def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = vhat**p
    return phi, psi


# noinspection PyUnusedLocal
def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
    vhat[:] = np.maximum(factor * vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


# noinspection PyUnusedLocal
def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    rho_inf = 2 / (1 - b2) - 1

    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    rho = rho_inf - 2 * t * b2**t / (1 - b2**t)

    if rho > 4:
        psi = np.sqrt(v / (1 - b2**t))
        r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
        psi /= r
    else:
        psi = np.ones(g.shape, g.dtype)
    # sanitize zero-gradient elements
    if eps > 0:
        psi = np.maximum(psi, np.sqrt(eps))
    return phi, psi


# Dictionary to link ADAM variation names to their functional algorithms.
phi_psi = {
    "adam": _adam_phi_psi,
    "nadam": _nadam_phi_psi,
    "amsgrad": _amsgrad_phi_psi,
    "padam": _padam_phi_psi,
    "adamx": _adamx_phi_psi,
    "radam": _radam_phi_psi,
}


class SingleItemArray:
    """Mock an array with only a single item"""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        return self.value
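

# Editorial example: a minimal sketch (not part of the original module)
# of a single "adam" step from the `phi_psi` table. On the first
# iteration the bias correction gives phi == g and psi == |g| + eps, so
# apart from the first-iteration damping in `AdaproxParameter.update`,
# the update `x -= step * phi / psi` initially moves each element by
# roughly the step size regardless of the gradient's scale.
def _example_phi_psi():
    g = np.array([0.5, -1.0])
    m = np.zeros(2)
    v = np.zeros(2)
    vhat = np.zeros(2)
    b1 = SingleItemArray(0.9)
    phi, psi = phi_psi["adam"](0, g, m, v, vhat, b1, 0.999, 1e-8, 0.25)
    assert np.allclose(phi, g)
    assert np.allclose(psi, np.abs(g) + 1e-8)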


class AdaproxParameter(Parameter):
    """Operator updated using the Proximal ADAM algorithm

    Uses multiple variants of adaptive quasi-Newton gradient descent
    * Adam (Kingma & Ba 2015)
    * NAdam (Dozat 2016)
    * AMSGrad (Reddi, Kale & Kumar 2018)
    * PAdam (Chen & Gu 2018)
    * AdamX (Phuong & Phong 2019)
    * RAdam (Liu et al. 2019)
    See details of the algorithms in the respective papers.
    """

    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        b1: float | SingleItemArray = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        p: float = 0.25,
        m0: np.ndarray | None = None,
        v0: np.ndarray | None = None,
        vhat0: np.ndarray | None = None,
        scheme: str = "amsgrad",
        prox_e_rel: float = 1e-6,
    ):
        shape = x.shape
        dtype = x.dtype
        if m0 is None:
            m0 = np.zeros(shape, dtype=dtype)

        if v0 is None:
            v0 = np.zeros(shape, dtype=dtype)

        if vhat0 is None:
            vhat0 = np.ones(shape, dtype=dtype) * -np.inf

        super().__init__(
            x,
            {
                "m": m0,
                "v": v0,
                "vhat": vhat0,
            },
            step,
            grad,
            prox,
        )

        if isinstance(b1, float):
            _b1 = SingleItemArray(b1)
        else:
            _b1 = b1

        self.b1 = _b1
        self.b2 = b2
        self.eps = eps
        self.p = p

        self.scheme = scheme
        self.phi_psi = phi_psi[scheme]
        self.e_rel = prox_e_rel

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the
        proximal gradient method (PGM).

        See `~Parameter` for more.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # This is a scheme that Peter Melchior and I came up with to
            # dampen the known effect of ADAM, where the first iteration
            # is often much larger than desired.
            _x += -step * phi / psi / 10

        self.x = cast(Callable, self.prox)(_x)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> AdaproxParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return AdaproxParameter(
            deepcopy(self.x, memo),
            self.step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            deepcopy(self.helpers["m"], memo),
            deepcopy(self.helpers["v"], memo),
            deepcopy(self.helpers["vhat"], memo),
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )

    def __copy__(self) -> AdaproxParameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return AdaproxParameter(
            self.x,
            self.step,
            self.grad,
            self.prox,
            self.b1,
            self.b2,
            self.eps,
            self.p,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            scheme=self.scheme,
            prox_e_rel=self.e_rel,
        )
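

# Editorial example: a minimal sketch (not part of the original module)
# of fitting with AdaproxParameter. As in the FISTA sketch above, it
# minimizes 0.5 * ||x - a||^2 subject to x >= 0 with the default
# "amsgrad" scheme; `input_grad` is unused by the toy gradient callable.
def _example_adaprox():
    a = np.array([1.0, -2.0, 3.0])
    param = AdaproxParameter(
        np.full(3, 0.5),
        step=DEFAULT_ADAPROX_FACTOR,
        grad=lambda input_grad, x: x - a,
        prox=lambda x: np.maximum(x, 0),
    )
    for it in range(2000):
        param.update(it, np.zeros(3))
    # The iterates approach the constrained optimum
    assert np.allclose(param.x, [1.0, 0.0, 3.0], atol=1e-2)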


class FixedParameter(Parameter):
    """A parameter that is not updated"""

    def __init__(self, x: np.ndarray):
        super().__init__(x, {}, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        pass

    def __copy__(self) -> FixedParameter:
        """Create a shallow copy of this parameter.

        Returns
        -------
        parameter:
            A shallow copy of this parameter.
        """
        return FixedParameter(self.x)

    def __deepcopy__(self, memo: dict[int, Any] | None = None) -> FixedParameter:
        """Create a deep copy of this parameter.

        Parameters
        ----------
        memo:
            A memoization dictionary used by `copy.deepcopy`.

        Returns
        -------
        parameter:
            A deep copy of this parameter.
        """
        return FixedParameter(deepcopy(self.x, memo))


def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` in direction `axis`"""
    return np.maximum(minimum, factor * x.mean(axis=axis))
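

# Editorial example: a minimal sketch (not part of the original module)
# showing `relative_step` used as the step callable of a `Parameter`
# via `functools.partial`, so the step size scales with `x` itself.
def _example_relative_step():
    from functools import partial

    x = np.array([1.0, 2.0, 3.0])
    param = Parameter(x, {}, partial(relative_step, factor=0.1))
    assert np.isclose(param.step, 0.2)  # 0.1 * mean(x)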