Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 16%

215 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-17 01:49 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel") 

24 

25import logging 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import ( 

32 ChoiceField, 

33 Config, 

34 ConfigDictField, 

35 ConfigField, 

36 DictField, 

37 Field, 

38 FieldValidationError, 

39 ListField, 

40) 

41from matplotlib.figure import Figure 

42from matplotlib.gridspec import GridSpec 

43from matplotlib.patches import Rectangle 

44 

45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

46from ...statistics import sigmaMad 

47from .plotUtils import addPlotInfo 

48 

49log = logging.getLogger(__name__) 

50 

51 

class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - ``statsLabels`` is a list of exactly three strings naming the three
      stat columns; the labels accept LaTeX formatting.

    - The other parameters (``stat1``, ``stat2``, ``stat3``) are lists of
      strings that specify vector keys corresponding to scalar values
      computed in the prep/process/produce steps of an analysis tools
      plot/metric configurable action. There should be one key for each
      group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists where
    the stats panel shows N / median / sigma_mad for each group in the panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        # Must be an ordered sequence: a set literal here would give the
        # labels an arbitrary order, so a column header could end up above
        # the wrong statistic.
        default=("N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"),
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=[],
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=[],
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=[],
    )

102 

103 

class HistPanel(Config):
    """Configuration of a single histogram panel of a HistPlot.

    A panel holds one or more histograms (``hists``) drawn on shared axes,
    together with the x-axis range policy (``rangeType``, ``lowerRange``,
    ``upperRange``), binning, y-axis scaling, an optional reference line and
    the configuration of the accompanying statistics panel.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            # The range is centred on the median (see HistPlot._getPanelRange),
            # not on sigmaMad itself.
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType``.

        Raises
        ------
        FieldValidationError
            Raised if a percentile range falls outside [0, 100], if a
            sigmaMad lower range is negative, if a fixed range has zero
            width, or if ``histDensity`` is set without a ``referenceValue``.
        """
        super().validate()
        # The parentheses matter: without them ``and`` binds tighter than
        # ``or`` and any upperRange > 100 is rejected, even for the "fixed"
        # and "sigmaMad" range types where such values are legitimate.
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = f"For rangeType {self.rangeType}, ranges must obey: lowerRange >= 0 and upperRange <= 100."
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                f"For rangeType {self.rangeType}, lower range must obey: lowerRange >= 0 (the lower "
                "range is set as median - lowerRange*sigmaMad)."
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                f"For rangeType {self.rangeType}, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)."
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)

185 

186 

class HistPlot(PlotAction):
    """Plot action producing a multi-panel figure of histograms.

    Each entry of ``panels`` becomes one set of axes holding one or more
    step histograms; a side figure carries a table of three summary
    statistics per histogram for every panel.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield a ``(key, Vector)`` pair for every configured histogram.

        The keys are the keys of each panel's ``hists`` dict, i.e. exactly
        the names later used to index the input data in `_makePanel`. The
        dict values are legend labels only and are not part of the schema.
        """
        for panel in self.panels:  # type: ignore
            for hist_key in self.panels[panel].hists:  # type: ignore
                yield hist_key, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Create the figure; all arguments are forwarded to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Set up the figure: histograms on the left (3/4 of the width),
        # statistics on the right.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        colors = self._assignColors()
        num_panels = len(self.panels)
        # Countdown bookkeeping used to position each panel's stats legend.
        nth_panel = num_panels
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            # The last panel spans the leftover grid cells when the grid is
            # not completely filled, so step the column counter once more.
            if nth_panel == 0 and nrows * ncols - num_panels > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            # Each panel gets its own statistics legend in the side figure,
            # built from the handles of the histograms just drawn.
            handles, _ = ax.get_legend_handles_labels()
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Returns
        -------
        axs : `list`
            One axes object per panel, in row-major order.
        ncols, nrows : `int`
            Shape of the grid (one column for at most one panel, two
            columns otherwise).
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: let it extend over the remaining column(s)
                    # of its row, then stop adding axes.
                    axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel name, one color per histogram. Colors are taken
            consecutively across all panels and wrap around when the map is
            exhausted.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor an attribute of
            ``plt.cm`` exposing a ``colors`` list.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                all_colors = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                # Keep the original failure attached for easier debugging.
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigmaMAD of each histogram.
        stats_dict : `dict`
            Labels (``"statLabels"``) and values (``"stat1"``, ``"stat2"``,
            ``"stat3"``) to display in the statistics panel.
        """
        panel_cfg = self.panels[panel]

        # First pass: statistics of the finite values of every histogram;
        # they are needed to fix the common x-range before plotting.
        finite_data, nums, meds, mads = [], [], [], []
        for hist in panel_cfg.hists:
            hist_data = data[hist][np.isfinite(data[hist])]
            num, med, mad = self._calcStats(hist_data)
            finite_data.append(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        # Second pass: draw each histogram plus a dashed line at its median.
        for i, hist in enumerate(panel_cfg.hists):
            ax.hist(
                finite_data[i],
                range=panel_range,
                bins=panel_cfg.bins,
                histtype="step",
                density=panel_cfg.histDensity,
                lw=2,
                color=colors[i],
                label=panel_cfg.hists[hist],
            )
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panel_cfg.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panel_cfg.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panel_cfg.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # NOTE(review): for an upper limit below 1 the logarithm is
            # negative, so this scaling lowers the limit instead of raising
            # it -- confirm whether that case needs handling.
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panel_cfg.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        # Use the custom statistics only when some stat keys were actually
        # configured. Testing ``statsPanel is None`` alone is not enough: a
        # config field presumably always holds a (possibly untouched) config
        # instance, which would leave the statistics columns empty.
        stats_cfg = panel_cfg.statsPanel
        use_custom_stats = stats_cfg is not None and any(
            len(keys) > 0 for keys in (stats_cfg.stat1, stats_cfg.stat2, stats_cfg.stat3)
        )
        if use_custom_stats:
            stats_dict = {
                "statLabels": stats_cfg.statsLabels,
                "stat1": [data[stat] for stat in stats_cfg.stat1],
                "stat2": [data[stat] for stat in stats_cfg.stat2],
                "stat3": [data[stat] for stat in stats_cfg.stat3],
            }
        else:
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based config settings.

        Parameters
        ----------
        data
            Keyed input data, indexed by histogram key.
        panel : `str`
            Name of the panel in ``self.panels``.
        mads, meds : `list`, optional
            sigmaMAD and median of every histogram in the panel; only used
            when the panel's ``rangeType`` is ``"sigmaMad"``.

        Returns
        -------
        panel_range : `list`
            Two-element ``[lower, upper]`` range.

        Raises
        ------
        RuntimeError
            Raised if the configured ``rangeType`` is not recognized.
        """
        panel_range = [np.nan, np.nan]
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Use the module-level logger: nothing visible in this file
                # provides a ``log`` attribute on the action itself.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range is the union, over all histograms in the panel, of the
        ``[lowerRange, upperRange]`` percentile interval of each histogram's
        finite values.
        """
        percentiles = [self.panels[panel].lowerRange, self.panels[panel].upperRange]
        panel_range = [np.nan, np.nan]
        for hist in self.panels[panel].hists:
            # Drop NaN *and* +/-inf, matching the selection applied before
            # plotting; an infinite value would otherwise yield an infinite
            # (unusable) histogram range.
            finite_values = data[hist][np.isfinite(data[hist])]
            hist_range = np.nanpercentile(finite_values, percentiles)
            # nanmin/nanmax ignore the NaN placeholders of the initial range.
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data.

        Returns
        -------
        num, med, mad
            ``len(data)``, ``np.nanmedian(data)`` and ``sigmaMad(data)``.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The artists are drawn on a hidden twin of ``ax`` so that they get
        their own legend in the upper right corner. Returns ``ax``.
        """
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        panel_cfg = self.panels[panel]
        # In density mode only the reference curve is labelled, not the line.
        if panel_cfg.histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(panel_cfg.referenceValue)
        ax2.axvline(panel_cfg.referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label)
        if panel_cfg.histDensity:
            # Unit-width Gaussian centred on the reference value, sampled
            # in 1/100 steps of the panel range.
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = panel_cfg.referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The table is built as a four-column legend: the histogram handles
        followed by one column per statistic, each headed by its label from
        ``stats_dict["statLabels"]``. Only the length of ``nums`` is used
        here (for vertical placement); ``meds`` and ``mads`` are unused.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels: real handles fill the first
        # column, the three statistics columns only carry text
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        stat_labels = stats_dict["statLabels"]
        legend_labels = [""] * (len(handles) + 1)
        for column, stat_key in enumerate(("stat1", "stat2", "stat3")):
            legend_labels.append(stat_labels[column])
            legend_labels.extend(f"{x:.3g}" for x in stats_dict[stat_key])

        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)