Coverage for python / lsst / analysis / tools / actions / plot / barPlots.py: 16%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 18:53 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BarPanel", "BarPlot") 

24 

25import operator as op 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Config, ConfigDictField, DictField, Field 

32from lsst.utils.plotting import set_rubin_plotstyle 

33from matplotlib.figure import Figure 

34from matplotlib.gridspec import GridSpec 

35from matplotlib.patches import Rectangle 

36 

37from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

38from .plotUtils import addPlotInfo 

39 

40 

class BarPanel(Config):
    """Configuration for one panel of a bar plot.

    A panel carries its own x-axis label, its y-axis scaling, and the set of
    bar graphs that are drawn inside it.
    """

    # Text placed on the x axis of this panel.
    label = Field[str](default="label", doc="Panel x-axis label.")
    # Mapping of bar graph ID -> legend text; this field must be set.
    bars = DictField[str, str](
        optional=False,
        doc=(
            "A dict specifying the bar graphs to be plotted in this panel. Keys are used to identify "
            "bar graph IDs. Values are used to add to the legend label displayed in the upper corner of the "
            "panel."
        ),
    )
    # Passed verbatim to ``Axes.set_yscale`` when the panel is drawn.
    yscale = Field[str](default="linear", doc="Y axis scaling.")

58 

59 

60class BarPlot(PlotAction): 

61 """A plotting tool which can take multiple keyed data inputs 

62 and can create one or more bar graphs. 

63 """ 

64 

65 panels = ConfigDictField( 

66 doc="A configurable dict describing the panels to be plotted, and the bar graphs for each panel.", 

67 keytype=str, 

68 itemtype=BarPanel, 

69 default={}, 

70 ) 

71 cmap = Field[str]( 

72 doc="Color map used for bar lines. All types available via `plt.cm` may be used. " 

73 "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.", 

74 default="newtab10", 

75 ) 

76 

77 def getInputSchema(self) -> KeyedDataSchema: 

78 for panel in self.panels: # type: ignore 

79 for barData in self.panels[panel].bars.items(): # type: ignore 

80 yield barData, Vector 

81 

82 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure: 

83 return self.makePlot(data, **kwargs) 

84 

85 def makePlot( 

86 self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs # type: ignore 

87 ) -> Figure: 

88 """Make an N-panel plot with a user-configurable number of bar graphs 

89 displayed in each panel. 

90 

91 Parameters 

92 ---------- 

93 data : `KeyedData` 

94 The catalog to plot the points from. 

95 plotInfo : `dict` 

96 An optional dictionary of information about the data being 

97 plotted with keys: 

98 

99 `"run"` 

100 Output run for the plots (`str`). 

101 `"tractTableType"` 

102 Table from which results are taken (`str`). 

103 `"plotName"` 

104 Output plot name (`str`) 

105 `"SN"` 

106 The global signal-to-noise data threshold (`float`) 

107 `"skymap"` 

108 The type of skymap used for the data (`str`). 

109 `"tract"` 

110 The tract that the data comes from (`int`). 

111 `"bands"` 

112 The bands used for this data (`str` or `list`). 

113 `"visit"` 

114 The visit that the data comes from (`int`) 

115 

116 Returns 

117 ------- 

118 fig : `matplotlib.figure.Figure` 

119 The resulting figure. 

120 

121 """ 

122 

123 # set up figure 

124 set_rubin_plotstyle() 

125 fig = plt.figure(dpi=400) 

126 bar_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1]) 

127 axs = self._makeAxes(bar_fig) 

128 

129 # loop over each panel; plot bar graphs 

130 cols = self._assignColors() 

131 all_handles, all_nums, all_vector_labels, all_x_values = [], [], [], [] 

132 for panel, ax in zip(self.panels, axs): 

133 nums, sorted_label, sorted_x_values = self._makePanel(data, panel, ax, cols[panel], **kwargs) 

134 handles, labels = ax.get_legend_handles_labels() # code for plotting 

