Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 24%
121 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-30 03:25 -0700
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-30 03:25 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot")
25from collections import defaultdict
26from typing import Mapping
28import matplotlib.pyplot as plt
29import numpy as np
30from lsst.pex.config import Config, ConfigDictField, DictField, Field
31from matplotlib.figure import Figure
32from matplotlib.gridspec import GridSpec
33from matplotlib.patches import Rectangle
35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
36from ...statistics import sigmaMad
37from .plotUtils import addPlotInfo
class HistPanel(Config):
    """Configuration of a single histogram panel.

    A panel has one x-axis label, one y-axis scale, one binning and one pair
    of percentile limits, all shared by every histogram listed in ``hists``.
    """

    label = Field[str](default="label", doc="Panel x-axis label.")
    hists = DictField[str, str](
        optional=False,
        doc=(
            "A dict specifying the histograms to be plotted in this panel. "
            "Keys are used to identify histogram IDs. "
            "Values are used to add to the legend label displayed in the upper corner of the panel."
        ),
    )
    yscale = Field[str](default="linear", doc="Y axis scaling.")
    bins = Field[int](default=50, doc="Number of x axis bins.")
    pLower = Field[float](
        default=2.0,
        doc=(
            "Percentile used to determine the lower range of the histogram bins. "
            "If more than one histogram is plotted in the panel, "
            "the percentile limit is the minimum value across all input data."
        ),
    )
    pUpper = Field[float](
        default=98.0,
        doc=(
            "Percentile used to determine the upper range of the histogram bins. "
            "If more than one histogram is plotted, "
            "the percentile limit is the maximum value across all input data."
        ),
    )
class HistPlot(PlotAction):
    """Plot action that draws step histograms in one or more panels.

    Every entry of ``panels`` becomes one axes holding the histograms named
    in that panel's ``hists`` mapping. A side panel lists, for every
    histogram, the number of finite points, the median and the value
    returned by ``sigmaMad``.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the data keys needed to draw the configured histograms.

        Yields
        ------
        key, type : `tuple`
            One ``(histogram key, Vector)`` pair for every histogram of every
            panel.
        """
        for panel in self.panels:  # type: ignore
            # Iterate over the keys only: the keys are what ``data`` is
            # indexed with when plotting, the values are legend labels.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Create the figure; thin wrapper around `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            Mapping indexed by the histogram keys configured in ``panels``;
            each value must support boolean-mask indexing.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)
            If `None`, no general plot information is added.
        **kwargs
            Forwarded to the per-panel plotting helper.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            Raised if no panels are configured, or if ``cmap`` is not a
            usable color map.
        """
        # Fail early, before a pyplot figure is created, when there is
        # nothing to lay out.
        if len(self.panels) == 0:
            raise ValueError("HistPlot requires at least one configured panel.")

        # set up figure
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms
        cols = self._assignColors()
        all_handles, all_nums, all_meds, all_mads = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, meds, mads = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            # the handles are reused as the first column of the statistics table
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_meds += meds
            all_mads += mads

        # add side panel; add statistics
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_meds, all_mads)

        # add general plot info
        # NOTE(review): addPlotInfo is assumed to require a mapping; it is
        # skipped when the caller relies on the ``None`` default.
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        hist_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=hist_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        A single panel uses one column, any larger number uses two. The last
        panel may extend to the end of its row, so an odd number of panels
        leaves no empty grid cell.

        Parameters
        ----------
        fig : figure or subfigure
            Object on which the axes are created with ``add_subplot``.

        Returns
        -------
        axs : `list`
            One axes per configured panel, in row-major order.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # final panel: let it run to the end of the row, then stop
                    axs.append(fig.add_subplot(gs[row : row + 1, col : min(col + 2, ncols + 1)]))
                    break

        return axs

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        cols : `collections.defaultdict` [`str`, `list`]
            For each panel key, one color per histogram. Colors are taken
            consecutively across all panels and wrap around when the color
            map has fewer entries than there are histograms.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom color map nor an attribute
            of ``plt.cm`` providing ``copy().colors``.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            try:
                all_cols = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                # Covers both an unknown name and an object that lacks
                # ``copy`` or ``colors``.
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        cols = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Mapping indexed by histogram key.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax : axes
            Axes to draw on.
        col : `list`
            One color per histogram of this panel, in configuration order.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        nums, meds, mads : `list`
            Per-histogram statistics as returned by `_calcStats`.
        """
        panel_config = self.panels[panel]
        panel_range = self._getPanelRange(data, panel)
        nums, meds, mads = [], [], []
        for i, hist in enumerate(panel_config.hists):
            # only finite values are histogrammed and enter the statistics
            hist_data = data[hist][np.isfinite(data[hist])]
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                lw=2,
                color=col[i],
                label=panel_config.hists[hist],
            )
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
            # mark the median with a dashed line in the histogram's color
            ax.axvline(med, ls="--", lw=1, c=col[i])
        ax.legend(fontsize=6, loc="upper left")
        ax.set_xlim(panel_range)
        ax.set_xlabel(panel_config.label)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=7)
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # NOTE: scales the exponent, so headroom is only gained when the
            # upper limit is above 1.
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, meds, mads

    def _getPanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range runs from the smallest ``pLower`` percentile to the largest
        ``pUpper`` percentile over all histograms of the panel. Percentiles
        are computed on finite values only, matching the values that are
        actually histogrammed.

        Returns
        -------
        panel_range : `list` [`float`]
            ``[lower, upper]``; an entry stays NaN if no histogram provides a
            non-NaN limit.
        """
        panel_config = self.panels[panel]
        percentiles = [panel_config.pLower, panel_config.pUpper]
        panel_range = [np.nan, np.nan]
        for hist in panel_config.hists:
            values = data[hist]
            hist_range = np.nanpercentile(values[np.isfinite(values)], percentiles)
            # nanmin/nanmax ignore the NaN placeholder of the first iteration
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data.

        Returns
        -------
        num : `int`
            Length of ``data``.
        med : `float`
            NaN-ignoring median of ``data``.
        mad : `float`
            Result of ``sigmaMad(data)``.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addStatisticsPanel(self, fig, handles, nums, meds, mads):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are laid out as a four-column legend: line handles,
        counts, medians and sigma-MAD values, each column with a header row.

        Parameters
        ----------
        fig : figure or subfigure
            Object on which the (invisible) host axes are created.
        handles : `list`
            Legend handles of all histograms.
        nums, meds, mads : `list`
            Per-histogram statistics, in the same order as ``handles``.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Num"]
            # counts are converted so that every legend label is a string
            + [str(x) for x in nums]
            + ["Med"]
            + [f"{x:0.1f}" for x in meds]
            + ["${{\\sigma}}_{{MAD}}$"]
            + [f"{x:0.1f}" for x in mads]
        )

        # add the legend
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )