Coverage for tests / test_fit_bootstrap_model.py: 20%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:48 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import math 

23 

24import astropy.table 

25import lsst.gauss2d.fit as g2f 

26from lsst.multiprofit.componentconfig import ( 

27 CentroidConfig, 

28 FluxFractionParameterConfig, 

29 FluxParameterConfig, 

30 GaussianComponentConfig, 

31 ParameterConfig, 

32 SersicComponentConfig, 

33 SersicIndexParameterConfig, 

34) 

35from lsst.multiprofit.errors import RaDecConversionNotImplementedError 

36from lsst.multiprofit.fitting.fit_bootstrap_model import ( 

37 CatalogExposurePsfBootstrap, 

38 CatalogExposureSourcesBootstrap, 

39 CatalogPsfBootstrapConfig, 

40 CatalogSourceBootstrapConfig, 

41 CatalogSourceFitterBootstrap, 

42 NoisyObservationConfig, 

43 NoisyPsfObservationConfig, 

44) 

45from lsst.multiprofit.fitting.fit_psf import ( 

46 CatalogPsfFitter, 

47 CatalogPsfFitterConfig, 

48 CatalogPsfFitterConfigData, 

49) 

50from lsst.multiprofit.fitting.fit_source import CatalogSourceFitterConfig, CatalogSourceFitterConfigData 

51from lsst.multiprofit.modelconfig import ModelConfig 

52from lsst.multiprofit.modeller import ModelFitConfig 

53from lsst.multiprofit.observationconfig import CoordinateSystemConfig 

54from lsst.multiprofit.plotting import ErrorValues, plot_catalog_bootstrap, plot_loglike 

55from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig 

56from lsst.multiprofit.utils import get_params_uniq 

57import numpy as np 

58import pytest 

59 

# Base (rows, columns) of the source-fitting images; the fixtures add a
# per-channel offset to each dimension.
shape_img = (23, 27)

# Initial shape parameters of the Sersic source model.
reff_x_src = 2.5
reff_y_src = 3.6
rho_src = -0.25
# NOTE(review): nser_src is not referenced elsewhere in this module as shown;
# the Sersic index initial value used by config_fitter_source is 1.0 - confirm intent.
nser_src = 2.0

# TODO: These can be parameterized; should they be?
compute_errors_no_covar = True
compute_errors_from_jacobian = True
include_point_source = False
n_sources = 3
# Set to True for interactive debugging (but don't commit)
plot = False

70 

71 

@pytest.fixture(scope="module")
def channels():
    """Map each of the band names R, G and B to its `g2f.Channel`."""
    bands = ("R", "G", "B")
    return dict(zip(bands, map(g2f.Channel.get, bands)))

75 

76 

def _make_psf_fitter_config(idx: int, n_rows: int, n_cols: int) -> CatalogPsfFitterConfig:
    """Build the two-Gaussian fractional PSF fitter config for the idx-th channel.

    The centroid starts at the image centre; Gaussian sizes and rho drift
    slightly with ``idx`` so that each channel has a distinct PSF.
    """
    centroid = CentroidConfig(
        x=ParameterConfig(value_initial=n_cols / 2.0),
        y=ParameterConfig(value_initial=n_rows / 2.0),
    )
    gauss_inner = GaussianComponentConfig(
        flux=FluxParameterConfig(value_initial=1.0, fixed=True),
        fluxfrac=FluxFractionParameterConfig(value_initial=0.5, fixed=False),
        size_x=ParameterConfig(value_initial=1.5 + 0.1 * idx),
        size_y=ParameterConfig(value_initial=1.7 + 0.13 * idx),
        rho=ParameterConfig(value_initial=-0.035 - 0.007 * idx),
    )
    # No explicit flux here; the fraction is fixed at 1.0 (presumably the last
    # component of a fractional group takes the remaining flux - confirm).
    gauss_outer = GaussianComponentConfig(
        size_x=ParameterConfig(value_initial=3.1 + 0.24 * idx),
        size_y=ParameterConfig(value_initial=2.7 + 0.16 * idx),
        rho=ParameterConfig(value_initial=0.06 + 0.012 * idx),
        fluxfrac=FluxFractionParameterConfig(value_initial=1.0, fixed=True),
    )
    group = ComponentGroupConfig(
        centroids={"default": centroid},
        components_gauss={"gauss1": gauss_inner, "gauss2": gauss_outer},
        is_fractional=True,
    )
    return CatalogPsfFitterConfig(model=SourceConfig(component_groups={"": group}))


