Coverage for tests / test_gp_interp.py: 21%
65 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:26 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:26 +0000
# This file is part of meas_algorithms.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
23import unittest
25import numpy as np
27import lsst.utils.tests
28import lsst.geom
29import lsst.afw.image as afwImage
30from lsst.meas.algorithms import (
31 InterpolateOverDefectGaussianProcess,
32 GaussianProcessTreegp,
33)
def rbf_kernel(x1, x2, sigma, correlation_length):
    """Compute the radial basis function (RBF) kernel matrix.

    Each entry is
    ``K[i, j] = sigma**2 * exp(-0.5 * |x1[i] - x2[j]|**2 / correlation_length**2)``.

    Parameters
    ----------
    x1 : `np.ndarray`
        Locations of the first set of points, shape (n_samples_1, n_features).
        A 1-D array is treated as n_samples_1 points with a single feature.
    x2 : `np.ndarray`
        Locations of the second set of points, shape (n_samples_2, n_features).
        A 1-D array is treated as n_samples_2 points with a single feature.
    sigma : `float`
        The scale (amplitude) parameter of the kernel.
    correlation_length : `float`
        The correlation length parameter of the kernel.

    Returns
    -------
    kernel : `np.ndarray`
        RBF kernel matrix with shape (n_samples_1, n_samples_2).
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    # Promote 1-D inputs to column vectors so the broadcasting below works.
    if x1.ndim == 1:
        x1 = x1[:, None]
    if x2.ndim == 1:
        x2 = x2[:, None]
    # Pairwise squared Euclidean distances via broadcasting -> (n1, n2).
    diff = x1[:, None, :] - x2[None, :, :]
    distance_squared = np.sum(diff**2, axis=-1)
    return (sigma**2) * np.exp(-0.5 * distance_squared / (correlation_length**2))
class InterpolateOverDefectGaussianProcessTestCase(lsst.utils.tests.TestCase):
    """Test InterpolateOverDefectGaussianProcess."""

    def setUp(self):
        """Build a smooth Gaussian random field image and flag some defects.

        ``self.reference`` keeps the unmodified field values (with the same
        mask planes as ``self.maskedimage``), so interpolated pixels can be
        compared against the true values after the task runs.
        """
        super().setUp()

        npoints = 1000
        self.std = 100
        self.correlation_length = 10.0
        self.white_noise = 1e-5

        rng = np.random.Generator(np.random.MT19937(5))

        x1 = rng.uniform(0, 99, npoints)
        x2 = rng.uniform(0, 120, npoints)
        coord1 = np.array([x1, x2]).T

        kernel = rbf_kernel(coord1, coord1, self.std, self.correlation_length)
        kernel += np.eye(npoints) * self.white_noise**2

        # Data augmentation: drawing a Gaussian random field directly on the
        # full 100 x 121 grid is too slow, so draw it on 1000 random points
        # and then interpolate it onto the grid with a GP.
        z1 = rng.multivariate_normal(np.zeros(npoints), kernel)

        x1 = np.linspace(0, 99, 100)
        x2 = np.linspace(0, 120, 121)
        x2, x1 = np.meshgrid(x2, x1)
        coord2 = np.array([x1.reshape(-1), x2.reshape(-1)]).T

        tgp = GaussianProcessTreegp(
            std=self.std,
            correlation_length=self.correlation_length,
            white_noise=self.white_noise,
            mean=0.0,
        )
        tgp.fit(coord1, z1)
        z2 = tgp.predict(coord2)
        # z2[x, y] is the field value at pixel column x, row y.
        z2 = z2.reshape(100, 121)

        self.maskedimage = afwImage.MaskedImageF(100, 121)
        # afw numpy views are indexed [y, x], hence the transpose. This is the
        # vectorized equivalent of setting (z2[x, y], 0, 1.0) pixel by pixel.
        self.maskedimage.image.array[:, :] = z2.T
        self.maskedimage.mask.array[:, :] = 0
        self.maskedimage.variance.array[:, :] = 1.0

        # Clone the maskedimage so we can compare it after running the task.
        self.reference = self.maskedimage.clone()

        # Flagged pixels are also set to NaN so the test can verify that the
        # interpolation really replaces them. Each call overwrites the mask
        # bits at those pixels, so the order below matters where regions
        # overlap.
        # A central block of pixels as SAT.
        self._flag_defect(slice(30, 35), slice(40, 45), "SAT")
        # An entire column (x = 54) as BAD.
        self._flag_defect(slice(54, 55), slice(None), "BAD")
        # An entire row (y = 110) as BAD.
        self._flag_defect(slice(None), slice(110, 111), "BAD")
        # A short diagonal of pixels as CR.
        for i in range(74, 78):
            self._flag_defect(i, i, "CR")
        # One of the image edges (column x = 0) as EDGE.
        self._flag_defect(slice(0, 1), slice(None), "EDGE")
        # A smaller EDGE streak along row y = 0.
        self._flag_defect(slice(25, 28), slice(0, 1), "EDGE")

        # Copy the defect mask into the reference (its image is untouched),
        # so the mask planes can be compared after running the task.
        self.reference.mask.array[:, :] = self.maskedimage.mask.array

    def _flag_defect(self, x, y, plane):
        """Set mask plane ``plane`` (replacing existing bits) and NaN the
        image at ``[x, y]`` of ``self.maskedimage``; ``x``/``y`` may be ints
        or slices."""
        self.maskedimage.mask[x, y] = afwImage.Mask.getPlaneBitMask(plane)
        self.maskedimage.image[x, y] = np.nan

    def test_interpolation(self):
        """Test that defects are interpolated close to the true field.

        After running the task, no NaNs may remain, the mask and variance
        planes must be unchanged, and pixel values (excluding the EDGE column
        x = 0) must match the reference within ``atol=2``.
        """
        gp = InterpolateOverDefectGaussianProcess(
            self.maskedimage,
            defects=["BAD", "SAT", "CR", "EDGE"],
            fwhm=self.correlation_length,
            bin_image=False,
            bin_spacing=30,
            threshold_dynamic_binning=1000,
            threshold_subdivide=20000,
            correlation_length_cut=5,
            log=None,
        )

        gp.run()

        # Assert that the mask and the variance planes remain unchanged.
        self.assertMasksEqual(self.maskedimage.mask, self.reference.mask)
        self.assertImagesEqual(self.maskedimage.variance, self.reference.variance)

        # Check that interpolated pixels are close to the reference (original),
        # and that none of them is still NaN. Column x = 0 is excluded.
        self.assertTrue(np.isfinite(self.maskedimage.image.array).all())
        self.assertImagesAlmostEqual(
            self.maskedimage.image[1:, :],
            self.reference.image[1:, :],
            atol=2,
        )
def setup_module(module):
    """Module-level setup hook: initialize the LSST test utilities.

    The ``module`` argument is accepted for the hook signature but unused.
    """
    lsst.utils.tests.init()
class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
    """Reuse the checks from ``lsst.utils.tests.MemoryTestCase`` unchanged."""
# Allow running this test file directly as a script.
if __name__ == "__main__":
    lsst.utils.tests.init()
    unittest.main()