Coverage for python/lsst/analysis/tools/actions/plot/barPlots.py: 15%

153 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-07 10:56 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BarPanel", "BarPlot") 

24 

25import operator as op 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.analysis.tools.actions.plot.plotUtils import addPlotInfo 

32from lsst.analysis.tools.interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

33from lsst.pex.config import Config, ConfigDictField, DictField, Field 

34from matplotlib.figure import Figure 

35from matplotlib.gridspec import GridSpec 

36from matplotlib.patches import Rectangle 

37 

38 

class BarPanel(Config):
    """Configuration of a single panel of a `BarPlot`.

    Every key of ``bars`` names one data vector whose distinct values are
    counted and drawn as bars in this panel. The panel's x axis is titled
    with ``label`` and its y axis is scaled according to ``yscale``.
    """

    label = Field[str](
        default="label",
        doc="Panel x-axis label.",
    )
    # NOTE(review): BarPlot in this module only reads the keys of ``bars``;
    # the values do not appear to be used anywhere for the legend — confirm.
    bars = DictField[str, str](
        optional=False,
        doc=(
            "A dict specifying the bar graphs to be plotted in this panel. Keys are used to identify "
            "bar graph IDs. Values are used to add to the legend label displayed in the upper corner of the "
            "panel."
        ),
    )
    yscale = Field[str](
        default="linear",
        doc="Y axis scaling.",
    )

54 

55 

