Coverage for python / lsst / multiprofit / plotting / plot_catalog_bootstrap.py: 7%
109 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["plot_catalog_bootstrap"]
24from collections import defaultdict
25from typing import Any, Iterable
27import astropy.table
28import matplotlib.pyplot as plt
29import numpy as np
31from ..fitting.fit_source import CatalogSourceFitterConfig, ModelConfig
32from ..utils import set_config_from_dict
# Natural logarithm of 10, used to convert flux errors into magnitude errors
ln10 = np.log(10.0)
def _flux_to_mag(flux: np.ndarray, flux_err: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Convert fluxes and flux errors to magnitudes and magnitude errors.

    Parameters
    ----------
    flux
        The flux values.
    flux_err
        The flux errors (same shape as ``flux``).

    Returns
    -------
    mag
        ``-2.5 * log10(flux)``, with no zeropoint applied.
    mag_err
        The first-order propagated error, ``2.5 / ln(10) * flux_err / flux``.
        The sign follows that of ``flux``; callers combine errors in quadrature.
    """
    mag = -2.5 * np.log10(flux)
    mag_err = (2.5 / ln10) * flux_err / flux
    return mag, mag_err


def _group_flux_columns(
    colnames_flux: Iterable[str],
    prefix: str,
    suffix_flux: str,
) -> tuple[dict[str, list[str]], int]:
    """Group flux column names by band and count the distinct components.

    Parameters
    ----------
    colnames_flux
        Flux column names, each expected to look like
        ``{prefix}{component}_{band}{suffix_flux}`` or
        ``{prefix}{band}{suffix_flux}`` (unnamed component).
    prefix
        The column prefix to strip from the start of each name.
    suffix_flux
        The flux suffix to strip from the end of each name.

    Returns
    -------
    colnames_flux_band
        Flux column names keyed by band, in first-seen band order.
    n_comps
        The number of distinct component names (an unnamed component counts
        as one).
    """
    colnames_flux_band = defaultdict(list)
    comps = set()
    for colname in colnames_flux:
        comp_band = colname.removeprefix(prefix).removesuffix(suffix_flux)
        # Split on the last underscore so that component names may themselves
        # contain underscores; no underscore at all means an unnamed component
        comp, _, band = comp_band.rpartition("_")
        colnames_flux_band[band].append(colname)
        comps.add(comp)
    return dict(colnames_flux_band), len(comps)


def plot_catalog_bootstrap(
    catalog_bootstrap: astropy.table.Table,
    n_bins: int | None = None,
    paramvals_ref: Iterable[np.ndarray] | None = None,
    plot_total_fluxes: bool = False,
    plot_colors: bool = False,
    **kwargs: Any,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot a bootstrap catalog for a single source model.

    Parameters
    ----------
    catalog_bootstrap
        A bootstrap catalog, as returned by
        `multiprofit.fit_bootstrap_model.CatalogSourceFitterBootstrap`.
        Its ``meta["config"]`` must hold a serialized
        `CatalogSourceFitterConfig`; it is not modified.
    n_bins
        The number of bins for parameter value histograms. Default
        is sqrt(N) with a minimum of 10, where N is the number of rows in
        ``catalog_bootstrap``.
    paramvals_ref
        Reference parameter values to plot, if any. Must contain one value
        per measured column (every column with a matching error column), in
        catalog column order.
    plot_total_fluxes
        Whether to plot total fluxes, not just component.
    plot_colors
        Whether to plot colors (magnitude differences between consecutive
        bands) in addition to fluxes.
    **kwargs
        Keyword arguments to pass to matplotlib hist calls.

    Returns
    -------
    fig, ax
        Matplotlib figure and axis handles, as returned by plt.subplots.
        ``ax`` is always a 2D array of shape (n_rows, 3).

    Raises
    ------
    ValueError
        Raised if ``paramvals_ref`` has the wrong length, or if the catalog
        has no error columns to plot.
    RuntimeError
        Raised if a new total flux column name collides with an existing
        reference value.
    """
    n_sources = len(catalog_bootstrap)
    if n_bins is None:
        n_bins = max(int(np.ceil(np.sqrt(n_sources))), 10)

    config = CatalogSourceFitterConfig()
    # Copy so that overriding config_model below does not leak back into the
    # caller's catalog metadata
    config_dict = dict(catalog_bootstrap.meta["config"])
    # TODO: Figure out if this can be implemented correctly in DM-48911
    # In the meantime, we don't need to know the ModelConfig to format columns
    # However, it would be useful to get the band and component name (if any)
    # given a formatted flux (error) column key
    config_dict["config_model"] = ModelConfig().toDict()
    set_config_from_dict(config, config_dict)
    prefix = config.prefix_column
    suffix_err = config.suffix_error
    # This won't work if the flux format isn't a suffix
    # TODO: Consider if this can be fixed in DM-48911
    suffix_flux = config.get_key_flux("", "")

    # TODO: There are probably better ways of doing this
    # Every column that has an error column is treated as a measured parameter
    colnames_err = [col for col in catalog_bootstrap.colnames if col.endswith(suffix_err)]
    colnames_meas = [col.removesuffix(suffix_err) for col in colnames_err]
    n_params_init = len(colnames_meas)
    if paramvals_ref is not None:
        # Materialize so that any iterable can be measured and indexed, and
        # so later checks never rely on the (ambiguous) truthiness of arrays
        paramvals_ref = list(paramvals_ref)
        if len(paramvals_ref) != n_params_init:
            raise ValueError(f"{len(paramvals_ref)=} != {n_params_init=}")

    results_good = catalog_bootstrap[catalog_bootstrap[f"{prefix}n_iter"] > 0]

    if plot_total_fluxes or plot_colors:
        # Reference values keyed by column name; turned back into a sequence
        # aligned with colnames_meas once all derived columns are added
        refs = None if paramvals_ref is None else dict(zip(colnames_meas, paramvals_ref))
        results_dict = {}
        for colname_meas, colname_err in zip(colnames_meas, colnames_err):
            results_dict[colname_meas] = results_good[colname_meas]
            results_dict[colname_err] = results_good[colname_err]

        colnames_flux = [colname for colname in colnames_meas if colname.endswith(suffix_flux)]
        colnames_flux_band, n_comps = _group_flux_columns(colnames_flux, prefix, suffix_flux)

        band_prev = mag_prev = mag_err_prev = mag_ref_prev = None
        for band, colnames_band in colnames_flux_band.items():
            # There's no need to make a total flux column with one component
            # ... unless there's a component with fixed flux, but that isn't
            # supported anyway.
            if n_comps >= 2:
                colname_flux = config.get_key_flux(band=band, label=prefix)
                colname_flux_err = f"{colname_flux}{suffix_err}"
                results_dict[colname_flux] = np.sum(
                    [results_good[colname] for colname in colnames_band],
                    axis=0,
                )
                # Component flux errors are combined in quadrature
                results_dict[colname_flux_err] = np.sqrt(
                    np.sum(
                        [results_good[f"{colname}{suffix_err}"] ** 2 for colname in colnames_band],
                        axis=0,
                    )
                )
                # Always computed (when refs exist) since colors need it too
                flux_ref = None if refs is None else sum(refs[colname] for colname in colnames_band)
                if plot_total_fluxes:
                    if refs is not None:
                        if colname_flux in refs:
                            raise RuntimeError(
                                f"Tried to set a new total flux column {colname_flux} but it already exists"
                            )
                        refs[colname_flux] = flux_ref
                    colnames_meas.append(colname_flux)
                    colnames_err.append(colname_flux_err)
            else:
                # With a single component, its flux already is the total flux
                colname_flux = colnames_band[0]
                colname_flux_err = f"{colname_flux}{suffix_err}"
                flux_ref = None if refs is None else refs[colname_flux]

            if plot_colors:
                mag, mag_err = _flux_to_mag(results_dict[colname_flux], results_dict[colname_flux_err])
                mag_ref = None if flux_ref is None else -2.5 * np.log10(flux_ref)
                if band_prev is not None:
                    colname_color = f"{prefix}{band_prev}-{band}{suffix_flux}"
                    colname_color_err = f"{colname_color}{suffix_err}"
                    colnames_meas.append(colname_color)
                    colnames_err.append(colname_color_err)
                    results_dict[colname_color] = mag_prev - mag
                    # Both errors are already in magnitudes: combine in quadrature
                    # without re-applying the 2.5 / ln(10) factor
                    results_dict[colname_color_err] = np.hypot(mag_err, mag_err_prev)
                    if refs is not None:
                        refs[colname_color] = mag_ref_prev - mag_ref
                band_prev, mag_prev, mag_err_prev, mag_ref_prev = band, mag, mag_err, mag_ref

        results_good = results_dict
        if refs is not None:
            paramvals_ref = [refs[colname] for colname in colnames_meas]

    n_colnames = len(colnames_err)
    if n_colnames == 0:
        raise ValueError(f"No columns ending in {suffix_err=} found in catalog_bootstrap to plot")
    n_cols = 3
    n_rows = int(np.ceil(n_colnames / n_cols))

    # squeeze=False keeps ax two-dimensional even with a single row, so that
    # ax[idx_row][idx_col] always indexes an Axes
    fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, constrained_layout=True, squeeze=False)

    for idx_colname, (colname_meas, colname_err) in enumerate(zip(colnames_meas, colnames_err)):
        idx_row, idx_col = divmod(idx_colname, n_cols)
        colname_short = colname_meas.removeprefix(prefix)
        values = results_good[colname_meas]
        errors = results_good[colname_err]
        median = np.median(values)
        std = np.std(values)

        median_err = np.median(errors)

        axis = ax[idx_row][idx_col]
        axis.hist(values, bins=n_bins, color="b", label="fit values", **kwargs)

        label = "median +/- stddev"
        for offset in (-std, 0, std):
            axis.axvline(median + offset, label=label, color="k")
            # Only the first of the three lines gets a legend entry
            label = None
        if paramvals_ref is not None:
            value_ref = paramvals_ref[idx_colname]
            label_value = f" {value_ref=:.3e} bias={median - value_ref:.3e}"
            axis.axvline(value_ref, label="reference", color="k", linestyle="--")
        else:
            label_value = f" {median=:.3e}"
        # Errors are shifted by the median so they share the value axis
        axis.hist(median + errors, bins=n_bins, color="r", label="median + error", **kwargs)
        axis.set_title(f"{colname_short} {std=:.3e} vs {median_err=:.3e}")
        axis.set_xlabel(f"{colname_short} {label_value}")
        axis.legend()

    return fig, ax