Coverage for python/lsst/scarlet/lite/fft.py: 13%

151 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-19 10:38 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ["Fourier"] 

25 

26import operator 

27from typing import Callable, Sequence 

28 

29import numpy as np 

30from numpy.typing import DTypeLike 

31from scipy import fftpack 

32 

33 

def centered(arr: np.ndarray, newshape: Sequence[int]) -> np.ndarray:
    """Extract the central ``newshape`` region of ``arr``.

    Parameters
    ----------
    arr:
        The array to crop.
    newshape:
        The shape of the region to extract. No entry may exceed the
        corresponding dimension of ``arr``.

    Returns
    -------
    result:
        A view into ``arr`` covering the central ``newshape`` region.

    Raises
    ------
    ValueError
        If any entry of ``newshape`` is larger than the matching
        dimension of ``arr``.

    Notes
    -----
    If a dimension of ``arr`` is odd and the target dimension is even,
    the center of ``arr`` is placed on the center-right pixel of the
    result. This differs slightly from the scipy implementation, which
    uses the center-left pixel. The convention follows `np.fft.fftshift`,
    so that converting back and forth between FFT standard order
    (zero frequency and position in the bottom left) and a centered
    origin is consistent.
    """
    target = np.array(newshape)
    current = np.array(arr.shape)

    if np.any(target > current):
        raise ValueError(
            "arr must be larger than newshape in both dimensions, "
            f"received {arr.shape}, and {target}"
        )

    # The "+ 1" rounds the offset up, biasing odd->even crops to the right.
    lower = (current - target + 1) // 2
    upper = lower + target
    return arr[tuple(slice(lo, hi) for lo, hi in zip(lower, upper))]

69 

70 

def fast_zero_pad(arr: np.ndarray, pad_width: Sequence[Sequence[int]]) -> np.ndarray:
    """Zero-pad ``arr``; a fast equivalent of ``numpy.pad(mode="constant")``.

    ``numpy.pad`` with zeros is much slower (~1000x) because it does not
    allocate the output with ``np.zeros`` and copy the input in once.

    Parameters
    ----------
    arr:
        The array to pad.
    pad_width:
        ``(before, after)`` number of values padded to the edges of each
        axis. See the ``numpy.pad`` docs for more.

    Returns
    -------
    result: np.ndarray
        A new array with the same dtype as ``arr``, padded with zeros.
    """
    padded = np.zeros(
        tuple(dim + width[0] + width[1] for dim, width in zip(arr.shape, pad_width)),
        dtype=arr.dtype,
    )

    # Region of the output that receives the original data.
    window = []
    for length, (before, after) in zip(padded.shape, pad_width):
        window.append(slice(before, length - after))

    padded[tuple(window)] = arr
    return padded

96 

97 

98def _pad( 

99 arr: np.ndarray, 

100 newshape: Sequence[int], 

101 axes: int | Sequence[int] | None = None, 

102 mode: str = "constant", 

103 constant_values: float = 0, 

104) -> np.ndarray: 

105 """Pad an array to fit into newshape 

106 

107 Pad `arr` with zeros to fit into newshape, 

108 which uses the `np.fft.fftshift` convention of moving 

109 the center pixel of `arr` (if `arr.shape` is odd) to 

110 the center-right pixel in an even shaped `newshape`. 

111 

112 Parameters 

113 ---------- 

114 arr: 

115 The arrray to pad. 

116 newshape: 

117 The new shape of the array. 

118 axes: 

119 The axes that are being reshaped. 

120 mode: 

121 The numpy mode used to pad the array. 

122 In other words, how to fill the new padded elements. 

123 See ``numpy.pad`` for details. 

124 constant_values: 

125 If `mode` == "constant" then this is the value to set all of 

126 the new padded elements to. 

127 """ 

128 _newshape = np.asarray(newshape) 

129 if axes is None: 

130 currshape = np.array(arr.shape) 

131 diff = _newshape - currshape 

132 startind = (diff + 1) // 2 

133 endind = diff - startind 

134 pad_width = list(zip(startind, endind)) 

135 else: 

136 # only pad the axes that will be transformed 

137 pad_width = [(0, 0) for _ in arr.shape] 

138 if isinstance(axes, int): 

139 axes = [axes] 

140 for a, axis in enumerate(axes): 

141 diff = _newshape[a] - arr.shape[axis] 

142 startind = (diff + 1) // 2 

143 endind = diff - startind 

144 pad_width[axis] = (startind, endind) 

145 if mode == "constant" and constant_values == 0: 

146 result = fast_zero_pad(arr, pad_width) 

147 else: 

148 result = np.pad(arr, tuple(pad_width), mode=mode) # type: ignore 

149 return result 