class BarPlot(PlotAction):
    """Plot one or more panels of bar graphs plus a summary side panel.

    For every bar key configured in a panel, each distinct finite value of
    the corresponding data vector becomes an x-axis bin. The bar height is
    the number of finite entries equal to that value. Bars from different
    keys that share a bin are drawn side by side within that bin. A side
    panel tabulates the bin, count and source key of every bar drawn.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the bar graphs for each panel.",
        keytype=str,
        itemtype=BarPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for bar lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the data vectors read by this action.

        Yields
        ------
        key : `str`
            A bar key configured in one of the panels. This is the key used
            to look the vector up in the input data.
        type : `type`
            Always `Vector`.
        """
        for panel in self.panels:  # type: ignore
            # Only the bar key names an input vector; yielding the whole
            # (key, value) item would give a tuple, which is not a valid key.
            for barKey in self.panels[panel].bars:  # type: ignore
                yield barKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the plot; see `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of bar graphs
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from. It must contain a vector
            for every bar key configured in ``panels``.
        plotInfo : `dict`, optional
            An optional dictionary of information about the data being
            plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`).
            `"SN"`
                The global signal-to-noise data threshold (`float`).
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`).
        **kwargs
            Forwarded to `_makePanel`, which currently ignores them.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # set up figure: bar graphs on the left (3/4), summary table on the right (1/4)
        fig = plt.figure(dpi=400)
        bar_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(bar_fig)

        # loop over each panel; plot bar graphs
        cols = self._assignColors()
        all_handles, all_nums, all_vector_labels, all_x_values = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, sorted_labels, sorted_x_values = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_vector_labels += sorted_labels
            all_x_values += sorted_x_values

        # add side panel; add statistics
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_vector_labels, all_x_values)

        # add general plot info
        if plotInfo is not None:
            bar_fig = addPlotInfo(bar_fig, plotInfo)

        # finish up
        bar_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=bar_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Create one axes per panel on ``fig``.

        Panels are laid out on a grid of at most two columns. The final
        panel stretches across whatever grid cells remain in its row.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure`
            Sub-figure to add the axes to.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in ``self.panels`` order (empty when there
            are no panels).
        """
        num_panels = len(self.panels)
        if num_panels == 0:
            # GridSpec rejects a zero-row grid; there is nothing to draw.
            return []
        ncols = 1 if num_panels == 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index < num_panels - 1:
                axs.append(fig.add_subplot(gs[row, col]))
            else:
                # last panel fills the remainder of its row
                axs.append(fig.add_subplot(gs[row, col:]))
        return axs

    def _assignColors(self):
        """Assign one color to every configured bar, cycling through ``cmap``.

        Listed (discrete) colormaps contribute their own color list. Any
        other `plt.cm` colormap is sampled evenly, once per configured bar.

        Returns
        -------
        cols : `collections.defaultdict` [`str`, `list`]
            Maps each panel name to one color per bar, in the panel's bar
            order.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a colormap
            available from `matplotlib.pyplot.cm`.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        n_bars = sum(len(self.panels[panel].bars) for panel in self.panels)

        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            try:
                colormap = getattr(plt.cm, self.cmap)
            except AttributeError:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from None
            if hasattr(colormap, "colors"):
                # discrete colormap: use its own list of colors
                all_cols = colormap.colors
            elif callable(colormap) and hasattr(colormap, "N"):
                # continuous colormap: sample one color per bar along it
                all_cols = [tuple(rgba) for rgba in colormap(np.linspace(0.0, 1.0, max(n_bars, 1)))]
            else:
                raise ValueError(f"Unrecognized color map: {self.cmap}")

        counter = 0
        cols = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].bars:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Draw the bar graphs of a single panel onto ``ax``.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain a vector for every bar key of the panel.
        panel : `str`
            Name of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        col : `list`
            One color per bar key of the panel, in configuration order.
        **kwargs
            Unused.

        Returns
        -------
        nums : `list` [`int`]
            Height (count) of each bar drawn, in drawing order.
        sorted_labels : `list` [`str`]
            Bar key of each bar drawn.
        sorted_x_values : `list` [`int`]
            Bin value of each bar drawn.
        """
        x_values, assigned_labels, assigned_colors = self._assignBinElements(data, panel, col)
        sorted_x_values, sorted_labels, sorted_colors = self._sortBarBins(
            x_values, assigned_labels, assigned_colors
        )
        widths, columns = self._getBarWidths(sorted_x_values)

        # Finite entries of each vector, filtered once per key rather than once per bar.
        finite_values = {}
        for label in sorted_labels:
            if label not in finite_values:
                values = data[label]
                finite_values[label] = values[np.isfinite(values)]

        nums = []
        for bin_value, label, color, width, column in zip(
            sorted_x_values, sorted_labels, sorted_colors, widths, columns
        ):
            bar_data = op.countOf(finite_values[label], bin_value)
            # Bars sharing a bin split [bin - 0.5, bin + 0.5] evenly so the group
            # is centred on its bin; a lone bar (width 1) sits exactly on it.
            bin_center = bin_value - 0.5 + width * (column + 0.5)
            ax.bar(bin_center, bar_data, width, lw=2, label=label, color=color)
            nums.append(bar_data)

        # Get plot range: one tick per integer between the extreme bins
        if sorted_x_values:
            ax.set_xticks(list(range(int(min(sorted_x_values)), int(max(sorted_x_values)) + 1)))
        ax.set_xlabel(self.panels[panel].label)
        ax.set_yscale(self.panels[panel].yscale)
        ax.tick_params(labelsize=7)
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, sorted_labels, sorted_x_values

    def _assignBinElements(self, data, panel, col):
        """Expand each bar key of a panel into one entry per distinct value.

        Non-finite entries (NaN, inf) are ignored, matching how bars are
        counted in `_makePanel`.

        Parameters
        ----------
        data : `KeyedData`
            Input data keyed by bar key.
        panel : `str`
            Name of the panel in ``self.panels``.
        col : `list`
            One color per bar key of the panel, in configuration order.

        Returns
        -------
        x_values : `list` [`int`]
            Each distinct finite value of each vector, cast to `int`.
        assigned_labels : `list` [`str`]
            Bar key each entry of ``x_values`` came from.
        assigned_colors : `list`
            Color of the bar key each entry of ``x_values`` came from.
        """
        x_values, assigned_labels, assigned_colors = [], [], []
        for i, single_label in enumerate(self.panels[panel].bars):
            values = data[single_label]
            # NOTE(review): values are truncated by int(); this presumably
            # assumes integer-valued vectors — confirm with callers.
            for element in np.unique(values[np.isfinite(values)]):
                x_values.append(int(element))
                assigned_labels.append(single_label)
                assigned_colors.append(col[i])  # Assign color from color cmap
        return x_values, assigned_labels, assigned_colors

    def _sortBarBins(self, x_values, assigned_labels, assigned_colors):
        """Sort bins by x value, reordering labels and colors to match.

        A stable sort is used so bars sharing a bin keep their configuration
        order.

        Returns
        -------
        sorted_x_values, sorted_labels, sorted_colors : `list`
            The three inputs reordered by ascending x value.
        """
        order = np.argsort(x_values, kind="stable")
        sorted_x_values = [x_values[k] for k in order]
        sorted_labels = [assigned_labels[k] for k in order]
        sorted_colors = [assigned_colors[k] for k in order]
        return sorted_x_values, sorted_labels, sorted_colors

    def _getBarWidths(self, x_values):
        """Determine each bar's width and its column within its bin.

        Parameters
        ----------
        x_values : `list` [`int`]
            Bin value of every bar, expected sorted so equal bins are adjacent.

        Returns
        -------
        width : `list` [`float`]
            ``1 / n`` for each bar, where ``n`` is how many bars share its bin.
        columns : `list` [`int`]
            Position (0, 1, ...) of each bar within its run of equal bins.
        """
        counts = defaultdict(int)
        for x in x_values:
            counts[x] += 1
        width = [1.0 / counts[x] for x in x_values]

        columns = []
        for position, x in enumerate(x_values):
            if position > 0 and x == x_values[position - 1]:
                columns.append(columns[-1] + 1)
            else:
                columns.append(0)
        return width, columns

    def _addStatisticsPanel(self, fig, handles, nums, sorted_labels, sorted_x_value):
        """Add an adjoining panel tabulating every bar that was drawn.

        The table is a frameless four-column legend: a color swatch, then
        the bin, the count and the source key of each bar. Each column is
        headed by a title row.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure`
            Sub-figure to draw the table in.
        handles : `list`
            Legend handle of each bar.
        nums : `list` [`int`]
            Count of each bar.
        sorted_labels : `list` [`str`]
            Source key of each bar.
        sorted_x_value : `list` [`int`]
            Bin of each bar.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # Four groups of len(handles) + 1 entries (title row + one row per bar);
        # with ncol=4 each group fills one legend column.
        n_bars = len(handles)
        legend_handles = [empty] + list(handles) + [empty] * (3 * n_bars + 3)
        legend_labels = (
            [""] * (n_bars + 1)
            + ["Bin"]
            + list(sorted_x_value)
            + ["Count"]
            + list(nums)
            + ["Sources"]
            + list(sorted_labels)
        )

        # add the legend
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )