Coverage for tests / test_fit_bootstrap_model.py: 20%
143 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import math
24import astropy.table
25import lsst.gauss2d.fit as g2f
26from lsst.multiprofit.componentconfig import (
27 CentroidConfig,
28 FluxFractionParameterConfig,
29 FluxParameterConfig,
30 GaussianComponentConfig,
31 ParameterConfig,
32 SersicComponentConfig,
33 SersicIndexParameterConfig,
34)
35from lsst.multiprofit.errors import RaDecConversionNotImplementedError
36from lsst.multiprofit.fitting.fit_bootstrap_model import (
37 CatalogExposurePsfBootstrap,
38 CatalogExposureSourcesBootstrap,
39 CatalogPsfBootstrapConfig,
40 CatalogSourceBootstrapConfig,
41 CatalogSourceFitterBootstrap,
42 NoisyObservationConfig,
43 NoisyPsfObservationConfig,
44)
45from lsst.multiprofit.fitting.fit_psf import (
46 CatalogPsfFitter,
47 CatalogPsfFitterConfig,
48 CatalogPsfFitterConfigData,
49)
50from lsst.multiprofit.fitting.fit_source import CatalogSourceFitterConfig, CatalogSourceFitterConfigData
51from lsst.multiprofit.modelconfig import ModelConfig
52from lsst.multiprofit.modeller import ModelFitConfig
53from lsst.multiprofit.observationconfig import CoordinateSystemConfig
54from lsst.multiprofit.plotting import ErrorValues, plot_catalog_bootstrap, plot_loglike
55from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
56from lsst.multiprofit.utils import get_params_uniq
57import numpy as np
58import pytest
# Base (n_rows, n_cols) of the simulated source images; config_data_sources
# grows both dimensions by 2 pixels per channel index.
shape_img = (23, 27)

# Reference Sersic source parameters (effective sizes along x/y, rho and
# Sersic index). nser_src is not referenced by the fixtures in this module.
reff_x_src = 2.5
reff_y_src = 3.6
rho_src = -0.25
nser_src = 2.0

# TODO: These can be parameterized; should they be?
# The two error flags are passed straight through to CatalogSourceFitterConfig.
compute_errors_no_covar = True
compute_errors_from_jacobian = True
# When True, the source model gains a fixed zero-size Gaussian "ps" component.
include_point_source = False
# Number of bootstrap sources per catalog (PSF and source fits alike).
n_sources = 3
# Set to True for interactive debugging (but don't commit)
plot = False
@pytest.fixture(scope="module")
def channels():
    """Map each band name ("R", "G", "B") to its ``g2f.Channel``."""
    bands = ("R", "G", "B")
    return dict(zip(bands, (g2f.Channel.get(band) for band in bands)))
@pytest.fixture(scope="module")
def config_fitter_psfs(channels) -> dict[g2f.Channel, CatalogExposurePsfBootstrap]:
    """Build one PSF bootstrap fitter config per channel.

    Every channel gets a fractional two-Gaussian PSF model ("gauss1" with a
    free flux fraction, "gauss2" taking the fixed remainder). The PSF image
    size and the initial Gaussian sizes/rho shift slightly with the channel's
    index, so the bands are not identical.

    Returns
    -------
    config_datas
        CatalogExposurePsfBootstrap instances keyed by channel.
    """
    configs = {}
    for index, channel in enumerate(channels.values()):
        rows = 17 + index * 2
        cols = 15 + index * 2

        # Start the centroid at the middle of the PSF image.
        centroid = CentroidConfig(
            x=ParameterConfig(value_initial=cols / 2.0),
            y=ParameterConfig(value_initial=rows / 2.0),
        )
        gaussians = {
            "gauss1": GaussianComponentConfig(
                flux=FluxParameterConfig(value_initial=1.0, fixed=True),
                fluxfrac=FluxFractionParameterConfig(value_initial=0.5, fixed=False),
                size_x=ParameterConfig(value_initial=1.5 + 0.1 * index),
                size_y=ParameterConfig(value_initial=1.7 + 0.13 * index),
                rho=ParameterConfig(value_initial=-0.035 - 0.007 * index),
            ),
            "gauss2": GaussianComponentConfig(
                size_x=ParameterConfig(value_initial=3.1 + 0.24 * index),
                size_y=ParameterConfig(value_initial=2.7 + 0.16 * index),
                rho=ParameterConfig(value_initial=0.06 + 0.012 * index),
                fluxfrac=FluxFractionParameterConfig(value_initial=1.0, fixed=True),
            ),
        }
        group = ComponentGroupConfig(
            centroids={"default": centroid},
            components_gauss=gaussians,
            is_fractional=True,
        )
        fitter_config = CatalogPsfFitterConfig(model=SourceConfig(component_groups={"": group}))

        boot_config = CatalogPsfBootstrapConfig(
            observation=NoisyPsfObservationConfig(n_rows=rows, n_cols=cols, gain=1e5),
            n_sources=n_sources,
        )
        configs[channel] = CatalogExposurePsfBootstrap(config=fitter_config, config_boot=boot_config)

    return configs
@pytest.fixture(scope="module")
def config_fitter_source(channels) -> CatalogSourceFitterConfigData:
    """Build the multi-band source fitter config data used by test_fit_source.

    The model has a single source with one Sersic component (free Sersic
    index) initialized from the module-level ``reff_*_src``/``rho_src``
    values. When ``include_point_source`` is set, a fixed zero-size Gaussian
    "ps" component is added. The config data spans every channel supplied by
    the ``channels`` fixture.
    """
    config_fit = ModelFitConfig(fit_linear_iter=3)

    # Optional point-source-like component: zero size, fixed shape.
    gaussians = {}
    if include_point_source:
        gaussians["ps"] = GaussianComponentConfig(
            flux=FluxParameterConfig(value_initial=1000),
            rho=ParameterConfig(value_initial=0, fixed=True),
            size_x=ParameterConfig(value_initial=0, fixed=True),
            size_y=ParameterConfig(value_initial=0, fixed=True),
        )

    sersics = {
        "ser": SersicComponentConfig(
            prior_size_mean=reff_y_src,
            prior_size_stddev=1.0,
            prior_axrat_mean=reff_x_src / reff_y_src,
            prior_axrat_stddev=0.2,
            flux=FluxParameterConfig(value_initial=5000),
            rho=ParameterConfig(value_initial=rho_src),
            size_x=ParameterConfig(value_initial=reff_x_src),
            size_y=ParameterConfig(value_initial=reff_y_src),
            sersic_index=SersicIndexParameterConfig(fixed=False, value_initial=1.0),
        ),
    }
    group = ComponentGroupConfig(components_gauss=gaussians, components_sersic=sersics)
    source = SourceConfig(component_groups={"": group})

    config = CatalogSourceFitterConfig(
        config_fit=config_fit,
        config_model=ModelConfig(sources={"": source}),
        convert_cen_xy_to_radec=False,
        compute_errors_no_covar=compute_errors_no_covar,
        compute_errors_from_jacobian=compute_errors_from_jacobian,
    )
    return CatalogSourceFitterConfigData(channels=tuple(channels.values()), config=config)
@pytest.fixture(scope="module")
def tables_psf_fits(config_fitter_psfs) -> dict[g2f.Channel, astropy.table.Table]:
    """Fit every channel's bootstrapped PSFs and return the result tables by channel.

    Each CatalogExposurePsfBootstrap is passed as both the catalog/exposure
    input (``catexp``) and the fitter config data (``config_data``).
    """
    fitter = CatalogPsfFitter()
    tables = {}
    for channel, catexp_psf in config_fitter_psfs.items():
        tables[channel] = fitter.fit(catexp=catexp_psf, config_data=catexp_psf)
    return tables
@pytest.fixture(scope="module")
def config_data_sources(
    config_fitter_psfs,
    tables_psf_fits,
) -> dict[g2f.Channel, CatalogExposureSourcesBootstrap]:
    """Build one source bootstrap catalog/exposure per channel.

    Each channel uses a NoisyObservationConfig whose image size (based on
    ``shape_img``) and coordinate-system origin shift with the channel's
    index, paired with that channel's PSF fit table from ``tables_psf_fits``.
    """
    base_rows, base_cols = shape_img[0], shape_img[1]
    sources = {}
    for index, channel in enumerate(config_fitter_psfs):
        psf_table = tables_psf_fits[channel]
        observation = NoisyObservationConfig(
            n_rows=base_rows + index * 2,
            n_cols=base_cols + index * 2,
            band=channel.name,
            background=100,
            coordsys=CoordinateSystemConfig(x_min=-2 + 3 * index, y_min=5 - 4 * index),
        )
        boot_config = CatalogSourceBootstrapConfig(observation=observation, n_sources=n_sources)
        sources[channel] = CatalogExposureSourcesBootstrap(
            config_boot=boot_config,
            table_psf_fits=psf_table,
        )
    return sources
def test_fit_psf(config_fitter_psfs, tables_psf_fits):
    """Check the per-channel PSF fit tables against the initial PSF models.

    For every channel, the table must have one unflagged row per bootstrap
    source with all-finite values in the first row. The model initialized from
    that first row must have the same components and parameters as the model
    made directly from the config. Fixed parameters must match exactly; free
    ones must match within type-dependent tolerances.
    """
    for channel, results in tables_psf_fits.items():
        assert len(results) == n_sources
        assert np.sum(results["mpf_psf_unknown_flag"]) == 0
        assert all(np.isfinite(list(results[0].values())))

        config_data_psf = config_fitter_psfs[channel]
        psf_model_init = config_data_psf.config.make_psf_model()
        psfdata = CatalogPsfFitterConfigData(config=config_data_psf.config)
        psf_model_fit = psfdata.psf_model
        # NOTE(review): presumably updates psf_model_fit in place from the fitted row
        psfdata.init_psf_model(results[0])

        assert len(psf_model_init.components) == len(psf_model_fit.components)
        params_init = psf_model_init.parameters()
        params_fit = psf_model_fit.parameters()
        assert len(params_init) == len(params_fit)

        sigma_min_sq = config_data_psf.config.sigma_min**2
        for p_init, p_meas in zip(params_init, params_fit):
            assert p_meas.fixed == p_init.fixed
            if p_meas.fixed:
                assert p_init.value == p_meas.value
                continue

            value = p_meas.value
            # TODO: come up with better (noise-dependent) thresholds here
            # Tolerances are (re)assigned on every branch so that no parameter
            # type inherits the tolerances of the previously checked parameter.
            if isinstance(p_init, g2f.IntegralParameterD):
                atol, rtol = 0, 0.02
            elif isinstance(p_init, g2f.ProperFractionParameterD):
                atol, rtol = 0.1, 0.01
            elif isinstance(p_init, g2f.RhoParameterD):
                atol, rtol = 0.05, 0.1
            else:
                atol, rtol = 0.01, 0.1
                if isinstance(p_init, (g2f.SigmaXParameterD, g2f.SigmaYParameterD)):
                    # Combine the fitted size with sigma_min in quadrature before
                    # comparing to the initial size (presumably the fit reports
                    # sigma without the sigma_min floor - confirm in fit_psf).
                    value = math.sqrt(value**2 + sigma_min_sq)
            assert np.isclose(p_init.value, value, atol=atol, rtol=rtol)
def _check_radec_conversion_flagged(fitter, catalog_multi, catexps, config_fitter_source):
    """Assert that every row is flagged when RA/Dec conversion is requested but unsupported.

    Both values of ``defer_radec_conversion`` are exercised, ending on the
    original value. Afterwards ``convert_cen_xy_to_radec`` is set back to
    False and ``defer_radec_conversion`` to its original value, even if an
    assertion fails, because the config comes from a module-scoped fixture.
    """
    config = config_fitter_source.config
    defer_conversion = config.defer_radec_conversion
    config.convert_cen_xy_to_radec = True

    conversion_error_cls = RaDecConversionNotImplementedError
    conversion_error_key = conversion_error_cls.column_name()
    fitter.errors_expected[conversion_error_cls] = conversion_error_key
    config.flag_errors[conversion_error_key] = conversion_error_cls.__name__

    try:
        # Test both code paths for failure to convert RA/Dec, returning to original
        for value in (not defer_conversion, defer_conversion):
            config.defer_radec_conversion = value
            results = fitter.fit(catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source)
            assert np.all(results[f"mpf_{conversion_error_key}"] == 1)
    finally:
        config.defer_radec_conversion = defer_conversion
        config.convert_cen_xy_to_radec = False


def _compute_bootstrap_variances(fitter, model, options_hessian):
    """Return (full, diagonal-only) variances computed with the model image as data.

    Each observation's image is temporarily overwritten with the model's own
    evaluated output. The original pixel values are always restored, even if a
    variance computation raises.
    """
    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.image)
    model.evaluate()
    img_data_old = []
    try:
        for obs, output in zip(model.data, model.outputs):
            img_data_old.append(obs.image.data.copy())
            img = obs.image.data
            img.flat = output.data.flat
        variances = fitter.modeller.compute_variances(model, transformed=False, options=options_hessian)
        variances_diag = fitter.modeller.compute_variances(
            model, transformed=False, options=options_hessian, use_diag_only=True
        )
    finally:
        # zip stops at the shorter sequence, so a partial overwrite is undone too
        for obs, img_datum_old in zip(model.data, img_data_old):
            obs.image.data.flat = img_datum_old.flat
    return variances, variances_diag


def test_fit_source(config_fitter_source, config_data_sources):
    """Fit bootstrapped multi-band sources and compare several error estimates.

    Checks that unsupported RA/Dec conversion flags every row, and that a
    normal fit returns one finite, unflagged row per source. It then computes
    inverse-Hessian, Jacobian and bootstrap variances for the first source's
    model and plots them against reference parameter values.
    """
    fitter = CatalogSourceFitterBootstrap()
    # We don't have or need a multiband input catalog - just use the first one
    catalog_multi = next(iter(config_data_sources.values())).get_catalog()
    catexps = list(config_data_sources.values())

    _check_radec_conversion_flagged(fitter, catalog_multi, catexps, config_fitter_source)

    results = fitter.fit(catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source)
    assert len(results) == n_sources
    assert np.sum(results["mpf_unknown_flag"]) == 0
    assert all(np.isfinite(list(results[0].values())))

    model = fitter.get_model(
        0,
        catalog_multi=catalog_multi,
        catexps=catexps,
        config_data=config_fitter_source,
        results=results,
    )

    # Rebuild the sources from config and initialize them from the input
    # catalog to get reference ("true") free-parameter values.
    model_sources, _priors = config_fitter_source.config.make_sources(
        channels=list(config_data_sources.keys())
    )
    model_true = g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model_sources)
    fitter.initialize_model(model_true, catalog_multi[0], catexps=catexps)
    params_true = tuple(param.value for param in get_params_uniq(model_true, fixed=False))
    plot_catalog_bootstrap(
        results, histtype="step", paramvals_ref=params_true, plot_total_fluxes=True, plot_colors=True
    )
    if plot:
        import matplotlib.pyplot as plt

        plt.show()

    # Diagonal-only variances, without and with return_negative
    variances = []
    for return_negative in (False, True):
        variances.append(
            fitter.modeller.compute_variances(
                model,
                transformed=False,
                options=g2f.HessianOptions(return_negative=return_negative),
                use_diag_only=True,
            )
        )
        assert np.all(variances[-1] > 0)
        if return_negative:
            # NOTE(review): the assertion above currently makes this clamp a
            # no-op; it only protects np.sqrt below if that assertion is relaxed.
            variances = np.array(variances)
            variances[variances <= 0] = 0
            variances = list(variances)

    # Bootstrap errors (explicitly return_negative=True, as the last loop pass used)
    options_hessian = g2f.HessianOptions(return_negative=True)
    variances_bootstrap, variances_bootstrap_diag = _compute_bootstrap_variances(
        fitter, model, options_hessian
    )
    variances_jac = fitter.modeller.compute_variances(model, transformed=False)
    variances_jac_diag = fitter.modeller.compute_variances(model, transformed=False, use_diag_only=True)

    errors_plot = {
        "inv_hess": ErrorValues(values=np.sqrt(variances[0]), kwargs_plot={"linestyle": "-", "color": "r"}),
        "-inv_hess": ErrorValues(values=np.sqrt(variances[1]), kwargs_plot={"linestyle": "--", "color": "r"}),
        "inv_jac": ErrorValues(values=np.sqrt(variances_jac), kwargs_plot={"linestyle": "-.", "color": "r"}),
        "boot_hess": ErrorValues(
            values=np.sqrt(variances_bootstrap), kwargs_plot={"linestyle": "-", "color": "b"}
        ),
        "boot_diag": ErrorValues(
            values=np.sqrt(variances_bootstrap_diag), kwargs_plot={"linestyle": "--", "color": "b"}
        ),
        "boot_jac_diag": ErrorValues(
            values=np.sqrt(variances_jac_diag), kwargs_plot={"linestyle": "-.", "color": "m"}
        ),
    }
    plot_loglike(model, errors=errors_plot, values_reference=params_true)
    if plot:
        import matplotlib.pyplot as plt

        plt.tight_layout()
        plt.show()