150 

151 

def get_fft_shape(
    im_or_shape1: np.ndarray | Sequence[int],
    im_or_shape2: np.ndarray | Sequence[int],
    padding: int = 3,
    axes: int | Sequence[int] | None = None,
    use_max: bool = False,
) -> tuple:
    """Compute a fast FFT shape that can hold two images.

    For each selected axis, the two extents are combined (summed, or the
    larger one taken). The result is padded and rounded up with
    ``fftpack.next_fast_len``.

    Parameters
    ----------
    im_or_shape1:
        The left image, or the shape of an image.
    im_or_shape2:
        The right image, or the shape of an image.
    padding:
        Extra pixels added to every selected axis before rounding.
    axes:
        The axes that are being transformed. ``None`` selects all axes.
    use_max:
        Combine the shapes with the maximum (``True``) instead of the
        sum (``False``).

    Returns
    -------
    shape:
        Tuple of the shape to use when the two images are transformed
        into k-space.

    Raises
    ------
    ValueError
        If the two inputs have a different number of dimensions.
    """

    def _to_shape(obj):
        # Arrays contribute their ``shape``; anything else is a shape already.
        return np.asarray(obj.shape if isinstance(obj, np.ndarray) else obj)

    shape1 = _to_shape(im_or_shape1)
    shape2 = _to_shape(im_or_shape2)
    if len(shape1) != len(shape2):
        raise ValueError(
            "img1 and img2 must have the same number of dimensions, "
            f"but got {len(shape1)} and {len(shape2)}"
        )

    combine = np.maximum if use_max else np.add
    if axes is None:
        combined = combine(shape1, shape2)
    else:
        picked = [axes] if isinstance(axes, int) else list(axes)
        combined = combine(shape1[picked], shape2[picked]).astype(int)

    combined += padding
    return tuple(fftpack.next_fast_len(extent) for extent in combined)

218 

219 