@pytest.fixture(scope="module")
def config_fitter_psfs(channels) -> dict[g2f.Channel, CatalogExposurePsfBootstrap]:
    """Build one bootstrapped PSF catalog/config per channel.

    Each successive channel gets an image two pixels larger in both dimensions
    and the per-index PSF parameters from `_make_psf_fitter_config`.
    """
    bootstraps = {}
    for idx, channel in enumerate(channels.values()):
        n_rows = 17 + 2 * idx
        n_cols = 15 + 2 * idx
        config_psf = _make_psf_fitter_config(idx, n_rows, n_cols)
        config_boot = CatalogPsfBootstrapConfig(
            observation=NoisyPsfObservationConfig(n_rows=n_rows, n_cols=n_cols, gain=1e5),
            n_sources=n_sources,
        )
        bootstraps[channel] = CatalogExposurePsfBootstrap(config=config_psf, config_boot=config_boot)
    return bootstraps

121 

122 

@pytest.fixture(scope="module")
def config_fitter_source(channels) -> CatalogSourceFitterConfigData:
    """Build the multi-channel source fitter config: one Sersic component
    (plus an optional zero-size Gaussian point source) in a single group.
    """
    components_gauss = {}
    if include_point_source:
        # Zero-size Gaussian with fixed shape acting as a point source.
        components_gauss["ps"] = GaussianComponentConfig(
            flux=FluxParameterConfig(value_initial=1000),
            rho=ParameterConfig(value_initial=0, fixed=True),
            size_x=ParameterConfig(value_initial=0, fixed=True),
            size_y=ParameterConfig(value_initial=0, fixed=True),
        )
    sersic = SersicComponentConfig(
        prior_size_mean=reff_y_src,
        prior_size_stddev=1.0,
        prior_axrat_mean=reff_x_src / reff_y_src,
        prior_axrat_stddev=0.2,
        flux=FluxParameterConfig(value_initial=5000),
        rho=ParameterConfig(value_initial=rho_src),
        size_x=ParameterConfig(value_initial=reff_x_src),
        size_y=ParameterConfig(value_initial=reff_y_src),
        sersic_index=SersicIndexParameterConfig(fixed=False, value_initial=1.0),
    )
    group = ComponentGroupConfig(components_gauss=components_gauss, components_sersic={"ser": sersic})
    config = CatalogSourceFitterConfig(
        config_fit=ModelFitConfig(fit_linear_iter=3),
        config_model=ModelConfig(sources={"": SourceConfig(component_groups={"": group})}),
        convert_cen_xy_to_radec=False,
        compute_errors_no_covar=compute_errors_no_covar,
        compute_errors_from_jacobian=compute_errors_from_jacobian,
    )
    return CatalogSourceFitterConfigData(channels=tuple(channels.values()), config=config)

171 

172 

@pytest.fixture(scope="module")
def tables_psf_fits(config_fitter_psfs) -> dict[g2f.Channel, astropy.table.Table]:
    """Fit each channel's bootstrapped PSF catalog and return the result tables."""
    fitter = CatalogPsfFitter()
    tables = {}
    for channel, bootstrap in config_fitter_psfs.items():
        # The bootstrap object serves as both the exposure and the config data.
        tables[channel] = fitter.fit(catexp=bootstrap, config_data=bootstrap)
    return tables

184 

185 

@pytest.fixture(scope="module")
def config_data_sources(
    config_fitter_psfs,
    tables_psf_fits,
) -> dict[g2f.Channel, CatalogExposureSourcesBootstrap]:
    """Build one bootstrapped source catalog/exposure per channel.

    Images grow by two pixels per channel from ``shape_img``, and each channel
    gets a distinct coordinate-system origin. The PSF comes from that channel's
    PSF fit table.
    """
    n_rows_base = shape_img[0]
    n_cols_base = shape_img[1]
    bootstraps = {}
    for idx, channel in enumerate(config_fitter_psfs):
        observation = NoisyObservationConfig(
            n_rows=n_rows_base + 2 * idx,
            n_cols=n_cols_base + 2 * idx,
            band=channel.name,
            background=100,
            coordsys=CoordinateSystemConfig(x_min=-2 + 3 * idx, y_min=5 - 4 * idx),
        )
        config_boot = CatalogSourceBootstrapConfig(observation=observation, n_sources=n_sources)
        bootstraps[channel] = CatalogExposureSourcesBootstrap(
            config_boot=config_boot,
            table_psf_fits=tables_psf_fits[channel],
        )
    return bootstraps

