Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 21%

122 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-16 01:27 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("HistPanel", "HistPlot") 

24 

25from collections import defaultdict 

26from typing import Mapping 

27 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from lsst.pex.config import Config, ConfigDictField, DictField, Field 

31from matplotlib.figure import Figure 

32from matplotlib.gridspec import GridSpec 

33from matplotlib.patches import Rectangle 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

36from ...statistics import sigmaMad 

37from .plotUtils import addPlotInfo 

38 

39 

class HistPanel(Config):
    """Configuration for one panel of a histogram plot.

    A panel has a single x-axis and may hold several histograms. The keys of
    ``hists`` select which entries of the input data are drawn; the values
    become the legend labels for those histograms.
    """

    # Text placed on the panel's x-axis.
    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )

    # Mapping of histogram key -> legend label; must be provided.
    hists = DictField[str, str](
        doc=(
            "A dict specifying the histograms to be plotted in this panel. "
            "Keys are used to identify histogram IDs. "
            "Values are used to add to the legend label displayed in the "
            "upper corner of the panel."
        ),
        optional=False,
    )

    # Passed straight to the axes' y-scale setter when the panel is drawn.
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )

    # Bin count shared by every histogram in the panel.
    bins = Field[int](
        doc="Number of x axis bins.",
        default=50,
    )

    # Lower/upper percentiles that define the panel's x-axis range.
    pLower = Field[float](
        doc=(
            "Percentile used to determine the lower range of the histogram bins. "
            "If more than one histogram is plotted in the panel, "
            "the percentile limit is the minimum value across all input data."
        ),
        default=2.0,
    )
    pUpper = Field[float](
        doc=(
            "Percentile used to determine the upper range of the histogram bins. "
            "If more than one histogram is plotted, "
            "the percentile limit is the maximum value across all input data."
        ),
        default=98.0,
    )

69 

70 