135 all_handles += handles 

136 all_nums += nums 

137 all_vector_labels += sorted_label 

138 all_x_values += sorted_x_values 

139 

140 # add side panel; add statistics 

141 self._addStatisticsPanel(side_fig, all_handles, all_nums, all_vector_labels, all_x_values) 

142 

143 # add general plot info 

144 if plotInfo is not None: 

145 bar_fig = addPlotInfo(bar_fig, plotInfo) 

146 

147 # finish up 

148 bar_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=bar_fig.transFigure) 

149 plt.draw() 

150 return fig 

151 

152 def _makeAxes(self, fig): 

153 """Determine axes layout for main bar graph figure.""" 

154 num_panels = len(self.panels) 

155 if num_panels <= 1: 

156 ncols = 1 

157 else: 

158 ncols = 2 

159 nrows = int(np.ceil(num_panels / ncols)) 

160 

161 gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45) 

162 

163 axs = [] 

164 counter = 0 

165 for row in range(nrows): 

166 for col in range(ncols): 

167 counter += 1 

168 if counter < num_panels: 

169 axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1])) 

170 else: 

171 axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])])) 

172 break 

173 

174 return axs 

175 

176 def _assignColors(self): 

177 """Assign colors to bar graphs using a given color map.""" 

178 custom_cmaps = dict( 

179 # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782 

180 newtab10=[ 

181 "#4e79a7", 

182 "#f28e2b", 

183 "#e15759", 

184 "#76b7b2", 

185 "#59a14f", 

186 "#edc948", 

187 "#b07aa1", 

188 "#ff9da7", 

189 "#9c755f", 

190 "#bab0ac", 

191 ], 

192 # https://personal.sron.nl/~pault/#fig:scheme_bright 

193 bright=[ 

194 "#4477AA", 

195 "#EE6677", 

196 "#228833", 

197 "#CCBB44", 

198 "#66CCEE", 

199 "#AA3377", 

200 "#BBBBBB", 

201 ], 

202 # https://personal.sron.nl/~pault/#fig:scheme_vibrant 

203 vibrant=[ 

204 "#EE7733", 

205 "#0077BB", 

206 "#33BBEE", 

207 "#EE3377", 

208 "#CC3311", 

209 "#009988", 

210 "#BBBBBB", 

211 ], 

212 ) 

213 if self.cmap in custom_cmaps.keys(): 

214 all_cols = custom_cmaps[self.cmap] 

215 else: 

216 try: 

217 all_cols = getattr(plt.cm, self.cmap).copy().colors 

218 except AttributeError: 

219 raise ValueError(f"Unrecognized color map: {self.cmap}") 

220 

221 counter = 0 

222 cols = defaultdict(list) 

223 for panel in self.panels: 

224 for bar in self.panels[panel].bars: 

225 cols[panel].append(all_cols[counter % len(all_cols)]) 

226 counter += 1 

227 return cols 

228 

229 def _makePanel(self, data, panel, ax, col, **kwargs): 

230 """Plot a single panel containing bar graphs.""" 

231 nums = [] 

232 x_values, assigned_labels, assigned_colors = self._assignBinElements(data, panel, col) 

233 sorted_x_values, sorted_labels, sorted_colors = self._sortBarBins( 

234 x_values, assigned_labels, assigned_colors 

235 ) 

236 width, columns = self._getBarWidths(sorted_x_values) 

237 

238 for i, bin in enumerate(sorted_x_values): 

239 bar_data = op.countOf(data[sorted_labels[i]][np.isfinite(data[sorted_labels[i]])], bin) 

240 

241 if width[i] == 1: 

242 bin_center = bin 

243 else: 

244 bin_center = bin - 0.35 + width[i] * columns[i] 

245 

246 ax.bar(bin_center, bar_data, width[i], lw=2, label=sorted_labels[i], color=sorted_colors[i]) 

247 nums.append(bar_data) 

248 

249 # Get plot range 

