Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 24%

121 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-26 03:20 -0700

# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

21from __future__ import annotations 

22 

23__all__ = ("HistPanel", "HistPlot") 

24 

25from collections import defaultdict 

26from typing import Mapping 

27 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from lsst.pex.config import Config, ConfigDictField, DictField, Field 

31from matplotlib.figure import Figure 

32from matplotlib.gridspec import GridSpec 

33from matplotlib.patches import Rectangle 

34from scipy.stats import median_absolute_deviation as sigmaMad 

35 

36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

37from .plotUtils import addPlotInfo 

38 

39 

class HistPanel(Config):
    """Configuration for one panel of a `HistPlot` figure.

    A panel holds one or more histograms drawn on a shared set of axes; the
    fields below control which histograms appear and how the axes are
    scaled and binned.
    """

    # Text placed on the panel's x-axis.
    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )

    # Mapping of histogram ID -> legend text.  Mandatory: a panel without
    # histograms has nothing to draw.
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )

    # Scale name applied to the panel's y-axis.
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )

    # Bin count used for every histogram in the panel.
    bins = Field[int](
        doc="Number of x axis bins.",
        default=50,
    )

    # Lower / upper percentiles that set the x-axis extent of the panel.
    pLower = Field[float](
        doc="Percentile used to determine the lower range of the histogram bins. If more than one histogram "
        "is plotted in the panel, the percentile limit is the minimum value across all input data.",
        default=2.0,
    )
    pUpper = Field[float](
        doc="Percentile used to determine the upper range of the histogram bins. If more than one histogram "
        "is plotted, the percentile limit is the maximum value across all input data.",
        default=98.0,
    )

69 

70 

