Coverage for python/lsst/scarlet/lite/parameters.py: 23%

164 statements  

coverage.py v7.5.0, created at 2024-05-01 15:13 -0700

# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

__all__ = [
    "parameter",
    "Parameter",
    "FistaParameter",
    "AdaproxParameter",
    "FixedParameter",
    "relative_step",
    "phi_psi",
    "DEFAULT_ADAPROX_FACTOR",
]

from typing import Callable, Sequence, cast

import numpy as np
import numpy.typing as npt

from .bbox import Box

# The default factor used for adaprox parameter steps
DEFAULT_ADAPROX_FACTOR = 1e-2


def step_function_wrapper(step: float) -> Callable:
    """Wrapper to make a numerical step into a step function

    Parameters
    ----------
    step:
        The step to take for a given array.

    Returns
    -------
    step_function:
        The step function that takes an array and returns the
        numerical step.
    """
    return lambda x: step

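
# The short sketch below is illustrative only (it is not part of scarlet_lite;
# the helper name `_example_constant_step` is made up for this example).
# The wrapped callable ignores its argument and always returns the same step,
# which is how `Parameter.step` treats constant and adaptive steps uniformly.
def _example_constant_step() -> None:
    step_fn = step_function_wrapper(0.5)
    assert step_fn(np.zeros(3)) == 0.5  # the result does not depend on the array
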

class Parameter:
    """A parameter in a `Component`

    Parameters
    ----------
    x:
        The array of values that is being fit.
    helpers:
        A dictionary of helper arrays that are used by an optimizer to
        persist values like the gradient of `x`, the Hessian of `x`, etc.
    step:
        A numerical step value or function to calculate the step for a
        given `x`.
    grad:
        A function to calculate the gradient of `x`.
    prox:
        A function to take the proximal operator of `x`.
    """

    def __init__(
        self,
        x: np.ndarray,
        helpers: dict[str, np.ndarray],
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
    ):
        self.x = x
        self.helpers = helpers

        # Accept plain numbers (including the integer 0 used by `copy` and
        # `FixedParameter`) as constant steps, so that `step` is always callable.
        if isinstance(step, (int, float)):
            _step = step_function_wrapper(step)
        else:
            _step = step

        self._step = _step
        self.grad = grad
        self.prox = prox

    @property
    def step(self) -> float:
        """Calculate the step

        Returns
        -------
        step:
            The numerical step for the current value of `x`.
        """
        return self._step(self.x)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the array that is being fit."""
        return self.x.shape

    @property
    def dtype(self) -> npt.DTypeLike:
        """The numpy dtype of the array that is being fit."""
        return self.x.dtype

    def copy(self) -> Parameter:
        """Copy this parameter, including all of the helper arrays."""
        helpers = {k: v.copy() for k, v in self.helpers.items()}
        return Parameter(self.x.copy(), helpers, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter in one iteration.

        This includes the gradient update, proximal update,
        and any meta parameters that are stored as class
        attributes to update the parameter.

        Parameters
        ----------
        it:
            The current iteration
        input_grad:
            The gradient from the full model, passed to the parameter.
        """
        raise NotImplementedError("Base Parameters cannot be updated")

    def resize(self, old_box: Box, new_box: Box):
        """Resize the parameter and all of the helper parameters

        Parameters
        ----------
        old_box:
            The old bounding box for the parameter.
        new_box:
            The new bounding box for the parameter.
        """
        slices = new_box.overlapped_slices(old_box)
        x = np.zeros(new_box.shape, dtype=self.dtype)
        x[slices[0]] = self.x[slices[1]]
        self.x = x

        for name, value in self.helpers.items():
            result = np.zeros(new_box.shape, dtype=self.dtype)
            result[slices[0]] = value[slices[1]]
            self.helpers[name] = result


def parameter(x: np.ndarray | Parameter) -> Parameter:
    """Convert a `np.ndarray` into a `Parameter`.

    Parameters
    ----------
    x:
        The array or parameter to convert into a `Parameter`.

    Returns
    -------
    result:
        `x`, converted into a `Parameter` if necessary.
    """
    if isinstance(x, Parameter):
        return x
    return Parameter(x, {}, 0)

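
# Illustrative sketch only (not part of scarlet_lite; `_example_parameter` is a
# made-up name): wrap a plain array in a `Parameter`, either directly with a
# constant step or via `parameter`, and read back its basic properties.
def _example_parameter() -> None:
    x = np.arange(6, dtype=float).reshape(2, 3)
    param = Parameter(x, {}, 0.1)      # constant step wrapped into a callable
    assert param.step == 0.1
    assert param.shape == (2, 3)
    wrapped = parameter(param)         # already a Parameter: returned unchanged
    assert wrapped is param
    duplicate = param.copy()           # deep copy of `x` and all helper arrays
    assert duplicate.x is not param.x
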

class FistaParameter(Parameter):
    """A `Parameter` that updates itself using the Beck-Teboulle 2009
    FISTA proximal gradient method.

    See https://www.ceremade.dauphine.fr/~carlier/FISTA
    """

    def __init__(
        self,
        x: np.ndarray,
        step: float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        t0: float = 1,
        z0: np.ndarray | None = None,
    ):
        if z0 is None:
            z0 = x

        super().__init__(
            x,
            {"z": z0},
            step,
            grad,
            prox,
        )
        self.t = t0

    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `Parameter` for the full description.
        """
        step = self.step / np.sum(args[0] * args[0])
        _x = self.x
        _z = self.helpers["z"]

        y = _z - step * cast(Callable, self.grad)(input_grad, _x, *args)
        if self.prox is not None:
            x = self.prox(y)
        else:
            x = y
        t = 0.5 * (1 + np.sqrt(1 + 4 * self.t**2))
        omega = 1 + (self.t - 1) / t
        self.helpers["z"] = _x + omega * (x - _x)
        _x[:] = x
        self.t = t

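
# Illustrative sketch only (not part of scarlet_lite; `_example_fista` is a
# made-up name): run the FISTA update on a toy quadratic loss
# 0.5 * ||x - target||^2 with a non-negativity constraint. The extra positional
# argument plays the role of the array whose squared sum normalizes the step
# inside `FistaParameter.update`; the gradient callable here simply passes the
# input gradient through.
def _example_fista() -> None:
    target = np.array([1.0, -2.0, 3.0])
    param = FistaParameter(
        np.zeros(3),
        step=1.0,
        grad=lambda input_grad, x, *args: input_grad,
        prox=lambda x: np.maximum(x, 0),
    )
    weight = np.ones(3)
    for it in range(100):
        # The gradient of the toy loss with respect to the model is x - target.
        param.update(it, param.x - target, weight)
    # param.x should converge toward [1, 0, 3].
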

# The following code block contains different update methods for
# various implementations of ADAM.
# We currently use the `amsgrad` update by default,
# but it can easily be interchanged by passing a different
# variant name to the `AdaproxParameter`.


# noinspection PyUnusedLocal
def _adam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _nadam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = (b1[it] * m[:] + (1 - b1[it]) * g) / (1 - b1[it] ** t)
    psi = np.sqrt(v / (1 - b2**t)) + eps
    return phi, psi


# noinspection PyUnusedLocal
def _amsgrad_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


def _padam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    vhat[:] = np.maximum(vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = vhat**p
    return phi, psi


# noinspection PyUnusedLocal
def _adamx_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    phi = m
    factor = (1 - b1[it]) ** 2 / (1 - b1[it - 1]) ** 2
    vhat[:] = np.maximum(factor * vhat, v)
    # sanitize zero-gradient elements
    if eps > 0:
        vhat = np.maximum(vhat, eps)
    psi = np.sqrt(vhat)
    return phi, psi


# noinspection PyUnusedLocal
def _radam_phi_psi(it, g, m, v, vhat, b1, b2, eps, p):
    rho_inf = 2 / (1 - b2) - 1

    # moving averages
    m[:] = (1 - b1[it]) * g + b1[it] * m
    v[:] = (1 - b2) * (g**2) + b2 * v

    # bias correction
    t = it + 1
    phi = m / (1 - b1[it] ** t)
    rho = rho_inf - 2 * t * b2**t / (1 - b2**t)

    if rho > 4:
        psi = np.sqrt(v / (1 - b2**t))
        r = np.sqrt((rho - 4) * (rho - 2) * rho_inf / (rho_inf - 4) / (rho_inf - 2) / rho)
        psi /= r
    else:
        psi = np.ones(g.shape, g.dtype)
    # sanitize zero-gradient elements
    if eps > 0:
        psi = np.maximum(psi, np.sqrt(eps))
    return phi, psi


# Dictionary to link ADAM variation names to their functional algorithms.
phi_psi = {
    "adam": _adam_phi_psi,
    "nadam": _nadam_phi_psi,
    "amsgrad": _amsgrad_phi_psi,
    "padam": _padam_phi_psi,
    "adamx": _adamx_phi_psi,
    "radam": _radam_phi_psi,
}

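
# Illustrative sketch only (not part of scarlet_lite; `_example_phi_psi_step`
# is a made-up name): every entry in `phi_psi` shares the same call signature,
# so a single manual step looks the same for each variant. `m`, `v`, and `vhat`
# are updated in place, just as `AdaproxParameter.update` uses them below.
def _example_phi_psi_step() -> None:
    g = np.array([0.5, -1.0, 2.0])       # gradient of the loss w.r.t. x
    m = np.zeros(3)                       # first moment estimate
    v = np.zeros(3)                       # second moment estimate
    vhat = np.full(3, -np.inf)            # running maximum of v (amsgrad-style)
    b1 = np.full(10, 0.9)                 # per-iteration beta_1 schedule
    phi, psi = phi_psi["amsgrad"](0, g, m, v, vhat, b1, 0.999, 1e-8, 0.25)
    x = np.ones(3)
    x -= 0.1 * phi / psi                  # a single quasi-Newton gradient step
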

class SingleItemArray:
    """Mock an array with only a single item"""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, item):
        return self.value


class AdaproxParameter(Parameter):
    """Operator updated using the Proximal ADAM algorithm

    Uses multiple variants of adaptive quasi-Newton gradient descent
    * Adam (Kingma & Ba 2015)
    * NAdam (Dozat 2016)
    * AMSGrad (Reddi, Kale & Kumar 2018)
    * PAdam (Chen & Gu 2018)
    * AdamX (Phuong & Phong 2019)
    * RAdam (Liu et al. 2019)
    See details of the algorithms in the respective papers.
    """


    def __init__(
        self,
        x: np.ndarray,
        step: Callable | float,
        grad: Callable | None = None,
        prox: Callable | None = None,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        p: float = 0.25,
        m0: np.ndarray | None = None,
        v0: np.ndarray | None = None,
        vhat0: np.ndarray | None = None,
        scheme: str = "amsgrad",
        prox_e_rel: float = 1e-6,
    ):
        shape = x.shape
        dtype = x.dtype
        if m0 is None:
            m0 = np.zeros(shape, dtype=dtype)

        if v0 is None:
            v0 = np.zeros(shape, dtype=dtype)

        if vhat0 is None:
            vhat0 = np.ones(shape, dtype=dtype) * -np.inf

        super().__init__(
            x,
            {
                "m": m0,
                "v": v0,
                "vhat": vhat0,
            },
            step,
            grad,
            prox,
        )

        if isinstance(b1, float):
            _b1 = SingleItemArray(b1)
        else:
            _b1 = b1

        self.b1 = _b1
        self.b2 = b2
        self.eps = eps
        self.p = p

        self.phi_psi = phi_psi[scheme]
        self.e_rel = prox_e_rel


    def update(self, it: int, input_grad: np.ndarray, *args):
        """Update the parameter and meta-parameters using the PGM

        See `~Parameter` for more.
        """
        _x = self.x
        # Calculate the gradient
        grad = cast(Callable, self.grad)(input_grad, _x, *args)
        # Get the update for the parameter
        phi, psi = self.phi_psi(
            it,
            grad,
            self.helpers["m"],
            self.helpers["v"],
            self.helpers["vhat"],
            self.b1,
            self.b2,
            self.eps,
            self.p,
        )
        # Calculate the step size
        step = self.step
        if it > 0:
            _x += -step * phi / psi
        else:
            # This is a scheme that Peter Melchior and I came up with to
            # dampen the known effect of ADAM, where the first iteration
            # is often much larger than desired.
            _x += -step * phi / psi / 10

        self.x = cast(Callable, self.prox)(_x)

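
# Illustrative sketch only (not part of scarlet_lite; `_example_adaprox` is a
# made-up name): fit a small array to a target with `AdaproxParameter` and the
# default "amsgrad" scheme. A proximal operator is required because `update`
# always applies `self.prox` at the end of each iteration.
def _example_adaprox() -> None:
    target = np.array([1.0, -2.0, 3.0])
    param = AdaproxParameter(
        np.zeros(3),
        step=DEFAULT_ADAPROX_FACTOR,
        grad=lambda input_grad, x, *args: input_grad,
        prox=lambda x: np.maximum(x, 0),
    )
    for it in range(1000):
        # The gradient of 0.5 * ||x - target||^2 with respect to the model is x - target.
        param.update(it, param.x - target)
    # param.x should approach [1, 0, 3].
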

class FixedParameter(Parameter):
    """A parameter that is not updated"""

    def __init__(self, x: np.ndarray):
        super().__init__(x, {}, 0)

    def update(self, it: int, input_grad: np.ndarray, *args):
        pass


def relative_step(
    x: np.ndarray,
    factor: float = 0.1,
    minimum: float = 0,
    axis: int | Sequence[int] | None = None,
):
    """Step size set at `factor` times the mean of `x` along `axis`"""
    return np.maximum(minimum, factor * x.mean(axis=axis))
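
# Illustrative sketch only (not part of scarlet_lite; `_example_relative_step`
# is a made-up name): `relative_step` matches the step-callable interface used
# by `Parameter`, so it can be bound with `functools.partial` and handed to a
# parameter in place of a constant step.
def _example_relative_step() -> None:
    from functools import partial

    x = np.array([[1.0, 2.0], [3.0, 4.0]])
    step_fn = partial(relative_step, factor=0.1, minimum=1e-3)
    param = Parameter(x, {}, step_fn)
    # param.step is recomputed from the current values of x:
    # factor * x.mean() = 0.1 * 2.5, i.e. approximately 0.25.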