Coverage for python / lsst / analysis / tools / actions / plot / barPlots.py: 16%
155 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:53 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:53 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("BarPanel", "BarPlot")

import operator as op
from collections import Counter, defaultdict
from typing import Mapping

import matplotlib.pyplot as plt
import numpy as np
from lsst.pex.config import Config, ConfigDictField, DictField, Field
from lsst.utils.plotting import set_rubin_plotstyle
from matplotlib.colors import Colormap, ListedColormap
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.patches import Rectangle

from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
from .plotUtils import addPlotInfo
class BarPanel(Config):
    """Configuration of one panel of a `BarPlot`.

    A panel draws one or more bar graphs (selected by the keys of
    ``bars``). It also carries a label for its x axis and the scale of its
    y axis.
    """

    # Text placed under the panel's x axis.
    label = Field[str](
        default="label",
        doc="Panel x-axis label.",
    )
    # Bar graphs drawn in this panel. NOTE(review): in BarPlot only the keys
    # appear to be read (they are the data columns and the bar labels); the
    # values are not used for any legend there -- confirm before relying on
    # them.
    bars = DictField[str, str](
        optional=False,
        doc=(
            "A dict specifying the bar graphs to be plotted in this panel. Keys are used to identify "
            "bar graph IDs. Values are used to add to the legend label displayed in the upper corner of the "
            "panel."
        ),
    )
    # Scale name handed to ``Axes.set_yscale`` by BarPlot (e.g. "linear", "log").
    yscale = Field[str](
        default="linear",
        doc="Y axis scaling.",
    )