class HistPlot(PlotAction):
    """Plot action that draws one or more panels of step histograms.

    Each entry of ``panels`` becomes one axes holding the histograms listed
    in that panel's ``hists``. A narrow side panel next to the histograms
    tabulates, for every histogram, the number of finite points, the median
    and the ``sigmaMad`` statistic.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield a ``(key, Vector)`` pair for every configured histogram.

        The keys are the keys of each panel's ``hists`` dict, i.e. exactly
        the names used to index ``data`` when the panels are drawn.
        """
        for panel in self.panels:  # type: ignore
            # Iterate the keys only: the dict values are legend labels, not
            # data keys, so they must not become part of the schema key.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Build the figure; a thin wrapper that forwards to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            Keyed container that is indexed by every histogram key configured
            in ``panels``; each entry is filtered with `numpy.isfinite`
            before being histogrammed.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted. When
            not `None` it is forwarded to ``addPlotInfo``. Documented keys:
                `"run"`
                    Output run for the plots (`str`).
                `"tractTableType"`
                    Table from which results are taken (`str`).
                `"plotName"`
                    Output plot name (`str`)
                `"SN"`
                    The global signal-to-noise data threshold (`float`)
                `"skymap"`
                    The type of skymap used for the data (`str`).
                `"tract"`
                    The tract that the data comes from (`int`).
                `"bands"`
                    The bands used for this data (`str` or `list`).
                `"visit"`
                    The visit that the data comes from (`int`)
        **kwargs
            Forwarded unchanged to ``_makePanel``.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Set up the figure: histograms on the left (3/4 of the width),
        # statistics table on the right (1/4 of the width).
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs = self._makeAxes(hist_fig)

        # Loop over each panel; plot histograms and gather their statistics.
        cols = self._assignColors()
        all_handles, all_nums, all_meds, all_mads = [], [], [], []
        for panel, ax in zip(self.panels, axs):
            nums, meds, mads = self._makePanel(data, panel, ax, cols[panel], **kwargs)
            # Only the handles are reused (as row markers in the side panel).
            handles, _ = ax.get_legend_handles_labels()
            all_handles += handles
            all_nums += nums
            all_meds += meds
            all_mads += mads

        # Add side panel with the summary statistics.
        self._addStatisticsPanel(side_fig, all_handles, all_nums, all_meds, all_mads)

        # Add general plot info.
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # Finish up: one shared y-axis label for all histogram panels.
        hist_fig.text(0.01, 0.42, "Frequency", rotation=90, transform=hist_fig.transFigure)
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        One axes is created per configured panel, filling a grid row by row.
        A single panel uses one column, otherwise two columns are used. The
        last panel is stretched to the end of its row, so an odd number of
        panels leaves no empty grid cell.

        Parameters
        ----------
        fig : figure-like
            Object providing ``add_subplot``; the axes are added to it.

        Returns
        -------
        axs : `list`
            The created axes, in panel order.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.13, right=0.99, bottom=0.1, top=0.88, wspace=0.25, hspace=0.45)

        axs = []
        for index in range(num_panels):
            row, col = divmod(index, ncols)
            if index < num_panels - 1:
                axs.append(fig.add_subplot(gs[row, col]))
            else:
                # Last panel: span from its column to the end of the row.
                axs.append(fig.add_subplot(gs[row, col:]))
        return axs

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Colors are handed out in order across all panels and wrap around
        when there are more histograms than colors in the map.

        Returns
        -------
        cols : `collections.defaultdict` [`str`, `list`]
            For each panel key, one color per histogram of that panel, in
            the iteration order of the panel's ``hists``.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is not one of the custom maps and the lookup
            ``getattr(plt.cm, cmap).copy().colors`` fails with
            `AttributeError` (unknown name, or an object without ``colors``).
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_cols = custom_cmaps[self.cmap]
        else:
            try:
                all_cols = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                # Keep the original AttributeError attached for debugging.
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        cols = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                cols[panel].append(all_cols[counter % len(all_cols)])
                counter += 1
        return cols

    def _makePanel(self, data, panel, ax, col, **kwargs):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Container indexed by the histogram keys of this panel.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        col : `list`
            One color per histogram of the panel, in ``hists`` order.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        nums, meds, mads : `list`
            Per-histogram statistics as returned by ``_calcStats``, in
            ``hists`` order.
        """
        panel_config = self.panels[panel]
        panel_range = self._getPanelRange(data, panel)
        nums, meds, mads = [], [], []
        for i, hist in enumerate(panel_config.hists):
            # Drop NaN/inf entries before plotting and computing statistics.
            hist_data = data[hist][np.isfinite(data[hist])]
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                lw=2,
                color=col[i],
                label=panel_config.hists[hist],
            )
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
            # Mark the median with a dashed line in the histogram's color.
            ax.axvline(med, ls="--", lw=1, c=col[i])
        ax.legend(fontsize=6, loc="upper left")
        ax.set_xlim(panel_range)
        ax.set_xlabel(panel_config.label)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=7)
        # Add a buffer to the top of the plot to allow headspace for labels.
        bottom, top = ax.get_ylim()
        if ax.get_yscale() == "log":
            # Stretch in log space: the exponent of the upper limit grows 10%.
            top = 10 ** (np.log10(top) * 1.1)
        else:
            top *= 1.1
        ax.set_ylim(bottom, top)
        return nums, meds, mads

    def _getPanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The ``pLower``/``pUpper`` percentiles are computed (ignoring NaN)
        for every histogram of the panel; the returned range runs from the
        smallest lower limit to the largest upper limit.

        Returns
        -------
        panel_range : `list`
            ``[low, high]``; stays ``[nan, nan]`` if the panel has no
            histograms.
        """
        panel_config = self.panels[panel]
        percentiles = [panel_config.pLower, panel_config.pUpper]
        # Start from NaN so the first histogram's limits always win.
        panel_range = [np.nan, np.nan]
        for hist in panel_config.hists:
            hist_range = np.nanpercentile(data[hist], percentiles)
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate summary statistics of the input data.

        Returns
        -------
        num : `int`
            Number of entries in ``data``.
        med : `float`
            NaN-ignoring median of ``data``.
        mad : `float`
            Value returned by ``sigmaMad(data)``.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addStatisticsPanel(self, fig, handles, nums, meds, mads):
        """Add an adjoining panel containing histogram summary statistics.

        The table is built as a four-column legend on an invisible axes:
        histogram line handles, then "Num", "Med" and sigma-MAD columns, each
        with a header row followed by one row per histogram.

        Parameters
        ----------
        fig : figure-like
            Object providing ``add_subplot``; receives the statistics axes.
        handles : `list`
            Legend handles of all histograms, in plotting order.
        nums, meds, mads : `list`
            Per-histogram statistics, in the same order as ``handles``.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        # NOTE(review): this adjusts pyplot's current figure rather than the
        # ``fig`` argument; confirm that is the intended target.
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.09, top=0.9)

        # Empty handle, used to populate the bespoke legend layout.
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # Set up new legend handles and labels. Both lists hold
        # 4 * (len(handles) + 1) entries: one header plus one row per
        # histogram for each of the four columns.
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["Num"]
            + [str(num) for num in nums]  # keep every label a string
            + ["Med"]
            + [f"{x:0.1f}" for x in meds]
            + ["${{\\sigma}}_{{MAD}}$"]
            + [f"{x:0.1f}" for x in mads]
        )

        # Add the legend.
        ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            ncol=4,
            handletextpad=-0.25,
            fontsize=6,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
        )

317 )