Coverage for tests/utils.py: 15%

109 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-19 10:38 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import sys 

23import traceback 

24from typing import Sequence 

25from unittest import TestCase 

26 

27import numpy as np 

28from lsst.scarlet.lite.bbox import Box 

29from lsst.scarlet.lite.fft import match_kernel 

30from lsst.scarlet.lite.image import Image 

31from lsst.scarlet.lite.utils import integrated_circular_gaussian 

32from numpy.testing import assert_almost_equal, assert_array_equal 

33from numpy.typing import DTypeLike 

34from scipy.signal import convolve as scipy_convolve 

35 

36__all__ = ["get_psfs", "ObservationData", "ScarletTestCase"] 

37 

38 

def get_psfs(sigmas: float | Sequence[float]) -> np.ndarray:
    """Build a stack of integrated circular Gaussian PSFs.

    Parameters
    ----------
    sigmas:
        Either a single Gaussian width or an iterable of widths
        (e.g. one per band). A non-iterable value is treated as a
        sequence containing just that value.

    Returns
    -------
    psfs:
        Array with one PSF per width, stacked along the first axis.
    """
    try:
        iter(sigmas)
    except TypeError:
        # A scalar width: wrap it so the comprehension below still works.
        widths = (sigmas,)
    else:
        widths = sigmas
    return np.array([integrated_circular_gaussian(sigma=width) for width in widths])

46 

47 

def _build_doc_script(lines: Sequence[str]) -> str:
    """Convert document lines into an executable python script.

    Code inside ``.. code-block:: python`` directives is dedented and kept.
    Every other line becomes a comment or an empty line, so line ``k`` of
    the returned script corresponds to line ``k`` of the document.

    Raises
    ------
    ValueError
        If a new code block starts before the previous one has ended.
    """
    full_script = ""
    script = ""
    whitespace = 0
    code_block_start = None
    for n, line in enumerate(lines):
        if ".. code-block:: python" in line:
            if code_block_start is not None:
                message = (
                    f"End of the previous code block starting at line {code_block_start + 1} "
                    f"was not detected by the new code block starting at line {n + 1}"
                )
                raise ValueError(message)
            code_block_start = n
            # Code lines are indented 4 columns past the directive marker.
            whitespace = line.index(".. code-block:: python") + 4
            full_script += f"# {n + 1}: " + line
        elif code_block_start is not None:
            if not line.strip():
                # Blank (or whitespace-only) lines never end a block; keep
                # them as empty lines to preserve the line numbering.
                script += "\n"
                continue
            indent = len(line) - len(line.lstrip())
            if indent < whitespace:
                # A dedented line ends the block. It is replaced by an empty
                # line so the line count still matches the document.
                code_block_start = None
                whitespace = 0
                full_script += script + "\n"
                script = ""
            else:
                script += line[whitespace:]
        else:
            full_script += f"# {n + 1}: " + line

    if code_block_start is not None:
        # The document ended inside a code block: do not drop its code.
        full_script += script
    return full_script


def execute_doc_scripts(filename: str):
    """Test python code in docstrings and document files.

    Any lines not containing code are replaced with a comment (or an empty
    line), so that if any of the code blocks fail, the line with the error
    matches the line number in the .rst file or the python file with the
    docstring.

    Parameters
    ----------
    filename:
        The name of the file to test.

    Raises
    ------
    ValueError
        If a code block starts before the previous one has ended.
    RuntimeError
        If executing the extracted code raises an exception. The original
        exception is chained as the cause and its traceback is printed.
    """
    with open(filename, encoding="utf-8") as file:
        lines = file.readlines()

    full_script = _build_doc_script(lines)

    # Use a single namespace dict (seeded with this module's globals, which
    # were already visible to the scripts) so that names defined at the top
    # level of a script are visible inside functions that script defines.
    namespace = dict(globals())
    try:
        # Compiling with the document's filename makes tracebacks report
        # the real file and line number.
        exec(compile(full_script, filename, "exec"), namespace)
    except Exception as err:
        traceback.print_exc()
        msg = f"Error encountered in a docstring for the file {filename}."
        raise RuntimeError(msg) from err

103 

104 

