Coverage for tests/test_io.py: 19%

66 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-03-01 11:54 +0000

1# This file is part of lsst.scarlet.lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import json 

23import os 

24 

25import numpy as np 

26from lsst.scarlet.lite import Blend, Image, Observation, io 

27from lsst.scarlet.lite.initialization import FactorizedChi2Initialization 

28from lsst.scarlet.lite.models.free_form import FreeFormComponent 

29from lsst.scarlet.lite.operators import Monotonicity 

30from lsst.scarlet.lite.utils import integrated_circular_gaussian 

31from numpy.testing import assert_almost_equal 

32from utils import ScarletTestCase 

33 

34 

class TestIo(ScarletTestCase):
    """Round-trip tests for serializing scarlet lite blends with ``io``."""

    def setUp(self) -> None:
        """Build an observation and an initialized blend from the test data."""
        # ``os.path.abspath`` normalizes ``<this module>/../..`` to the parent
        # of the directory that contains this module.
        filename = os.path.join(__file__, "..", "..", "data", "hsc_cosmos_35.npz")
        filename = os.path.abspath(filename)
        # ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that holds
        # the file open; use it as a context manager so the handle is released
        # once every array needed below has been read (each key read once).
        with np.load(filename) as data:
            images = data["images"]
            variance = data["variance"]
            psfs = data["psfs"]
            catalog = data["catalog"]
            bands = data["filters"]

        model_psf = integrated_circular_gaussian(sigma=0.8)
        self.detect = np.sum(images, axis=0)
        # Source centers as (y, x) rows, one per catalog entry.
        self.centers = np.array([catalog["y"], catalog["x"]]).T
        self.observation = Observation(
            Image(images, bands=bands),
            Image(variance, bands=bands),
            Image(1 / variance, bands=bands),
            psfs,
            model_psf[None],
            bands=bands,
        )
        monotonicity = Monotonicity((101, 101))
        init = FactorizedChi2Initialization(self.observation, self.centers, monotonicity=monotonicity)
        self.blend = Blend(init.sources, self.observation)

    @staticmethod
    def _assign_ids(blend: Blend) -> None:
        """Give every source in ``blend`` a deterministic record and peak id."""
        for index, source in enumerate(blend.sources):
            source.record_id = index * 10
            source.peak_id = index

    def _round_trip(self, blend: Blend):
        """Serialize ``blend`` through a JSON string and rebuild it as a blend.

        The blend is stored under key ``1`` of a ``ScarletModelData`` together
        with this test's model PSF, dumped to a JSON string, parsed back into
        the ``io`` data classes, and converted back into scarlet models.
        """
        blend_data = io.ScarletBlendData.from_blend(blend, (51, 67))
        model_data = io.ScarletModelData(
            psf=self.observation.model_psf,
            blends={1: blend_data},
        )

        # Get the json string for the model
        model_str = model_data.json()
        # Load the model string from the json
        model_dict = json.loads(model_str)
        # Load the full set of model data classes from the json string
        model_data = io.ScarletModelData.parse_obj(model_dict)
        # Convert the model data into scarlet models
        return model_data.blends[1].minimal_data_to_blend(
            model_psf=model_data.psf,
            dtype=blend.observation.dtype,
        )

    def test_json(self):
        """A factorized blend survives the JSON round trip source by source."""
        blend = self.blend
        self._assign_ids(blend)
        loaded_blend = self._round_trip(blend)

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())

        # Source counts were asserted equal above, so zip pairs every source.
        for source1, source2 in zip(blend.sources, loaded_blend.sources):
            self.assertTupleEqual(source1.center, source2.center)
            self.assertEqual(len(source1.components), len(source2.components))
            self.assertBoxEqual(source1.bbox, source2.bbox)
            for component1, component2 in zip(source1.components, source2.components):
                self.assertEqual(component1.peak, component2.peak)
                assert_almost_equal(component1.spectrum, component2.spectrum)
                assert_almost_equal(component1.morph, component2.morph)
                self.assertBoxEqual(component1.bbox, component2.bbox)

    def test_cube_component(self):
        """A blend containing a ``FreeFormComponent`` survives the round trip."""
        blend = self.blend
        self._assign_ids(blend)
        component = blend.sources[-1].components[-1]
        # Replace one of the components with a Free-Form component.
        blend.sources[-1].components[-1] = FreeFormComponent(
            bands=self.observation.bands,
            spectrum=component.spectrum,
            morph=component.morph,
            model_bbox=self.observation.bbox,
        )

        loaded_blend = self._round_trip(blend)

        self.assertEqual(len(blend.sources), len(loaded_blend.sources))
        self.assertEqual(len(blend.components), len(loaded_blend.components))
        self.assertImageAlmostEqual(blend.get_model(), loaded_blend.get_model())