Coverage for python / lsst / multiprofit / plotting / plot_sersicmix_interp.py: 6%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:48 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public names exported by ``from <module> import *``.
__all__ = ["Interpolator", "plot_sersicmix_interp"]

23 

24from typing import Any, Type, TypeAlias 

25 

26import lsst.gauss2d.fit as g2f 

27import matplotlib as mpl 

28import matplotlib.pyplot as plt 

29import numpy as np 

30 

31from .types import FigureAxes 

32 

# An interpolator accepted by plot_sersicmix_interp: either a
# g2f.SersicMixInterpolator used directly, or a (class, kwargs) tuple.
# In the tuple form, plot_sersicmix_interp constructs
# ``class(x, y, **kwargs)`` from the knot values of each Gaussian component
# and requires the result to be callable and to provide ``derivative()``.
# kwargs may be None, which is treated as an empty dict.
Interpolator: TypeAlias = g2f.SersicMixInterpolator | tuple[Type, dict[str, Any]]

34 

35 

def plot_sersicmix_interp(
    interps: dict[str, tuple[Interpolator, str | tuple]], n_ser: np.ndarray, **kwargs: Any
) -> FigureAxes:
    """Plot Gaussian mixture Sersic profile interpolated values.

    Parameters
    ----------
    interps
        Dict of (interpolator, linestyle) tuples by name. Each interpolator
        is either a `lsst.gauss2d.fit.SersicMixInterpolator`, or a
        (class, kwargs) tuple. In the tuple form, ``class(x, y, **kwargs)``
        is constructed on the knot values of each Gaussian component, and
        the result must be callable and provide ``derivative()``. kwargs may
        be None.
    n_ser
        Array of Sersic index values to plot interpolated quantities for.
    **kwargs
        Keyword arguments to pass to matplotlib.pyplot.subplots.

    Returns
    -------
    figure
        The resulting figure.
    axes
        The 2x2 axes. The rows show the integral and sigma. The columns show
        the values and their derivatives with respect to the Sersic index.

    Raises
    ------
    ValueError
        Raised if ``interps`` contains no SersicMixInterpolator, or if the
        SersicMixInterpolator instances do not all share the same order.
    """
    orders = {
        name: interp.order
        for name, (interp, _) in interps.items()
        if isinstance(interp, g2f.SersicMixInterpolator)
    }
    if not orders:
        raise ValueError(
            "interps must contain at least one g2f.SersicMixInterpolator to determine the order"
        )
    order_set = set(orders.values())
    if len(order_set) != 1:
        raise ValueError(f"len(set({orders})) != 1; all interpolators must have the same order")
    (order,) = order_set

    # mpl.cm.get_cmap was removed in matplotlib 3.9; use the colormap registry.
    cmap = mpl.colormaps["tab20b"]
    # Spread the components across the colormap. The max() guards order == 1,
    # which would otherwise divide by zero.
    denom = max(order - 1, 1)
    colors_ord = [cmap(i_ord / denom) for i_ord in range(order)]

    n_ser_min = np.min(n_ser)
    n_ser_max = np.max(n_ser)

    # Read every knot once. The non-native interpolators below are fit to all
    # knots, while only the knots inside the plotted range are drawn.
    knots = g2f.sersic_mix_knots(order=order)
    n_knots = len(knots)
    x_knots = [knot.sersicindex for knot in knots]
    n_ser_knots_all = np.array(x_knots, dtype=float)
    integrals_knots_all = np.empty((n_knots, order))
    sigmas_knots_all = np.empty((n_knots, order))
    for i_knot, knot in enumerate(knots):
        for i_ord in range(order):
            values = knot.values[i_ord]
            integrals_knots_all[i_knot, i_ord] = values.integral
            sigmas_knots_all[i_knot, i_ord] = values.sigma

    # Keep knots with n_ser_min < index <= n_ser_max. For knots sorted by index,
    # this matches the original first/last scan. If there are no such knots,
    # the selection is simply empty instead of raising.
    in_range = (n_ser_knots_all > n_ser_min) & (n_ser_knots_all <= n_ser_max)
    n_ser_knots = n_ser_knots_all[in_range]
    integrals_knots = integrals_knots_all[in_range, :]
    sigmas_knots = sigmas_knots_all[in_range, :]

    n_values = len(n_ser)
    integrals, dintegrals, sigmas, dsigmas = (
        {name: np.empty((n_values, order)) for name in interps} for _ in range(4)
    )

    for name, (interp, _) in interps.items():
        if isinstance(interp, g2f.SersicMixInterpolator):
            for i_val, value in enumerate(n_ser):
                values = interp.integralsizes(value)
                derivs = interp.integralsizes_derivs(value)
                for i_ord in range(order):
                    integrals[name][i_val, i_ord] = values[i_ord].integral
                    sigmas[name][i_val, i_ord] = values[i_ord].sigma
                    dintegrals[name][i_val, i_ord] = derivs[i_ord].integral
                    dsigmas[name][i_val, i_ord] = derivs[i_ord].sigma
        else:
            # Use a distinct name so the caller's subplots **kwargs is not clobbered.
            interp_class = interp[0]
            kwargs_interp = interp[1] if interp[1] is not None else {}
            for i_ord in range(order):
                interp_int = interp_class(x_knots, integrals_knots_all[:, i_ord], **kwargs_interp)
                dinterp_int = interp_int.derivative()
                interp_sigma = interp_class(x_knots, sigmas_knots_all[:, i_ord], **kwargs_interp)
                dinterp_sigma = interp_sigma.derivative()
                for i_val, value in enumerate(n_ser):
                    integrals[name][i_val, i_ord] = interp_int(value)
                    sigmas[name][i_val, i_ord] = interp_sigma(value)
                    dintegrals[name][i_val, i_ord] = dinterp_int(value)
                    dsigmas[name][i_val, i_ord] = dinterp_sigma(value)

    fig, axes = plt.subplots(2, 2, **kwargs)
    for idx_row, (yv, yd, yk, y_label) in (
        (0, (integrals, dintegrals, integrals_knots, "integral")),
        (1, (sigmas, dsigmas, sigmas_knots, "sigma")),
    ):
        is_label_row = idx_row == 1
        for idx_col, y_i, y_prefix in ((0, yv, ""), (1, yd, "d")):
            is_label_col = idx_col == 0
            # Only the bottom-left panel gets labels and a legend.
            make_label = is_label_col and is_label_row
            axis = axes[idx_row, idx_col]
            if is_label_col:
                # Knot values are drawn only on the values column, not the derivatives.
                for i_ord in range(order):
                    axis.plot(
                        n_ser_knots,
                        yk[:, i_ord],
                        "kx",
                        label="knots" if make_label and (i_ord == 0) else None,
                    )
            for name, (_, lstyle) in interps.items():
                for i_ord in range(order):
                    label = f"{name}" if make_label and (i_ord == 0) else None
                    axis.plot(n_ser, y_i[name][:, i_ord], c=colors_ord[i_ord], label=label, linestyle=lstyle)
            axis.set_xlim((n_ser_min, n_ser_max))
            axis.set_ylabel(f"{y_prefix}{y_label}")
            if make_label:
                axis.legend(loc="upper left")
    return fig, axes