class ObservationData:
    """Generate an image and the associated data used to create it."""

    def __init__(
        self,
        bands: tuple,
        psfs: np.ndarray,
        spectra: np.ndarray,
        morphs: Sequence[np.ndarray],
        centers: Sequence[tuple[int, int]],
        model_psf: np.ndarray | None = None,
        yx0: tuple[int, int] = (0, 0),
        dtype: DTypeLike = float,
        shape: tuple[int, int] = (35, 35),
    ):
        """Initialize the test dataset

        Parameters
        ----------
        bands:
            The bands of the image. Its length sets the size of the first
            axis of the generated images.
        psfs:
            The psf in each band as a (bands, Y, X) array.
        spectra:
            The spectrum of every component in the image: one value per
            band for each component.
        morphs:
            The 2D morphology for every component in the image. Each one is
            inserted in a box with the same shape as the morphology.
        centers:
            The center of every component in the image, in the same axis
            order as the morphology shapes.
        model_psf:
            The 2D PSF of the model space. Required: it is used to build
            the difference kernel.
        yx0:
            The origin of the generated images. Source boxes are shifted by
            this amount before being inserted into the image arrays.
        dtype:
            The data type of the generated images. The morphologies and the
            difference kernel are asserted to have this dtype as well.
        shape:
            The spatial (height, width) of the generated images.

        Raises
        ------
        ValueError
            If ``model_psf`` is not given.
        """
        if model_psf is None:
            raise ValueError("`model_psf` is required to build the difference kernel")
        assert len(spectra) == len(morphs) == len(centers), (
            "spectra, morphs, and centers must have the same length"
        )
        n_bands = len(bands)

        # Place each morphology so that its pixel ``(size - 1) // 2`` along
        # every axis lands on the requested center.
        origins = [
            tuple(c - (size - 1) // 2 for c, size in zip(center, morph.shape))
            for center, morph in zip(centers, morphs)
        ]
        # Each source's bounding box matches the shape of its morphology.
        boxes = [Box(morph.shape, origin) for morph, origin in zip(morphs, origins)]

        # Create the image with the sources placed according to their boxes
        images = np.zeros((n_bands, *shape), dtype=dtype)
        spectral_box = Box((n_bands,))
        for spectrum, morph, bbox in zip(spectra, morphs, boxes):
            images[(spectral_box @ (bbox - yx0)).slices] += spectrum[:, None, None] * morph[None, :, :]

        diff_kernel = match_kernel(psfs, model_psf[None], padding=3)
        convolved = np.array(
            [scipy_convolve(images[b], diff_kernel.image[b], mode="same") for b in range(n_bands)]
        ).astype(dtype)

        # The unconvolved sum of all the components.
        self.images = Image(images, bands=bands, yx0=yx0)
        # ``images`` convolved band-by-band with the difference kernel.
        self.convolved = Image(convolved, bands=bands, yx0=yx0)
        self.diff_kernel = diff_kernel
        # Each morphology as an Image located at its origin.
        self.morphs = [Image(morph, yx0=origin) for morph, origin in zip(morphs, origins)]

        # Sanity check that every generated product has the requested dtype.
        assert self.images.dtype == dtype
        assert self.convolved.dtype == dtype
        assert self.diff_kernel.image.dtype == dtype
        for morph in self.morphs:
            assert morph.dtype == dtype

162 

163 

class ScarletTestCase(TestCase):
    """A `TestCase` with assertions for scarlet lite boxes and images."""

    def assertBoxEqual(self, bbox: Box, truth: Box):  # noqa: N802
        """Assert that two boxes have the same shape and origin.

        Raises
        ------
        AssertionError
            If the shapes or the origins of the boxes differ.
        """
        try:
            self.assertTupleEqual(bbox.shape, truth.shape)
        except AssertionError as err:
            msg = f"Box shapes differ: {bbox.shape}!={truth.shape}"
            raise AssertionError(msg) from err
        try:
            self.assertTupleEqual(bbox.origin, truth.origin)
        except AssertionError as err:
            msg = f"Box origins differ: {bbox.origin}!={truth.origin}"
            raise AssertionError(msg) from err

    def assertImageAlmostEqual(self, image: Image, truth: Image, decimal: int = 7):  # noqa: N802
        """Assert that two images match in bands, bounding box, and values.

        Parameters
        ----------
        image:
            The image to test.
        truth:
            The expected image.
        decimal:
            Number of decimal places the pixel values must agree to,
            passed to `numpy.testing.assert_almost_equal`.

        Raises
        ------
        AssertionError
            If either argument is not an `Image`, or if the bands, the
            bounding boxes, or the pixel values differ.
        """
        if not isinstance(image, Image):
            raise AssertionError(f"image is a {type(image)}, not a lsst.scarlet.lite `Image`")
        if not isinstance(truth, Image):
            raise AssertionError(f"truth is a {type(truth)}, not a lsst.scarlet.lite `Image`")

        try:
            self.assertTupleEqual(image.bands, truth.bands)
        except AssertionError as err:
            msg = f"Mismatched bands: {image.bands} != {truth.bands}"
            raise AssertionError(msg) from err

        try:
            self.assertTupleEqual(image.bbox.shape, truth.bbox.shape)
            self.assertTupleEqual(image.bbox.origin, truth.bbox.origin)
        except AssertionError as err:
            msg = f"Bounding boxes do not match:\nimage: {image.bbox}\ntruth: {truth.bbox}"
            raise AssertionError(msg) from err

        # The images cover the same bands and bounding box,
        # so compare their pixel values directly.
        assert_almost_equal(image.data, truth.data, decimal=decimal)

    def assertImageEqual(self, image: Image, truth: Image):  # noqa: N802
        """Assert that two images match exactly.

        Runs the bands and bounding-box checks of `assertImageAlmostEqual`,
        then requires the pixel values to be exactly equal.
        """
        self.assertImageAlmostEqual(image, truth)
        assert_array_equal(image.data, truth.data)