Coverage for python / lsst / multiprofit / plotting / plot_loglike.py: 8%
89 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["plot_loglike"]
24import itertools
26import lsst.gauss2d.fit as g2f
27import matplotlib.pyplot as plt
28import numpy as np
30from ..utils import get_params_uniq
31from .config import linestyles_default
32from .errorvalues import ErrorValues
def _scan_loglike(
    model: g2f.ModelD,
    param: g2f.ParameterD,
    value_init: float,
    loglike_init: np.ndarray,
    step_init: float,
    n_steps: int,
) -> tuple[list[float], list[np.ndarray], list[float]]:
    """Step one parameter away from its initial value in a single direction,
    recording the loglikelihood and its derivative at each step.

    Parameters
    ----------
    model
        The model to evaluate. ``param`` should be its only free parameter.
    param
        The parameter to vary.
    value_init
        The value to start stepping from.
    loglike_init
        The loglikelihood values at ``value_init``; the returned
        loglikelihoods are relative to these.
    step_init
        The initial step, applied to ``param.value_transformed``. Its sign
        sets the direction of the scan.
    n_steps
        The maximum number of steps to take.

    Returns
    -------
    values, loglikes, dlls
        The parameter values, loglikelihoods relative to ``loglike_init``
        and loglikelihood derivatives at each successful step. All three
        lists have the same length, which is less than ``n_steps`` if an
        evaluation raised a RuntimeError.
    """
    values = []
    loglikes = []
    dlls = []
    param.value = value_init
    loglike_prev = np.zeros_like(loglike_init)
    step = step_init
    # TODO: This entire scheme should be improved/replaced
    # It sometimes takes excessively large steps
    # Option: Try to fit a curve once there are a couple of points
    # on each side of the peak
    for _ in range(n_steps):
        try:
            param.value_transformed += step
            loglike_new = np.array(model.evaluate()) - loglike_init
            # Only ``param`` is free here, so its derivative is presumably
            # the first (only) entry of the gradient.
            dll = model.compute_loglike_grad()[0]
        except RuntimeError:
            # Evaluation failed (e.g. an invalid parameter value): stop
            # scanning in this direction only. Nothing has been appended yet,
            # so the three output lists stay the same length.
            break
        values.append(param.value)
        loglikes.append(loglike_new)
        dlls.append(dll)
        # Adapt the step so that the total loglikelihood changes by roughly
        # 0.5-1 per step (assuming locally linear behaviour): shrink it after
        # a large change, grow it by 2-5x after a small one.
        dloglike_abs = np.abs(np.sum(loglike_new) - np.sum(loglike_prev))
        if dloglike_abs > 1:
            step /= dloglike_abs
        elif dloglike_abs < 0.5:
            step /= np.clip(dloglike_abs, 0.2, 0.5)
        loglike_prev = loglike_new
    return values, loglikes, dlls


def plot_loglike(
    model: g2f.ModelD,
    params: list[g2f.ParameterD] | None = None,
    n_values: int = 15,
    errors: dict[str, ErrorValues] | None = None,
    values_reference: np.ndarray | None = None,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot the loglikelihood and derivatives vs free parameter values around
    best-fit values.

    Each parameter is scanned on its own, with all others in ``params``
    temporarily fixed. The parameters' values and fixed flags are restored
    afterwards, even if an error occurs.

    Parameters
    ----------
    model
        The model to evaluate.
    params
        Free parameters to plot marginal loglikelihood for. Defaults to all
        of the model's free parameters.
    n_values
        The maximum number of evaluations to make on either side of each
        param value. Fewer are made on a side if an evaluation raises a
        RuntimeError.
    errors
        A dict keyed by label of uncertainties to plot. Values must be the same
        length as `params`. Their ``kwargs_plot`` are not modified.
    values_reference
        Reference values to plot (e.g. true parameter values). Must be the same
        length as `params`.

    Returns
    -------
    fig, ax
        Matplotlib figure and axis handles, as returned by plt.subplots.

    Raises
    ------
    ValueError
        If there are no parameters to plot, or if ``values_reference`` is not
        the same length as ``params``.
    """
    if errors is None:
        errors = {}
    loglike_grads = np.array(model.compute_loglike_grad())
    loglike_init = np.array(model.evaluate())

    if params is None:
        params = tuple(get_params_uniq(model, fixed=False))

    n_params = len(params)
    if n_params == 0:
        raise ValueError("No parameters to plot")
    if values_reference is not None and len(values_reference) != n_params:
        raise ValueError(f"{len(values_reference)=} != {n_params=}")

    n_rows = n_params
    fig, ax = plt.subplots(nrows=n_rows, ncols=2, figsize=(10, 3 * n_rows))
    # plt.subplots returns a 1D array of axes for a single row
    axes = [ax] if (n_rows == 1) else ax

    n_loglikes = len(loglike_init)
    labels = [channel.name for channel in model.data.channels]
    labels.extend(["prior", "total"])

    fixed_orig = [param.fixed for param in params]
    for param in params:
        param.fixed = True

    try:
        for row, param in enumerate(params):
            value_init = param.value
            # NOTE(review): assumes ``params`` is ordered like the model's
            # free parameters (true when ``params`` is None) - confirm for
            # caller-supplied params.
            grad_init = loglike_grads[row]
            # Start in the direction of the gradient's sign; a zero gradient
            # would otherwise produce zero-length steps.
            sign = np.sign(grad_init)
            step_init = 1e-4 * (sign if sign != 0 else 1.0)

            param.fixed = False
            try:
                values_fwd, loglikes_fwd, dlls_fwd = _scan_loglike(
                    model, param, value_init, loglike_init, step_init, n_values
                )
                values_bwd, loglikes_bwd, dlls_bwd = _scan_loglike(
                    model, param, value_init, loglike_init, -step_init, n_values
                )
            finally:
                param.value = value_init
                param.fixed = True

            values = np.array([value_init] + values_fwd + values_bwd)
            loglikes = np.array([np.zeros_like(loglike_init)] + loglikes_fwd + loglikes_bwd)
            dlls = np.array([grad_init] + dlls_fwd + dlls_bwd)
            order = np.argsort(values)
            values, loglikes, dlls = values[order], loglikes[order], dlls[order]
            totals = np.sum(loglikes, axis=1)

            ax_ll = axes[row][0]
            for idx in range(n_loglikes):
                ax_ll.plot(values, loglikes[:, idx], label=labels[idx])
            ax_ll.plot(values, totals, label=labels[-1])
            # Include the total in the limits: a sum of terms can lie outside
            # the range spanned by the individual terms.
            ylim_ll = (
                min(np.min(loglikes), np.min(totals)) - 1,
                max(np.max(loglikes), np.max(totals)) + 1,
            )
            ax_ll.vlines(value_init, ymin=ylim_ll[0], ymax=ylim_ll[1], color="k")

            suffix = f" {param.label}" if param.label else ""
            ax_ll.legend()
            ax_ll.set_title(f"{param.name}{suffix}")
            ax_ll.set_ylabel("loglike")
            ax_ll.set_ylim(*ylim_ll)

            ax_dll = axes[row][1]
            ax_dll.plot(values, dlls)
            ax_dll.axhline(0, color="k")
            ax_dll.set_ylabel("dloglike/dx")

            ylim_dll = (np.min(dlls), np.max(dlls))
            ax_dll.vlines(value_init, ymin=ylim_dll[0], ymax=ylim_dll[1], color="k", label="fit")
            if values_reference is not None:
                ax_dll.vlines(
                    values_reference[row], ymin=ylim_dll[0], ymax=ylim_dll[1], color="b", label="ref"
                )

            cycler_linestyle = itertools.cycle(linestyles_default)
            for name_error, valerr in errors.items():
                # Copy so the caller's kwargs (including any linestyle) survive
                # for subsequent rows and calls.
                kwargs_plot = dict(valerr.kwargs_plot)
                linestyle_default = next(cycler_linestyle)
                linestyle = kwargs_plot.pop("linestyle", linestyle_default)
                error = valerr.values[row]
                # Each axis gets error lines spanning its own y-range
                for idx_ax, (subplot, (ymin, ymax)) in enumerate(
                    ((ax_ll, ylim_ll), (ax_dll, ylim_dll))
                ):
                    subplot.vlines(
                        [value_init - error, value_init + error],
                        ymin=ymin,
                        ymax=ymax,
                        linestyles=[linestyle, linestyle],
                        label=name_error if (idx_ax == 1) else None,
                        **kwargs_plot,
                    )
            ax_dll.legend()
    finally:
        # Restore the parameters' original fixed state
        for param, fixed in zip(params, fixed_orig):
            param.fixed = fixed

    return fig, ax