Coverage for python / lsst / scarlet / lite / wavelet.py: 10%

121 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:40 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
# NOTE(review): `bspline_convolve`, `get_starlet_scales`, and
# `apply_wavelet_denoising` are omitted here — presumably intentional
# (internal helpers), but worth confirming.
__all__ = [
    "starlet_transform",
    "starlet_reconstruction",
    "multiband_starlet_transform",
    "multiband_starlet_reconstruction",
    "get_multiresolution_support",
]

29 

30from dataclasses import dataclass 

31from typing import Callable, Sequence 

32 

33import numpy as np 

34 

35 

def bspline_convolve(image: np.ndarray, scale: int) -> np.ndarray:
    """Convolve an image with a bspline at a given scale.

    This uses the spline
    `h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])`
    from Starck et al. 2011, applied separably along rows and then
    columns, with taps spaced `2**scale` pixels apart (the "à trous"
    scheme). Pixels whose shifted neighbor falls outside the image
    simply receive no contribution from that tap.

    Parameters
    ----------
    image:
        The 2D image or wavelet coefficients to convolve.
    scale:
        The wavelet scale for the convolution. This sets the
        spacing between adjacent pixels with the spline.

    Returns
    -------
    result:
        The result of convolving the `image` with the spline.
    """
    # Separable 1D bspline filter taps.
    kernel = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16]).astype(image.dtype)
    step = 2**scale
    double_step = 2 ** (scale + 1)

    # Slices selecting the regions that pair each pixel with its
    # neighbor `step` or `2*step` pixels away.
    lo2 = slice(None, -double_step)
    lo1 = slice(None, -step)
    hi1 = slice(step, None)
    hi2 = slice(double_step, None)

    # Pass 1: filter along the first (row) axis.
    rows = image * kernel[2]
    rows[hi2] += image[lo2] * kernel[0]
    rows[hi1] += image[lo1] * kernel[1]
    rows[lo1] += image[hi1] * kernel[3]
    rows[lo2] += image[hi2] * kernel[4]

    # Pass 2: filter the row-filtered image along the second (column) axis.
    out = rows * kernel[2]
    out[:, hi2] += rows[:, lo2] * kernel[0]
    out[:, hi1] += rows[:, lo1] * kernel[1]
    out[:, lo1] += rows[:, hi1] * kernel[3]
    out[:, lo2] += rows[:, hi2] * kernel[4]
    return out

78 

79 

80def get_starlet_scales(image_shape: Sequence[int], scales: int | None = None) -> int: 

81 """Get the number of scales to use in the starlet transform. 

82 

83 Parameters 

84 ---------- 

85 image_shape: 

86 The 2D shape of the image that is being transformed 

87 scales: 

88 The number of scales to transform with starlets. 

89 The total dimension of the starlet will have 

90 `scales+1` dimensions, since it will also hold 

91 the image at all scales higher than `scales`. 

92 

93 Returns 

94 ------- 

95 result: 

96 Number of scales, adjusted for the size of the image. 

97 """ 

98 # Number of levels for the Starlet decomposition 

99 max_scale = int(np.log2(np.min(image_shape[-2:]))) - 1 

100 if (scales is None) or scales > max_scale: 

101 scales = max_scale 

102 return int(scales) 

103 

104 

