Coverage for python/lsst/analysis/tools/actions/plot/barPlots.py: 15%

153 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-02 11:54 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("BarPanel", "BarPlot") 

24 

25import operator as op 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.analysis.tools.actions.plot.plotUtils import addPlotInfo 

32from lsst.analysis.tools.interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

33from lsst.pex.config import Config, ConfigDictField, DictField, Field 

34from matplotlib.figure import Figure 

35from matplotlib.gridspec import GridSpec 

36from matplotlib.patches import Rectangle 

37 

38 

class BarPanel(Config):
    """Configuration for a single panel of a `BarPlot`.

    A panel is drawn on its own axes and may contain several bar graphs,
    one per key of ``bars``.
    """

    # Text used as the x-axis label of this panel.
    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    # Mapping of input-data key -> display string. The keys select which
    # input vectors are histogrammed in this panel.
    # NOTE(review): BarPlot in this file only reads the keys (bars and the
    # statistics table are labelled with the keys); the values do not appear
    # to be used anywhere despite the field doc — confirm intended behavior.
    bars = DictField[str, str](
        doc="A dict specifying the bar graphs to be plotted in this panel. Keys are used to identify "
        "bar graph IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    # Passed to ``Axes.set_yscale`` for this panel (e.g. "linear" or "log").
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )

54 

55 

class BarPlot(PlotAction):
    """A plotting tool which can take multiple keyed data inputs
    and can create one or more bar graphs.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the bar graphs for each panel.",
        keytype=str,
        itemtype=BarPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for bar lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield ``(key, Vector)`` for every bar graph key in every panel.

        Only the keys of each panel's ``bars`` dict name input columns; the
        dict values are display strings and must not be part of the schema.
        """
        for panel in self.panels:  # type: ignore
            for barKey in self.panels[panel].bars.keys():  # type: ignore
                yield barKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of bar graphs
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            An optional dictionary of information about the data being
            plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            Raised if no panels are configured, or if ``cmap`` is not a
            recognized color map.
        """
        # Fail before any figure is created; GridSpec would otherwise raise
        # an opaque error for zero rows and leave a dangling figure behind.
        if not self.panels:
            raise ValueError("BarPlot requires at least one configured panel.")

        # set up figure: bar panels on the left, statistics table on the right
        fig = plt.figure(dpi=400)
        bar_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(bar_fig)

        # loop over each panel; plot bar graphs
        cols = self._assignColors()
        all_handles, all_nums, all_vector_labels, all_x_values = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, sorted_label, sorted_x_values = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_vector_labels += sorted_label
            all_x_values += sorted_x_values

        # add side panel; add statistics
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_vector_labels, all_x_values)

        # add general plot info
        if plotInfo is not None:
            bar_fig = addPlotInfo(bar_fig, plotInfo)

        # finish up
        bar_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=bar_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Create one axes per configured panel on ``fig``.

        Panels are laid out in one column (single panel) or two columns. The
        last panel stretches across the remaining columns of its row, so an
        odd final panel spans the full width.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index == num_panels - 1:
                axs.append(fig.add_subplot(gs[row, col:ncols]))
            else:
                axs.append(fig.add_subplot(gs[row, col]))
        return axs

    def _assignColors(self):
        """Assign one color to each bar graph key, panel by panel.

        Returns
        -------
        cols : `collections.defaultdict` [`str`, `list`]
            Panel name -> list of colors, one per key of that panel's
            ``bars`` in config order. Colors cycle through the chosen map.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a matplotlib
            colormap reachable as an attribute of ``plt.cm``.
        """
        from matplotlib.colors import Colormap

        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        n_bars = sum(len(self.panels[panel].bars) for panel in self.panels)

        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            cmap = getattr(plt.cm, self.cmap, None)
            if not isinstance(cmap, Colormap):
                raise ValueError(f"Unrecognized color map: {self.cmap}") from None
            listed = getattr(cmap, "colors", None)
            if listed is not None:
                # Discrete (listed) maps: use their colors in order.
                all_cols = listed
            else:
                # Continuous maps have no color list; sample them evenly.
                all_cols = list(cmap(np.linspace(0.0, 1.0, max(n_bars, 1))))

        cols = defaultdict(list)
        counter = 0
        for panel in self.panels:
            for _ in self.panels[panel].bars:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing bar graphs.

        Parameters
        ----------
        data : `KeyedData`
            Input data; each bar key of the panel indexes a vector.
        panel : `str`
            Name of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw the bars on.
        col : `list`
            One color per bar key of the panel, in config order.

        Returns
        -------
        nums : `list` [`int`]
            Count of each plotted bar.
        sorted_labels : `list` [`str`]
            Data key of each plotted bar.
        sorted_x_values : `list` [`int`]
            Bin value of each plotted bar.
        """
        x_values, assigned_labels, assigned_colors = self._assignBinElements(data, panel, col)
        sorted_x_values, sorted_labels, sorted_colors = self._sortBarBins(
            x_values, assigned_labels, assigned_colors
        )
        width, columns = self._getBarWidths(sorted_x_values)

        # Filter each vector to its finite values once instead of once per bin.
        finite_values = {}
        for label in dict.fromkeys(sorted_labels):
            values = np.asarray(data[label])
            finite_values[label] = values[np.isfinite(values)]

        nums = []
        for x_value, label, color, bar_width, column in zip(
            sorted_x_values, sorted_labels, sorted_colors, width, columns
        ):
            # Number of finite entries exactly equal to this bin value.
            count = int(np.count_nonzero(finite_values[label] == x_value))
            if bar_width == 1:
                bin_center = x_value
            else:
                # Bars sharing a bin are drawn side by side within the bin.
                bin_center = x_value - 0.35 + bar_width * column
            ax.bar(bin_center, count, bar_width, lw=2, label=label, color=color)
            nums.append(count)

        # Integer ticks spanning the plotted bins (none if nothing is plotted).
        if sorted_x_values:
            ax.set_xticks(list(range(int(min(sorted_x_values)), int(max(sorted_x_values)) + 1)))
        ax.set_xlabel(self.panels[panel].label)
        ax.set_yscale(self.panels[panel].yscale)
        ax.tick_params(labelsize=7)
        # add a buffer to the top of the plot to allow headspace for labels
        # NOTE(review): on a log axis this adds no headroom when the upper
        # limit is <= 1 (log10 <= 0) — revisit if small counts are common.
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, sorted_labels, sorted_x_values

    def _assignBinElements(self, data, panel, col):
        """List one entry per (bar key, distinct finite integer value).

        Non-finite values (NaN/inf) are skipped; they cannot be converted to
        an integer bin and are never counted by ``_makePanel`` anyway.

        Returns
        -------
        x_values : `list` [`int`]
            Bin value of each entry.
        assigned_labels : `list` [`str`]
            Bar key of each entry.
        assigned_colors : `list`
            Color of each entry, taken from ``col`` in bar-key order.
        """
        x_values, assigned_labels, assigned_colors = [], [], []
        for label, color in zip(self.panels[panel].bars, col):
            values = np.asarray(data[label])
            unique_elements = np.unique(values[np.isfinite(values)])
            # int() truncation can merge distinct floats into one bin;
            # dict.fromkeys de-duplicates while keeping ascending order.
            for x_value in dict.fromkeys(int(element) for element in unique_elements):
                x_values.append(x_value)
                assigned_labels.append(label)
                assigned_colors.append(color)
        return x_values, assigned_labels, assigned_colors

    def _sortBarBins(self, x_values, assigned_labels, assigned_colors):
        """Sort x values ascending and reorder labels and colors to match.

        A stable sort is used so that bars sharing a bin keep their
        configured (bar-key) order, giving a consistent left-to-right order
        of sources in every bin.
        """
        order = np.argsort(x_values, kind="stable")
        sorted_x_values = [x_values[position] for position in order]
        sorted_labels = [assigned_labels[position] for position in order]
        sorted_colors = [assigned_colors[position] for position in order]
        return sorted_x_values, sorted_labels, sorted_colors

    def _getBarWidths(self, x_values):
        """Determine the width of each bar and its column within its bin.

        Parameters
        ----------
        x_values : `list` [`int`]
            Bin values, sorted ascending.

        Returns
        -------
        width : `list` [`float`]
            ``1 / n`` where ``n`` is how many entries share that bin.
        columns : `list` [`int`]
            Position (0, 1, ...) of each entry among those sharing its bin.
        """
        counts = defaultdict(int)
        for x_value in x_values:
            counts[x_value] += 1

        seen = defaultdict(int)
        width, columns = [], []
        for x_value in x_values:
            width.append(1.0 / counts[x_value])
            columns.append(seen[x_value])
            seen[x_value] += 1
        return width, columns

    def _addStatisticsPanel(self, fig, handles, nums, sorted_labels, sorted_x_value):
        """Add an adjoining panel containing bar graph summary statistics.

        The statistics are drawn as a four-column legend: color swatch,
        bin value, count, and source key, one row per plotted bar.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # Legend entries fill column-major; each column holds len(handles) + 1
        # rows: a header row followed by one row per bar.
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Bin"]
            + list(sorted_x_value)
            + ["Count"]
            + list(nums)
            + ["Sources"]
            + list(sorted_labels)
        )

        # add the legend
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )