Coverage for tests / utils.py: 16%

136 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:28 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import sys 

23import traceback 

24from typing import Sequence, cast 

25from unittest import TestCase 

26 

27import numpy as np 

28from lsst.scarlet.lite.bbox import Box 

29from lsst.scarlet.lite.component import FactorizedComponent 

30from lsst.scarlet.lite.fft import match_kernel 

31from lsst.scarlet.lite.image import Image 

32from lsst.scarlet.lite.source import Source 

33from lsst.scarlet.lite.utils import integrated_circular_gaussian 

34from numpy.testing import assert_almost_equal, assert_array_equal 

35from numpy.typing import DTypeLike 

36from scipy.signal import convolve as scipy_convolve 

37 

38__all__ = ["get_psfs", "ObservationData", "ScarletTestCase"] 

39 

40 

def get_psfs(sigmas: float | Sequence[float]) -> np.ndarray:
    """Build one PSF image per sigma.

    Parameters
    ----------
    sigmas:
        A single sigma or an iterable of sigmas. A value that cannot be
        iterated over is treated as a single sigma.

    Returns
    -------
    psfs:
        Array built from the ``integrated_circular_gaussian`` result for
        each sigma, in input order.
    """
    # Promote a lone (non-iterable) sigma to a one-element tuple so the
    # comprehension below handles both call styles.
    if not np.iterable(sigmas):
        sigmas = (sigmas,)
    return np.array([integrated_circular_gaussian(sigma=value) for value in sigmas])

48 

49 

def execute_doc_scripts(filename: str):
    """Test python code in docstrings and document files.

    Every ``.. code-block:: python`` block in the file is extracted,
    dedented, and executed as a single script. Every other line is turned
    into a comment (or an empty line), so each input line maps to exactly
    one script line: if any of the code blocks fail, the line with the
    error will match the line number in the .rst file or python file with
    the docstring.

    A code block ends at the first non-blank line that is indented less
    than the block body (the directive's indent plus four), or at the end
    of the file.

    Parameters
    ----------
    filename:
        The name of the file to test.

    Raises
    ------
    ValueError
        If a new code-block directive is found before the end of the
        previous code block was detected (e.g. two blocks separated only
        by blank lines).
    RuntimeError
        If executing the extracted script raises an exception. The
        original traceback is printed first and kept as ``__cause__``.
    """
    with open(filename, encoding="utf-8") as file:
        lines = file.readlines()

    # Exactly one entry per input line; joined once at the end.
    script_lines = []
    whitespace = 0
    code_block_start = None
    for n, line in enumerate(lines):
        if ".. code-block:: python" in line:
            if code_block_start is not None:
                message = (
                    f"End of the previous code block starting at {code_block_start} "
                    f"was not detected by the new code block starting at {n}"
                )
                raise ValueError(message)
            code_block_start = n
            # Everything before the first ".." is the directive's indent;
            # the block body is indented four more columns.
            tab = line.partition("..")[0]
            whitespace = len(tab) + 4
            script_lines.append(f"# {n + 1}: " + line)
        elif code_block_start is not None:
            if not line.strip():
                # Blank (or whitespace-only) lines never end a block.
                script_lines.append("\n")
            elif len(line) - len(line.lstrip()) < whitespace:
                # First dedented text line: the block is over. The line
                # itself is replaced by an empty line to keep numbering.
                code_block_start = None
                whitespace = 0
                script_lines.append("\n")
            else:
                script_lines.append(line[whitespace:])
        else:
            script_lines.append(f"# {n + 1}: " + line)

    # Code lines are appended as they are read, so a block that runs to
    # the end of the file is included without a separate flush step.
    full_script = "".join(script_lines)

    # NOTE: this executes arbitrary code taken from ``filename``; only use
    # it on trusted documentation files.
    # A single namespace is used for both globals and locals so functions
    # defined in the snippets can see names the snippets defined earlier.
    # It is a copy of this module's globals, so the snippets keep access
    # to the module's imports without being able to modify the module.
    namespace = dict(globals())
    try:
        exec(full_script, namespace)
    except Exception as err:
        traceback.print_exception(type(err), err, err.__traceback__)
        msg = f"Error encountered in a docstring for the file {filename}."
        raise RuntimeError(msg) from err

105 

106 