213 

214 

def test_fit_psf(config_fitter_psfs, tables_psf_fits):
    """Check that bootstrapped PSF fits recover the configured PSF parameters.

    For every channel, the fit table must have one row per source, no unknown
    failure flags, and finite values in its first row. A PSF model initialized
    from that row must match the config's initial model: fixed parameters
    exactly, and free parameters within per-type tolerances.
    """
    for channel, results in tables_psf_fits.items():
        assert len(results) == n_sources
        assert np.sum(results["mpf_psf_unknown_flag"]) == 0
        assert all(np.isfinite(list(results[0].values())))

        config_data_psf = config_fitter_psfs[channel]
        psf_model_init = config_data_psf.config.make_psf_model()
        psfdata = CatalogPsfFitterConfigData(config=config_data_psf.config)
        # psf_model_fit is fetched before init_psf_model; this assumes
        # init_psf_model updates that same model in place - confirm.
        psf_model_fit = psfdata.psf_model
        psfdata.init_psf_model(results[0])
        assert len(psf_model_init.components) == len(psf_model_fit.components)

        params_init = psf_model_init.parameters()
        params_fit = psf_model_fit.parameters()
        assert len(params_init) == len(params_fit)
        sigma_min_sq = config_data_psf.config.sigma_min**2

        for p_init, p_meas in zip(params_init, params_fit):
            assert p_meas.fixed == p_init.fixed
            if p_meas.fixed:
                assert p_init.value == p_meas.value
                continue

            value = p_meas.value
            # TODO: come up with better (noise-dependent) thresholds here
            if isinstance(p_init, g2f.IntegralParameterD):
                atol, rtol = 0, 0.02
            elif isinstance(p_init, g2f.ProperFractionParameterD):
                atol, rtol = 0.1, 0.01
            elif isinstance(p_init, g2f.RhoParameterD):
                atol, rtol = 0.05, 0.1
            else:
                if isinstance(p_init, (g2f.SigmaXParameterD, g2f.SigmaYParameterD)):
                    # Compare sizes after adding sigma_min back in quadrature.
                    value = math.sqrt(value**2 + sigma_min_sq)
                # Always assign tolerances here. Previously the sigma branch
                # reused stale values from an earlier parameter, or raised
                # UnboundLocalError if a sigma was the first free parameter.
                atol, rtol = 0.01, 0.1
            assert np.isclose(p_init.value, value, atol=atol, rtol=rtol), (
                f"{type(p_init).__name__}: initial={p_init.value}, fit={value}, atol={atol}, rtol={rtol}"
            )

248 

249 

def _compute_variances_on_model_image(fitter, model, options_hessian):
    """Compute Hessian variances with each observation's image replaced by the model.

    The model is evaluated in image mode, and its outputs are copied over the
    observation pixels. Full and diagonal-only variances are then computed.
    The original pixels are always restored, even if a computation raises, so
    the caller's model data are left unchanged.

    Parameters
    ----------
    fitter : `CatalogSourceFitterBootstrap`
        Fitter whose ``modeller`` computes the variances.
    model : `lsst.gauss2d.fit.ModelD`
        Model whose observation images are temporarily overwritten.
    options_hessian : `lsst.gauss2d.fit.HessianOptions`
        Options passed to both variance computations.

    Returns
    -------
    variances, variances_diag
        Variances from the full and diagonal-only Hessians, respectively.
    """
    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.image)
    model.evaluate()
    img_data_old = []
    try:
        for obs, output in zip(model.data, model.outputs):
            img_data_old.append(obs.image.data.copy())
            img = obs.image.data
            img.flat = output.data.flat
        variances = fitter.modeller.compute_variances(model, transformed=False, options=options_hessian)
        variances_diag = fitter.modeller.compute_variances(
            model, transformed=False, options=options_hessian, use_diag_only=True
        )
    finally:
        # zip stops at the shorter list, so only images that were saved are restored.
        for obs, img_datum_old in zip(model.data, img_data_old):
            obs.image.data.flat = img_datum_old.flat
    return variances, variances_diag


