Coverage for tests / test_io.py: 14%
98 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:40 +0000
1# This file is part of lsst.scarlet.lite.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import json
23import os
25import numpy as np
26from lsst.scarlet.lite import Blend, Image, Observation, io
27from lsst.scarlet.lite.component import CubeComponent
28from lsst.scarlet.lite.initialization import FactorizedInitialization
29from lsst.scarlet.lite.operators import Monotonicity
30from lsst.scarlet.lite.utils import integrated_circular_gaussian
31from numpy.testing import assert_almost_equal
32from utils import ScarletTestCase
class TestIo(ScarletTestCase):
    """Round-trip tests for serializing scarlet lite blends with ``io``.

    Each test builds a blend from the HSC COSMOS test data, converts it to
    ``io.ScarletModelData`` (or a legacy dict), passes it through a JSON
    string, parses it back, rebuilds a blend with ``minimal_data_to_blend``
    and compares the rebuilt blend against the original one.
    """

    def setUp(self) -> None:
        """Build an observation and an initialized blend from the test data.

        The data file is read from ``data/hsc_cosmos_35.npz`` in the directory
        one level above the directory containing this test file.
        """
        # Resolve the data file relative to this test's directory. This
        # normalizes to the same absolute path as joining ".." onto the file
        # path itself, but states the intent directly.
        filename = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "data", "hsc_cosmos_35.npz")
        )
        # ``np.load`` on an .npz returns an NpzFile that keeps the archive open;
        # use it as a context manager so the handle is closed once every array
        # needed below has been read.
        with np.load(filename) as data:
            images = data["images"]
            variance = data["variance"]
            psfs = data["psfs"]
            catalog = data["catalog"]
            bands = data["filters"]

        model_psf = integrated_circular_gaussian(sigma=0.8)
        # Detection image: the images summed over their first (presumably
        # band) axis.
        self.detect = np.sum(images, axis=0)
        # One (y, x) row per catalog entry.
        self.centers = np.array([catalog["y"], catalog["x"]]).T
        self.observation = Observation(
            Image(images, bands=bands),
            Image(variance, bands=bands),
            # Inverse variance.
            Image(1 / variance, bands=bands),
            psfs,
            # Add a leading axis to the single model PSF.
            model_psf[None],
            bands=bands,
        )
        monotonicity = Monotonicity((101, 101))
        init = FactorizedInitialization(self.observation, self.centers, monotonicity=monotonicity)
        self.blend = Blend(init.sources, self.observation)

    def test_json(self):
        """Round-trip a blend through JSON and compare every source/component."""
        blend = self.blend
        # NOTE(review): ``psf`` and ``bands`` are not passed to
        # ``minimal_data_to_blend`` below, so presumably ``to_data`` picks them
        # up from this metadata. Unlike the other tests, "psf" here is the
        # *model* PSF rather than ``self.observation.psfs`` — confirm this is
        # intentional.
        blend.metadata = {
            "psf": self.observation.model_psf,
            "bands": tuple(str(band) for band in self.observation.bands),
        }
        blend_data = blend.to_data()
        model_data = io.ScarletModelData(
            blends={1: blend_data},
            metadata={"model_psf": self.observation.model_psf},
        )

        # Serialize to a JSON string, decode it, and rebuild the data classes.
        model_str = model_data.json()
        model_dict = json.loads(model_str)
        model_data = io.ScarletModelData.parse_obj(model_dict)
        metadata = model_data.metadata
        self.assertIsNotNone(metadata)
        # Convert the model data back into scarlet models.
        loaded_blend = model_data.blends[1].minimal_data_to_blend(
            model_psf=metadata["model_psf"],  # type: ignore
            dtype=blend.observation.dtype,
        )

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())
        # The serialized blend data must record the same bounding box.
        self.assertBoxEqual(blend.bbox, blend_data.bbox)

        # Source and component counts are asserted equal before each zip, so
        # zip cannot silently drop unmatched items.
        for source1, source2 in zip(blend.sources, loaded_blend.sources):
            self.assertTupleEqual(source1.center, source2.center)
            self.assertEqual(len(source1.components), len(source2.components))
            self.assertBoxEqual(source1.bbox, source2.bbox)
            for component1, component2 in zip(source1.components, source2.components):
                self.assertEqual(component1.peak, component2.peak)
                assert_almost_equal(component1.spectrum, component2.spectrum)
                assert_almost_equal(component1.morph, component2.morph)
                self.assertBoxEqual(component1.bbox, component2.bbox)

    def test_cube_component(self):
        """Round-trip a blend containing a ``CubeComponent`` plus source metadata."""
        blend = self.blend
        # Tag each source so the metadata round trip can be checked at the end.
        for i, source in enumerate(blend.sources):
            source.metadata = {"id": f"peak-{i}"}
        # Replace the last component of the last source with a free-form
        # component built from that component's own model and peak.
        component = blend.sources[-1].components[-1]
        blend.sources[-1].components[-1] = CubeComponent(
            model=component.get_model(),
            peak=component.peak,
        )

        blend_data = blend.to_data()
        model_data = io.ScarletModelData(
            blends={1: blend_data},
            metadata={
                "model_psf": self.observation.model_psf,
                "psf": self.observation.psfs,
                "bands": tuple(str(band) for band in self.observation.bands),
            },
        )

        # Serialize to a JSON string, decode it, and rebuild the data classes.
        model_str = model_data.json()
        model_dict = json.loads(model_str)
        model_data = io.ScarletModelData.parse_obj(model_dict)
        # Convert the model data back into scarlet models, supplying the PSFs
        # and bands from the round-tripped metadata.
        loaded_blend = model_data.blends[1].minimal_data_to_blend(
            model_psf=model_data.metadata["model_psf"],  # type: ignore
            bands=model_data.metadata["bands"],  # type: ignore
            psf=model_data.metadata["psf"],  # type: ignore
            dtype=blend.observation.dtype,
        )

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())

        # Check that the per-source metadata was stored correctly (source
        # counts were asserted equal above).
        for source1, source2 in zip(blend.sources, loaded_blend.sources):
            self.assertEqual(source1.metadata, source2.metadata)

    def test_legacy_json(self):
        """Parse pre-versioning JSON whose PSFs are stored as plain keys.

        The legacy dict has no ``metadata`` entry and no ``version`` keys.
        After parsing, ``metadata`` must be present and contain
        ``model_psf``, and the rebuilt blend must match the original.
        """
        blend = self.blend

        # Create legacy blend JSON data: the observed PSFs and the bands are
        # stored directly on the blend entry.
        blend_data = blend.to_data().as_dict()
        encoded_psf = io.utils.numpy_to_json(self.observation.psfs)
        blend_data["psf"] = encoded_psf["data"]
        blend_data["psf_shape"] = encoded_psf["shape"]
        blend_data["bands"] = tuple(str(band) for band in self.observation.bands)
        blend_data["psf_center"] = (10, 10)

        # Create legacy model data: the model PSF is stored at the top level.
        model_data = io.ScarletModelData(blends={}).as_dict()
        model_data["blends"][1] = blend_data
        encoded_model_psf = io.utils.numpy_to_json(self.observation.model_psf)
        model_data["psf"] = encoded_model_psf["data"]
        model_data["psfShape"] = encoded_model_psf["shape"]

        # Legacy models were pre-versioning, so delete any version key.
        # ``blend_data`` is the same dict stored in ``model_data["blends"][1]``,
        # so popping from it also updates ``model_data``.
        model_data.pop("version", None)
        blend_data.pop("version", None)
        for source in blend_data["sources"].values():
            source.pop("version", None)
            for component in source["components"]:
                component.pop("version", None)

        self.assertIsNone(model_data["metadata"])

        # Serialize to a JSON string, decode it, and rebuild the data classes.
        model_str = json.dumps(model_data)
        model_dict = json.loads(model_str)
        model_data = io.ScarletModelData.parse_obj(model_dict)
        metadata = model_data.metadata
        # Presumably parsing builds the metadata from the legacy top-level PSF
        # keys, since it was None before serialization.
        self.assertIsNotNone(metadata)

        # Convert the model data back into scarlet models.
        loaded_blend = model_data.blends[1].minimal_data_to_blend(
            model_psf=metadata["model_psf"],  # type: ignore
            dtype=blend.observation.dtype,
        )

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())