Coverage for tests / test_sourceconfig.py: 23%

61 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 08:43 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import lsst.gauss2d.fit as g2f 

23from lsst.multiprofit.componentconfig import ( 

24 GaussianComponentConfig, 

25 ParameterConfig, 

26 SersicComponentConfig, 

27 SersicIndexParameterConfig, 

28) 

29from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig 

30from lsst.multiprofit.utils import get_params_uniq 

31import numpy as np 

32import pytest 

33 

34 

@pytest.fixture(scope="module")
def centroid_limits():
    """Provide unbounded limits shared by both centroid parameters.

    Returns
    -------
    limits : `lsst.gauss2d.fit.LimitsD`
        Limits with ``min=-inf`` and ``max=inf``.
    """
    return g2f.LimitsD(min=-np.inf, max=np.inf)

39 

40 

@pytest.fixture(scope="module")
def centroid(centroid_limits):
    """Provide centroid parameters at the origin, with both axes fixed.

    Parameters
    ----------
    centroid_limits : `lsst.gauss2d.fit.LimitsD`
        Limits applied to both the x and y centroid parameters.

    Returns
    -------
    centroid : `lsst.gauss2d.fit.CentroidParameters`
        Centroid whose x and y parameters are initialized to 0 with
        ``fixed=True``.
    """
    # Both axes share identical limits and fixed-ness.
    param_kwargs = dict(limits=centroid_limits, fixed=True)
    return g2f.CentroidParameters(
        g2f.CentroidXParameterD(0, **param_kwargs),
        g2f.CentroidYParameterD(0, **param_kwargs),
    )

47 

48 

@pytest.fixture(scope="module")
def channels():
    """Provide the R, G and B channels keyed by band name.

    Returns
    -------
    channels : `dict` [`str`, `lsst.gauss2d.fit.Channel`]
        Channels obtained via ``g2f.Channel.get``, inserted in R, G, B order.
    """
    bands = ("R", "G", "B")
    return dict(zip(bands, map(g2f.Channel.get, bands)))

52 

53 

def test_ComponentGroupConfig(centroid):
    """Check that ValueError is raised for a group whose Gaussian and Sersic
    components share the name "x".

    The error may come either from construction or from ``validate``; both
    happen inside the ``pytest.raises`` context.
    """
    with pytest.raises(ValueError):
        group = ComponentGroupConfig(
            components_gauss={"x": GaussianComponentConfig()},
            components_sersic={"x": SersicComponentConfig()},
        )
        group.validate()

61 

62 

def test_SourceConfig_base():
    """Check that SourceConfig raises ValueError when ``component_groups``
    is omitted or is an empty mapping.

    The error may come either from construction or from ``validate``.
    """
    for kwargs in ({}, {"component_groups": {}}):
        with pytest.raises(ValueError):
            SourceConfig(**kwargs).validate()

71 

72 

def test_SourceConfig_fractional(centroid):
    """Build a PSF model from a fractional group of Gaussian components.

    Each component's initial rho and sizes are offset by a fixed step times
    its index. The resulting model must have no priors and one component
    per configured Gaussian.
    """
    rho, size_x, size_y = -0.3, 1.4, 1.6
    drho, dsize_x, dsize_y = 0.5, 1.6, 1.3
    n_components = 2

    def make_gaussian(idx):
        # Initial values step linearly with the component index.
        return GaussianComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx * drho),
            size_x=ParameterConfig(value_initial=size_x + idx * dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx * dsize_y),
        )

    group = ComponentGroupConfig(
        components_gauss={str(idx): make_gaussian(idx) for idx in range(n_components)},
        is_fractional=True,
    )
    config = SourceConfig(component_groups={"src": group})
    config.validate()

    channel = g2f.Channel.NONE
    # Outer list: one entry for the single "src" group; inner list: one
    # {channel: value} dict per component (structure as used by this call).
    fluxes_group = [{channel: 1.0}, {channel: 0.5}]
    psf_model, priors = config.make_psf_model([fluxes_group])

    assert len(priors) == 0
    assert len(psf_model.components) == n_components

105 

106 

def test_SourceConfig_linear(centroid, channels):
    """Build a source from linear Sersic components and check initial values.

    Two Sersic components are configured: "PS" (fixed Sersic index) and
    "Sersic". Their initial rho, sizes and Sersic index are offset by a
    fixed step times the component index, and each has per-channel fluxes.
    The resulting source must have:

    - no priors;
    - one component per configured Sersic;
    - integral and shape parameters that match their configured values.
    """
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 0.5, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 2.8, 13.9

    names = ("PS", "Sersic")
    config = SourceConfig(
        component_groups={
            "src": ComponentGroupConfig(
                components_sersic={
                    name: SersicComponentConfig(
                        rho=ParameterConfig(value_initial=rho + idx * drho),
                        size_x=ParameterConfig(value_initial=size_x + idx * dsize_x),
                        size_y=ParameterConfig(value_initial=size_y + idx * dsize_y),
                        sersic_index=SersicIndexParameterConfig(
                            value_initial=sersicn + idx * dsersicn,
                            # Only the first ("PS") component's index is fixed
                            fixed=idx == 0,
                            prior_mean=None,
                        ),
                    )
                    for idx, name in enumerate(names)
                }
            ),
        }
    )
    # One {channel: flux} dict per component. Component 0 gets the same flux
    # in every channel; later components vary with the channel index.
    fluxes = [
        {
            channel: flux + idx_channel * dflux * idx_comp
            for idx_channel, channel in enumerate(channels.values())
        }
        for idx_comp in range(len(config.component_groups["src"].components_sersic))
    ]
    source, priors = config.make_source([fluxes])
    assert len(priors) == 0
    # Without this, the per-component loop below would pass vacuously if no
    # components were created (and fluxes/names indexing assumes alignment).
    assert len(source.components) == len(names)

    for idx, component in enumerate(source.components):
        params = get_params_uniq(component)
        values_init = {
            g2f.RhoParameterD: rho + idx * drho,
            g2f.ReffXParameterD: size_x + idx * dsize_x,
            g2f.ReffYParameterD: size_y + idx * dsize_y,
            g2f.SersicIndexParameterD: sersicn + idx * dsersicn,
        }
        # These depend only on the component index, not on the group.
        fluxes_comp = fluxes[idx]
        name_comp = names[idx]
        for name_group, component_group in config.component_groups.items():
            config_comp = component_group.components_sersic[name_comp]
            # Map each fully-formatted integral label to its expected flux.
            fluxes_label = {
                config.format_label(
                    component_group.format_label(
                        label=config_comp.format_label(
                            label=config.get_integral_label_default(), name_channel=channel.name
                        ),
                        name_component=name_comp,
                    ),
                    name_group=name_group,
                ): fluxes_comp[channel]
                for channel in channels.values()
            }
            for param in params:
                if isinstance(param, g2f.IntegralParameterD):
                    assert fluxes_label[param.label] == param.value
                # Compare against None, not truthiness, so that a configured
                # initial value of 0.0 is still checked.
                elif (value_init := values_init.get(param.__class__)) is not None:
                    assert param.value == value_init