class Fourier:
    """An array that stores its Fourier Transform

    The `Fourier` class is used for images that will make
    use of their Fourier Transform multiple times.
    In order to prevent numerical artifacts the same image
    convolved with different images might require different
    padding, so the FFT for each different shape is stored
    in a dictionary.

    Parameters
    ----------
    image: np.ndarray
        The real space image.
    image_fft: dict[Sequence[Sequence[int]], np.ndarray]
        A dictionary of precalculated FFTs. Each key is a tuple
        ``(fft_shape, axes, all_axes)``, built the same way as in
        `Fourier.fft`.
    """

    def __init__(
        self,
        image: np.ndarray,
        image_fft: dict[Sequence[Sequence[int]], np.ndarray] | None = None,
    ):
        if image_fft is None:
            self._fft: dict[Sequence[Sequence[int]], np.ndarray] = {}
        else:
            self._fft = image_fft
        self._image = image

    @staticmethod
    def from_fft(
        image_fft: np.ndarray,
        fft_shape: Sequence[int],
        image_shape: Sequence[int],
        axes: int | Sequence[int] | None = None,
        dtype: DTypeLike = float,
    ) -> Fourier:
        """Generate a new Fourier object from an FFT dictionary

        If the fft of an image has been generated but not its
        real space image (for example when creating a convolution kernel),
        this method can be called to create a new `Fourier` instance
        from the k-space representation.

        Parameters
        ----------
        image_fft:
            The FFT of the image, as produced by `np.fft.rfftn`.
        fft_shape:
            "Fast" shape of the image used to generate the FFT.
            `np.fft.rfftn` only keeps ``n // 2 + 1`` entries along the
            last transformed axis. The original length therefore cannot
            be recovered from `image_fft.shape`, so this tells
            `np.fft.irfftn` how to go from complex k-space to real space.
        image_shape:
            The shape of the image *before padding*.
            This will regenerate the image with the extra
            padding stripped.
        axes:
            The dimension(s) of the array that will be transformed.
            ``None`` transforms every axis.
        dtype:
            The dtype of the regenerated real space image.

        Returns
        -------
        result:
            A `Fourier` object generated from the FFT, with `image_fft`
            cached under the key for `fft_shape` and `axes`.
        """
        if axes is None:
            axes = range(len(image_shape))
        if isinstance(axes, int):
            axes = [axes]
        all_axes = range(len(image_shape))
        image = np.fft.irfftn(image_fft, fft_shape, axes=axes).astype(dtype)
        # Shift the center of the image from the bottom left to the center
        image = np.fft.fftshift(image, axes=axes)
        # Trim the image to remove the padding added
        # to reduce fft artifacts
        image = centered(image, image_shape)
        key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        return Fourier(image, {key: image_fft})

    @property
    def image(self) -> np.ndarray:
        """The real space image"""
        return self._image

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of the real space image"""
        return self._image.shape

    def fft(self, fft_shape: Sequence[int], axes: int | Sequence[int]) -> np.ndarray:
        """The FFT of an image for a given `fft_shape` along desired `axes`

        The result is cached, so repeated calls with the same `fft_shape`
        and `axes` return the same array.

        Parameters
        ----------
        fft_shape:
            "Fast" shape of the image used to generate the FFT,
            one entry per axis in `axes`. The image is padded to this
            shape before it is transformed.
        axes:
            The dimension(s) of the array that will be transformed.

        Returns
        -------
        result:
            The `np.fft.rfftn` transform of the padded image.

        Raises
        ------
        ValueError
            If `fft_shape` and `axes` have a different length.
        """
        if isinstance(axes, int):
            axes = (axes,)
        all_axes = range(len(self.image.shape))
        fft_key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        # If this is the first time calling `fft` for this shape,
        # generate the FFT.
        if fft_key not in self._fft:
            if len(fft_shape) != len(axes):
                msg = (
                    "fft_shape and axes must have the same number of dimensions, "
                    f"got {fft_shape}, {axes}"
                )
                raise ValueError(msg)
            image = _pad(self.image, fft_shape, axes)
            # Move the image center to the origin before transforming
            self._fft[fft_key] = np.fft.rfftn(np.fft.ifftshift(image, axes), axes=axes)
        return self._fft[fft_key]

    def __len__(self) -> int:
        """Length of the image"""
        return len(self.image)

    @staticmethod
    def _is_integer_index(value) -> bool:
        """Whether `value` is a scalar integer index (bools are masks, not integers)."""
        return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

    def __getitem__(self, index: int | Sequence[int] | slice) -> Fourier:
        """Index the real space image and carry over reusable cached FFTs.

        Cached FFTs are only kept when `index` is made of integers and
        slices that leave every transformed axis completely untouched.
        Integer entries remove their axes, and each cache key is
        renumbered so that later calls to `fft` on the result find it.
        Cached FFTs that cannot be reused safely are dropped; for example,
        slicing a transformed axis or using advanced indexing drops them.
        Dropped FFTs are recomputed on demand.
        """
        # Per-axis view of the index (a bare int/slice applies to axis 0)
        entries = index if isinstance(index, tuple) else (index,)

        fft_kernels: dict[Sequence[Sequence[int]], np.ndarray] = {}
        is_basic = all(isinstance(e, slice) or self._is_integer_index(e) for e in entries)
        if is_basic:
            # Axes (positions in the original array) removed by integer indices
            removed = [n for n, e in enumerate(entries) if not isinstance(e, slice)]
            for (fft_shape, axes, all_axes), kernel in self._fft.items():
                ndim = len(all_axes)
                if len(entries) > ndim:
                    continue
                touches_transformed = any(
                    pos < len(entries)
                    and not (isinstance(entries[pos], slice) and entries[pos] == slice(None))
                    for pos in (ax % ndim for ax in axes)
                )
                if touches_transformed:
                    # The k-space data no longer describes the indexed image
                    continue
                # Renumber transformed axes, keeping the sign convention the
                # caller used so that later `fft` lookups match
                new_axes = tuple(
                    ax + sum(r > ax % ndim for r in removed)
                    if ax < 0
                    else ax - sum(r < ax for r in removed)
                    for ax in axes
                )
                new_key = (tuple(fft_shape), new_axes, tuple(range(ndim - len(removed))))
                fft_kernels[new_key] = kernel[index]

        # mpypy doesn't recognize that tuple[int, ...]
        # is a valid Sequence[int] for some reason
        return Fourier(self.image[index], fft_kernels)  # type: ignore

372 

373 

def _kspace_operation(
    image1: Fourier,
    image2: Fourier,
    padding: int,
    op: Callable,
    shape: Sequence[int],
    axes: int | Sequence[int],
) -> Fourier:
    """Combine two images in k-space using the binary operator `op`.

    Parameters
    ----------
    image1:
        The left-hand operand.
    image2:
        The right-hand operand.
    padding:
        The amount of padding to add before transforming into k-space.
    op:
        The operator used to combine the two FFTs, typically
        ``operator.mul`` (convolution) or ``operator.truediv``
        (deconvolution).
    shape:
        The shape of the output image.
    axes:
        The dimension(s) of the array that will be transformed.

    Returns
    -------
    result:
        A `Fourier` built from the combined FFT, cropped to `shape` and
        cast to the dtype of ``image1.image``.

    Raises
    ------
    ValueError
        If the two images have a different number of axes.
    """
    ndim1 = len(image1.shape)
    ndim2 = len(image2.shape)
    if ndim1 != ndim2:
        raise ValueError(f"Both images must have the same number of axes, got {ndim1} and {ndim2}")

    fft_shape = get_fft_shape(image1.image, image2.image, padding, axes)
    dividing = op in (operator.truediv, operator.floordiv, operator.itruediv, operator.ifloordiv)

    lhs = image1.fft(fft_shape, axes)
    rhs = image2.fft(fft_shape, axes)

    if not dividing:
        combined = op(lhs, rhs)
    else:
        # Tile a single-entry leading axis so both operands line up
        n_lhs = lhs.shape[0]
        n_rhs = rhs.shape[0]
        if n_rhs == 1 and n_lhs != n_rhs:
            rhs = np.tile(rhs, (n_lhs,) + (1,) * (rhs.ndim - 1))
        elif n_lhs == 1 and n_lhs != n_rhs:
            lhs = np.tile(lhs, (n_rhs,) + (1,) * (lhs.ndim - 1))
        # Divide only where the denominator is non-zero; elsewhere stay 0
        usable = rhs != 0
        combined = np.zeros(lhs.shape, dtype=lhs.dtype)
        combined[usable] = op(lhs[usable], rhs[usable])

    return Fourier.from_fft(combined, fft_shape, shape, axes, image1.image.dtype)

431 

432 

def match_kernel(
    kernel1: np.ndarray | Fourier,
    kernel2: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: bool = False,
) -> Fourier | np.ndarray:
    """Calculate the difference kernel to match kernel1 to kernel2

    Parameters
    ----------
    kernel1:
        The first kernel, either as array or as `Fourier` object
    kernel2:
        The second kernel, either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the kernels.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        Whether or not to normalize the input kernels. When ``True``,
        each kernel is divided by its sum over `axes`, so that it sums to
        one along the spatial axes. A new `Fourier` is then built from
        the normalized image, so FFTs cached on a `Fourier` input are
        not reused. A kernel whose sum is zero yields non-finite values.

    Returns
    -------
    result:
        The difference kernel to go from `kernel1` to `kernel2`.
    """

    def _prepare(kernel: np.ndarray | Fourier) -> Fourier:
        # Wrap the kernel as a `Fourier`, normalizing it first if requested
        if normalize:
            data = kernel.image if isinstance(kernel, Fourier) else np.asarray(kernel)
            sum_axes = axes if isinstance(axes, int) else tuple(axes)
            return Fourier(data / np.sum(data, axis=sum_axes, keepdims=True))
        if isinstance(kernel, Fourier):
            return kernel
        return Fourier(kernel)

    kernel1 = _prepare(kernel1)
    kernel2 = _prepare(kernel2)

    # Use the shape of the kernel with the longer leading axis for the output
    if kernel1.shape[0] < kernel2.shape[0]:
        shape = kernel2.shape
    else:
        shape = kernel1.shape

    diff = _kspace_operation(kernel1, kernel2, padding, operator.truediv, shape, axes=axes)
    if return_fourier:
        return diff
    return np.real(diff.image)

479 

480 

def convolve(
    image: np.ndarray | Fourier,
    kernel: np.ndarray | Fourier,
    padding: int = 3,
    axes: int | Sequence[int] = (-2, -1),
    return_fourier: bool = True,
    normalize: bool = False,
) -> np.ndarray | Fourier:
    """Convolve image with a kernel

    Parameters
    ----------
    image:
        Image either as array or as `Fourier` object
    kernel:
        Convolution kernel either as array or as `Fourier` object
    padding:
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes:
        Axes that contain the spatial information for the PSFs.
    return_fourier:
        Whether to return `Fourier` or array
    normalize:
        Whether or not to normalize the input kernel. When ``True``, the
        kernel (not the image) is divided by its sum over `axes`, so that
        it sums to one along the spatial axes. A new `Fourier` is then
        built from the normalized kernel, so FFTs cached on a `Fourier`
        kernel are not reused. A kernel whose sum is zero yields
        non-finite values.

    Returns
    -------
    result:
        The convolution of the image with the kernel.
    """
    if not isinstance(image, Fourier):
        image = Fourier(image)

    if normalize:
        # Rescale the kernel to unit sum over the spatial axes
        data = kernel.image if isinstance(kernel, Fourier) else np.asarray(kernel)
        sum_axes = axes if isinstance(axes, int) else tuple(axes)
        kernel = Fourier(data / np.sum(data, axis=sum_axes, keepdims=True))
    elif not isinstance(kernel, Fourier):
        kernel = Fourier(kernel)

    convolved = _kspace_operation(image, kernel, padding, operator.mul, image.shape, axes=axes)
    if return_fourier:
        return convolved
    return np.real(convolved.image)