Coverage for python / lsst / multiprofit / plotting / plot_sersicmix_interp.py: 6%
95 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["Interpolator", "plot_sersicmix_interp"]
24from typing import Any, Type, TypeAlias
26import lsst.gauss2d.fit as g2f
27import matplotlib as mpl
28import matplotlib.pyplot as plt
29import numpy as np
31from .types import FigureAxes
# An interpolator specification accepted by plot_sersicmix_interp: either a
# ready-made g2f.SersicMixInterpolator, or a (factory, kwargs) tuple. In the
# tuple form the factory is called as factory(x, y, **kwargs) for each
# Gaussian component, and the result must itself be callable and provide a
# derivative() method. NOTE(review): the plotting code also tolerates kwargs
# being None, although this alias does not advertise it.
Interpolator: TypeAlias = g2f.SersicMixInterpolator | tuple[Type, dict[str, Any]]
def plot_sersicmix_interp(
    interps: dict[str, tuple[Interpolator, str | tuple]], n_ser: np.ndarray, **kwargs: Any
) -> FigureAxes:
    """Plot Gaussian mixture Sersic profile interpolated values.

    A 2x2 grid of axes is made. The top row shows the component integrals
    and the bottom row the component sigmas; the left column shows the
    interpolated values (with the knots overplotted as black crosses) and
    the right column shows their derivatives with respect to the Sersic
    index. Only the bottom-left panel carries a legend.

    Parameters
    ----------
    interps
        Dict of (interpolator, linestyle) tuples by name. The interpolator
        is either a `g2f.SersicMixInterpolator` or a (factory, kwargs)
        tuple, where ``factory(x, y, **kwargs)`` must return a callable
        object that has a ``derivative()`` method; kwargs may be None.
        The linestyle is passed to matplotlib's ``plot``.
    n_ser
        Array of Sersic index values to plot interpolated quantities for.
        Must not be empty, since its minimum and maximum set the x limits.
    **kwargs
        Keyword arguments to pass to matplotlib.pyplot.subplots.

    Returns
    -------
    figure
        The resulting figure.
    axes
        The 2x2 array of axes drawn on the figure.

    Raises
    ------
    ValueError
        Raised if the `g2f.SersicMixInterpolator` instances in interps do
        not all share a single order (including when there are none).
    """
    orders = {
        name: interp.order
        for name, (interp, _) in interps.items()
        if isinstance(interp, g2f.SersicMixInterpolator)
    }
    orders_unique = set(orders.values())
    if len(orders_unique) != 1:
        raise ValueError(f"len(set({orders})) != 1; all interpolators must have the same order")
    (order,) = orders_unique

    # mpl.cm.get_cmap was removed in matplotlib 3.9; use the registry instead.
    cmap = mpl.colormaps["tab20b"]
    # Guard the denominator so a single-component mixture does not divide
    # by zero; it simply gets the first colour of the map.
    color_denom = max(order - 1.0, 1.0)
    colors_ord = [cmap(i_ord / color_denom) for i_ord in range(order)]

    n_ser_min = np.min(n_ser)
    n_ser_max = np.max(n_ser)

    # Tabulate every knot once; both the knot markers and the tuple-style
    # interpolators below are built from these arrays.
    knots = g2f.sersic_mix_knots(order=order)
    n_knots = len(knots)
    x_knots = [knot.sersicindex for knot in knots]
    integrals_knots_all = np.array(
        [[knot.values[i_ord].integral for i_ord in range(order)] for knot in knots],
        dtype=float,
    ).reshape(n_knots, order)
    sigmas_knots_all = np.array(
        [[knot.values[i_ord].sigma for i_ord in range(order)] for knot in knots],
        dtype=float,
    ).reshape(n_knots, order)

    # Select the contiguous run of knots to overplot: it starts at the first
    # knot above the smallest plotted Sersic index and stops before the first
    # subsequent knot above the largest one.
    # NOTE(review): a knot exactly equal to n_ser_min is excluded while one
    # equal to n_ser_max is included - confirm that asymmetry is intended.
    i_knot_first = None
    i_knot_last = n_knots
    for i_knot, n_ser_knot in enumerate(x_knots):
        if i_knot_first is None:
            if n_ser_knot > n_ser_min:
                i_knot_first = i_knot
            else:
                continue
        if n_ser_knot > n_ser_max:
            i_knot_last = i_knot
            break
    if i_knot_first is None:
        # No knot lies above n_ser_min: nothing to overplot.
        i_knot_first = i_knot_last
    knots_shown = slice(i_knot_first, i_knot_last)
    integrals_knots = integrals_knots_all[knots_shown, :]
    sigmas_knots = sigmas_knots_all[knots_shown, :]
    n_ser_knots = np.array(x_knots[knots_shown], dtype=float)

    n_values = len(n_ser)
    integrals, dintegrals, sigmas, dsigmas = (
        {name: np.empty((n_values, order)) for name in interps} for _ in range(4)
    )

    for name, (interp, _) in interps.items():
        if isinstance(interp, g2f.SersicMixInterpolator):
            for i_val, value in enumerate(n_ser):
                values = interp.integralsizes(value)
                derivs = interp.integralsizes_derivs(value)
                for i_ord in range(order):
                    integrals[name][i_val, i_ord] = values[i_ord].integral
                    sigmas[name][i_val, i_ord] = values[i_ord].sigma
                    dintegrals[name][i_val, i_ord] = derivs[i_ord].integral
                    dsigmas[name][i_val, i_ord] = derivs[i_ord].sigma
        else:
            # Use a dedicated name for the interpolator options: rebinding
            # ``kwargs`` here would replace the subplots keyword arguments.
            factory = interp[0]
            kwargs_interp = interp[1] if interp[1] is not None else {}
            for i_ord in range(order):
                integrals_i = np.ascontiguousarray(integrals_knots_all[:, i_ord])
                sigmas_i = np.ascontiguousarray(sigmas_knots_all[:, i_ord])
                interp_int = factory(x_knots, integrals_i, **kwargs_interp)
                dinterp_int = interp_int.derivative()
                interp_sigma = factory(x_knots, sigmas_i, **kwargs_interp)
                dinterp_sigma = interp_sigma.derivative()
                for i_val, value in enumerate(n_ser):
                    integrals[name][i_val, i_ord] = interp_int(value)
                    sigmas[name][i_val, i_ord] = interp_sigma(value)
                    dintegrals[name][i_val, i_ord] = dinterp_int(value)
                    dsigmas[name][i_val, i_ord] = dinterp_sigma(value)

    fig, axes = plt.subplots(2, 2, **kwargs)
    for idx_row, (yv, yd, yk, y_label) in (
        (0, (integrals, dintegrals, integrals_knots, "integral")),
        (1, (sigmas, dsigmas, sigmas_knots, "sigma")),
    ):
        is_label_row = idx_row == 1
        for idx_col, y_i, y_prefix in ((0, yv, ""), (1, yd, "d")):
            is_label_col = idx_col == 0
            # Only the bottom-left panel is labelled and given a legend.
            make_label = is_label_col and is_label_row
            axis = axes[idx_row, idx_col]
            if is_label_col:
                # Knots are only meaningful on the value (not derivative) panels.
                for i_ord in range(order):
                    axis.plot(
                        n_ser_knots,
                        yk[:, i_ord],
                        "kx",
                        label="knots" if make_label and (i_ord == 0) else None,
                    )
            for name, (_, lstyle) in interps.items():
                for i_ord in range(order):
                    # One legend entry per interpolator, not per component.
                    label = f"{name}" if make_label and (i_ord == 0) else None
                    axis.plot(n_ser, y_i[name][:, i_ord], c=colors_ord[i_ord], label=label, linestyle=lstyle)
            axis.set_xlim((n_ser_min, n_ser_max))
            axis.set_ylabel(f"{y_prefix}{y_label}")
            if make_label:
                axis.legend(loc="upper left")
    return fig, axes