Coverage for python / lsst / multiprofit / plotting / plot_catalog_bootstrap.py: 7%

109 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:46 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["plot_catalog_bootstrap"] 

23 

24from collections import defaultdict 

25from typing import Any, Iterable 

26 

27import astropy.table 

28import matplotlib.pyplot as plt 

29import numpy as np 

30 

31from ..fitting.fit_source import CatalogSourceFitterConfig, ModelConfig 

32from ..utils import set_config_from_dict 

33 

34ln10 = np.log(10) 

35 

36 

def plot_catalog_bootstrap(
    catalog_bootstrap: astropy.table.Table,
    n_bins: int | None = None,
    paramvals_ref: Iterable[np.ndarray] | None = None,
    plot_total_fluxes: bool = False,
    plot_colors: bool = False,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot a bootstrap catalog for a single source model.

    One panel is drawn per fitted parameter, i.e. per column that has a
    matching error column. Each panel shows a histogram of the fitted
    values, vertical lines at the median and median +/- one standard
    deviation, a histogram of median + reported error, and (optionally)
    a reference value. Only rows with a positive iteration count are used.

    Parameters
    ----------
    catalog_bootstrap
        A bootstrap catalog, as returned by
        `multiprofit.fit_bootstrap_model.CatalogSourceFitterBootstrap`.
        Its ``meta["config"]`` entry is read but not modified.
    n_bins
        The number of bins for parameter value histograms. Default
        is sqrt(N) with a minimum of 10.
    paramvals_ref
        Reference parameter values to plot, if any. There must be one
        value per column that has an error column, in column order.
    plot_total_fluxes
        Whether to plot total fluxes, not just component.
        Totals are only built if there are at least two components.
    plot_colors
        Whether to plot colors in addition to fluxes. Colors are formed
        between consecutive bands from the per-band total flux columns.
    **kwargs
        Keyword arguments to pass to matplotlib hist calls.

    Returns
    -------
    fig, ax
        Matplotlib figure and axis handles, as returned by plt.subplots.
        ``ax`` is always a 2D array with three columns, even if only a
        single row of panels is needed.

    Raises
    ------
    ValueError
        Raised if the number of reference values does not match the number
        of parameter columns.
    RuntimeError
        Raised if reference values are given and a total flux column to be
        created has the same name as an existing parameter column.
    """
    n_sources = len(catalog_bootstrap)
    if n_bins is None:
        n_bins = max(int(np.ceil(np.sqrt(n_sources))), 10)

    config = CatalogSourceFitterConfig()
    # Shallow-copy so that overriding config_model below does not alter
    # the metadata of the caller's catalog
    config_dict = dict(catalog_bootstrap.meta["config"])
    # TODO: Figure out if this can be implemented correctly in DM-48911
    # In the meantime, we don't need to know the ModelConfig to format columns
    # However, it would be useful to get the band and component name (if any)
    # given a formatted flux (error) column key
    config_dict["config_model"] = ModelConfig().toDict()
    set_config_from_dict(config, config_dict)
    prefix = config.prefix_column
    suffix_err = config.suffix_error
    len_suffix_err = len(suffix_err)
    # This won't work if the flux format isn't a suffix
    # TODO: Consider if this can be fixed in DM-48911
    suffix_flux = config.get_key_flux("", "")

    # TODO: There are probably better ways of doing this
    colnames_err = [col for col in catalog_bootstrap.colnames if col.endswith(suffix_err)]
    colnames_meas = [col[:-len_suffix_err] for col in colnames_err]
    n_params_init = len(colnames_meas)
    if paramvals_ref is not None:
        # Materialize once: a generic iterable has no len() and the truth
        # value of a numpy array is ambiguous
        paramvals_ref = tuple(paramvals_ref)
        if len(paramvals_ref) != n_params_init:
            raise ValueError(f"{len(paramvals_ref)=} != {n_params_init=}")

    results_good = catalog_bootstrap[catalog_bootstrap[f"{prefix}n_iter"] > 0]

    if plot_total_fluxes or plot_colors:
        # Reference values keyed by column name, so that derived columns can
        # be added without relying on positional alignment
        refs_by_col = None if paramvals_ref is None else dict(zip(colnames_meas, paramvals_ref))

        results_dict = {}
        for colname_meas, colname_err in zip(colnames_meas, colnames_err):
            results_dict[colname_meas] = results_good[colname_meas]
            results_dict[colname_err] = results_good[colname_err]

        colnames_flux = [colname for colname in colnames_meas if colname.endswith(suffix_flux)]

        colnames_flux_band = defaultdict(list)
        colnames_flux_comp = defaultdict(list)

        for colname in colnames_flux:
            colname_short = colname.partition(prefix)[-1]
            comp_band = colname_short.split(suffix_flux)[0]
            comp, band = comp_band.split("_") if ("_" in comp_band) else ("", comp_band)
            colnames_flux_band[band].append(colname)
            colnames_flux_comp[comp].append(colname)

        n_comps = len(colnames_flux_comp)

        band_prev = None
        for band, colnames_band in colnames_flux_band.items():
            # There's no need to make a total flux column with one component
            # ... unless there's a component with fixed flux, but that isn't
            # supported anyway.
            if n_comps >= 2:
                colname_flux = f"{config.get_key_flux(band=band, label=prefix)}"
                colname_flux_err = f"{colname_flux}{suffix_err}"
                if refs_by_col is not None:
                    if colname_flux in refs_by_col:
                        raise RuntimeError(
                            f"Tried to set a new total flux column {colname_flux} but it already exists"
                        )
                    # Always keep the total reference flux: colors need it
                    # even if the total flux itself is not plotted
                    refs_by_col[colname_flux] = sum((refs_by_col[colname] for colname in colnames_band))
                results_dict[colname_flux] = np.sum(
                    [results_good[colname] for colname in colnames_band],
                    axis=0,
                )
                # Component errors are added in quadrature
                results_dict[colname_flux_err] = np.sqrt(
                    np.sum(
                        [results_good[f"{colname}{suffix_err}"] ** 2 for colname in colnames_band],
                        axis=0,
                    )
                )
                if plot_total_fluxes:
                    colnames_meas.append(colname_flux)
                    colnames_err.append(colname_flux_err)

            if plot_colors and band_prev:
                key_prev, key = (f"{prefix}{b}{suffix_flux}" for b in (band_prev, band))
                flux_prev, flux = results_dict[key_prev], results_dict[key]
                mag_prev, mag = (-2.5 * np.log10(flux_b) for flux_b in (flux_prev, flux))
                # First-order propagation: sigma_mag = sigma_flux/(0.4*ln(10)*flux)
                # The sign is irrelevant because the errors are combined
                # in quadrature below
                mag_err_prev, mag_err = (
                    results_dict[f"{key_b}{suffix_err}"] / (0.4 * flux_b * ln10)
                    for key_b, flux_b in ((key_prev, flux_prev), (key, flux))
                )
                colname_color = f"{prefix}{band_prev}-{band}{suffix_flux}"
                colname_color_err = f"{colname_color}{suffix_err}"
                colnames_meas.append(colname_color)
                colnames_err.append(colname_color_err)

                results_dict[colname_color] = mag_prev - mag
                # The magnitude errors already include the 2.5/ln(10) factor,
                # so it must not be applied a second time here
                results_dict[colname_color_err] = np.hypot(mag_err, mag_err_prev)
                if refs_by_col is not None:
                    mag_prev_ref, mag_ref = (-2.5 * np.log10(refs_by_col[key_b]) for key_b in (key_prev, key))
                    refs_by_col[colname_color] = mag_prev_ref - mag_ref

            band_prev = band

        results_good = results_dict
        if refs_by_col is not None:
            # Look up by name so that the references always line up with the
            # plotted columns, whichever derived columns were added
            paramvals_ref = tuple(refs_by_col[colname] for colname in colnames_meas)

    n_colnames = len(colnames_err)
    n_cols = 3
    # Keep at least one row so that subplots is never asked for an empty grid
    n_rows = max(int(np.ceil(n_colnames / n_cols)), 1)

    # squeeze=False keeps ax 2D even for a single row of panels
    fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, constrained_layout=True, squeeze=False)

    for idx_colname, (colname_meas, colname_err) in enumerate(zip(colnames_meas, colnames_err)):
        idx_row, idx_col = divmod(idx_colname, n_cols)
        colname_short = colname_meas.partition(prefix)[-1]
        values = results_good[colname_meas]
        errors = results_good[colname_err]
        median = np.median(values)
        std = np.std(values)

        median_err = np.median(errors)

        axis = ax[idx_row][idx_col]
        axis.hist(values, bins=n_bins, color="b", label="fit values", **kwargs)

        # Only the first of the three lines gets a legend entry
        label = "median +/- stddev"
        for offset in (-std, 0, std):
            axis.axvline(median + offset, label=label, color="k")
            label = None
        if paramvals_ref is not None:
            value_ref = paramvals_ref[idx_colname]
            label_value = f" {value_ref=:.3e} bias={median - value_ref:.3e}"
            axis.axvline(value_ref, label="reference", color="k", linestyle="--")
        else:
            label_value = f" {median=:.3e}"
        axis.hist(median + errors, bins=n_bins, color="r", label="median + error", **kwargs)
        axis.set_title(f"{colname_short} {std=:.3e} vs {median_err=:.3e}")
        axis.set_xlabel(f"{colname_short} {label_value}")
        axis.legend()

    return fig, ax