Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 15%

203 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-23 09:29 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("HistPanel", "HistPlot") 

24 

25import logging 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, FieldValidationError 

32from matplotlib.figure import Figure 

33from matplotlib.gridspec import GridSpec 

34from matplotlib.patches import Rectangle 

35 

36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

37from ...statistics import sigmaMad 

38from .plotUtils import addPlotInfo 

39 

40log = logging.getLogger(__name__) 

41 

42 

class HistPanel(Config):
    """Configuration for a single panel of a histogram plot.

    A panel holds one or more histograms that share an x-axis. The x-axis
    range is derived from ``rangeType``, ``lowerRange`` and ``upperRange``;
    see `validate` for the constraints placed on those values.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType``.

        Raises
        ------
        FieldValidationError
            Raised if any of the following holds:

            - ``rangeType`` is ``"percentile"`` and ``lowerRange < 0`` or
              ``upperRange > 100``;
            - ``rangeType`` is ``"sigmaMad"`` and ``lowerRange < 0``;
            - ``rangeType`` is ``"fixed"`` and the two bounds are equal;
            - ``histDensity`` is set without a ``referenceValue``.
        """
        super().validate()
        # The percentile bounds only apply to the "percentile" range type.
        # The parentheses matter: without them ``and`` binds tighter than
        # ``or`` and any upperRange > 100 is rejected for every range type.
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)

119 

120 