250 x_range = [x for x in range(int(min(sorted_x_values)), int(max(sorted_x_values)) + 1)] 

251 ax.set_xticks(x_range) 

252 ax.set_xlabel(self.panels[panel].label) 

253 ax.set_yscale(self.panels[panel].yscale) 

254 ax.tick_params(labelsize=7) 

255 # add a buffer to the top of the plot to allow headspace for labels 

256 ylims = list(ax.get_ylim()) 

257 if ax.get_yscale() == "log": 

258 ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1) 

259 else: 

260 ylims[1] *= 1.1 

261 ax.set_ylim(ylims[0], ylims[1]) 

262 return nums, sorted_labels, sorted_x_values 

263 

264 def _assignBinElements(self, data, panel, col): 

265 labels = [] 

266 assigned_labels = [] 

267 x_values = [] 

268 assigned_colors = [] 

269 n_labels = 0 

270 

271 for bar in self.panels[panel].bars: 

272 labels.append(bar) 

273 n_labels += 1 

274 

275 # If a label has multiple unique elements in it, repeats the label 

276 i = 0 

277 for single_label in labels: 

278 unique_elements = np.unique(data[single_label]) 

279 

280 for bin in unique_elements: 

281 x_values.append(int(bin)) 

282 

283 for count in range(len(unique_elements)): 

284 assigned_labels.append(single_label) 

285 assigned_colors.append(col[i]) # Assign color from color cmap 

286 

287 i += 1 

288 

289 return x_values, assigned_labels, assigned_colors 

290 

291 def _sortBarBins(self, x_values, assigned_labels, assigned_colors): 

292 """Sorts the existing x_values, assigned_labels, 

293 and assigned_colors/x_value from lowest to 

294 highest and then uses the sorted indices to sort 

295 all x, labels, and colors in that order. 

296 """ 

297 

298 sorted_indices = np.argsort(x_values) 

299 

300 sorted_labels = [] 

301 sorted_x_values = [] 

302 sorted_colors = [] 

303 

304 for position in sorted_indices: 

305 sorted_x_values.append(x_values[position]) 

306 sorted_labels.append(assigned_labels[position]) 

307 sorted_colors.append(assigned_colors[position]) 

308 

309 return sorted_x_values, sorted_labels, sorted_colors 

310 

311 def _getBarWidths(self, x_values): 

312 """Determine the width of the panels in each 

313 bin and which column is assigned.""" 

314 width = [] 

315 columns = [] 

316 current_column = 0 

317 current_i = 0 

318 

319 for i in x_values: 

320 # Number of repeating values 

321 n_repeating = x_values.count(i) 

322 width.append(1.0 / n_repeating) 

323 if n_repeating > 1 and current_column != 0 and current_i == i: 

324 columns.append(current_column) 

325 current_column += 1 

326 

327 else: 

328 current_column = 0 

329 columns.append(current_column) 

330 current_i = i 

331 current_column += 1 

332 

333 return width, columns 

334 

335 def _addStatisticsPanel(self, fig, handles, nums, sorted_labels, sorted_x_value): 

336 """Add an adjoining panel containing bar graph summary statistics.""" 

337 ax = fig.add_subplot(1, 1, 1) 

338 ax.axis("off") 

339 plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9) 

340 

341 # empty handle, used to populate the bespoke legend layout 

342 empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0) 

343 

344 # set up new legend handles and labels 

345 

346 legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3) 

347 legend_labels = ( 

348 ([""] * (len(handles) + 1)) 

349 + ["Bin"] 

350 + sorted_x_value 

351 + ["Count"] 

352 + nums 

353 + ["Sources"] 

354 + sorted_labels 

355 ) 

356 

357 # add the legend 

358 ax.legend( 

359 legend_handles, 

360 legend_labels, 

361 loc="lower left", 

362 ncol=4, 

363 handletextpad=-0.25, 

364 fontsize=6, 

365 borderpad=0, 

366 frameon=False, 

367 columnspacing=-0.25, 

368 )