Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 21%
122 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-12 03:16 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-11-12 03:16 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot")
25from collections import defaultdict
26from typing import Mapping
28import matplotlib.pyplot as plt
29import numpy as np
30from lsst.pex.config import Config, ConfigDictField, DictField, Field
31from matplotlib.figure import Figure
32from matplotlib.gridspec import GridSpec
33from matplotlib.patches import Rectangle
35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
36from ...statistics import sigmaMad
37from .plotUtils import addPlotInfo
class HistPanel(Config):
    """Configuration for a single panel of a `HistPlot` figure.

    A panel groups one or more histograms that share an x-axis label, a
    y-axis scaling, a bin count and a pair of percentile limits.
    """

    # Text placed on the x axis of the panel.
    label = Field[str](
        default="label",
        doc="Panel x-axis label.",
    )
    # Mapping of histogram ID -> legend label text; must be supplied.
    hists = DictField[str, str](
        optional=False,
        doc=(
            "A dict specifying the histograms to be plotted in this panel. "
            "Keys are used to identify histogram IDs. "
            "Values are used to add to the legend label displayed in the upper corner of the panel."
        ),
    )
    # Scale name applied to the panel's y axis.
    yscale = Field[str](
        default="linear",
        doc="Y axis scaling.",
    )
    # Number of bins used for every histogram in the panel.
    bins = Field[int](
        default=50,
        doc="Number of x axis bins.",
    )
    # Lower percentile limit of the binned range.
    pLower = Field[float](
        default=2.0,
        doc=(
            "Percentile used to determine the lower range of the histogram bins. "
            "If more than one histogram is plotted in the panel, "
            "the percentile limit is the minimum value across all input data."
        ),
    )
    # Upper percentile limit of the binned range.
    pUpper = Field[float](
        default=98.0,
        doc=(
            "Percentile used to determine the upper range of the histogram bins. "
            "If more than one histogram is plotted, "
            "the percentile limit is the maximum value across all input data."
        ),
    )
class HistPlot(PlotAction):
    """Plot action drawing a configurable set of histograms in N panels.

    Every entry of ``panels`` is drawn on its own axes in the main
    sub-figure. An adjoining side sub-figure tabulates, for every histogram,
    the number of points, the median and the value returned by `sigmaMad`.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield a ``(key, Vector)`` pair for every configured histogram.

        The keys are the keys of each panel's ``hists`` dict, i.e. the same
        names the plotting methods of this class use to index the input data.
        """
        for panel in self.panels:  # type: ignore
            # Iterate over the keys only: the dict values are legend labels,
            # not data keys, so they must not end up in the schema.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Create the figure by delegating to `makePlot`.

        All keyword arguments are forwarded unchanged.
        """
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to histogram, indexed by the keys of each panel's
            ``hists`` configuration.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)
            If `None`, no plot information is added to the figure.
        **kwargs
            Forwarded to `_makePanel`.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            Raised if no panels are configured or if the configured color
            map cannot be resolved.
        """
        # set up figure
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms
        cols = self._assignColors()
        all_handles, all_nums, all_meds, all_mads = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, meds, mads = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            # Only the handles are reused (in the statistics panel).
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_meds += meds
            all_mads += mads

        # add side panel; add statistics
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_meds, all_mads)

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        hist_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=hist_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        One panel uses a single column; more than one panel uses two columns.
        Panels are laid out in row-major order and the last panel is allowed
        to extend to the right edge of the grid, so that an odd number of
        panels does not leave an empty cell.

        Parameters
        ----------
        fig
            The (sub)figure the axes are added to.

        Returns
        -------
        axs : `list`
            One axes object per configured panel, in panel order.

        Raises
        ------
        ValueError
            Raised if no panels are configured.
        """
        num_panels = len(self.panels)
        if num_panels == 0:
            # GridSpec would reject a zero-row grid with a far less helpful
            # ValueError; fail early with a message naming the real problem.
            raise ValueError("At least one panel must be configured to make a histogram plot.")
        ncols = 1 if num_panels == 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        last = num_panels - 1
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index < last:
                axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
            else:
                # The final panel requests one extra column; the grid slice
                # is clipped to the available columns.
                axs.append(fig.add_subplot(gs[row : row + 1, col : min(col + 2, ncols + 1)]))
        return axs

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Colors are handed out in panel order, then histogram order, and wrap
        around when there are more histograms than colors.

        Returns
        -------
        cols : `collections.defaultdict`
            Maps each panel key to the list of colors of its histograms.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom color map nor an attribute
            of `plt.cm` exposing a ``colors`` attribute.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            try:
                all_cols = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                # Raised both for an unknown name and for a color map object
                # without a ``colors`` attribute; keep the original cause.
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        cols = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data
            Input data, indexed by histogram key.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax
            Axes to draw on.
        col : `list`
            Colors to use, one per histogram of the panel.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        nums, meds, mads : `list`
            Per-histogram statistics as returned by `_calcStats`.
        """
        panel_cfg = self.panels[panel]
        panel_range = self._getPanelRange(data, panel)
        nums, meds, mads = [], [], []
        for i, hist in enumerate(panel_cfg.hists):
            # Non-finite values are dropped before plotting and before the
            # statistics are computed.
            hist_data = data[hist][np.isfinite(data[hist])]
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_cfg.bins,
                histtype="step",
                lw=2,
                color=col[i],
                label=panel_cfg.hists[hist],
            )
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
            # Mark the median with a dashed line in the histogram's color.
            ax.axvline(med, ls="--", lw=1, c=col[i])
        ax.legend(fontsize=6, loc="upper left")
        ax.set_xlim(panel_range)
        ax.set_xlabel(panel_cfg.label)
        ax.set_yscale(panel_cfg.yscale)
        ax.tick_params(labelsize=7)
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # On a log axis the headroom is added in log space.
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, meds, mads

    def _getPanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The ``pLower``/``pUpper`` percentiles are computed for each histogram
        of the panel (ignoring NaN values); the returned range spans the
        smallest lower and the largest upper limit found.

        Returns
        -------
        panel_range : `list`
            Two-element ``[lower, upper]`` range; both entries are NaN when
            the panel has no histograms.
        """
        panel_cfg = self.panels[panel]
        percentiles = [panel_cfg.pLower, panel_cfg.pUpper]
        # Start from NaN so that the first histogram always sets the range.
        panel_range = [np.nan, np.nan]
        for hist in panel_cfg.hists:
            hist_range = np.nanpercentile(data[hist], percentiles)
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data.

        Returns
        -------
        num, med, mad
            ``len(data)``, ``np.nanmedian(data)`` and ``sigmaMad(data)``.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addStatisticsPanel(self, fig, handles, nums, meds, mads):
        """Add an adjoining panel containing histogram summary statistics.

        The table is built as a four-column legend on an otherwise invisible
        axes: the histogram line handles, followed by the "Num", "Med" and
        sigma-MAD columns.

        Parameters
        ----------
        fig
            The (sub)figure that receives the statistics axes.
        handles : `list`
            Legend handles of all histograms, in plotting order.
        nums, meds, mads : `list`
            Per-histogram statistics, in the same order as ``handles``.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        # NOTE: this goes through pyplot, i.e. it acts on pyplot's current
        # figure rather than explicitly on ``fig``.
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        # Column 1: blank header cell plus the real handles (no text).
        # Columns 2-4: a header cell plus one text-only cell per histogram.
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Num"]
            + nums
            + ["Med"]
            + [f"{x:0.1f}" for x in meds]
            + ["${{\\sigma}}_{{MAD}}$"]
            + [f"{x:0.1f}" for x in mads]
        )

        # add the legend
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )