Coverage for python/lsst/analysis/tools/actions/plot/barPlots.py: 15%
153 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-09 02:45 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-09 02:45 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("BarPanel", "BarPlot")
25import operator as op
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.analysis.tools.actions.plot.plotUtils import addPlotInfo
32from lsst.analysis.tools.interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
33from lsst.pex.config import Config, ConfigDictField, DictField, Field
34from matplotlib.figure import Figure
35from matplotlib.gridspec import GridSpec
36from matplotlib.patches import Rectangle
class BarPanel(Config):
    """Configuration of a single panel of bar graphs.

    ``bars`` selects the input vectors drawn in the panel, ``label`` titles
    the panel's x axis and ``yscale`` chooses the y-axis scaling.
    """

    label = Field[str](default="label", doc="Panel x-axis label.")
    bars = DictField[str, str](
        optional=False,
        doc="A dict specifying the bar graphs to be plotted in this panel. Keys are used to identify "
        "bar graph IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
    )
    yscale = Field[str](default="linear", doc="Y axis scaling.")
class BarPlot(PlotAction):
    """Plot one or more panels of bar graphs plus a summary-table side panel.

    Every key of a panel's ``bars`` names an input vector. For each distinct
    finite value in that vector, a bar is drawn whose height is the number of
    occurrences of the value. Bars from different vectors that share a value
    are drawn side by side inside that value's unit-wide bin.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the bar graphs for each panel.",
        keytype=str,
        itemtype=BarPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for bar lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield a ``(key, Vector)`` pair for every vector read by the plot.

        The keys of each panel's ``bars`` dict are the input-vector names.
        The dict values are descriptive labels, not data keys, so they are
        not part of the schema.
        """
        for panel in self.panels:  # type: ignore
            for barKey in self.panels[panel].bars:  # type: ignore
                yield barKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of bar graphs
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`, optional
            An optional dictionary of information about the data being
            plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)
        **kwargs
            Forwarded to `_makePanel` (currently unused there).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Bar-graph panels take the left three quarters of the figure; the
        # summary table takes the right quarter.
        fig = plt.figure(dpi=400)
        bar_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(bar_fig)

        # Draw every panel, collecting per-bar results for the summary table.
        cols = self._assignColors()
        all_handles, all_nums, all_vector_labels, all_x_values = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, sorted_labels, sorted_x_values = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_vector_labels += sorted_labels
            all_x_values += sorted_x_values

        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_vector_labels, all_x_values)

        if plotInfo is not None:
            bar_fig = addPlotInfo(bar_fig, plotInfo)

        bar_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=bar_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Create one axes per panel on a grid of at most two columns.

        The final panel extends to the right edge of its row, so an odd
        number of panels leaves no empty grid cell.
        """
        num_panels = len(self.panels)
        if num_panels == 0:
            # GridSpec rejects zero rows; there is nothing to lay out.
            return []
        ncols = 1 if num_panels == 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index < num_panels - 1:
                axs.append(fig.add_subplot(gs[row, col]))
            else:
                # Last panel: span from its column to the end of the row.
                axs.append(fig.add_subplot(gs[row, col:]))
        return axs

    def _assignColors(self):
        """Assign a color to every bar graph, cycling through the color map.

        Returns
        -------
        cols : `collections.defaultdict` [`str`, `list`]
            Maps each panel name to one color per key of its ``bars``, in
            iteration order. Colors continue across panels and wrap around
            when the map has fewer colors than there are bar graphs.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a color map
            available via ``plt.cm``.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            try:
                cmap = getattr(plt.cm, self.cmap)
            except AttributeError as err:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err
            if hasattr(cmap, "colors"):
                # Discrete (listed) color map: use its colors as-is.
                all_cols = list(cmap.colors)
            elif callable(cmap) and hasattr(cmap, "N"):
                # Continuous color map (no ``colors`` list): sample one evenly
                # spaced color per bar graph.
                n_bars = sum(len(self.panels[panel].bars) for panel in self.panels)
                all_cols = [cmap(frac) for frac in np.linspace(0, 1, max(n_bars, 1))]
            else:
                raise ValueError(f"Unrecognized color map: {self.cmap}")

        cols = defaultdict(list)
        color_index = 0
        for panel in self.panels:
            for _ in self.panels[panel].bars:
                cols[panel].append(all_cols[color_index % len(all_cols)])
                color_index += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing bar graphs.

        Returns
        -------
        nums : `list` [`int`]
            The height (count) of every bar, in drawing order.
        sorted_labels : `list` [`str`]
            The vector key of every bar, in drawing order.
        sorted_x_values : `list` [`int`]
            The bin value of every bar, in drawing order.
        """
        x_values, assigned_labels, assigned_colors = self._assignBinElements(data, panel, col)
        sorted_x_values, sorted_labels, sorted_colors = self._sortBarBins(
            x_values, assigned_labels, assigned_colors
        )
        width, columns = self._getBarWidths(sorted_x_values)

        nums = []
        # Finite values of each vector, computed once per vector rather than
        # once per bar.
        finite_values = {}
        for x_value, label, color, bar_width, column in zip(
            sorted_x_values, sorted_labels, sorted_colors, width, columns
        ):
            if label not in finite_values:
                values = np.asarray(data[label])
                finite_values[label] = values[np.isfinite(values)]
            bar_data = op.countOf(finite_values[label], x_value)

            if bar_width == 1:
                bin_center = x_value
            else:
                # Bars sharing a bin are offset side by side within it.
                bin_center = x_value - 0.35 + bar_width * column

            ax.bar(bin_center, bar_data, bar_width, lw=2, label=label, color=color)
            nums.append(bar_data)

        # One tick per integer across the plotted range (none if the panel
        # has no finite data, where min/max would raise).
        if sorted_x_values:
            ax.set_xticks(list(range(int(min(sorted_x_values)), int(max(sorted_x_values)) + 1)))
        ax.set_xlabel(self.panels[panel].label)
        ax.set_yscale(self.panels[panel].yscale)
        ax.tick_params(labelsize=7)
        # Add a buffer to the top of the plot to allow headspace for labels.
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, sorted_labels, sorted_x_values

    def _assignBinElements(self, data, panel, col):
        """List one entry per (vector, distinct finite value) pair in a panel.

        Returns
        -------
        x_values : `list` [`int`]
            Each distinct finite value, converted with ``int`` (the plot
            presumably expects integer-valued data; fractional values are
            truncated here — confirm with callers).
        assigned_labels : `list` [`str`]
            The vector key each entry came from.
        assigned_colors : `list`
            The vector's color from ``col``, repeated for each of its values.
        """
        x_values, assigned_labels, assigned_colors = [], [], []
        for i, label in enumerate(self.panels[panel].bars):
            values = np.asarray(data[label])
            # NaN/inf cannot become an integer bin (int() raises) and are
            # never counted by _makePanel, so drop them before finding bins.
            unique_elements = np.unique(values[np.isfinite(values)])
            for element in unique_elements:
                x_values.append(int(element))
                assigned_labels.append(label)
                assigned_colors.append(col[i])
        return x_values, assigned_labels, assigned_colors

    def _sortBarBins(self, x_values, assigned_labels, assigned_colors):
        """Sort bins by x value, reordering labels and colors to match.

        A stable sort is used so that bars sharing a bin keep the order in
        which their vectors are configured.
        """
        order = np.argsort(x_values, kind="stable")
        sorted_x_values = [x_values[i] for i in order]
        sorted_labels = [assigned_labels[i] for i in order]
        sorted_colors = [assigned_colors[i] for i in order]
        return sorted_x_values, sorted_labels, sorted_colors

    def _getBarWidths(self, x_values):
        """Return each bar's width and its column index within its bin.

        A bin holding ``k`` bars gives each a width of ``1/k``. The column
        index counts earlier bars with the same value; this relies on
        ``x_values`` being sorted so equal values are adjacent.
        """
        counts = defaultdict(int)
        for value in x_values:
            counts[value] += 1

        width, columns = [], []
        seen = defaultdict(int)
        for value in x_values:
            width.append(1.0 / counts[value])
            columns.append(seen[value])
            seen[value] += 1
        return width, columns

    def _addStatisticsPanel(self, fig, handles, nums, sorted_labels, sorted_x_value):
        """Add an adjoining panel containing bar graph summary statistics.

        The statistics are drawn as a four-column legend: color swatches, then
        "Bin", "Count" and "Sources" columns. This relies on matplotlib filling
        legend columns top to bottom, with invisible handles and empty labels
        used as padding.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # Empty handle, used to populate the bespoke legend layout.
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # Each column holds len(handles) + 1 entries (header + one per bar).
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Bin"]
            + sorted_x_value
            + ["Count"]
            + nums
            + ["Sources"]
            + sorted_labels
        )

        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )