Coverage for python / lsst / multiprofit / plotting / plot_loglike.py: 8%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:46 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["plot_loglike"] 

23 

24import itertools 

25 

26import lsst.gauss2d.fit as g2f 

27import matplotlib.pyplot as plt 

28import numpy as np 

29 

30from ..utils import get_params_uniq 

31from .config import linestyles_default 

32from .errorvalues import ErrorValues 

33 

34 

35def plot_loglike( 

36 model: g2f.ModelD, 

37 params: list[g2f.ParameterD] | None = None, 

38 n_values: int = 15, 

39 errors: dict[str, ErrorValues] | None = None, 

40 values_reference: np.ndarray | None = None, 

41) -> tuple[plt.Figure, plt.Axes]: 

42 """Plot the loglikehood and derivatives vs free parameter values around 

43 best-fit values. 

44 

45 Parameters 

46 ---------- 

47 model 

48 The model to evaluate. 

49 params 

50 Free parameters to plot marginal loglikelihood for. 

51 n_values 

52 The number of evaluations to make on either side of each param value. 

53 errors 

54 A dict keyed by label of uncertainties to plot. Values must be the same 

55 length as `params`. 

56 values_reference 

57 Reference values to plot (e.g. true parameter values). Must be the same 

58 length as `params`. 

59 

60 Returns 

61 ------- 

62 fig, ax 

63 Matplotlib figure and axis handles, as returned by plt.subplots. 

64 """ 

65 if errors is None: 

66 errors = {} 

67 loglike_grads = np.array(model.compute_loglike_grad()) 

68 loglike_init = np.array(model.evaluate()) 

69 

70 if params is None: 

71 params = tuple(get_params_uniq(model, fixed=False)) 

72 

73 n_params = len(params) 

74 

75 if values_reference is not None and len(values_reference) != n_params: 

76 raise ValueError(f"{len(values_reference)=} != {n_params=}") 

77 

78 n_rows = n_params 

79 fig, ax = plt.subplots(nrows=n_rows, ncols=2, figsize=(10, 3 * n_rows)) 

80 axes = [ax] if (n_rows == 1) else ax 

81 

82 n_loglikes = len(loglike_init) 

83 labels = [channel.name for channel in model.data.channels] 

84 labels.extend(["prior", "total"]) 

85 

86 for param in params: 

87 param.fixed = True 

88 

89 for row, param in enumerate(params): 

90 value_init = param.value 

91 param.fixed = False 

92 values = [value_init] 

93 loglikes = [loglike_init * 0] 

94 dlls = [loglike_grads[row]] 

95 

96 diff_init = 1e-4 * np.sign(loglike_grads[row]) 

97 diff = diff_init 

98 

99 # TODO: This entire scheme should be improved/replaced 

100 # It sometimes takes excessively large steps 

101 # Option: Try to fit a curve once there are a couple of points 

102 # on each side of the peak 

103 idx_prev = -1 

104 for idx in range(2 * n_values): 

105 try: 

106 param.value_transformed += diff 

107 loglikes_new = np.array(model.evaluate()) - loglike_init 

108 dloglike_actual = np.sum(loglikes_new) - np.sum(loglikes[idx_prev]) 

109 values.append(param.value) 

110 loglikes.append(loglikes_new) 

111 dloglike_actual_abs = np.abs(dloglike_actual) 

112 if dloglike_actual_abs > 1: 

113 diff /= dloglike_actual_abs 

114 elif dloglike_actual_abs < 0.5: 

115 diff /= np.clip(dloglike_actual_abs, 0.2, 0.5) 

116 dlls.append(model.compute_loglike_grad()[0]) 

117 if idx == n_values: 

118 diff = -diff_init 

119 param.value = value_init 

120 idx_prev = 0 

121 else: 

122 idx_prev = -1 

123 except RuntimeError: 

124 break 

125 param.value = value_init 

126 param.fixed = True 

127 

128 subplot = axes[row][0] 

129 sorted = np.argsort(values) 

130 values = np.array(values)[sorted] 

131 loglikes = [loglikes[idx] for idx in sorted] 

132 dlls = np.array(dlls)[sorted] 

133 

134 for idx in range(n_loglikes): 

135 subplot.plot(values, [loglike[idx] for loglike in loglikes], label=labels[idx]) 

136 subplot.plot(values, np.sum(loglikes, axis=1), label=labels[-1]) 

137 vline_kwargs = dict(ymin=np.min(loglikes) - 1, ymax=np.max(loglikes) + 1, color="k") 

138 subplot.vlines(value_init, **vline_kwargs) 

139 

140 suffix = f" {param.label}" if param.label else "" 

141 subplot.legend() 

142 subplot.set_title(f"{param.name}{suffix}") 

143 subplot.set_ylabel("loglike") 

144 subplot.set_ylim(vline_kwargs["ymin"], vline_kwargs["ymax"]) 

145 

146 subplot = axes[row][1] 

147 subplot.plot(values, dlls) 

148 subplot.axhline(0, color="k") 

149 subplot.set_ylabel("dloglike/dx") 

150 

151 vline_kwargs = dict(ymin=np.min(dlls), ymax=np.max(dlls)) 

152 subplot.vlines(value_init, **vline_kwargs, color="k", label="fit") 

153 if values_reference is not None: 

154 subplot.vlines(values_reference[row], **vline_kwargs, color="b", label="ref") 

155 

156 cycler_linestyle = itertools.cycle(linestyles_default) 

157 for name_error, valerr in errors.items(): 

158 linestyle = valerr.kwargs_plot.pop("linestyle", next(cycler_linestyle)) 

159 for idx_ax in range(2): 

160 axes[row][idx_ax].vlines( 

161 [value_init - valerr.values[row], value_init + valerr.values[row]], 

162 linestyles=[linestyle, linestyle], 

163 label=name_error if (idx_ax == 1) else None, 

164 **valerr.kwargs_plot, 

165 **vline_kwargs, 

166 ) 

167 subplot.legend() 

168 

169 for param in params: 

170 param.fixed = False 

171 

172 return fig, ax