class BarPlot(PlotAction):
    """A plotting tool which can take multiple keyed data inputs
    and can create one or more bar graphs.

    Every distinct finite value found in an input vector becomes a bin on
    the x axis. The bar height is the number of finite entries of that
    vector equal to the bin value. Bars from different vectors that land in
    the same bin of a panel are drawn side by side. A side panel tabulates
    the bin, count and source vector of every bar.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the bar graphs for each panel.",
        keytype=str,
        itemtype=BarPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for bar lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the vectors read by this action.

        Yields
        ------
        name : `str`
            Key of a bar graph in one of the configured panels (a key of
            ``BarPanel.bars``). This is the key looked up in the input data.
        type : `type`
            Always `Vector`.
        """
        for panel in self.panels:  # type: ignore
            # Yield only the data key. ``bars.items()`` would yield a
            # (key, legend label) tuple, which is not a column name.
            for name in self.panels[panel].bars:  # type: ignore
                yield name, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of bar graphs
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            An optional dictionary of information about the data being
            plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)
        **kwargs
            Forwarded to ``_makePanel`` (which does not currently use them).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # set up figure: bar panels on the left (3/4), statistics on the right
        set_rubin_plotstyle()
        fig = plt.figure(dpi=400)
        bar_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(bar_fig)

        # loop over each panel; plot bar graphs
        cols = self._assignColors()
        all_handles, all_nums, all_vector_labels, all_x_values = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, sorted_label, sorted_x_values = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            # Use the bar containers directly (one per ax.bar call, in plot
            # order) so the handles stay aligned with nums/labels/x values.
            # get_legend_handles_labels() would silently drop bars whose
            # label starts with "_".
            all_handles += list(ax.containers)
            all_nums += nums
            all_vector_labels += sorted_label
            all_x_values += sorted_x_values

        # add side panel; add statistics
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_vector_labels, all_x_values)

        # add general plot info
        if plotInfo is not None:
            bar_fig = addPlotInfo(bar_fig, plotInfo)

        # finish up
        bar_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=bar_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main bar graph figure.

        A single panel gets one column; otherwise panels are placed on a
        two-column grid, row by row. The last panel spans whatever remains
        of its row, so an odd final panel fills the full width.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure`
            The (sub)figure to add the axes to.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per configured panel, in panel order. Empty when no
            panels are configured.
        """
        num_panels = len(self.panels)
        if num_panels == 0:
            # GridSpec needs at least one row; there is nothing to lay out.
            return []
        ncols = 1 if num_panels == 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index == num_panels - 1:
                # last panel: take the rest of its row
                axs.append(fig.add_subplot(gs[row, col:]))
            else:
                axs.append(fig.add_subplot(gs[row, col]))
        return axs

    def _assignColors(self):
        """Assign colors to bar graphs using a given color map.

        Colors are handed out in order across all panels and bars. When
        there are more bars than colors, the color list is cycled.

        Returns
        -------
        cols : `collections.defaultdict` [`str`, `list`]
            Maps each panel name to a list holding one color per entry of
            that panel's ``bars``, in iteration order.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither one of the custom maps defined
            here nor a colormap available from `plt.cm`.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            cmap = getattr(plt.cm, self.cmap, None)
            if not isinstance(cmap, Colormap):
                raise ValueError(f"Unrecognized color map: {self.cmap}")
            if isinstance(cmap, ListedColormap):
                # Discrete map (e.g. tab10): use its color list directly.
                all_cols = cmap.colors
            else:
                # Continuous map (no ``colors`` list): sample it evenly,
                # one color per bar across all panels.
                n_bars = sum(len(self.panels[panel].bars) for panel in self.panels)
                all_cols = [cmap(x) for x in np.linspace(0.0, 1.0, max(n_bars, 1))]

        counter = 0
        cols = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].bars:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing bar graphs.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by the keys of the panel's ``bars``.
        panel : `str`
            Name of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        col : `list`
            One color per entry of the panel's ``bars``, in order.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        nums : `list` [`int`]
            Height (count) of each bar, in plotting order.
        sorted_labels : `list` [`str`]
            Data key of each bar, in plotting order.
        sorted_x_values : `list` [`int`]
            Bin of each bar, in plotting order (ascending).
        """
        x_values, assigned_labels, assigned_colors = self._assignBinElements(data, panel, col)
        sorted_x_values, sorted_labels, sorted_colors = self._sortBarBins(
            x_values, assigned_labels, assigned_colors
        )
        width, columns = self._getBarWidths(sorted_x_values)

        # Finite entries of each vector, extracted once per vector instead
        # of once per bar.
        finite_data = {}
        for label in sorted_labels:
            if label not in finite_data:
                values = np.asarray(data[label])
                finite_data[label] = values[np.isfinite(values)]

        nums = []
        for x_value, label, color, bar_width, column in zip(
            sorted_x_values, sorted_labels, sorted_colors, width, columns
        ):
            bar_data = int(np.count_nonzero(finite_data[label] == x_value))
            if bar_width == 1:
                bin_center = x_value
            else:
                # Several bars share this bin: shift each one by its column
                # index so they sit side by side.
                bin_center = x_value - 0.35 + bar_width * column

            ax.bar(bin_center, bar_data, bar_width, lw=2, label=label, color=color)
            nums.append(bar_data)

        # Get plot range: one tick per integer from the lowest to the
        # highest bin (values are sorted). Skipped when the panel is empty.
        if sorted_x_values:
            ax.set_xticks(list(range(sorted_x_values[0], sorted_x_values[-1] + 1)))
        ax.set_xlabel(self.panels[panel].label)
        ax.set_yscale(self.panels[panel].yscale)
        ax.tick_params(labelsize=7)
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, sorted_labels, sorted_x_values

    def _assignBinElements(self, data, panel, col):
        """Build one bar entry per (vector, distinct finite value) pair.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by the keys of the panel's ``bars``.
        panel : `str`
            Name of the panel in ``self.panels``.
        col : `list`
            One color per entry of the panel's ``bars``, in order.

        Returns
        -------
        x_values : `list` [`int`]
            Bin (``int`` of the distinct value) of each bar entry.
        assigned_labels : `list` [`str`]
            Data key that each entry comes from.
        assigned_colors : `list`
            Color of each entry (the color of its vector).
        """
        x_values = []
        assigned_labels = []
        assigned_colors = []

        for single_label, color in zip(self.panels[panel].bars, col):
            values = np.asarray(data[single_label])
            # NaN/inf cannot be turned into an int bin (int() raises) and
            # are not counted by _makePanel, so drop them up front.
            unique_elements = np.unique(values[np.isfinite(values)])
            # NOTE(review): the data appears to be expected to be integer
            # valued; non-integer values are truncated here, and the
            # resulting bin will not count them.
            for element in unique_elements:
                x_values.append(int(element))
                assigned_labels.append(single_label)
                assigned_colors.append(color)

        return x_values, assigned_labels, assigned_colors

    def _sortBarBins(self, x_values, assigned_labels, assigned_colors):
        """Sort bar entries by bin, keeping labels and colors aligned.

        The sort is stable, so entries sharing a bin keep the order in
        which their vectors appear in the panel's ``bars``.

        Returns
        -------
        sorted_x_values, sorted_labels, sorted_colors : `list`
            The three inputs reordered by ascending ``x_values``.
        """
        entries = sorted(zip(x_values, assigned_labels, assigned_colors), key=op.itemgetter(0))
        sorted_x_values = [x for x, _, _ in entries]
        sorted_labels = [label for _, label, _ in entries]
        sorted_colors = [color for _, _, color in entries]
        return sorted_x_values, sorted_labels, sorted_colors

    def _getBarWidths(self, x_values):
        """Determine the width of the bars in each bin and their column.

        Parameters
        ----------
        x_values : `list` [`int`]
            Bin of each bar. Expected to be sorted so that equal bins are
            adjacent.

        Returns
        -------
        width : `list` [`float`]
            ``1 / n`` for each bar, where ``n`` is the number of bars in
            its bin.
        columns : `list` [`int`]
            Position of each bar within its bin (0 for the first bar of
            the bin, 1 for the next, ...).
        """
        n_per_bin = Counter(x_values)
        seen = Counter()
        width = []
        columns = []
        for x in x_values:
            width.append(1.0 / n_per_bin[x])
            columns.append(seen[x])
            seen[x] += 1
        return width, columns

    def _addStatisticsPanel(self, fig, handles, nums, sorted_labels, sorted_x_value):
        """Add an adjoining panel containing bar graph summary statistics.

        The statistics are drawn as a 4-column legend. The first column
        holds the bar handles with no text. It is followed by "Bin",
        "Count" and "Sources" columns, each headed by its title.
        ``handles``, ``nums``, ``sorted_labels`` and ``sorted_x_value``
        must have the same length and order for the rows to line up.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        # NOTE(review): pyplot call, so this acts on the current figure
        # (the parent figure created in makePlot), not only on ``fig``.
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels; matplotlib fills legend
        # columns top to bottom, so each column gets len(handles) + 1 rows
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Bin"]
            + sorted_x_value
            + ["Count"]
            + nums
            + ["Sources"]
            + sorted_labels
        )

        # add the legend
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )