Coverage for tests/utils.py: 15%
109 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-20 03:40 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-20 03:40 -0700
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import sys
23import traceback
24from typing import Sequence
25from unittest import TestCase
27import numpy as np
28from lsst.scarlet.lite.bbox import Box
29from lsst.scarlet.lite.fft import match_kernel
30from lsst.scarlet.lite.image import Image
31from lsst.scarlet.lite.utils import integrated_circular_gaussian
32from numpy.testing import assert_almost_equal, assert_array_equal
33from numpy.typing import DTypeLike
34from scipy.signal import convolve as scipy_convolve
# Public names exported by a star import of this module.
__all__ = [
    "get_psfs",
    "ObservationData",
    "ScarletTestCase",
]
def get_psfs(sigmas: float | Sequence[float]) -> np.ndarray:
    """Create a stack of integrated circular Gaussian PSFs.

    Parameters
    ----------
    sigmas:
        Either a single Gaussian width or an iterable of widths.

    Returns
    -------
    psfs:
        The result of ``integrated_circular_gaussian`` for every width,
        combined into a single array with ``np.array``.
    """
    try:
        iter(sigmas)
    except TypeError:
        # A scalar width is treated as a sequence with one entry.
        widths = (sigmas,)
    else:
        widths = sigmas
    kernels = [integrated_circular_gaussian(sigma=width) for width in widths]
    return np.array(kernels)
def _extract_doc_script(lines: Sequence[str]) -> str:
    """Build an executable python script from the lines of a document.

    Every ``.. code-block:: python`` block is dedented and kept as code.
    Every other line becomes a comment (or a blank line) so that line ``k``
    of the returned script corresponds to line ``k`` of the document.

    Parameters
    ----------
    lines:
        The lines of the document, as returned by ``file.readlines()``.

    Returns
    -------
    script:
        The python script extracted from the document.

    Raises
    ------
    ValueError
        If a new code block starts before the previous one has ended.
    """
    directive = ".. code-block:: python"
    parts: list[str] = []
    block: list[str] = []
    whitespace = 0
    code_block_start = None
    for n, line in enumerate(lines):
        if directive in line:
            if code_block_start is not None:
                message = (
                    f"End of the previous code block starting at line {code_block_start + 1} "
                    f"was not detected by the new code block starting at line {n + 1}"
                )
                raise ValueError(message)
            code_block_start = n
            # Code inside the block is expected to be indented 4 spaces
            # past the directive.
            whitespace = line.index(directive) + 4
            parts.append(f"# {n + 1}: {line}")
        elif code_block_start is None:
            parts.append(f"# {n + 1}: {line}")
        elif not line.strip():
            # Blank or whitespace-only lines do not end a code block.
            block.append("\n")
        elif len(line) - len(line.lstrip()) < whitespace:
            # A dedented line ends the block. The line itself is replaced by
            # a blank line to keep the script aligned with the document.
            parts.extend(block)
            parts.append("\n")
            block = []
            code_block_start = None
            whitespace = 0
        else:
            block.append(line[whitespace:])
    # A code block that runs to the end of the document must still be run.
    parts.extend(block)
    return "".join(parts)


def execute_doc_scripts(filename: str):
    """Test python code in docstrings and document files.

    Any lines not containing code are replaced with a comment or a newline
    character. That way, if any of the code blocks fail, the line with the
    error matches the line number in the .rst file or in the python file
    with the docstring.

    Parameters
    ----------
    filename:
        The name of the file to test.

    Raises
    ------
    ValueError
        If a code block starts before the previous one has ended.
    RuntimeError
        If compiling or executing the extracted code raises an exception.
        The original traceback is printed and chained as the cause.
    """
    with open(filename, encoding="utf-8") as file:
        lines = file.readlines()

    full_script = _extract_doc_script(lines)

    # Use a single namespace for globals and locals. With separate mappings,
    # functions defined by the script could not see the script's own
    # top-level names, such as its imports. Starting from a copy of this
    # module's globals keeps every name that was previously visible.
    # NOTE: this executes the document's code; only use it on trusted files.
    namespace = dict(globals())
    try:
        # Compiling with the document's filename makes tracebacks report
        # the document's own line numbers.
        exec(compile(full_script, filename, "exec"), namespace)
    except Exception as err:
        traceback.print_exc()
        msg = f"Error encountered in a docstring for the file {filename}."
        raise RuntimeError(msg) from err
