Coverage for tests / test_io.py: 14%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 08:40 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import json 

23import os 

24 

25import numpy as np 

26from lsst.scarlet.lite import Blend, Image, Observation, io 

27from lsst.scarlet.lite.component import CubeComponent 

28from lsst.scarlet.lite.initialization import FactorizedInitialization 

29from lsst.scarlet.lite.operators import Monotonicity 

30from lsst.scarlet.lite.utils import integrated_circular_gaussian 

31from numpy.testing import assert_almost_equal 

32from utils import ScarletTestCase 

33 

34 

class TestIo(ScarletTestCase):
    """Round-trip tests for persisting blends via ``io.ScarletModelData``.

    Each test serializes the blend built in `setUp` to JSON, parses it back
    into model data, rebuilds a blend from that data and compares the
    result with the original blend.
    """

    def setUp(self) -> None:
        """Build the observation and the initialized blend used by the tests."""
        # ``abspath`` collapses the ``..`` segments lexically, so the path
        # resolves to ``data/hsc_cosmos_35.npz`` next to this file's directory.
        data_path = os.path.abspath(os.path.join(__file__, "..", "..", "data", "hsc_cosmos_35.npz"))
        data = np.load(data_path)

        # Pull each archive entry out once and reuse it below.
        images = data["images"]
        variance = data["variance"]
        catalog = data["catalog"]
        bands = data["filters"]

        model_psf = integrated_circular_gaussian(sigma=0.8)

        # Detection image: sum of the image cube along its first axis.
        self.detect = np.sum(images, axis=0)
        # One (y, x) row per catalog entry.
        self.centers = np.array([catalog["y"], catalog["x"]]).T

        self.observation = Observation(
            Image(images, bands=bands),
            Image(variance, bands=bands),
            Image(1 / variance, bands=bands),
            data["psfs"],
            model_psf[None],
            bands=bands,
        )

        init = FactorizedInitialization(
            self.observation,
            self.centers,
            monotonicity=Monotonicity((101, 101)),
        )
        self.blend = Blend(init.sources, self.observation)

    def test_json(self):
        """Serialize a blend to JSON and verify the reloaded blend matches."""
        blend = self.blend
        observation = self.observation
        blend.metadata = {
            "psf": observation.model_psf,
            "bands": tuple(map(str, observation.bands)),
        }
        blend_data = blend.to_data()
        original_model = io.ScarletModelData(
            blends={1: blend_data},
            metadata={"model_psf": observation.model_psf},
        )

        # Serialize to a JSON string, decode it into plain Python objects and
        # rebuild the full set of model data classes from the decoded dict.
        decoded = json.loads(original_model.json())
        restored_model = io.ScarletModelData.parse_obj(decoded)
        restored_metadata = restored_model.metadata
        self.assertIsNotNone(restored_metadata)

        # Convert the model data back into scarlet models.
        loaded_blend = restored_model.blends[1].minimal_data_to_blend(
            model_psf=restored_metadata["model_psf"],  # type: ignore
            dtype=blend.observation.dtype,
        )

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())
        self.assertBoxEqual(blend.bbox, blend_data.bbox)

        # The source counts were asserted equal above, so pairing them up
        # visits every source.
        for expected_source, actual_source in zip(blend.sources, loaded_blend.sources):
            self.assertTupleEqual(expected_source.center, actual_source.center)
            self.assertEqual(len(expected_source.components), len(actual_source.components))
            self.assertBoxEqual(expected_source.bbox, actual_source.bbox)
            component_pairs = zip(expected_source.components, actual_source.components)
            for expected_component, actual_component in component_pairs:
                self.assertEqual(expected_component.peak, actual_component.peak)
                assert_almost_equal(expected_component.spectrum, actual_component.spectrum)
                assert_almost_equal(expected_component.morph, actual_component.morph)
                self.assertBoxEqual(expected_component.bbox, actual_component.bbox)

    def test_cube_component(self):
        """Round-trip a blend holding a ``CubeComponent`` and source metadata."""
        blend = self.blend
        observation = self.observation

        # Tag every source so the metadata round trip can be checked below.
        for index, source in enumerate(blend.sources):
            source.metadata = {"id": f"peak-{index}"}

        # Replace one of the components with a Free-Form component.
        last_components = blend.sources[-1].components
        replaced = last_components[-1]
        last_components[-1] = CubeComponent(
            model=replaced.get_model(),
            peak=replaced.peak,
        )

        original_model = io.ScarletModelData(
            blends={1: blend.to_data()},
            metadata={
                "model_psf": observation.model_psf,
                "psf": observation.psfs,
                "bands": tuple(map(str, observation.bands)),
            },
        )

        # JSON string -> decoded dict -> model data classes.
        decoded = json.loads(original_model.json())
        restored_model = io.ScarletModelData.parse_obj(decoded)
        restored_metadata = restored_model.metadata

        # Convert the model data back into scarlet models.
        loaded_blend = restored_model.blends[1].minimal_data_to_blend(
            model_psf=restored_metadata["model_psf"],  # type: ignore
            bands=restored_metadata["bands"],  # type: ignore
            psf=restored_metadata["psf"],  # type: ignore
            dtype=blend.observation.dtype,
        )

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())

        # Check that the metadata was stored correctly.
        for expected_source, actual_source in zip(blend.sources, loaded_blend.sources):
            self.assertEqual(expected_source.metadata, actual_source.metadata)

    def test_legacy_json(self):
        """Load a hand-built legacy (pre-versioning) JSON model and compare."""
        blend = self.blend
        observation = self.observation

        # Create legacy blend JSON data: the PSF, its shape, the bands and the
        # PSF center are stored directly on the blend dictionary.
        legacy_blend = blend.to_data().as_dict()
        observed_psf = io.utils.numpy_to_json(observation.psfs)
        legacy_blend["psf"] = observed_psf["data"]
        legacy_blend["psf_shape"] = observed_psf["shape"]
        legacy_blend["bands"] = tuple(map(str, observation.bands))
        legacy_blend["psf_center"] = (10, 10)

        # Create legacy model data: the model PSF sits at the top level.
        legacy_model = io.ScarletModelData(blends={}).as_dict()
        legacy_model["blends"][1] = legacy_blend
        model_psf_json = io.utils.numpy_to_json(observation.model_psf)
        legacy_model["psf"] = model_psf_json["data"]
        legacy_model["psfShape"] = model_psf_json["shape"]

        # Legacy models were pre-versioning, so delete any version key.
        legacy_model.pop("version", None)
        legacy_blend.pop("version", None)
        for legacy_source in legacy_blend["sources"].values():
            legacy_source.pop("version", None)
            for legacy_component in legacy_source["components"]:
                legacy_component.pop("version", None)

        self.assertIsNone(legacy_model["metadata"])

        # Dump the dictionary to a JSON string, decode it again and load the
        # full set of model data classes from the decoded dict.
        decoded = json.loads(json.dumps(legacy_model))
        restored_model = io.ScarletModelData.parse_obj(decoded)
        restored_metadata = restored_model.metadata
        self.assertIsNotNone(restored_metadata)

        # Convert the model data back into scarlet models.
        loaded_blend = restored_model.blends[1].minimal_data_to_blend(
            model_psf=restored_metadata["model_psf"],  # type: ignore
            dtype=blend.observation.dtype,
        )

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())

192 self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())