Coverage for python / lsst / analysis / tools / actions / plot / barPlots.py: 16%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 09:19 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BarPanel", "BarPlot") 

24 

25import operator as op 

26from collections import defaultdict 

27from collections.abc import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from matplotlib.figure import Figure 

32from matplotlib.gridspec import GridSpec 

33from matplotlib.patches import Rectangle 

34 

35from lsst.pex.config import Config, ConfigDictField, DictField, Field 

36from lsst.utils.plotting import set_rubin_plotstyle 

37 

38from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

39from .plotUtils import addPlotInfo 

40 

41 

42class BarPanel(Config): 

43 """A configurable class describing a panel in a bar plot.""" 

44 

45 label = Field[str]( 

46 doc="Panel x-axis label.", 

47 default="label", 

48 ) 

49 bars = DictField[str, str]( 

50 doc="A dict specifying the bar graphs to be plotted in this panel. Keys are used to identify " 

51 "bar graph IDs. Values are used to add to the legend label displayed in the upper corner of the " 

52 "panel.", 

53 optional=False, 

54 ) 

55 yscale = Field[str]( 

56 doc="Y axis scaling.", 

57 default="linear", 

58 ) 

59 

60 

61class BarPlot(PlotAction): 

62 """A plotting tool which can take multiple keyed data inputs 

63 and can create one or more bar graphs. 

64 """ 

65 

66 panels = ConfigDictField( 

67 doc="A configurable dict describing the panels to be plotted, and the bar graphs for each panel.", 

68 keytype=str, 

69 itemtype=BarPanel, 

70 default={}, 

71 ) 

72 cmap = Field[str]( 

73 doc="Color map used for bar lines. All types available via `plt.cm` may be used. " 

74 "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.", 

75 default="newtab10", 

76 ) 

77 

78 def getInputSchema(self) -> KeyedDataSchema: 

79 for panel in self.panels: # type: ignore 

80 for barData in self.panels[panel].bars.items(): # type: ignore 

81 yield barData, Vector 

82 

83 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure: 

84 return self.makePlot(data, **kwargs) 

85 

86 def makePlot( 

87 self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs # type: ignore 

88 ) -> Figure: 

89 """Make an N-panel plot with a user-configurable number of bar graphs 

90 displayed in each panel. 

91 

92 Parameters 

93 ---------- 

94 data : `KeyedData` 

95 The catalog to plot the points from. 

96 plotInfo : `dict` 

97 An optional dictionary of information about the data being 

98 plotted with keys: 

99 

100 `"run"` 

101 Output run for the plots (`str`). 

102 `"tractTableType"` 

103 Table from which results are taken (`str`). 

104 `"plotName"` 

105 Output plot name (`str`) 

106 `"SN"` 

107 The global signal-to-noise data threshold (`float`) 

108 `"skymap"` 

109 The type of skymap used for the data (`str`). 

110 `"tract"` 

111 The tract that the data comes from (`int`). 

112 `"bands"` 

113 The bands used for this data (`str` or `list`). 

114 `"visit"` 

115 The visit that the data comes from (`int`) 

116 

117 Returns 

118 ------- 

119 fig : `matplotlib.figure.Figure` 

120 The resulting figure. 

121 

122 """ 

123 

124 # set up figure 

125 set_rubin_plotstyle() 

126 fig = plt.figure(dpi=400) 

127 bar_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1]) 

128 axs = self._makeAxes(bar_fig) 

129 

130 # loop over each panel; plot bar graphs 

131 cols = self._assignColors() 

132 all_handles, all_nums, all_vector_labels, all_x_values = [], [], [], [] 

133 for panel, ax in zip(self.panels, axs): 

134 nums, sorted_label, sorted_x_values = self._makePanel(data, panel, ax, cols[panel], **kwargs) 

135 handles, labels = ax.get_legend_handles_labels() # code for plotting 

136 all_handles += handles 

137 all_nums += nums 

