Coverage for tests / utils.py: 16%
136 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-17 08:40 +0000
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import sys
23import traceback
24from typing import Sequence, cast
25from unittest import TestCase
27import numpy as np
28from lsst.scarlet.lite.bbox import Box
29from lsst.scarlet.lite.component import FactorizedComponent
30from lsst.scarlet.lite.fft import match_kernel
31from lsst.scarlet.lite.image import Image
32from lsst.scarlet.lite.source import Source
33from lsst.scarlet.lite.utils import integrated_circular_gaussian
34from numpy.testing import assert_almost_equal, assert_array_equal
35from numpy.typing import DTypeLike
36from scipy.signal import convolve as scipy_convolve
# Public helpers exported by ``from utils import *``.
__all__ = ["get_psfs", "execute_doc_scripts", "ObservationData", "ScarletTestCase"]
def get_psfs(sigmas: float | Sequence[float]) -> np.ndarray:
    """Build a stack of integrated circular Gaussian PSFs.

    Parameters
    ----------
    sigmas:
        A single Gaussian width, or an iterable of widths. Anything that
        cannot be iterated is treated as a one-element sequence.

    Returns
    -------
    psfs:
        ``np.array`` built from one ``integrated_circular_gaussian`` result
        per width, in the order the widths were given.
    """
    try:
        iter(sigmas)
    except TypeError:
        # A scalar width: wrap it so the comprehension below still works.
        widths = (sigmas,)
    else:
        widths = sigmas
    return np.array([integrated_circular_gaussian(sigma=width) for width in widths])
def execute_doc_scripts(filename: str):
    """Test python code in docstrings and document files.

    Every ``.. code-block:: python`` directive in ``filename`` is located and
    the indented code that follows it is extracted. All other lines are kept
    as python comments (prefixed with their 1-indexed line number). The
    combined script therefore has exactly one line per line of ``filename``.
    If any of the code blocks fail, the line with the error matches the line
    number in the .rst file or python file with the docstring.

    All code blocks are executed in order in a single shared namespace, so
    later blocks may use names defined by earlier ones.

    Parameters
    ----------
    filename:
        The name of the file to test.

    Raises
    ------
    ValueError
        If a new code-block directive appears while still indented inside
        the previous code block.
    RuntimeError
        If executing the extracted code raises. The original traceback is
        printed to stderr and chained as the ``__cause__``.
    """
    directive = ".. code-block:: python"
    with open(filename, encoding="utf-8") as file:
        lines = file.readlines()

    full_script = ""
    script = ""
    whitespace = 0
    code_block_start: int | None = None
    for n, line in enumerate(lines):
        if code_block_start is not None:
            if not line.strip():
                # Blank lines (even with trailing whitespace) never end a
                # block; keep exactly one newline so line numbers stay aligned.
                script += "\n"
                continue
            indent = len(line) - len(line.lstrip())
            if indent >= whitespace:
                if directive in line:
                    message = (
                        f"End of the previous code block starting at line {code_block_start + 1} "
                        f"was not detected by the new code block starting at line {n + 1}"
                    )
                    raise ValueError(message)
                script += line[whitespace:]
                continue
            # A dedented line closes the current block. Fall through so the
            # line itself is handled below: it may open another code block.
            full_script += script
            script = ""
            code_block_start = None
            whitespace = 0

        if directive in line:
            code_block_start = n
            # Code is expected to be indented 4 columns past the directive.
            whitespace = line.index(directive) + 4
        full_script += f"# {n + 1}: " + line

    if code_block_start is not None:
        # The file ended inside a code block; that code must still run.
        full_script += script

    # Use one dict as both globals and locals. Functions defined by the
    # examples can then see the examples' top-level names. Copying the
    # module globals keeps everything that was visible before available.
    namespace = dict(globals())
    try:
        # Compiling with the real filename makes tracebacks point at (and
        # display lines from) the documented file.
        exec(compile(full_script, filename, "exec"), namespace)
    except Exception as err:
        traceback.print_exception(*sys.exc_info())
        msg = f"Error encountered in a docstring for the file {filename}."
        raise RuntimeError(msg) from err