class ObservationData:
    """Generate an image and associated data used to create the image."""

    def __init__(
        self,
        bands: tuple,
        psfs: np.ndarray,
        spectra: np.ndarray,
        morphs: Sequence[np.ndarray],
        centers: Sequence[tuple[int, int]],
        model_psf: np.ndarray | None = None,
        yx0: tuple[int, int] = (0, 0),
        dtype: DTypeLike = float,
        shape: tuple[int, int] = (35, 35),
    ):
        """Initialize the test dataset

        Parameters
        ----------
        bands:
            The bands of the image; its length sets the number of
            spectral planes.
        psfs:
            The psf in each band as a (bands, Y, X) array.
        spectra:
            The spectrum of all the components in the image.
        morphs:
            The morphology for every component in the image.
        centers:
            The center of every component in the image
        model_psf:
            The 2D PSF of the model space. Although it has a default of
            ``None`` it is required to build the difference kernel.
        yx0:
            The origin passed to the generated images.
        dtype:
            The data type of the generated arrays.
        shape:
            The (height, width) of the generated image.

        Raises
        ------
        TypeError
            If ``model_psf`` is not provided.
        """
        if model_psf is None:
            # Previously this failed later with an opaque
            # "'NoneType' object is not subscriptable" TypeError.
            raise TypeError("model_psf is required to build the difference kernel")
        assert len(spectra) == len(morphs) == len(centers)

        n_bands = len(bands)
        # The origin of each morphology is chosen so that `center` falls
        # on the middle pixel of the morphology.
        origins = [
            tuple(c - (extent - 1) // 2 for c, extent in zip(center, morph.shape))
            for center, morph in zip(centers, morphs)
        ]
        # Define the bounding box for each source from its own morphology
        # shape, so the box always matches the array inserted below.
        boxes = [Box(morph.shape, origin) for morph, origin in zip(morphs, origins)]

        # Create the image with the sources placed according to their boxes
        images = np.zeros((n_bands, *shape), dtype=dtype)
        spectral_box = Box((n_bands,))
        for spectrum, morph, bbox in zip(spectra, morphs, boxes):
            images[(spectral_box @ (bbox - yx0)).slices] += spectrum[:, None, None] * morph[None, :, :]

        # Convolve each band with the kernel returned by `match_kernel`
        # for the per-band PSFs and the model PSF.
        diff_kernel = match_kernel(psfs, model_psf[None], padding=3)
        convolved = np.array(
            [scipy_convolve(images[b], diff_kernel.image[b], mode="same") for b in range(n_bands)]
        )
        convolved = convolved.astype(dtype)

        self.images = Image(images, bands=bands, yx0=yx0)
        self.convolved = Image(convolved, bands=bands, yx0=yx0)
        self.diff_kernel = diff_kernel
        self.morphs = [Image(morph, yx0=origin) for morph, origin in zip(morphs, origins)]

        # Sanity checks: every product must have the requested dtype.
        assert self.images.dtype == dtype
        assert self.convolved.dtype == dtype
        assert self.diff_kernel.image.dtype == dtype
        for morph in self.morphs:
            assert morph.dtype == dtype

164 

165 

class ScarletTestCase(TestCase):
    """A `unittest.TestCase` with extra assertions for scarlet lite objects."""

    def assertBoxEqual(self, bbox: Box, truth: Box):  # noqa: N802
        """Check that two boxes have the same shape and the same origin.

        Parameters
        ----------
        bbox:
            The box to test.
        truth:
            The expected box.

        Raises
        ------
        AssertionError
            If the shapes differ or, when the shapes agree, the origins
            differ.
        """
        # The shape is checked before the origin, so a shape mismatch is
        # the one reported when both differ.
        try:
            self.assertTupleEqual(bbox.shape, truth.shape)
        except AssertionError:
            raise AssertionError(f"Box shapes differ: {bbox.shape}!={truth.shape}")

        try:
            self.assertTupleEqual(bbox.origin, truth.origin)
        except AssertionError:
            raise AssertionError(f"Box origins differ: {bbox.origin}!={truth.origin}")

    def assertImageAlmostEqual(self, image: Image, truth: Image, decimal: int = 7):  # noqa: N802
        """Check that two images match in bands, bounding box and,
        up to ``decimal`` decimal places, data.

        Parameters
        ----------
        image:
            The image to test.
        truth:
            The expected image.
        decimal:
            Precision forwarded to `numpy.testing.assert_almost_equal`.

        Raises
        ------
        AssertionError
            If either argument is not an `Image`, or if the bands,
            bounding boxes, or data do not match.
        """
        # Reject anything that is not an `Image` before touching attributes.
        if not isinstance(image, Image):
            raise AssertionError(f"image is a {type(image)}, not a lsst.scarlet.lite `Image`")
        if not isinstance(truth, Image):
            raise AssertionError(f"truth is a {type(truth)}, not a lsst.scarlet.lite `Image`")

        try:
            self.assertTupleEqual(image.bands, truth.bands)
        except AssertionError:
            raise AssertionError(f"Mismatched bands:{image.bands} != {truth.bands}")

        # A mismatch in either the shape or the origin of the bounding
        # boxes is reported with the same message.
        try:
            self.assertTupleEqual(image.bbox.shape, truth.bbox.shape)
            self.assertTupleEqual(image.bbox.origin, truth.bbox.origin)
        except AssertionError:
            raise AssertionError(f"Bounding boxes do not overlap:\nimage: {image.bbox}\ntruth: {truth.bbox}")

        # The images overlap in multi-band image space,
        # check the values of the images
        assert_almost_equal(image.data, truth.data, decimal=decimal)

    def assertImageEqual(self, image: Image, truth: Image):  # noqa: N802
        """Check that two images match exactly.

        The approximate comparison runs first (it validates types, bands
        and bounding boxes), followed by an exact comparison of the data.
        """
        self.assertImageAlmostEqual(image, truth)
        assert_array_equal(image.data, truth.data)

    def assertFactorizedComponentEqual(  # noqa: N802
        self,
        component: FactorizedComponent,
        truth: FactorizedComponent,
    ):
        """Check that two factorized components are identical.

        Compares, in order: bands, peak, the ``x`` value of the spectrum
        and morphology parameters, the bounding box, and the scalar
        settings ``bg_rms``, ``bg_thresh``, ``floor``, ``padding`` and
        ``is_symmetric``.
        """
        self.assertTupleEqual(component.bands, truth.bands)
        self.assertTupleEqual(component.peak, truth.peak)
        assert_array_equal(component._spectrum.x, truth._spectrum.x)
        assert_array_equal(component._morph.x, truth._morph.x)
        self.assertBoxEqual(component.bbox, truth.bbox)
        for attribute in ("bg_rms", "bg_thresh", "floor", "padding", "is_symmetric"):
            self.assertEqual(getattr(component, attribute), getattr(truth, attribute))

    def assertSourceEqual(self, source: Source, truth: Source):  # noqa: N802
        """Check that two sources have the same number of components,
        bounding box and bands, and pairwise-equal components.

        Every component is compared with
        `assertFactorizedComponentEqual`.
        """
        self.assertEqual(source.n_components, truth.n_components)
        self.assertBoxEqual(source.bbox, truth.bbox)
        self.assertTupleEqual(source.bands, truth.bands)
        for found, expected in zip(source.components, truth.components):
            # `cast` only informs the type checker; it has no runtime effect.
            self.assertFactorizedComponentEqual(
                cast(FactorizedComponent, found),
                cast(FactorizedComponent, expected),
            )

    def assertObservationEqual(self, obs: ObservationData, truth: ObservationData):  # noqa: N802
        """Check that two observations are identical.

        Compares the ``images``, ``variance`` and ``weights`` images
        exactly, then the ``psfs``, ``model_psf`` and ``noise_rms``
        arrays, and finally the bounding boxes.

        NOTE(review): the parameters are annotated as `ObservationData`,
        but the ``__init__`` of that class in this file does not set
        ``variance``, ``weights``, ``psfs``, ``model_psf``, ``noise_rms``
        or ``bbox``. Presumably this is meant for the library's
        observation class; confirm and update the annotation.
        """
        for attribute in ("images", "variance", "weights"):
            self.assertImageEqual(getattr(obs, attribute), getattr(truth, attribute))
        for attribute in ("psfs", "model_psf", "noise_rms"):
            assert_array_equal(getattr(obs, attribute), getattr(truth, attribute))
        self.assertBoxEqual(obs.bbox, truth.bbox)