Coverage for tests/test_fft.py: 15%
109 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-07 11:26 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-07 11:26 +0000
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import operator
24import lsst.scarlet.lite.fft as fft
25import numpy as np
26from lsst.scarlet.lite import Fourier
27from lsst.scarlet.lite.utils import integrated_circular_gaussian
28from numpy.testing import assert_almost_equal, assert_array_equal
29from scipy.signal import convolve as scipy_convolve
30from utils import ScarletTestCase, get_psfs
class TestFourier(ScarletTestCase):
    """Tests for the ``lsst.scarlet.lite.fft`` module and ``Fourier`` class.

    Covers padding, FFT shift/unshift and centering, FFT shape computation,
    ``Fourier`` construction (directly and from an existing FFT),
    k-space convolution, and PSF matching kernels for single-band and
    multi-band PSFs.
    """

    def test_shift(self):
        """Test that padding and fft shift/unshift are consistent"""
        a0 = np.ones((1, 1))
        a_pad = fft._pad(a0, (5, 4))
        # With an odd amount of padding (3 columns here) the truth array
        # places the extra pixel before the data, not after it.
        truth = [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
        assert_array_equal(a_pad, truth)

        # ifftshift moves the padded center pixel to the origin [0, 0]
        a_shift = np.fft.ifftshift(a_pad)
        truth = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ]
        assert_array_equal(a_shift, truth)

        # Shifting back should give us a_pad again
        a_shift_back = np.fft.fftshift(a_shift)
        assert_array_equal(a_shift_back, a_pad)

    def test_center(self):
        """Test that centered method is compatible with shift/unshift"""
        shape = (5, 2)
        a0 = np.arange(10).reshape(shape)
        a_pad = fft._pad(a0, (9, 11))
        truth = [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 4, 5, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 6, 7, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 8, 9, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
        assert_array_equal(a_pad, truth)

        # ifftshift wraps the block around so the central pixel is at [0, 0]
        a_shift = np.fft.ifftshift(a_pad)
        truth = [
            [4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
        assert_array_equal(a_shift, truth)

        # Shifting back should give us a_pad again
        a_shift_back = np.fft.fftshift(a_shift)
        assert_array_equal(a_shift_back, a_pad)

        # `centered` should undo the padding, returning the original array
        a_final = fft.centered(a_pad, shape)
        assert_array_equal(a_final, a0)

        # Requesting a shape larger than the input is an error
        with self.assertRaises(ValueError):
            fft.centered(a_final, (20, 20))

    def test_pad(self):
        """Test `_pad` with explicit axes, a single axis, and a non-constant mode"""
        x = np.arange(40).reshape(2, 4, 5)

        # Pad only the last two axes, leaving the leading axis untouched
        truth = np.zeros((2, 10, 11), dtype=int)
        truth[:, 3:7, 3:8] = x.copy()
        result = fft._pad(x, (10, 11), axes=(1, 2))
        assert_array_equal(result, truth)

        # A single integer axis is also accepted
        truth = np.zeros((4, 4, 5))
        truth[1:3] = x
        result = fft._pad(x, (4,), axes=0)
        assert_array_equal(result, truth)

        # Extra keyword arguments (here `mode`) behave like `np.pad`
        truth = np.pad(x, 5, mode="edge")
        result = fft._pad(x, (12, 14, 15), mode="edge")
        assert_array_equal(result, truth)

    def test_get_fft_shape(self):
        """Test the FFT shape computed for a pair of input shapes"""
        shape1 = (3, 11)
        shape2 = (5, 10)

        # Every result is normalized with `tuple(...)` because
        # `assertTupleEqual` requires an actual tuple instance.
        shape = tuple(fft.get_fft_shape(shape1, shape2))
        self.assertTupleEqual(shape, (12, 24))

        shape = tuple(fft.get_fft_shape(shape1, shape2, use_max=True))
        self.assertTupleEqual(shape, (8, 15))

        shape = tuple(fft.get_fft_shape(shape1, shape2, axes=1))
        self.assertTupleEqual(shape, (24,))

        shape = tuple(fft.get_fft_shape(shape1, shape2, axes=1, use_max=True))
        self.assertTupleEqual(shape, (15,))

        # Shapes with a different number of dimensions are rejected
        with self.assertRaises(ValueError):
            fft.get_fft_shape((1, 2), (1, 2, 3))

    def test_2d_psf_matching(self):
        """Test matching two 2D psfs"""
        # Narrow PSF
        psf1 = Fourier(get_psfs(1))
        # Wide PSF
        psf2 = Fourier(get_psfs(2))

        # Test narrow to wide: convolving psf1 with the kernel must give psf2
        kernel_1to2 = fft.match_kernel(psf2, psf1)
        img2 = fft.convolve(psf1, kernel_1to2)
        assert_almost_equal(img2.image, psf2.image)

        # Test wide to narrow: convolving psf2 with the kernel must give psf1
        kernel_2to1 = fft.match_kernel(psf1, psf2)
        img1 = fft.convolve(psf2, kernel_2to1)
        assert_almost_equal(img1.image, psf1.image)

    def test_from_fft(self):
        """Test that `Fourier.from_fft` recovers an image from a precomputed FFT"""
        x = integrated_circular_gaussian(sigma=1.0)
        # Pad by 3 pixels per side so `_x` matches the (21, 21) FFT shape
        # passed to `from_fft` below.
        _x = np.pad(x, 3, mode="constant")
        fft_x = np.fft.rfftn(np.fft.ifftshift(_x))
        fourier = Fourier.from_fft(fft_x, (21, 21), (15, 15))
        assert_almost_equal(fourier.image, x)
        self.assertEqual(len(fourier), 15)

    def test_fourier(self):
        """Test `Fourier` image storage and its cached FFTs over all or one axis"""
        x = integrated_circular_gaussian(sigma=1.0)
        fourier = Fourier(x)
        assert_almost_equal(fourier.image, x)

        # FFT over both axes, matching a manually padded and shifted array
        _x = np.pad(x, 3, mode="constant")
        fft_x = np.fft.rfftn(np.fft.ifftshift(_x))
        assert_almost_equal(fourier.fft((21, 21), (0, 1)), fft_x)
        self.assertEqual(len(fourier), 15)

        # FFT over axis 1 only: pad and shift along that axis alone
        _x = np.pad(x, ((0, 0), (3, 3)), mode="constant")
        fft_x = np.fft.rfftn(np.fft.ifftshift(_x, axes=1), axes=(1,))
        assert_almost_equal(fourier.fft((21,), 1), fft_x)

        # Axes that do not exist for this image are an error
        with self.assertRaises(ValueError):
            fourier.fft((3, 4, 5), (2, 3))

    def test_convolutions(self):
        """Test k-space convolution against a direct scipy convolution"""
        x = integrated_circular_gaussian(sigma=1.0)
        y = integrated_circular_gaussian(sigma=1.3)

        # Operands with different dimensionality (2D vs 3D) are expected
        # to raise.
        with self.assertRaises(ValueError):
            fft._kspace_operation(Fourier(x), Fourier(y[None, :, :]), 3, operator.mul, (15, 15), (0, 1))

        convolved = fft.convolve(x, y, return_fourier=False)
        truth = scipy_convolve(x, y, mode="same", method="direct")
        assert_almost_equal(convolved, truth)

    def test_multiband_psf_matching(self):
        """Test matching two PSFs with a spectral dimension"""
        # Narrow PSF
        psf1 = Fourier(get_psfs(1))
        # Wide PSF
        psf2 = Fourier(get_psfs((1, 2, 3)))

        # Narrow to wide
        kernel_1to2 = fft.match_kernel(psf2, psf1)
        image = fft.convolve(kernel_1to2, psf1)
        assert_almost_equal(psf2.image, image.image)

        # `return_fourier=False` must return the same kernel as a plain array
        kernel_array = fft.match_kernel(psf2, psf1, return_fourier=False)
        assert_almost_equal(kernel_array, kernel_1to2.image)

        # Wide to narrow: every band of the result should match the
        # single narrow PSF.
        kernel_2to1 = fft.match_kernel(psf1, psf2)
        image = fft.convolve(kernel_2to1, psf2).image

        # Use subTest so that a failure reports which band did not match
        for band, img in enumerate(image):
            with self.subTest(band=band):
                assert_almost_equal(img, psf1.image[0])