class ObservationData:
    """Generate an image and associated data used to create the image."""

    def __init__(
        self,
        bands: tuple,
        psfs: np.ndarray,
        spectra: np.ndarray,
        morphs: Sequence[np.ndarray],
        centers: Sequence[tuple[int, int]],
        model_psf: np.ndarray | None = None,
        yx0: tuple[int, int] = (0, 0),
        dtype: DTypeLike = float,
        shape: tuple[int, int] = (35, 35),
    ):
        """Initialize the test dataset

        Parameters
        ----------
        bands:
            The bands of the image; its length sets the number of image planes.
        psfs:
            The psf in each band as a (bands, Y, X) array.
        spectra:
            The spectrum of all the components in the image.
        morphs:
            The morphology for every component in the image.
        centers:
            The center of every component in the image
        model_psf:
            The 2D PSF of the model space. Required despite the ``None``
            default, because it is used to build the difference kernel.
        yx0:
            The origin of the generated images.
        dtype:
            The dtype of the generated images, convolved images,
            difference kernel and morphologies (checked by assertions).
        shape:
            The (Y, X) shape of each image plane.

        Raises
        ------
        ValueError
            If ``model_psf`` is not provided.
        """
        if model_psf is None:
            raise ValueError("model_psf is required to compute the difference kernel")
        assert len(spectra) == len(morphs) == len(centers)
        n_bands = len(bands)

        # Origin of each morphology, chosen so that the morphology is
        # centered (to within integer division) on its center pixel.
        origins = [
            tuple(center[i] - (morph.shape[i] - 1) // 2 for i in range(len(center)))
            for center, morph in zip(centers, morphs)
        ]
        # Each source's bounding box covers its full morphology.
        boxes = [Box(morph.shape, origin) for morph, origin in zip(morphs, origins)]

        # Create the image with the sources placed according to their boxes.
        images = np.zeros((n_bands, *shape), dtype=dtype)
        spectral_box = Box((n_bands,))
        for spectrum, morph, bbox in zip(spectra, morphs, boxes):
            # Box coordinates are shifted by yx0 into array-index space.
            images[(spectral_box @ (bbox - yx0)).slices] += spectrum[:, None, None] * morph[None, :, :]

        diff_kernel = match_kernel(psfs, model_psf[None], padding=3)
        convolved = np.array(
            [scipy_convolve(images[b], diff_kernel.image[b], mode="same") for b in range(n_bands)]
        )
        convolved = convolved.astype(dtype)

        self.images = Image(images, bands=bands, yx0=yx0)
        self.convolved = Image(convolved, bands=bands, yx0=yx0)
        self.diff_kernel = diff_kernel
        self.morphs = [Image(morph, yx0=origin) for morph, origin in zip(morphs, origins)]

        # Every generated product must honor the requested dtype.
        assert self.images.dtype == dtype
        assert self.convolved.dtype == dtype
        assert self.diff_kernel.image.dtype == dtype
        for morph in self.morphs:
            assert morph.dtype == dtype
class ScarletTestCase(TestCase):
    """TestCase with assertion helpers for scarlet lite objects."""

    def assertBoxEqual(self, bbox: Box, truth: Box):  # noqa: N802
        """Assert that two boxes have the same shape and origin.

        Parameters
        ----------
        bbox:
            The box being tested.
        truth:
            The expected box.
        """
        # ``msg`` keeps unittest's element-wise diff and appends a summary,
        # instead of re-raising a second AssertionError inside ``except``.
        self.assertTupleEqual(bbox.shape, truth.shape, msg=f"Box shapes differ: {bbox.shape}!={truth.shape}")
        self.assertTupleEqual(
            bbox.origin, truth.origin, msg=f"Box origins differ: {bbox.origin}!={truth.origin}"
        )

    def assertImageAlmostEqual(self, image: Image, truth: Image, decimal: int = 7):  # noqa: N802
        """Assert that two images match in bands, bounding box and data.

        Parameters
        ----------
        image:
            The image being tested.
        truth:
            The expected image.
        decimal:
            Number of decimal places passed to
            ``numpy.testing.assert_almost_equal`` for the pixel data.
        """
        if not isinstance(image, Image):
            raise AssertionError(f"image is a {type(image)}, not a lsst.scarlet.lite `Image`")
        if not isinstance(truth, Image):
            raise AssertionError(f"truth is a {type(truth)}, not a lsst.scarlet.lite `Image`")

        self.assertTupleEqual(image.bands, truth.bands, msg=f"Mismatched bands:{image.bands} != {truth.bands}")

        bbox_msg = f"Bounding boxes differ:\nimage: {image.bbox}\ntruth: {truth.bbox}"
        self.assertTupleEqual(image.bbox.shape, truth.bbox.shape, msg=bbox_msg)
        self.assertTupleEqual(image.bbox.origin, truth.bbox.origin, msg=bbox_msg)

        # Bands and bounding boxes are identical, so the data arrays are
        # aligned and can be compared directly.
        assert_almost_equal(image.data, truth.data, decimal=decimal)

    def assertImageEqual(self, image: Image, truth: Image):  # noqa: N802
        """Assert that two images are exactly equal.

        The approximate check runs first so that band and bounding-box
        mismatches produce descriptive messages.
        """
        self.assertImageAlmostEqual(image, truth)
        assert_array_equal(image.data, truth.data)

    def assertFactorizedComponentEqual(  # noqa: N802
        self,
        component: FactorizedComponent,
        truth: FactorizedComponent,
    ):
        """Assert that two factorized components have identical state.

        Compares bands, peak, the ``x`` arrays of the private ``_spectrum``
        and ``_morph`` parameters, bbox, and the scalar configuration
        attributes.
        """
        self.assertTupleEqual(component.bands, truth.bands)
        self.assertTupleEqual(component.peak, truth.peak)
        assert_array_equal(component._spectrum.x, truth._spectrum.x)
        assert_array_equal(component._morph.x, truth._morph.x)
        self.assertBoxEqual(component.bbox, truth.bbox)
        self.assertEqual(component.bg_rms, truth.bg_rms)
        self.assertEqual(component.bg_thresh, truth.bg_thresh)
        self.assertEqual(component.floor, truth.floor)
        self.assertEqual(component.padding, truth.padding)
        self.assertEqual(component.is_symmetric, truth.is_symmetric)

    def assertSourceEqual(self, source: Source, truth: Source):  # noqa: N802
        """Assert that two sources and all of their components are equal."""
        self.assertEqual(source.n_components, truth.n_components)
        self.assertBoxEqual(source.bbox, truth.bbox)
        self.assertTupleEqual(source.bands, truth.bands)
        # The component counts already match, so zip compares every pair.
        for comp, comp_truth in zip(source.components, truth.components):
            # ``cast`` only informs the type checker; it is a no-op at runtime.
            self.assertFactorizedComponentEqual(
                cast(FactorizedComponent, comp),
                cast(FactorizedComponent, comp_truth),
            )

    def assertObservationEqual(self, obs: ObservationData, truth: ObservationData):  # noqa: N802
        """Assert that two observations have equal images, weights, PSFs and bbox."""
        # NOTE(review): ObservationData in this module only sets images,
        # convolved, diff_kernel and morphs. It has no variance, weights,
        # psfs, model_psf, noise_rms or bbox, so the annotation looks wrong.
        # This presumably expects an lsst.scarlet.lite Observation; confirm.
        self.assertImageEqual(obs.images, truth.images)
        self.assertImageEqual(obs.variance, truth.variance)
        self.assertImageEqual(obs.weights, truth.weights)
        assert_array_equal(obs.psfs, truth.psfs)
        assert_array_equal(obs.model_psf, truth.model_psf)
        assert_array_equal(obs.noise_rms, truth.noise_rms)
        self.assertBoxEqual(obs.bbox, truth.bbox)