class ObservationData:
    """Generate an image and associated data used to create the image."""

    def __init__(
        self,
        bands: tuple,
        psfs: np.ndarray,
        spectra: np.ndarray,
        morphs: Sequence[np.ndarray],
        centers: Sequence[tuple[int, int]],
        model_psf: np.ndarray | None = None,
        yx0: tuple[int, int] = (0, 0),
        dtype: DTypeLike = float,
        image_shape: tuple[int, int] = (35, 35),
    ):
        """Initialize the test dataset

        Parameters
        ----------
        bands:
            The names of the bands in the image. Its length sets the number
            of bands in the generated images.
        psfs:
            The psf in each band as a (bands, Y, X) array.
        spectra:
            The spectrum of all the components in the image.
        morphs:
            The morphology for every component in the image.
        centers:
            The center of every component in the image.
        model_psf:
            The 2D PSF of the model space. It is required because it is used
            to build the difference kernel.
        yx0:
            The origin of the generated images.
        dtype:
            The dtype of the generated images and morphologies.
        image_shape:
            The (Y, X) shape of the generated images. The default matches
            the size that used to be hard-coded.

        Raises
        ------
        TypeError
            If ``model_psf`` is not given.
        """
        assert len(spectra) == len(morphs) == len(centers)
        if model_psf is None:
            raise TypeError("ObservationData requires a model_psf to build the difference kernel")
        n_bands = len(bands)

        # Place each morphology so that its central pixel lands on its center.
        origins = [
            tuple(c - (size - 1) // 2 for c, size in zip(center, morph.shape))
            for center, morph in zip(centers, morphs)
        ]
        # The bounding box of each source matches the shape of its morphology.
        boxes = [Box(morph.shape, origin) for morph, origin in zip(morphs, origins)]

        # Create the image with the sources placed according to their boxes.
        images = np.zeros((n_bands, *image_shape), dtype=dtype)
        spectral_box = Box((n_bands,))
        for spectrum, morph, bbox in zip(spectra, morphs, boxes):
            images[(spectral_box @ (bbox - yx0)).slices] += spectrum[:, None, None] * morph[None, :, :]

        # Convolve each band of the model images with its difference kernel.
        diff_kernel = match_kernel(psfs, model_psf[None], padding=3)
        convolved = np.array(
            [scipy_convolve(images[b], diff_kernel.image[b], mode="same") for b in range(n_bands)]
        )
        convolved = convolved.astype(dtype)

        self.images = Image(images, bands=bands, yx0=yx0)
        self.convolved = Image(convolved, bands=bands, yx0=yx0)
        self.diff_kernel = diff_kernel
        self.morphs = [Image(morph, yx0=origin) for morph, origin in zip(morphs, origins)]

        # Sanity checks that the requested dtype propagated everywhere.
        assert self.images.dtype == dtype
        assert self.convolved.dtype == dtype
        assert self.diff_kernel.image.dtype == dtype
        for morph in self.morphs:
            assert morph.dtype == dtype
class ScarletTestCase(TestCase):
    """A `TestCase` with assertions for lsst.scarlet.lite `Box` and `Image` objects."""

    def assertBoxEqual(self, bbox: Box, truth: Box):  # noqa: N802
        """Assert that two boxes have the same shape and origin.

        Raises
        ------
        AssertionError
            If the shapes or the origins of the boxes differ.
        """
        try:
            self.assertTupleEqual(bbox.shape, truth.shape)
        except AssertionError as err:
            msg = f"Box shapes differ: {bbox.shape}!={truth.shape}"
            raise AssertionError(msg) from err
        try:
            self.assertTupleEqual(bbox.origin, truth.origin)
        except AssertionError as err:
            msg = f"Box origins differ: {bbox.origin}!={truth.origin}"
            raise AssertionError(msg) from err

    def assertImageAlmostEqual(self, image: Image, truth: Image, decimal: int = 7):  # noqa: N802
        """Assert that two images match in type, bands and bounding box, and
        that their data agree to ``decimal`` decimal places.

        Raises
        ------
        AssertionError
            If either argument is not an `Image`, the bands or bounding
            boxes differ, or the data are not almost equal.
        """
        if not isinstance(image, Image):
            raise AssertionError(f"image is a {type(image)}, not a lsst.scarlet.lite `Image`")
        if not isinstance(truth, Image):
            raise AssertionError(f"truth is a {type(truth)}, not a lsst.scarlet.lite `Image`")

        try:
            self.assertTupleEqual(image.bands, truth.bands)
        except AssertionError as err:
            msg = f"Mismatched bands:{image.bands} != {truth.bands}"
            raise AssertionError(msg) from err

        try:
            self.assertTupleEqual(image.bbox.shape, truth.bbox.shape)
            self.assertTupleEqual(image.bbox.origin, truth.bbox.origin)
        except AssertionError as err:
            # The check is for equal boxes, not just overlapping ones.
            msg = f"Bounding boxes differ:\nimage: {image.bbox}\ntruth: {truth.bbox}"
            raise AssertionError(msg) from err

        # The images have the same bands and bounding box,
        # so compare their pixel values.
        assert_almost_equal(image.data, truth.data, decimal=decimal)

    def assertImageEqual(self, image: Image, truth: Image):  # noqa: N802
        """Assert that two images are identical.

        This runs the type, band and bounding-box checks of
        `assertImageAlmostEqual` and then requires exactly equal data.
        """
        self.assertImageAlmostEqual(image, truth)
        assert_array_equal(image.data, truth.data)