def test_fit_source(config_fitter_source, config_data_sources):
    """Fit bootstrapped multi-band sources and sanity-check fits and errors.

    First, both RA/Dec conversion code paths (deferred and immediate) must flag
    every row with the not-implemented conversion error. Then a fit without
    conversion must succeed for every source. Finally, several variance
    estimates are computed and plotted (the figures are shown only when
    ``plot`` is set).
    """
    fitter = CatalogSourceFitterBootstrap()
    # We don't have or need a multiband input catalog - just use the first one
    catalog_multi = next(iter(config_data_sources.values())).get_catalog()
    catexps = list(config_data_sources.values())
    config = config_fitter_source.config

    conversion_error_cls = RaDecConversionNotImplementedError
    conversion_error_key = conversion_error_cls.column_name()
    fitter.errors_expected[conversion_error_cls] = conversion_error_key
    config.flag_errors[conversion_error_key] = conversion_error_cls.__name__

    # Test both code paths for failure to convert RA/Dec. The module-scoped
    # fixture config is restored in ``finally`` so a failure cannot leak state.
    defer_conversion = config.defer_radec_conversion
    config.convert_cen_xy_to_radec = True
    try:
        for value in (not defer_conversion, defer_conversion):
            config.defer_radec_conversion = value
            results = fitter.fit(catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source)
            assert np.all(results[f"mpf_{conversion_error_key}"] == 1)
    finally:
        config.defer_radec_conversion = defer_conversion
        config.convert_cen_xy_to_radec = False

    results = fitter.fit(catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source)
    assert len(results) == n_sources
    assert np.sum(results["mpf_unknown_flag"]) == 0
    assert all(np.isfinite(list(results[0].values())))

    model = fitter.get_model(
        0,
        catalog_multi=catalog_multi,
        catexps=catexps,
        config_data=config_fitter_source,
        results=results,
    )

    # Reference ("true") parameter values: a fresh model built from the config
    # and initialized from the first catalog row.
    model_sources, _ = config.make_sources(channels=list(config_data_sources.keys()))
    model_true = g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model_sources)
    fitter.initialize_model(model_true, catalog_multi[0], catexps=catexps)
    params_true = tuple(param.value for param in get_params_uniq(model_true, fixed=False))
    plot_catalog_bootstrap(
        results, histtype="step", paramvals_ref=params_true, plot_total_fluxes=True, plot_colors=True
    )
    if plot:
        import matplotlib.pyplot as plt

        plt.show()

    # Diagonal-only variances without and with return_negative; all must be
    # positive. This assertion makes any later clamping of non-positive values
    # to zero a no-op, so no clamp is applied.
    variances = []
    for return_negative in (False, True):
        variances_i = fitter.modeller.compute_variances(
            model,
            transformed=False,
            options=g2f.HessianOptions(return_negative=return_negative),
            use_diag_only=True,
        )
        assert np.all(variances_i > 0)
        variances.append(variances_i)

    # Bootstrap errors: recompute variances treating the model image as data.
    # return_negative=True is stated explicitly rather than read from the
    # loop variable left over from the loop above.
    options_hessian = g2f.HessianOptions(return_negative=True)
    variances_bootstrap, variances_bootstrap_diag = _compute_variances_on_model_image(
        fitter, model, options_hessian
    )
    variances_jac = fitter.modeller.compute_variances(model, transformed=False)
    variances_jac_diag = fitter.modeller.compute_variances(model, transformed=False, use_diag_only=True)

    errors_plot = {
        "inv_hess": ErrorValues(values=np.sqrt(variances[0]), kwargs_plot={"linestyle": "-", "color": "r"}),
        "-inv_hess": ErrorValues(values=np.sqrt(variances[1]), kwargs_plot={"linestyle": "--", "color": "r"}),
        "inv_jac": ErrorValues(values=np.sqrt(variances_jac), kwargs_plot={"linestyle": "-.", "color": "r"}),
        "boot_hess": ErrorValues(
            values=np.sqrt(variances_bootstrap), kwargs_plot={"linestyle": "-", "color": "b"}
        ),
        "boot_diag": ErrorValues(
            values=np.sqrt(variances_bootstrap_diag), kwargs_plot={"linestyle": "--", "color": "b"}
        ),
        "boot_jac_diag": ErrorValues(
            values=np.sqrt(variances_jac_diag), kwargs_plot={"linestyle": "-.", "color": "m"}
        ),
    }
    plot_loglike(model, errors=errors_plot, values_reference=params_true)
    if plot:
        import matplotlib.pyplot as plt

        plt.tight_layout()
        plt.show()