Coverage for tests/test_fft.py: 15%

109 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-20 03:40 -0700

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import operator 

23 

24import lsst.scarlet.lite.fft as fft 

25import numpy as np 

26from lsst.scarlet.lite import Fourier 

27from lsst.scarlet.lite.utils import integrated_circular_gaussian 

28from numpy.testing import assert_almost_equal, assert_array_equal 

29from scipy.signal import convolve as scipy_convolve 

30from utils import ScarletTestCase, get_psfs 

31 

32 

class TestFourier(ScarletTestCase):
    """Tests for the padding, centering and shifting helpers in
    ``lsst.scarlet.lite.fft``, its k-space operations, and the ``Fourier``
    container class.
    """

    def test_shift(self):
        """Check that padding and the fft shift/unshift round trip agree."""
        single_pixel = np.ones((1, 1))
        padded = fft._pad(single_pixel, (5, 4))
        # The unit pixel lands at row 2, column 2 of the 5x4 frame.
        expected_padded = [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
        assert_array_equal(padded, expected_padded)

        # ifftshift moves that pixel to the origin.
        unshifted = np.fft.ifftshift(padded)
        expected_unshifted = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
        assert_array_equal(unshifted, expected_unshifted)

        # fftshift inverts ifftshift, so the padded array is recovered.
        assert_array_equal(np.fft.fftshift(unshifted), padded)

    def test_center(self):
        """Check that ``fft.centered`` is compatible with padding and shifting."""
        original_shape = (5, 2)
        source = np.arange(10).reshape(original_shape)
        padded = fft._pad(source, (9, 11))
        expected_padded = [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 4, 5, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 6, 7, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 8, 9, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
        assert_array_equal(padded, expected_padded)

        unshifted = np.fft.ifftshift(padded)
        expected_unshifted = [
            [4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
        assert_array_equal(unshifted, expected_unshifted)

        # fftshift inverts ifftshift, so the padded array is recovered.
        assert_array_equal(np.fft.fftshift(unshifted), padded)

        # fft.centered crops the padding away, returning the original array.
        cropped = fft.centered(padded, original_shape)
        assert_array_equal(cropped, source)

        # Asking for a (20, 20) region of the (5, 2) array must raise.
        with self.assertRaises(ValueError):
            fft.centered(cropped, (20, 20))

    def test_pad(self):
        """Check ``fft._pad`` with explicit axes and with a non-default mode."""
        cube = np.arange(40).reshape(2, 4, 5)

        # Pad only the two trailing axes; the data sit at offset (3, 3).
        expected = np.zeros((2, 10, 11), dtype=int)
        expected[:, 3:7, 3:8] = cube.copy()
        assert_array_equal(fft._pad(cube, (10, 11), axes=(1, 2)), expected)

        # Pad only the leading axis, given as a bare int.
        expected = np.zeros((4, 4, 5))
        expected[1:3] = cube
        assert_array_equal(fft._pad(cube, (4,), axes=0), expected)

        # Padding every axis by 5 per side in "edge" mode matches np.pad.
        expected = np.pad(cube, 5, mode="edge")
        assert_array_equal(fft._pad(cube, (12, 14, 15), mode="edge"), expected)

    def test_get_fft_shape(self):
        """Check the FFT shape computed for several ``axes``/``use_max`` options."""
        shape_a = (3, 11)
        shape_b = (5, 10)

        self.assertTupleEqual(tuple(fft.get_fft_shape(shape_a, shape_b)), (12, 24))
        self.assertTupleEqual(fft.get_fft_shape(shape_a, shape_b, use_max=True), (8, 15))
        self.assertTupleEqual(fft.get_fft_shape(shape_a, shape_b, axes=1), (24,))
        self.assertTupleEqual(
            fft.get_fft_shape(shape_a, shape_b, axes=1, use_max=True), (15,)
        )

        # Shapes with different numbers of dimensions are rejected.
        with self.assertRaises(ValueError):
            fft.get_fft_shape((1, 2), (1, 2, 3))

    def test_2d_psf_matching(self):
        """Check matching between a narrow and a wide PSF, in both directions."""
        narrow = Fourier(get_psfs(1))
        wide = Fourier(get_psfs(2))

        # Narrow -> wide: convolving the narrow PSF with the kernel gives the wide PSF.
        to_wide = fft.match_kernel(wide, narrow)
        assert_almost_equal(fft.convolve(narrow, to_wide).image, wide.image)

        # Wide -> narrow: convolving the wide PSF with the kernel gives the narrow PSF.
        to_narrow = fft.match_kernel(narrow, wide)
        assert_almost_equal(fft.convolve(wide, to_narrow).image, narrow.image)

    def test_from_fft(self):
        """Check that ``Fourier.from_fft`` recovers the image behind an FFT."""
        gaussian = integrated_circular_gaussian(sigma=1.0)
        # Pad by 3 pixels per side and center on the origin before transforming.
        padded = np.pad(gaussian, 3, mode="constant")
        transform = np.fft.rfftn(np.fft.ifftshift(padded))

        fourier = Fourier.from_fft(transform, (21, 21), (15, 15))
        assert_almost_equal(fourier.image, gaussian)
        self.assertEqual(len(fourier), 15)

    def test_fourier(self):
        """Check ``Fourier.image`` and ``Fourier.fft`` against direct numpy FFTs."""
        gaussian = integrated_circular_gaussian(sigma=1.0)
        fourier = Fourier(gaussian)
        assert_almost_equal(fourier.image, gaussian)

        # Transform over both axes.
        padded = np.pad(gaussian, 3, mode="constant")
        expected = np.fft.rfftn(np.fft.ifftshift(padded))
        assert_almost_equal(fourier.fft((21, 21), (0, 1)), expected)
        self.assertEqual(len(fourier), 15)

        # Transform over the second axis only.
        padded = np.pad(gaussian, ((0, 0), (3, 3)), mode="constant")
        expected = np.fft.rfftn(np.fft.ifftshift(padded, axes=1), axes=(1,))
        assert_almost_equal(fourier.fft((21,), 1), expected)

        # This shape/axes request is invalid for the image and must raise.
        with self.assertRaises(ValueError):
            fourier.fft((3, 4, 5), (2, 3))

    def test_convolutions(self):
        """Check ``fft.convolve`` against scipy, plus a k-space error case."""
        narrow = integrated_circular_gaussian(sigma=1.0)
        wide = integrated_circular_gaussian(sigma=1.3)

        # A 2D and a 3D operand with these arguments must raise.
        with self.assertRaises(ValueError):
            fft._kspace_operation(
                Fourier(narrow), Fourier(wide[None, :, :]), 3, operator.mul, (15, 15), (0, 1)
            )

        result = fft.convolve(narrow, wide, return_fourier=False)
        expected = scipy_convolve(narrow, wide, mode="same", method="direct")
        assert_almost_equal(result, expected)

    def test_multiband_psf_matching(self):
        """Check PSF matching when the PSFs carry a spectral dimension."""
        narrow = Fourier(get_psfs(1))
        wide = Fourier(get_psfs((1, 2, 3)))

        # Narrow -> wide.
        to_wide = fft.match_kernel(wide, narrow)
        matched = fft.convolve(to_wide, narrow)
        assert_almost_equal(wide.image, matched.image)

        # Requesting a plain array gives the same kernel as the Fourier image.
        to_wide_array = fft.match_kernel(wide, narrow, return_fourier=False)
        assert_almost_equal(to_wide_array, to_wide.image)

        # Wide -> narrow: every band should match the narrow PSF's first band.
        to_narrow = fft.match_kernel(narrow, wide)
        for band_image in fft.convolve(to_narrow, wide).image:
            assert_almost_equal(band_image, narrow.image[0])