class HistPlot(PlotAction):
    """Plot action that draws a configurable grid of histogram panels.

    Each entry of ``panels`` becomes one set of axes holding one or more
    step histograms.  A narrow side panel lists, for every histogram, the
    number of points, the median and the sigma-MAD.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the name and type of every input this action reads.

        Yields
        ------
        key, type : `tuple`
            The histogram ID and `Vector`, once per histogram in every
            configured panel.
        """
        for panel in self.panels:  # type: ignore
            # Iterate the keys only: the plotting methods look data up with
            # ``data[key]``, so the schema must name the key alone and not
            # the (key, legend label) pair.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Produce the figure; thin wrapper around `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot.  Every histogram ID configured in ``panels``
            is looked up as ``data[histID]``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                `"run"`
                    Output run for the plots (`str`).
                `"tractTableType"`
                    Table from which results are taken (`str`).
                `"plotName"`
                    Output plot name (`str`)
                `"SN"`
                    The global signal-to-noise data threshold (`float`)
                `"skymap"`
                    The type of skymap used for the data (`str`).
                `"tract"`
                    The tract that the data comes from (`int`).
                `"bands"`
                    The bands used for this data (`str` or `list`).
                `"visit"`
                    The visit that the data comes from (`int`)
        **kwargs
            Forwarded to the per-panel plotting routine.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # set up figure: histograms on the left (3/4 width), statistics on
        # the right (1/4 width)
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms and gather their statistics
        cols = self._assignColors()
        all_handles, all_nums, all_meds, all_mads = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, meds, mads = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            # only the handles are reused (as colour keys in the side panel)
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_meds += meds
            all_mads += mads

        # add side panel; add statistics
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_meds, all_mads)

        # add general plot info
        # NOTE(review): plotInfo defaults to None and is passed through
        # unchanged -- confirm that addPlotInfo accepts None.
        hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        hist_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=hist_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels are laid out row by row in a single column when there is at
        most one panel and in two columns otherwise.  The final panel is
        stretched over whatever columns remain in its row.

        Parameters
        ----------
        fig
            Figure (or sub-figure) that receives the axes.

        Returns
        -------
        axs : `list`
            One axes object per configured panel, in panel order.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index < num_panels - 1:
                axs.append(fig.add_subplot(gs[row, col]))
            else:
                # last panel: fill the rest of the row so an odd panel count
                # does not leave an empty cell
                axs.append(fig.add_subplot(gs[row, col:]))
        return axs

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Colors are handed out in order across all panels and wrap around
        when the color map is exhausted.

        Returns
        -------
        cols : `collections.defaultdict`
            Maps each panel key to a list with one color per histogram.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is not one of the custom maps and looking it
            up on ``plt.cm`` (or reading its ``colors``) fails.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            try:
                all_cols = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                # Either the name is unknown to plt.cm or the color map
                # object exposes no ``colors`` attribute.
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        cols = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data
            Keyed data; indexed by histogram ID.
        panel : `str`
            Key of the panel configuration in ``self.panels``.
        ax
            Axes to draw on.
        col : `list`
            One color per histogram in this panel.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        nums, meds, mads : `list`
            Per-histogram point count, median and sigma-MAD, in the order
            the histograms are configured.
        """
        panel_cfg = self.panels[panel]
        panel_range = self._getPanelRange(data, panel)
        nums, meds, mads = [], [], []
        for i, hist in enumerate(panel_cfg.hists):
            # drop NaN / inf before binning and before computing statistics
            hist_data = data[hist][np.isfinite(data[hist])]
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_cfg.bins,
                histtype="step",
                lw=2,
                color=col[i],
                label=panel_cfg.hists[hist],
            )
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
            # mark the median with a dashed line in the histogram's color
            ax.axvline(med, ls="--", lw=1, c=col[i])
        ax.legend(fontsize=6, loc="upper left")
        ax.set_xlim(panel_range)
        ax.set_xlabel(panel_cfg.label)
        ax.set_yscale(panel_cfg.yscale)
        ax.tick_params(labelsize=7)
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # grow the upper limit by 10% in log space
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])
        return nums, meds, mads

    def _getPanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range is the smallest ``pLower`` percentile and the largest
        ``pUpper`` percentile found among the panel's histograms.

        Returns
        -------
        panel_range : `list`
            ``[lower, upper]``; both NaN when the panel has no histograms.
        """
        panel_cfg = self.panels[panel]
        # seeded with NaN so the nan-aware min / max ignore the seed
        panel_range = [np.nan, np.nan]
        for hist in panel_cfg.hists:
            lower, upper = np.nanpercentile(data[hist], [panel_cfg.pLower, panel_cfg.pUpper])
            panel_range[0] = np.nanmin([panel_range[0], lower])
            panel_range[1] = np.nanmax([panel_range[1], upper])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data.

        Returns
        -------
        num, med, mad
            ``len(data)``, ``np.nanmedian(data)`` and ``sigmaMad(data)``.
        """
        num = len(data)
        med = np.nanmedian(data)
        # NOTE(review): sigmaMad is the module-level alias of
        # scipy.stats.median_absolute_deviation, which newer SciPy releases
        # replace with median_abs_deviation -- confirm against the SciPy
        # version in use.
        mad = sigmaMad(data)
        return num, med, mad

    def _addStatisticsPanel(self, fig, handles, nums, meds, mads):
        """Add an adjoining panel containing histogram summary statistics.

        The table is built as a four-column legend on invisible axes: a
        column of color keys followed by ``Num``, ``Med`` and sigma-MAD
        columns, each with a header row.

        Parameters
        ----------
        fig
            Figure (or sub-figure) hosting the statistics panel.
        handles : `list`
            Legend handles of all histograms, used as color keys.
        nums, meds, mads : `list`
            Per-histogram statistics, in the same order as ``handles``.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels; the legend is filled column
        # by column, so each column is one header entry plus one entry per
        # histogram
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Num"]
            + nums
            + ["Med"]
            + [f"{x:0.1f}" for x in meds]
            + ["${{\\sigma}}_{{MAD}}$"]
            + [f"{x:0.1f}" for x in mads]
        )

        # add the legend
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )