Coverage for python/lsst/scarlet/lite/wavelet.py: 7% (115 statements)
coverage.py v7.5.1, created at 2024-05-16 02:46 -0700

# This file is part of scarlet_lite.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = [
    "starlet_transform",
    "starlet_reconstruction",
    "multiband_starlet_transform",
    "multiband_starlet_reconstruction",
    "get_multiresolution_support",
]

from typing import Callable, Sequence

import numpy as np


def bspline_convolve(image: np.ndarray, scale: int) -> np.ndarray:
    """Convolve an image with a bspline at a given scale.

    This uses the spline
    `h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])`
    from Starck et al. 2011.

    Parameters
    ----------
    image:
        The 2D image or wavelet coefficients to convolve.
    scale:
        The wavelet scale for the convolution. This sets the
        spacing between adjacent pixels sampled by the spline.

    Returns
    -------
    result:
        The result of convolving the `image` with the spline.
    """
    # Filter for the starlet transform: a B-spline
    h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])
    j = scale

    slice0 = slice(None, -(2 ** (j + 1)))
    slice1 = slice(None, -(2**j))
    slice3 = slice(2**j, None)
    slice4 = slice(2 ** (j + 1), None)
    # Convolve along the first axis
    col = image * h1d[2]
    col[slice4] += image[slice0] * h1d[0]
    col[slice3] += image[slice1] * h1d[1]
    col[slice1] += image[slice3] * h1d[3]
    col[slice0] += image[slice4] * h1d[4]

    # Convolve along the second axis
    result = col * h1d[2]
    result[:, slice4] += col[:, slice0] * h1d[0]
    result[:, slice3] += col[:, slice1] * h1d[1]
    result[:, slice1] += col[:, slice3] * h1d[3]
    result[:, slice0] += col[:, slice4] * h1d[4]
    return result

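# A minimal usage sketch (the array values and shape below are illustrative):
# the spline convolution preserves the shape of its input at every scale.
#
#   import numpy as np
#   img = np.random.default_rng(0).normal(size=(32, 32))
#   smoothed = bspline_convolve(img, scale=0)
#   assert smoothed.shape == img.shape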

def get_starlet_scales(image_shape: Sequence[int], scales: int | None = None) -> int:
    """Get the number of scales to use in the starlet transform.

    Parameters
    ----------
    image_shape:
        The 2D shape of the image that is being transformed.
    scales:
        The number of scales to transform with starlets.
        The starlet dictionary will have `scales+1` planes,
        since the final plane holds the image smoothed at
        all scales higher than `scales`.

    Returns
    -------
    result:
        Number of scales, adjusted for the size of the image.
    """
    # Number of levels for the starlet decomposition
    max_scale = int(np.log2(np.min(image_shape[-2:]))) - 1
    if (scales is None) or scales > max_scale:
        scales = max_scale
    return int(scales)

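# For example, a 64x64 image supports at most int(log2(64)) - 1 = 5 scales,
# so larger requests are clipped:
#
#   get_starlet_scales((64, 64))        # -> 5
#   get_starlet_scales((64, 64), 10)    # -> 5
#   get_starlet_scales((64, 64), 3)     # -> 3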

def starlet_transform(
    image: np.ndarray,
    scales: int | None = None,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Perform a starlet transform, or a 2nd generation starlet transform.

    Parameters
    ----------
    image:
        The image to transform into starlet coefficients.
    scales:
        The number of scales to transform with starlets.
        The starlet dictionary will have `scales+1` planes,
        since the final plane holds the image smoothed at
        all scales higher than `scales`.
    generation:
        The generation of the transform.
        This must be `1` or `2`.
    convolve2d:
        The filter function to use to convolve the image
        with starlets in 2D.

    Returns
    -------
    starlet:
        The starlet dictionary for the input `image`.
    """
    if len(image.shape) != 2:
        raise ValueError(f"Image should be 2D, got {len(image.shape)}D")
    if generation not in (1, 2):
        raise ValueError(f"generation should be 1 or 2, got {generation}")

    scales = get_starlet_scales(image.shape, scales)
    c = image
    if convolve2d is None:
        convolve2d = bspline_convolve

    # Wavelet set of coefficients
    starlet = np.zeros((scales + 1,) + image.shape, dtype=image.dtype)
    for j in range(scales):
        gen1 = convolve2d(c, j)

        if generation == 2:
            gen2 = convolve2d(gen1, j)
            starlet[j] = c - gen2
        else:
            starlet[j] = c - gen1

        c = gen1

    # The final plane is the image smoothed at the coarsest scale
    starlet[-1] = c
    return starlet

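# A minimal usage sketch (illustrative data): transforming a 64x64 image with
# three scales yields a (scales + 1, Ny, Nx) coefficient cube.
#
#   import numpy as np
#   img = np.random.default_rng(0).normal(size=(64, 64))
#   coeffs = starlet_transform(img, scales=3)
#   assert coeffs.shape == (4, 64, 64)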

def multiband_starlet_transform(
    image: np.ndarray,
    scales: int | None = None,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Perform a starlet transform of a multiband image.

    See `starlet_transform` for a description of the parameters.
    """
    if len(image.shape) != 3:
        raise ValueError(f"Image should be 3D (bands, height, width), got {len(image.shape)}D")
    if generation not in (1, 2):
        raise ValueError(f"generation should be 1 or 2, got {generation}")
    scales = get_starlet_scales(image.shape, scales)

    wavelets = np.empty((scales + 1,) + image.shape, dtype=image.dtype)
    for band, band_image in enumerate(image):
        wavelets[:, band] = starlet_transform(
            band_image, scales=scales, generation=generation, convolve2d=convolve2d
        )
    return wavelets

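# A minimal usage sketch (illustrative data): each band of a (bands, Ny, Nx)
# cube is transformed independently, giving (scales + 1, bands, Ny, Nx).
#
#   import numpy as np
#   cube = np.random.default_rng(0).normal(size=(3, 64, 64))
#   coeffs = multiband_starlet_transform(cube, scales=3)
#   assert coeffs.shape == (4, 3, 64, 64)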

def starlet_reconstruction(
    starlets: np.ndarray,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Reconstruct an image from a dictionary of starlets.

    Parameters
    ----------
    starlets:
        The starlet dictionary used to reconstruct the image
        with dimension (scales+1, Ny, Nx).
    generation:
        The generation of the starlet transform (either ``1`` or ``2``).
    convolve2d:
        The filter function to use to convolve the image
        with starlets in 2D.

    Returns
    -------
    image:
        The 2D image reconstructed from the input `starlets`.
    """
    # For the 1st generation transform the reconstruction is the
    # sum of the coefficients over all scales.
    if generation == 1:
        return np.sum(starlets, axis=0)
    if convolve2d is None:
        convolve2d = bspline_convolve
    scales = len(starlets) - 1

    # For the 2nd generation transform, work back from the coarsest scale,
    # smoothing the running reconstruction before adding each detail plane.
    c = starlets[-1]
    for i in range(1, scales + 1):
        j = scales - i
        cj = convolve2d(c, j)
        c = cj + starlets[j]
    return c

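# A round-trip sketch (illustrative data): for either generation the
# reconstruction recovers the original image to floating-point precision.
#
#   import numpy as np
#   img = np.random.default_rng(0).normal(size=(64, 64))
#   coeffs = starlet_transform(img, scales=3, generation=2)
#   assert np.allclose(starlet_reconstruction(coeffs, generation=2), img)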

def multiband_starlet_reconstruction(
    starlets: np.ndarray,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Reconstruct a multiband image.

    See `starlet_reconstruction` for a description of the parameters.
    """
    scales, bands, height, width = starlets.shape
    result = np.zeros((bands, height, width), dtype=starlets.dtype)
    for band in range(bands):
        result[band] = starlet_reconstruction(
            starlets[:, band], generation=generation, convolve2d=convolve2d
        )
    return result

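# The multiband round trip mirrors the single-band sketch above
# (illustrative data):
#
#   import numpy as np
#   cube = np.random.default_rng(0).normal(size=(3, 64, 64))
#   coeffs = multiband_starlet_transform(cube, scales=3)
#   assert np.allclose(multiband_starlet_reconstruction(coeffs), cube)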

def get_multiresolution_support(
    image: np.ndarray,
    starlets: np.ndarray,
    sigma: float,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
) -> np.ndarray:
    """Calculate the multi-resolution support for a
    dictionary of starlet coefficients.

    This is different for ground- and space-based telescopes.
    For space-based telescopes the procedure from Starck and Murtagh 1998
    iteratively calculates the multi-resolution support.
    For ground-based images, where the PSF is much wider and essentially
    every pixel contains signal at some scale, we use a modified method
    that estimates the support at each scale independently.

    Parameters
    ----------
    image:
        The image to transform into starlet coefficients.
    starlets:
        The starlet dictionary used to reconstruct `image` with
        dimension (scales+1, Ny, Nx).
    sigma:
        The standard deviation of the `image`.
    sigma_scaling:
        The multiple of `sigma` to use to calculate significance.
        Coefficients `w` where `|w| > sigma_scaling * sigma_j`, with
        `sigma_j` the standard deviation at the jth scale, are
        considered significant.
    epsilon:
        The convergence criterion of the algorithm.
        Once `|new_sigma_j - sigma_j| / new_sigma_j < epsilon` the
        algorithm has completed.
    max_iter:
        Maximum number of iterations to fit `sigma_j` at each scale.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground-based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.

    Returns
    -------
    M:
        Mask with significant coefficients in `starlets` set to 1.
    """
    if image_type not in ("ground", "space"):
        raise ValueError(f"image_type must be 'ground' or 'space', got {image_type}")

    if image_type == "space":
        # Calculate sigma_je, the standard deviation at
        # each scale due to gaussian noise
        noise_img = np.random.normal(size=image.shape)
        noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1)
        sigma_je = np.zeros((len(noise_starlet),))
        for j, star in enumerate(noise_starlet):
            sigma_je[j] = np.std(star)
        noise = image - starlets[-1]

        last_sigma_i = sigma
        for it in range(max_iter):
            m = np.abs(starlets) > sigma_scaling * sigma * sigma_je[:, None, None]
            s = np.sum(m, axis=0) == 0
            sigma_i = np.std(noise * s)
            if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon:
                break
            last_sigma_i = sigma_i
    else:
        # Sigma to use for significance at each scale.
        # Initially we use the input `sigma`.
        sigma_j = np.ones((len(starlets),), dtype=image.dtype) * sigma
        last_sigma_j = sigma_j
        for it in range(max_iter):
            m = np.abs(starlets) > sigma_scaling * sigma_j[:, None, None]
            # Take the standard deviation of the currently
            # insignificant coefficients at each scale
            s = ~m
            sigma_j = np.std(starlets * s.astype(int), axis=(1, 2))
            # At lower scales all of the pixels may be significant,
            # so sigma is effectively zero. To avoid infinities we
            # only check the scales with non-zero sigma.
            cut = sigma_j > 0
            if np.all(np.abs(sigma_j[cut] - last_sigma_j[cut]) / sigma_j[cut] < epsilon):
                break

            last_sigma_j = sigma_j
    # noinspection PyUnboundLocalVariable
    return m.astype(int)

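# A minimal usage sketch (the image, source, and noise level are illustrative):
# build the support mask for a noisy image and count significant coefficients
# per scale.
#
#   import numpy as np
#   rng = np.random.default_rng(0)
#   img = rng.normal(scale=0.1, size=(64, 64))
#   img[28:36, 28:36] += 5.0                    # bright synthetic source
#   coeffs = starlet_transform(img, scales=3)
#   support = get_multiresolution_support(img, coeffs, sigma=0.1)
#   support.shape                               # (4, 64, 64), entries 0 or 1
#   support.sum(axis=(1, 2))                    # significant pixels per scale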

def apply_wavelet_denoising(
    image: np.ndarray,
    sigma: float | None = None,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
    positive: bool = True,
) -> np.ndarray:
    """Apply wavelet denoising.

    Uses the algorithm and notation from Starck et al. 2011, section 4.1.

    Parameters
    ----------
    image:
        The image to denoise.
    sigma:
        The standard deviation of the image.
        If `None`, the median absolute deviation of the image is used.
    sigma_scaling:
        The threshold in units of sigma to declare a coefficient significant.
    epsilon:
        Convergence criterion for determining the support.
    max_iter:
        The maximum number of iterations.
        This applies to both finding the support and the denoising loop.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground-based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.
    positive:
        Whether or not the expected result should be positive.

    Returns
    -------
    result:
        The resulting denoised image after `max_iter` iterations.
    """
    image_coeffs = starlet_transform(image)
    if sigma is None:
        # Median absolute deviation as a robust estimate of the noise level
        sigma = np.median(np.absolute(image - np.median(image)))
    coeffs = image_coeffs.copy()
    support = get_multiresolution_support(
        image=image,
        starlets=coeffs,
        sigma=sigma,
        sigma_scaling=sigma_scaling,
        epsilon=epsilon,
        max_iter=max_iter,
        image_type=image_type,
    )
    x = starlet_reconstruction(coeffs)

    # Iteratively add back the significant part of the residual coefficients
    for n in range(max_iter):
        coeffs = starlet_transform(x)
        x = x + starlet_reconstruction(support * (image_coeffs - coeffs))
        if positive:
            x[x < 0] = 0
    return x
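# A minimal end-to-end sketch (the source, noise level, and sigma are
# illustrative):
#
#   import numpy as np
#   rng = np.random.default_rng(0)
#   clean = np.zeros((64, 64))
#   clean[28:36, 28:36] = 5.0                   # bright synthetic source
#   noisy = clean + rng.normal(scale=0.1, size=clean.shape)
#   denoised = apply_wavelet_denoising(noisy, sigma=0.1)
#   assert denoised.shape == noisy.shape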