def starlet_transform(
    image: np.ndarray,
    scales: int | None = None,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Perform a starlet transform, or 2nd gen starlet transform.

    Parameters
    ----------
    image:
        The image to transform into starlet coefficients.
    scales:
        The number of scale to transform with starlets.
        The total dimension of the starlet will have
        `scales+1` dimensions, since it will also hold
        the image at all scales higher than `scales`.
    generation:
        The generation of the transform.
        This must be `1` or `2`.
    convolve2d:
        The filter function to use to convolve the image
        with starlets in 2D. Defaults to `bspline_convolve`.

    Returns
    -------
    starlet:
        The starlet dictionary for the input `image`,
        with shape `(scales + 1,) + image.shape`.

    Raises
    ------
    ValueError
        If `image` is not 2D or `generation` is not 1 or 2.
    """
    if len(image.shape) != 2:
        raise ValueError(f"Image should be 2D, got {len(image.shape)}")
    if generation not in (1, 2):
        raise ValueError(f"generation should be 1 or 2, got {generation}")

    n_scales = get_starlet_scales(image.shape, scales)
    if convolve2d is None:
        convolve2d = bspline_convolve

    # Coefficient stack: one plane per scale plus the final smoothed image.
    coeffs = np.zeros((n_scales + 1,) + image.shape, dtype=image.dtype)
    current = image
    for level in range(n_scales):
        smoothed = convolve2d(current, level)
        if generation == 2:
            # 2nd-gen starlets subtract a doubly-smoothed image,
            # which improves the reconstruction properties.
            coeffs[level] = current - convolve2d(smoothed, level)
        else:
            coeffs[level] = current - smoothed
        # The smoothed image becomes the input for the next scale.
        current = smoothed

    # The residual low-pass image is stored in the final slot.
    coeffs[-1] = current
    return coeffs

159 

160 

def multiband_starlet_transform(
    image: np.ndarray,
    scales: int | None = None,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Perform a starlet transform of a multiband image.

    Each band is transformed independently with `starlet_transform`.
    See `starlet_transform` for a description of the parameters.

    Returns
    -------
    wavelets:
        Starlet coefficients with shape `(scales + 1,) + image.shape`,
        where axis 1 indexes the band.

    Raises
    ------
    ValueError
        If `image` is not 3D or `generation` is not 1 or 2.
    """
    if len(image.shape) != 3:
        # Report the actual shape (the original interpolated the rank
        # while claiming to show the shape).
        raise ValueError(f"Image should be 3D (bands, height, width), got shape {image.shape}")
    if generation not in (1, 2):
        raise ValueError(f"generation should be 1 or 2, got {generation}")
    scales = get_starlet_scales(image.shape, scales)

    wavelets = np.empty((scales + 1,) + image.shape, dtype=image.dtype)
    # Use a distinct loop variable: the original shadowed the `image`
    # parameter with the per-band slice inside the loop.
    for band, band_image in enumerate(image):
        wavelets[:, band] = starlet_transform(
            band_image, scales=scales, generation=generation, convolve2d=convolve2d
        )
    return wavelets

181 

182 

183def starlet_reconstruction( 

184 starlets: np.ndarray, 

185 generation: int = 2, 

186 convolve2d: Callable | None = None, 

187) -> np.ndarray: 

188 """Reconstruct an image from a dictionary of starlets 

189 

190 Parameters 

191 ---------- 

192 starlets: 

193 The starlet dictionary used to reconstruct the image 

194 with dimension (scales+1, Ny, Nx). 

195 generation: 

196 The generation of the starlet transform (either ``1`` or ``2``). 

197 convolve2d: 

198 The filter function to use to convolve the image 

199 with starlets in 2D. 

200 

201 Returns 

202 ------- 

203 image: 

204 The 2D image reconstructed from the input `starlet`. 

205 """ 

206 if generation == 1: 

207 return np.sum(starlets, axis=0) 

208 if convolve2d is None: 

209 convolve2d = bspline_convolve 

210 scales = len(starlets) - 1 

211 

212 c = starlets[-1] 

213 for i in range(1, scales + 1): 

214 j = scales - i 

215 cj = convolve2d(c, j) 

216 c = cj + starlets[j] 

217 return c 

218 

219 

def multiband_starlet_reconstruction(
    starlets: np.ndarray,
    generation: int = 2,
    convolve2d: Callable | None = None,
) -> np.ndarray:
    """Reconstruct a multiband image.

    Each band is reconstructed independently with
    `starlet_reconstruction`. See `starlet_reconstruction` for a
    description of the remainder of the parameters.

    Parameters
    ----------
    starlets:
        Starlet coefficients with dimension (scales+1, bands, Ny, Nx).

    Returns
    -------
    result:
        The reconstructed image with dimension (bands, Ny, Nx).
    """
    n_bands = starlets.shape[1]
    # Output has one 2D image per band.
    result = np.zeros(starlets.shape[1:], dtype=starlets.dtype)
    for band_index in range(n_bands):
        result[band_index] = starlet_reconstruction(
            starlets[:, band_index], generation=generation, convolve2d=convolve2d
        )
    return result

235 

236 

@dataclass
class MultiResolutionSupport:
    """Multi-resolution support of a starlet coefficient dictionary.

    Produced by `get_multiresolution_support`.
    """

    # Integer mask over the starlet coefficients: 1 where a coefficient
    # is significant (|w| above the scaled sigma threshold), 0 otherwise.
    support: np.ndarray
    # Estimated standard deviation at each starlet scale.
    sigma: np.ndarray

241 

242 

def get_multiresolution_support(
    image: np.ndarray,
    starlets: np.ndarray,
    sigma: np.floating,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
) -> MultiResolutionSupport:
    """Calculate the multi-resolution support for a
    dictionary of starlet coefficients.

    This is different for ground and space based telescopes.
    For space-based telescopes the procedure in Starck and Murtagh 1998
    iteratively calculates the multi-resolution support.
    For ground based images, where the PSF is much wider and there are no
    pixels with no signal at all scales, we use a modified method that
    estimates support at each scale independently.

    Parameters
    ----------
    image:
        The image to transform into starlet coefficients.
    starlets:
        The starlet dictionary used to reconstruct `image` with
        dimension (scales+1, Ny, Nx).
    sigma:
        The standard deviation of the `image`.
    sigma_scaling:
        The multiple of `sigma` to use to calculate significance.
        Coefficients `w` where `|w| > K*sigma_j`, where `sigma_j` is
        standard deviation at the jth scale, are considered significant.
    epsilon:
        The convergence criteria of the algorithm.
        Once `|new_sigma_j - sigma_j|/new_sigma_j < epsilon` the
        algorithm has completed.
    max_iter:
        Maximum number of iterations to fit `sigma_j` at each scale.
        Must be at least 1; see the note on `m` below.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.

    Returns
    -------
    M:
        Mask with significant coefficients in `starlets` set to `True`.

    Raises
    ------
    ValueError
        If `image_type` is not "ground" or "space".
    """
    if image_type not in ("ground", "space"):
        raise ValueError(f"image_type must be 'ground' or 'space', got {image_type}")

    if image_type == "space":
        # Calculate sigma_je, the standard deviation at
        # each scale due to gaussian noise
        # NOTE(review): the noise image is unseeded, so the "space"
        # branch is nondeterministic across calls — confirm intentional.
        noise_img = np.random.normal(size=image.shape)
        noise_starlet = starlet_transform(noise_img, generation=1, scales=len(starlets) - 1)
        sigma_je = np.zeros((len(noise_starlet),))
        for j, star in enumerate(noise_starlet):
            sigma_je[j] = np.std(star)
        # Residual of the image after removing the coarsest scale.
        noise = image - starlets[-1]

        last_sigma_i = sigma
        for it in range(max_iter):
            # Significant coefficients: |w| above the per-scale threshold.
            m = np.abs(starlets) > sigma_scaling * sigma * sigma_je[:, None, None]
            # Pixels insignificant at every scale.
            s = np.sum(m, axis=0) == 0
            # Re-estimate the noise level from those quiet pixels only.
            sigma_i = np.std(noise * s)
            if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon:
                break
            last_sigma_i = sigma_i
        sigma_j = sigma_je
    else:
        # Sigma to use for significance at each scale
        # Initially we use the input `sigma`
        sigma_j = np.full(len(starlets), sigma, dtype=image.dtype)
        last_sigma_j = sigma_j
        for it in range(max_iter):
            m = np.abs(starlets) > sigma_scaling * sigma_j[:, None, None]
            # Take the standard deviation of the current
            # insignificant coeffs at each scale
            s = ~m
            sigma_j = np.std(starlets * s.astype(int), axis=(1, 2))
            # At lower scales all of the pixels may be significant,
            # so sigma is effectively zero. To avoid infinities we
            # only check the scales with non-zero sigma
            cut = sigma_j > 0
            if np.all(np.abs(sigma_j[cut] - last_sigma_j[cut]) / sigma_j[cut] < epsilon):
                break

            last_sigma_j = sigma_j
    # noinspection PyUnboundLocalVariable
    # NOTE(review): `m` is only bound inside the loops above; if
    # `max_iter < 1` this raises UnboundLocalError — confirm callers
    # always pass max_iter >= 1.
    return MultiResolutionSupport(support=m.astype(int), sigma=sigma_j)

334 

335 

def apply_wavelet_denoising(
    image: np.ndarray,
    sigma: np.floating | None = None,
    sigma_scaling: float = 3,
    epsilon: float = 1e-1,
    max_iter: int = 20,
    image_type: str = "ground",
    positive: bool = True,
) -> np.ndarray:
    """Apply wavelet denoising.

    Uses the algorithm and notation from Starck et al. 2011, section 4.1.

    Parameters
    ----------
    image:
        The image to denoise.
    sigma:
        The standard deviation of the image.
        If `None`, the median absolute deviation of the image is used.
    sigma_scaling:
        The threshold in units of sigma to declare a coefficient significant.
    epsilon:
        Convergence criteria for determining the support.
    max_iter:
        The maximum number of iterations.
        This applies to both finding the support and the denoising loop.
    image_type:
        The type of image that is being used.
        This should be "ground" for ground based images with wide PSFs or
        "space" for images from space-based telescopes with a narrow PSF.
    positive:
        Whether or not the expected result should be positive.

    Returns
    -------
    result:
        The resulting denoised image after `max_iter` iterations.
    """
    target_coeffs = starlet_transform(image)
    if sigma is None:
        # Median absolute deviation as a robust noise estimate.
        sigma = np.median(np.absolute(image - np.median(image)))

    # Determine which coefficients are significant; only those
    # contribute to the iterative correction below.
    support = get_multiresolution_support(
        image=image,
        starlets=target_coeffs.copy(),
        sigma=sigma,
        sigma_scaling=sigma_scaling,
        epsilon=epsilon,
        max_iter=max_iter,
        image_type=image_type,
    )

    # Initial estimate: reconstruct directly from the full transform.
    result = starlet_reconstruction(target_coeffs)

    for _ in range(max_iter):
        current_coeffs = starlet_transform(result)
        # Push the significant coefficients of the estimate back
        # toward those of the input image.
        residual = support.support * (target_coeffs - current_coeffs)
        result = result + starlet_reconstruction(residual)
        if positive:
            # Clip negative pixels when the result is known to be positive.
            result[result < 0] = 0
    return result