class HistPlot(PlotAction):
    """Plot action that draws one or more histograms in each of N panels.

    The figure is split into a main sub-figure holding the histogram panels
    and a narrower side sub-figure that lists, for every histogram, the
    number of points, the median and the sigma MAD.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the data keys required by the configured histograms.

        Yields
        ------
        key : `str`
            Histogram ID, i.e. a key of ``hists`` in one of the panels. This
            is the key used to look the vector up in the input data.
        type : `type`
            Always `Vector`.
        """
        for panel in self.panels:  # type: ignore
            # Iterate over the keys only: the dict values are legend labels,
            # not data keys, so they do not belong in the schema.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the plot; all arguments are forwarded to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from histogram ID (the keys of ``hists`` in each panel)
            to the vector of values to histogram.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
                `"run"`
                    Output run for the plots (`str`).
                `"tractTableType"`
                    Table from which results are taken (`str`).
                `"plotName"`
                    Output plot name (`str`)
                `"SN"`
                    The global signal-to-noise data threshold (`float`)
                `"skymap"`
                    The type of skymap used for the data (`str`).
                `"tract"`
                    The tract that the data comes from (`int`).
                `"bands"`
                    The bands used for this data (`str` or `list`).
                `"visit"`
                    The visit that the data comes from (`int`)
            If `None`, no plot information is added to the figure.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            Raised if no panels are configured.
        """
        # Fail before any figure is allocated; an empty layout would
        # otherwise only be rejected later, when the grid is built.
        if not self.panels:
            raise ValueError("HistPlot requires at least one panel to be configured.")

        # set up figure
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms
        colors = self._assignColors()
        num_panels = len(self.panels)
        nth_panel = num_panels
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            # nth_panel/nth_col/nth_row count down as panels are drawn; they
            # are only used to position the matching statistics legend.
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - num_panels > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            # add side panel; add statistics
            handles, _ = ax.get_legend_handles_labels()
            self._addStatisticsPanel(
                side_fig,
                list(handles),
                list(nums),
                list(meds),
                list(mads),
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or sub-figure
            Figure to which the panel axes are added.

        Returns
        -------
        axs : `list`
            One axes object per configured panel, in row-major order.
        ncols : `int`
            Number of grid columns: 1 for a single panel, otherwise 2.
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: let it also take the next column, so a
                    # trailing odd panel fills the whole row.
                    axs.append(fig.add_subplot(gs[row : row + 1, col : min(col + 2, ncols + 1)]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel name, one color per histogram in that panel.
            Colors are taken in order from the configured color map and
            wrap around when it runs out.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither one of the custom color maps nor
            an attribute of `plt.cm` that exposes a ``colors`` list.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                all_colors = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                # Covers both an unknown name and a color map object that
                # has no ``colors`` attribute.
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by histogram ID.
        panel : `str`
            Name of the panel (a key of ``self.panels``) to draw.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        colors : `list`
            One color per histogram in the panel.
        label_font_size : `float`, optional
            Font size of the axis labels.
        legend_font_size : `float`, optional
            Font size of the legend entries.
        ncols : `int`, optional
            Number of panel columns in the figure.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigma MAD of each histogram,
            in the iteration order of the panel's ``hists``.
        """
        panel_config = self.panels[panel]

        # Only finite values are histogrammed and used for the statistics.
        finite_data = {}
        nums, meds, mads = [], [], []
        for hist in panel_config.hists:
            values = data[hist]
            finite_data[hist] = values[np.isfinite(values)]
            num, med, mad = self._calcStats(finite_data[hist])
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, hist in enumerate(panel_config.hists):
            ax.hist(
                finite_data[hist],
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                density=panel_config.histDensity,
                lw=2,
                color=colors[i],
                label=panel_config.hists[hist],
            )
            # Dashed vertical line at the median of this histogram.
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panel_config.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panel_config.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panel_config.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        return nums, meds, mads

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based on config settings.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by histogram ID.
        panel : `str`
            Name of the panel (a key of ``self.panels``).
        mads, meds : `list`, optional
            Sigma MAD and median of each histogram in the panel. Only used,
            and then required, when the panel's ``rangeType`` is
            ``"sigmaMad"``.

        Returns
        -------
        panel_range : `list`
            Two-element ``[lower, upper]`` x-axis range.

        Raises
        ------
        RuntimeError
            Raised if the panel's ``rangeType`` is not recognized.
        """
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Use the module-level logger: no ``log`` attribute is
                # defined on this class in this file.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The panel's ``lowerRange`` and ``upperRange`` are used as
        percentiles. The returned range spans the smallest lower limit and
        the largest upper limit found across the panel's histograms.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by histogram ID.
        panel : `str`
            Name of the panel (a key of ``self.panels``).

        Returns
        -------
        panel_range : `list`
            Two-element ``[lower, upper]`` range; both entries stay NaN if
            the panel has no histograms.
        """
        panel_range = [np.nan, np.nan]
        for hist in self.panels[panel].hists:
            hist_range = np.nanpercentile(
                data[hist], [self.panels[panel].lowerRange, self.panels[panel].upperRange]
            )
            # nanmin/nanmax ignore the NaN placeholders of the first pass.
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, the median, and the sigma
        MAD (as returned by `sigmaMad`) of the input data.

        Parameters
        ----------
        data : array-like
            Values to summarize.

        Returns
        -------
        num : `int`
            Number of elements in ``data``.
        med : `float`
            NaN-ignoring median of ``data``.
        mad : `float`
            Result of ``sigmaMad(data)``.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes of the panel.
        panel : `str`
            Name of the panel (a key of ``self.panels``).
        panel_range : `list`
            Two-element ``[lower, upper]`` x-axis range of the panel.
        legend_font_size : `float`, optional
            Font size of the reference legend.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The input axes.
        """
        # The reference artists live on a hidden twin axes so that they get
        # their own legend, separate from the histogram legend.
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        # With a density plot the line is left unlabelled; the legend then
        # only shows the reference curve drawn below.
        if self.panels[panel].histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(self.panels[panel].referenceValue)
        ax2.axvline(
            self.panels[panel].referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label
        )
        if self.panels[panel].histDensity:
            # Unit-sigma Gaussian centred on the reference value.
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = self.panels[panel].referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are laid out as a four-column legend: the histogram
        handles, followed by the number of points, the median and the sigma
        MAD of each histogram.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or sub-figure
            Figure to which the statistics legend is added.
        handles : `list`
            Legend handles of the histograms, one per histogram.
        nums, meds, mads : `list`
            Number of points, median and sigma MAD of each histogram, in the
            same order as ``handles``.
        legend_font_size : `float`, optional
            Font size of the legend text and title.
        yAnchor0 : `float`, optional
            Lower y position of the associated histogram panel; used to
            vertically align the legend with it.
        nth_row, nth_col : `int`, optional
            Row and column counters of the associated panel, as computed by
            `makePlot`.
        title_str : `str`, optional
            Legend title.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels; with ncol=4 each run of
        # len(handles) + 1 entries forms one column (header + one row per
        # histogram)
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["N$_{{data}}$"]
            + nums
            + ["Med"]
            + [f"{x:.2f}" for x in meds]
            + ["${{\\sigma}}_{{MAD}}$"]
            + [f"{x:.2f}" for x in mads]
        )

        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        # NOTE(review): presumably registers the legend as a plain artist so
        # it persists alongside other legends - confirm.
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)