138 all_vector_labels += sorted_label 

139 all_x_values += sorted_x_values 

140 

141 # add side panel; add statistics 

142 self._addStatisticsPanel(side_fig, all_handles, all_nums, all_vector_labels, all_x_values) 

143 

144 # add general plot info 

145 if plotInfo is not None: 

146 bar_fig = addPlotInfo(bar_fig, plotInfo) 

147 

148 # finish up 

149 bar_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=bar_fig.transFigure) 

150 plt.draw() 

151 return fig 

152 

153 def _makeAxes(self, fig): 

154 """Determine axes layout for main bar graph figure.""" 

155 num_panels = len(self.panels) 

156 if num_panels <= 1: 

157 ncols = 1 

158 else: 

159 ncols = 2 

160 nrows = int(np.ceil(num_panels / ncols)) 

161 

162 gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45) 

163 

164 axs = [] 

165 counter = 0 

166 for row in range(nrows): 

167 for col in range(ncols): 

168 counter += 1 

169 if counter < num_panels: 

170 axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1])) 

171 else: 

172 axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])])) 

173 break 

174 

175 return axs 

176 

177 def _assignColors(self): 

178 """Assign colors to bar graphs using a given color map.""" 

179 custom_cmaps = dict( 

180 # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782 

181 newtab10=[ 

182 "#4e79a7", 

183 "#f28e2b", 

184 "#e15759", 

185 "#76b7b2", 

186 "#59a14f", 

187 "#edc948", 

188 "#b07aa1", 

189 "#ff9da7", 

190 "#9c755f", 

191 "#bab0ac", 

192 ], 

193 # https://personal.sron.nl/~pault/#fig:scheme_bright 

194 bright=[ 

195 "#4477AA", 

196 "#EE6677", 

197 "#228833", 

198 "#CCBB44", 

199 "#66CCEE", 

200 "#AA3377", 

201 "#BBBBBB", 

202 ], 

203 # https://personal.sron.nl/~pault/#fig:scheme_vibrant 

204 vibrant=[ 

205 "#EE7733", 

206 "#0077BB", 

207 "#33BBEE", 

208 "#EE3377", 

209 "#CC3311", 

210 "#009988", 

211 "#BBBBBB", 

212 ], 

213 ) 

214 if self.cmap in custom_cmaps.keys(): 

215 all_cols = custom_cmaps[self.cmap] 

216 else: 

217 try: 

218 all_cols = getattr(plt.cm, self.cmap).copy().colors 

219 except AttributeError: 

220 raise ValueError(f"Unrecognized color map: {self.cmap}") 

221 

222 counter = 0 

223 cols = defaultdict(list) 

224 for panel in self.panels: 

225 for bar in self.panels[panel].bars: 

226 cols[panel].append(all_cols[counter % len(all_cols)]) 

227 counter += 1 

228 return cols 

229 

230 def _makePanel(self, data, panel, ax, col, **kwargs): 

231 """Plot a single panel containing bar graphs.""" 

232 nums = [] 

233 x_values, assigned_labels, assigned_colors = self._assignBinElements(data, panel, col) 

234 sorted_x_values, sorted_labels, sorted_colors = self._sortBarBins( 

235 x_values, assigned_labels, assigned_colors 

236 ) 

237 width, columns = self._getBarWidths(sorted_x_values) 

238 

239 for i, bin in enumerate(sorted_x_values): 

240 bar_data = op.countOf(data[sorted_labels[i]][np.isfinite(data[sorted_labels[i]])], bin) 

241 

242 if width[i] == 1: 

243 bin_center = bin 

244 else: 

245 bin_center = bin - 0.35 + width[i] * columns[i] 

246 

247 ax.bar(bin_center, bar_data, width[i], lw=2, label=sorted_labels[i], color=sorted_colors[i]) 

248 nums.append(bar_data) 

249 

250 # Get plot range 

251 x_range = [x for x in range(int(min(sorted_x_values)), int(max(sorted_x_values)) + 1)] 

