Coverage for python/lsst/analysis/tools/actions/plot/barPlots.py: 15%

153 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-23 09:42 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BarPanel", "BarPlot") 

24 

25import operator as op 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Config, ConfigDictField, DictField, Field 

32from matplotlib.figure import Figure 

33from matplotlib.gridspec import GridSpec 

34from matplotlib.patches import Rectangle 

35 

36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

37from .plotUtils import addPlotInfo 

38 

39 

class BarPanel(Config):
    """Configuration for a single panel of a bar plot.

    Each entry of ``bars`` describes one bar graph drawn in the panel; in
    `BarPlot` the keys of ``bars`` are the names looked up in the input data.
    ``label`` is used as the panel's x-axis label and ``yscale`` is passed to
    the axes' ``set_yscale``.
    """

    label = Field[str](doc="Panel x-axis label.", default="label")
    bars = DictField[str, str](
        doc=(
            "A dict specifying the bar graphs to be plotted in this panel. "
            "Keys are used to identify bar graph IDs. Values are used to add to "
            "the legend label displayed in the upper corner of the panel."
        ),
        optional=False,
    )
    yscale = Field[str](doc="Y axis scaling.", default="linear")

57 

58 

class BarPlot(PlotAction):
    """A plotting tool which can take multiple keyed data inputs
    and can create one or more bar graphs.

    For every bar configured in a panel, the matching data vector is reduced
    to its finite values. Each distinct value (truncated to `int`) becomes a
    bar whose height is the number of finite entries equal to that integer.
    A side panel tabulates the bin, count and source of every bar.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the bar graphs for each panel.",
        keytype=str,
        itemtype=BarPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for bar lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the name and type of every data vector this action reads.

        The keys of each panel's ``bars`` dict are the names looked up in the
        input data (see ``_assignBinElements`` and ``_makePanel``). Therefore
        the keys themselves are yielded, not ``(key, value)`` pairs.
        """
        for panel in self.panels:  # type: ignore
            for barKey in self.panels[panel].bars:  # type: ignore
                yield barKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of bar graphs
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`, optional
            An optional dictionary of information about the data being
            plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)
        **kwargs
            Passed through to the per-panel plotting helper; not otherwise
            used.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Set up the figure: bar panels on the left (3/4), statistics on the right (1/4).
        fig = plt.figure(dpi=400)
        bar_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(bar_fig)

        # Loop over each panel and plot its bar graphs, collecting the
        # per-bar information needed by the statistics panel.
        cols = self._assignColors()
        all_handles, all_nums, all_vector_labels, all_x_values = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, sorted_labels, sorted_x_values = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_vector_labels += sorted_labels
            all_x_values += sorted_x_values

        # Add the side panel containing the summary statistics.
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_vector_labels, all_x_values)

        # Add general plot info.
        if plotInfo is not None:
            bar_fig = addPlotInfo(bar_fig, plotInfo)

        # Finish up with a shared y-axis label.
        bar_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=bar_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main bar graph figure.

        Panels are placed on a grid with one column (single panel) or two
        columns. The last panel spans from its own column to the end of its
        row, so an odd number of panels leaves no empty grid cell.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in panel order. The list is empty when no
            panels are configured.
        """
        num_panels = len(self.panels)
        if num_panels == 0:
            # GridSpec requires at least one row; there is nothing to draw.
            return []
        ncols = 1 if num_panels == 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index < num_panels - 1:
                axs.append(fig.add_subplot(gs[row, col]))
            else:
                # Last panel: stretch to fill the remainder of its row.
                axs.append(fig.add_subplot(gs[row, col:]))
        return axs

    def _assignColors(self):
        """Assign colors to bar graphs using a given color map.

        Colors are taken in order across all panels and bars, cycling when
        there are more bars than colors.

        Returns
        -------
        cols : `collections.defaultdict` [`str`, `list`]
            Maps each panel name to a list holding one color per bar.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a color map
            available as an attribute of `plt.cm`.
        """
        from matplotlib.colors import Colormap

        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            try:
                cmap = getattr(plt.cm, self.cmap)
            except AttributeError as err:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err
            if not isinstance(cmap, Colormap):
                # Attributes of plt.cm that are not color maps (functions, classes).
                raise ValueError(f"Unrecognized color map: {self.cmap}")
            if hasattr(cmap, "colors"):
                # Discrete (listed) color maps: use their colors directly.
                all_cols = cmap.colors
            else:
                # Continuous color maps have no `colors` attribute; sample them
                # evenly, one color per bar across all panels.
                n_bars = max(1, sum(len(self.panels[panel].bars) for panel in self.panels))
                all_cols = [tuple(rgba) for rgba in cmap(np.linspace(0, 1, n_bars))]

        counter = 0
        cols = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].bars:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing bar graphs.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by the keys of the panel's ``bars``.
        panel : `str`
            Name of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        col : `list`
            One color per bar of the panel, in ``bars`` order.
        **kwargs
            Unused.

        Returns
        -------
        nums : `list` [`int`]
            Bar heights (counts), in plotting order.
        sorted_labels : `list` [`str`]
            Data key of each bar, in plotting order.
        sorted_x_values : `list` [`int`]
            Bin of each bar, in plotting order.
        """
        x_values, assigned_labels, assigned_colors = self._assignBinElements(data, panel, col)
        sorted_x_values, sorted_labels, sorted_colors = self._sortBarBins(
            x_values, assigned_labels, assigned_colors
        )
        widths, columns = self._getBarWidths(sorted_x_values)

        nums = []
        for x_value, label, color, width, column in zip(
            sorted_x_values, sorted_labels, sorted_colors, widths, columns
        ):
            values = np.asarray(data[label])
            count = op.countOf(values[np.isfinite(values)], x_value)
            # Bars sharing a bin split the unit interval [x - 0.5, x + 0.5]
            # evenly, so the group is centred on the tick and never spills
            # into a neighbouring bin. A lone bar (width 1) sits exactly on x.
            bin_center = x_value - 0.5 + width * (column + 0.5)
            ax.bar(bin_center, count, width, lw=2, label=label, color=color)
            nums.append(count)

        # Tick every integer across the plotted range (nothing to tick if empty).
        if sorted_x_values:
            ax.set_xticks(list(range(int(sorted_x_values[0]), int(sorted_x_values[-1]) + 1)))
        ax.set_xlabel(self.panels[panel].label)
        ax.set_yscale(self.panels[panel].yscale)
        ax.tick_params(labelsize=7)
        # Add a buffer to the top of the plot to allow headspace for labels.
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, sorted_labels, sorted_x_values

    def _assignBinElements(self, data, panel, col):
        """Create one bar entry per distinct finite value of each bar's data.

        Non-finite values (NaN, inf) are skipped: `int` cannot convert them,
        and ``_makePanel`` excludes them from the counts anyway.

        Returns
        -------
        x_values : `list` [`int`]
            Distinct finite values, truncated to `int`.
        assigned_labels : `list` [`str`]
            Data key each entry came from.
        assigned_colors : `list`
            Color of each entry, taken from ``col`` (one color per bar key).
        """
        x_values, assigned_labels, assigned_colors = [], [], []
        for label, color in zip(self.panels[panel].bars, col):
            values = np.asarray(data[label])
            unique_elements = np.unique(values[np.isfinite(values)])
            x_values.extend(int(value) for value in unique_elements)
            # Repeat the label and color once per distinct value of this bar.
            assigned_labels.extend([label] * len(unique_elements))
            assigned_colors.extend([color] * len(unique_elements))
        return x_values, assigned_labels, assigned_colors

    def _sortBarBins(self, x_values, assigned_labels, assigned_colors):
        """Sort the parallel lists of bins, labels and colors by bin value.

        A stable sort is used, so bars that share a bin keep the order in
        which their keys appear in the panel configuration.

        Returns
        -------
        sorted_x_values, sorted_labels, sorted_colors : `list`
            The inputs reordered by ascending x value.
        """
        order = np.argsort(x_values, kind="stable")
        sorted_x_values = [x_values[i] for i in order]
        sorted_labels = [assigned_labels[i] for i in order]
        sorted_colors = [assigned_colors[i] for i in order]
        return sorted_x_values, sorted_labels, sorted_colors

    def _getBarWidths(self, x_values):
        """Determine the width of each bar and which column of its bin it occupies.

        ``x_values`` must be sorted so that equal values are adjacent. A bin
        shared by ``n`` bars gives each bar width ``1 / n``. Columns count
        ``0 .. n - 1`` within each run of equal values.

        Returns
        -------
        width : `list` [`float`]
            Width of each bar.
        columns : `list` [`int`]
            Position of each bar within its bin.
        """
        counts = defaultdict(int)
        for x_value in x_values:
            counts[x_value] += 1

        width, columns = [], []
        previous = None
        column = 0
        for x_value in x_values:
            column = column + 1 if x_value == previous else 0
            width.append(1.0 / counts[x_value])
            columns.append(column)
            previous = x_value
        return width, columns

    def _addStatisticsPanel(self, fig, handles, nums, sorted_labels, sorted_x_value):
        """Add an adjoining panel containing bar graph summary statistics.

        The statistics are drawn as a four-column legend, filled column by
        column: color swatches, then "Bin", "Count" and "Sources" columns,
        each with one header row followed by one row per bar.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # Empty handle, used to populate the bespoke legend layout.
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # Set up new legend handles and labels. Each of the four columns
        # holds len(handles) + 1 entries; only the first column shows swatches.
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Bin"]
            + sorted_x_value
            + ["Count"]
            + nums
            + ["Sources"]
            + sorted_labels
        )

        # Add the legend; negative spacings pack the columns tightly.
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )