Coverage for tests/test_component.py: 16%

121 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-20 03:40 -0700

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from typing import Callable 

23 

24import numpy as np 

25from lsst.scarlet.lite import Box, Image, Parameter 

26from lsst.scarlet.lite.component import ( 

27 Component, 

28 FactorizedComponent, 

29 default_adaprox_parameterization, 

30 default_fista_parameterization, 

31) 

32from lsst.scarlet.lite.operators import Monotonicity 

33from numpy.testing import assert_almost_equal, assert_array_equal 

34from utils import ScarletTestCase 

35 

36 

class DummyComponent(Component):
    """Minimal ``Component`` subclass used as a test double.

    ``resize``, ``update`` and ``get_model`` are empty stubs that
    implicitly return ``None``; only ``parameterize`` does any work.
    ``test_parameterization`` passes instances of this class to the
    default parameterizations and expects ``NotImplementedError``.
    """

    def resize(self) -> bool:
        """Do nothing (stub)."""

    def update(self, it: int, input_grad: np.ndarray):
        """Do nothing (stub); both arguments are ignored."""

    def get_model(self) -> Image:
        """Do nothing (stub); no model is built."""

    def parameterize(self, parameterization: Callable) -> None:
        """Call ``parameterization`` with this component as its only argument."""
        parameterization(self)

49 

50 

class TestFactorizedComponent(ScarletTestCase):
    """Tests for ``FactorizedComponent``.

    Covers construction, model generation, gradients, proximal
    operators, resizing and parameterization.
    """

    def setUp(self) -> None:
        """Create the shared three-band component used by most tests."""
        self.bands = ("g", "r", "i")
        self.spectrum = np.arange(3, dtype=np.float32)
        self.morph = np.arange(20, dtype=np.float32).reshape(4, 5)
        self.full_shape = (3, 100, 100)
        self.model_box = Box((100, 100))

        # The (4, 5) box at (22, 31) corresponds to the slice
        # [22:26, 31:36] that test_get_model fills in the full model.
        self.component = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            Box((4, 5), (22, 31)),
            (24, 33),
        )

    def test_constructor(self):
        """Check constructor defaults and that explicit arguments are stored."""
        bbox = self.component.bbox

        # Only the required arguments are supplied here.
        minimal = FactorizedComponent(self.bands, self.spectrum, self.morph, bbox)

        for parameter, value, expected in (
            (minimal._spectrum, minimal.spectrum, self.spectrum),
            (minimal._morph, minimal.morph, self.morph),
        ):
            self.assertIsInstance(parameter, Parameter)
            assert_array_equal(value, expected)
        self.assertBoxEqual(minimal.bbox, bbox)
        self.assertIsNone(minimal.peak)
        self.assertIsNone(minimal.bg_rms)
        self.assertEqual(minimal.bg_thresh, 0.25)
        self.assertEqual(minimal.floor, 1e-20)
        self.assertTupleEqual(minimal.shape, (3, 4, 5))

        # Now supply the optional arguments positionally and check each
        # one is exposed unchanged.
        peak = self.component.peak
        noise = np.arange(5) / 10
        threshold = 0.9
        min_value = 1e-10

        full = FactorizedComponent(
            self.bands,
            self.spectrum,
            self.morph,
            bbox,
            peak,
            noise,
            threshold,
            min_value,
        )

        self.assertTupleEqual(full.peak, peak)
        assert_array_equal(full.bg_rms, noise)  # type: ignore
        self.assertEqual(full.bg_thresh, threshold)
        self.assertEqual(full.floor, min_value)
        self.assertEqual(full.get_model().dtype, np.float32)

    def test_get_model(self):
        """Check the model is spectrum x morph and lands in the right slice."""
        expected = self.spectrum[:, None, None] * self.morph[None, :, :]
        assert_array_equal(self.component.get_model(), expected)

        # Reference: the same values placed by hand into a full-size array.
        reference = np.zeros(self.full_shape)
        reference[:, 22:26, 31:36] = expected

        # Adding the component model to an empty full-size Image should
        # reproduce the hand-built reference.
        accumulated = Image(np.zeros(self.full_shape), bands=self.bands)
        accumulated += self.component.get_model()

        assert_array_equal(accumulated.data, reference)

    def test_gradients(self):
        """Compare ``grad_spectrum`` and ``grad_morph`` with hand-computed values."""
        morph, spectrum = self.morph, self.spectrum
        scales = (1, 2, 3)

        # One plane per band: morph, 2 * morph, 3 * morph.
        input_grad = np.array([scale * morph for scale in scales])

        expected_spectrum_grad = np.array([np.sum(scale * morph**2) for scale in scales])
        assert_almost_equal(
            self.component.grad_spectrum(input_grad, spectrum, morph),
            expected_spectrum_grad,
        )

        expected_morph_grad = np.sum(input_grad * spectrum[:, None, None], axis=0)
        assert_almost_equal(
            self.component.grad_morph(input_grad, morph, spectrum),
            expected_morph_grad,
        )

    def test_proximal_operators(self):
        """Check ``prox_spectrum`` and ``prox_morph`` with and without constraints."""
        # Test spectrum positivity, morph threshold, and monotonicity
        raw_spectrum = np.array([-1, 2, 3], dtype=float)
        raw_morph = np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, -1]], dtype=float)
        bbox = Box((3, 3), (10, 10))

        # Case 1: center, background RMS, threshold and monotonicity given.
        constrained = FactorizedComponent(
            self.bands,
            raw_spectrum.copy(),
            raw_morph.copy(),
            bbox,
            (11, 11),
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=Monotonicity((101, 101), fit_radius=0),
        )
        constrained.prox_spectrum(constrained.spectrum)
        constrained.prox_morph(constrained.morph)

        assert_array_equal(constrained.spectrum, np.array([1e-20, 2, 3]))
        assert_array_equal(
            constrained.morph,
            np.array([[2.6666666666666667, 2, 1], [1, 5, 3], [0, 4, 0]]) / 5,
        )

        # Case 2: no center and no optional constraints.
        unconstrained = FactorizedComponent(
            self.bands,
            raw_spectrum.copy(),
            raw_morph.copy(),
            bbox,
            None,
        )
        unconstrained.prox_spectrum(unconstrained.spectrum)
        unconstrained.prox_morph(unconstrained.morph)

        assert_array_equal(unconstrained.spectrum, np.array([1e-20, 2, 3]))
        assert_array_equal(
            unconstrained.morph,
            np.array([[10, 2, 1], [1, 5, 3], [0.1, 4, 0]]) / 10,
        )

        # resize is expected to report False for this component.
        self.assertFalse(unconstrained.resize(Box((100, 100))))

    def test_resize(self):
        """Check ``resize`` changes a (10, 10) morph to the asserted (5, 5) box."""
        # Only a 3x3 patch of the 10x10 morph is populated.
        padded_morph = np.zeros((10, 10), dtype=float)
        padded_morph[3:6, 5:8] = np.arange(9).reshape(3, 3)

        component = FactorizedComponent(
            self.bands,
            np.array([1, 2, 3], dtype=float).copy(),
            padded_morph.copy(),
            Box((10, 10), (3, 5)),
            None,
            bg_rms=np.array([1, 1, 1]),
            bg_thresh=0.5,
            monotonicity=Monotonicity((101, 101), fit_radius=0),
            padding=1,
        )

        # State before resizing.
        self.assertTupleEqual(component.morph.shape, (10, 10))
        self.assertIsNone(component.component_center)

        component.resize(Box((100, 100)))

        # State after resizing.
        self.assertTupleEqual(component.morph.shape, (5, 5))
        self.assertTupleEqual(component.bbox.origin, (5, 9))
        self.assertTupleEqual(component.bbox.shape, (5, 5))
        self.assertIsNone(component.component_center)

    def test_parameterization(self):
        """Check helper names set by each default parameterization.

        Also checks that both default parameterizations raise
        ``NotImplementedError`` for a ``DummyComponent``.
        """
        component = self.component
        expected_model = self.spectrum[:, None, None] * self.morph[None, :, :]
        assert_array_equal(component.get_model(), expected_model)

        # Each parameterization should leave exactly these helper keys
        # on the morphology parameter.
        for parameterization, expected_helpers in (
            (default_fista_parameterization, {"z"}),
            (default_adaprox_parameterization, {"m", "v", "vhat"}),
        ):
            component.parameterize(parameterization)
            self.assertSetEqual(set(component._morph.helpers.keys()), expected_helpers)

        dummy_args = (tuple("grizy"), Box((5, 5)))
        for parameterization in (
            default_fista_parameterization,
            default_adaprox_parameterization,
        ):
            with self.assertRaises(NotImplementedError):
                parameterization(DummyComponent(*dummy_args))