252 ax.set_xticks(x_range) 

253 ax.set_xlabel(self.panels[panel].label) 

254 ax.set_yscale(self.panels[panel].yscale) 

255 ax.tick_params(labelsize=7) 

256 # add a buffer to the top of the plot to allow headspace for labels 

257 ylims = list(ax.get_ylim()) 

258 if ax.get_yscale() == "log": 

259 ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1) 

260 else: 

261 ylims[1] *= 1.1 

262 ax.set_ylim(ylims[0], ylims[1]) 

263 return nums, sorted_labels, sorted_x_values 

264 

265 def _assignBinElements(self, data, panel, col): 

266 labels = [] 

267 assigned_labels = [] 

268 x_values = [] 

269 assigned_colors = [] 

270 n_labels = 0 

271 

272 for bar in self.panels[panel].bars: 

273 labels.append(bar) 

274 n_labels += 1 

275 

276 # If a label has multiple unique elements in it, repeats the label 

277 i = 0 

278 for single_label in labels: 

279 unique_elements = np.unique(data[single_label]) 

280 

281 for bin in unique_elements: 

282 x_values.append(int(bin)) 

283 

284 for count in range(len(unique_elements)): 

285 assigned_labels.append(single_label) 

286 assigned_colors.append(col[i]) # Assign color from color cmap 

287 

288 i += 1 

289 

290 return x_values, assigned_labels, assigned_colors 

291 

292 def _sortBarBins(self, x_values, assigned_labels, assigned_colors): 

293 """Sorts the existing x_values, assigned_labels, 

294 and assigned_colors/x_value from lowest to 

295 highest and then uses the sorted indices to sort 

296 all x, labels, and colors in that order. 

297 """ 

298 

299 sorted_indices = np.argsort(x_values) 

300 

301 sorted_labels = [] 

302 sorted_x_values = [] 

303 sorted_colors = [] 

304 

305 for position in sorted_indices: 

306 sorted_x_values.append(x_values[position]) 

307 sorted_labels.append(assigned_labels[position]) 

308 sorted_colors.append(assigned_colors[position]) 

309 

310 return sorted_x_values, sorted_labels, sorted_colors 

311 

312 def _getBarWidths(self, x_values): 

313 """Determine the width of the panels in each 

314 bin and which column is assigned.""" 

315 width = [] 

316 columns = [] 

317 current_column = 0 

318 current_i = 0 

319 

320 for i in x_values: 

321 # Number of repeating values 

322 n_repeating = x_values.count(i) 

323 width.append(1.0 / n_repeating) 

324 if n_repeating > 1 and current_column != 0 and current_i == i: 

325 columns.append(current_column) 

326 current_column += 1 

327 

328 else: 

329 current_column = 0 

330 columns.append(current_column) 

331 current_i = i 

332 current_column += 1 

333 

334 return width, columns 

335 

336 def _addStatisticsPanel(self, fig, handles, nums, sorted_labels, sorted_x_value): 

337 """Add an adjoining panel containing bar graph summary statistics.""" 

338 ax = fig.add_subplot(1, 1, 1) 

339 ax.axis("off") 

340 plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9) 

341 

342 # empty handle, used to populate the bespoke legend layout 

343 empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0) 

344 

345 # set up new legend handles and labels 

346 

347 legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3) 

348 legend_labels = ( 

349 ([""] * (len(handles) + 1)) 

350 + ["Bin"] 

351 + sorted_x_value 

352 + ["Count"] 

353 + nums 

354 + ["Sources"] 

355 + sorted_labels 

356 ) 

357 

358 # add the legend 

359 ax.legend( 

360 legend_handles, 

361 legend_labels, 

362 loc="lower left", 

363 ncol=4, 

364 handletextpad=-0.25, 

365 fontsize=6, 

366 borderpad=0, 

367 frameon=False, 

368 columnspacing=-0.25, 

369 )