Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 15%

233 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-10 14:10 +0000

# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

21from __future__ import annotations 

22 

# Public names exported by this module: the HistPlot action and the two
# Config classes used to configure its panels.
__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")

24 

25import logging 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import ( 

32 ChoiceField, 

33 Config, 

34 ConfigDictField, 

35 ConfigField, 

36 DictField, 

37 Field, 

38 FieldValidationError, 

39 ListField, 

40) 

41from matplotlib.figure import Figure 

42from matplotlib.gridspec import GridSpec 

43from matplotlib.patches import Rectangle 

44 

45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

46from ...math import nanMax, nanMedian, nanMin, sigmaMad 

47from .plotUtils import addPlotInfo 

48 

# Module-level logger; HistPlot uses it to report when a sigmaMad-based panel
# range has zero width and the percentile range is used instead.
log = logging.getLogger(__name__)

50 

51 

class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - The ``statsLabels`` ListField specifies the names of the 3 stat
      columns; the names accept LaTeX formatting.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no stats are configured in the HistStatsPanel then the default behavior
    persists where the stats panel shows N / median / sigma_mad for each group
    in the panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        default=("N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"),
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )

    def validate(self):
        """Validate the config.

        Raises
        ------
        ValueError
            Raised if some, but not all, of ``stat1``, ``stat2`` and ``stat3``
            are set (non-empty).
        """
        super().validate()
        stats = (self.stat1, self.stat2, self.stat3)
        # The three stat lists are all-or-nothing: either the default stats
        # are used (none set) or all three custom columns are provided.
        if any(stats) and not all(stats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")

108 

109 

class HistPanel(Config):
    """A Config class that holds parameters to configure a single panel of a
    histogram plot. This class is intended to be used within the ``HistPlot``
    class.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    refRelativeToMedian = Field[bool](
        doc="Is the referenceValue meant to be an offset from the median?",
        default=False,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType`` and
        that ``referenceValue`` is set whenever ``histDensity`` is True.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if:

            - ``rangeType`` is ``percentile`` and the bounds fall outside
              ``[0, 100]``;
            - ``rangeType`` is ``sigmaMad`` and ``lowerRange`` is negative;
            - ``rangeType`` is ``fixed`` and both bounds are equal;
            - ``histDensity`` is True but ``referenceValue`` is None.
        """
        super().validate()
        # Parenthesize the bound checks: ``and`` binds tighter than ``or``, so
        # without them an upperRange above 100 would be rejected for every
        # rangeType (e.g. a fixed range of 0-200), not only for percentiles.
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)

201 

202 

class HistPlot(PlotAction):
    """Make an N-panel plot with a configurable number of histograms displayed
    in each panel. Reference lines showing values of interest may also be added
    to each histogram. Panels are configured using the ``HistPanel`` class.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield one ``(key, Vector)`` pair per histogram in every panel.

        Only the keys of each panel's ``hists`` dict name input data; the dict
        values are legend labels and are not part of the schema.
        """
        for panel in self.panels:  # type: ignore
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the histogram figure; see `makePlot` for the arguments."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot, looked up by the histogram keys of each panel's
            ``hists`` (and by the stat keys of a panel's ``statsPanel`` when
            custom stats are configured).
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
                `"run"`
                    Output run for the plots (`str`).
                `"tractTableType"`
                    Table from which results are taken (`str`).
                `"plotName"`
                    Output plot name (`str`)
                `"SN"`
                    The global signal-to-noise data threshold (`float`)
                `"skymap"`
                    The type of skymap used for the data (`str`).
                `"tract"`
                    The tract that the data comes from (`int`).
                `"bands"`
                    The bands used for this data (`str` or `list`).
                `"visit"`
                    The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example histogram plot may be seen below:

        .. image:: /_static/analysis_tools/histPlotExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        # Set up the figure: histogram panels in the left subfigure (3/4 of
        # the width), statistics legends in the right one.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # Loop over each panel; plot histograms.
        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            # Add this panel's statistics to the side figure, reusing the
            # histogram line handles so the legend colors match the panel.
            handles, _ = ax.get_legend_handles_labels()
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # Add general plot info.
        if plotInfo is not None:
            addPlotInfo(hist_fig, plotInfo)

        # Finish up.
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels are laid out row by row in up to two columns. The last panel
        spans the remainder of its row, so with an odd number of panels (more
        than one) the final panel is full width.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure` or `matplotlib.figure.Figure`
            Figure to which the axes are added.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in ``self.panels`` order.
        ncols : `int`
            Number of grid columns.
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: stretch it over the rest of the row (the
                    # slice stop is clipped to the grid width by GridSpec).
                    axs.append(fig.add_subplot(gs[row : row + 1, col : min(col + 2, ncols + 1)]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Colors are taken in order from the configured map across all panels,
        cycling back to the start when there are more histograms than colors.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel key, one color per histogram in ``hists`` order.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is not a custom map and cannot be resolved to a
            ``plt.cm`` attribute exposing a ``colors`` list.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                all_colors = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                # Either the name is unknown or the map has no discrete
                # ``colors`` list; keep the original error as the cause.
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            Input data, looked up by the panel's histogram (and stat) keys.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        colors : `list`
            One color per histogram in the panel.
        label_font_size : `int`, optional
            Font size of the axis labels.
        legend_font_size : `int`, optional
            Font size of the legend(s).
        ncols : `int`, optional
            Number of panel columns in the figure; used to decide whether the
            x tick labels need extra formatting to fit.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigmaMad of each histogram.
        stats_dict : `dict`
            ``"statLabels"`` and ``"stat1"``-``"stat3"`` entries for the side
            statistics panel; the stat lists are empty if the panel range is
            not finite.

        Raises
        ------
        RuntimeError
            Raised if only some of the panel's custom stats are configured.
        """
        panelConfig = self.panels[panel]
        # Non-finite values are excluded from both the statistics and the
        # histograms.
        finite_data = [data[hist][np.isfinite(data[hist])] for hist in panelConfig.hists]
        nums, meds, mads = [], [], []
        for hist_data in finite_data:
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)

        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)
        if not all(np.isfinite(panel_range)):
            # No usable x range (e.g. no finite data): draw nothing and leave
            # the statistics panel empty.
            stats_dict = {key: [] for key in ("stat1", "stat2", "stat3")}
            stats_dict["statLabels"] = [""] * 3
            return nums, meds, mads, stats_dict

        for i, (hist, hist_data) in enumerate(zip(panelConfig.hists, finite_data)):
            if len(hist_data) > 0:
                ax.hist(
                    hist_data,
                    range=panel_range,
                    bins=panelConfig.bins,
                    histtype="step",
                    density=panelConfig.histDensity,
                    lw=2,
                    color=colors[i],
                    label=panelConfig.hists[hist],
                )
                # Dashed line at this histogram's median.
                ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panelConfig.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panelConfig.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panelConfig.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # Add a buffer to the top of the plot to allow headspace for labels.
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # Pad the exponent by 10% of its magnitude. Using abs() keeps the
            # padding upward when the top of the axis is below 1 (e.g. a
            # log-scaled density); scaling the exponent by 1.1 would instead
            # lower the limit there and clip the data.
            log_top = np.log10(ylims[1])
            ylims[1] = 10 ** (log_top + 0.1 * np.abs(log_top))
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given.
        # If histDensity is True, also plot a reference PDF with
        # mean = the reference value and sigma = 1 for reference.
        if panelConfig.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, meds, legend_font_size=legend_font_size)

        # Check if we should use the default stats panel or if a custom one
        # has been created.
        statsPanel = panelConfig.statsPanel
        statKeys = [statsPanel.stat1, statsPanel.stat2, statsPanel.stat3]
        if not any(statKeys):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        elif all(statKeys):
            stats_dict = {
                "statLabels": statsPanel.statsLabels,
                "stat1": [data[stat] for stat in statsPanel.stat1],
                "stat2": [data[stat] for stat in statsPanel.stat2],
                "stat3": [data[stat] for stat in statsPanel.stat3],
            }
        else:
            raise RuntimeError("Invalid configuration of HistStatsPanel")
        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based on config settings.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            Input data (used for percentile ranges).
        panel : `str`
            Key of the panel in ``self.panels``.
        mads, meds : `list` [`float`], optional
            Per-histogram sigmaMad and median; required for ``sigmaMad``.

        Returns
        -------
        panel_range : `list` [`float`]
            ``[lower, upper]`` x-axis limits; may contain NaN when no finite
            range could be determined.

        Raises
        ------
        RuntimeError
            Raised if ``rangeType`` is not one of the supported values.
        """
        panelConfig = self.panels[panel]
        rangeType = panelConfig.rangeType
        lowerRange = panelConfig.lowerRange
        upperRange = panelConfig.upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = nanMax(mads)
            maxMed = nanMax(meds)
            minMed = nanMin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # NOTE(review): the fallback reuses lowerRange/upperRange as
                # percentiles although they are sigmaMad multipliers here.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range is the smallest lower and largest upper percentile over
        all non-empty histograms in the panel (NaNs ignored).

        Returns
        -------
        panel_range : `list` [`float`]
            ``[lower, upper]``; both NaN if every histogram is empty.
        """
        panelConfig = self.panels[panel]
        panel_range = [np.nan, np.nan]
        for hist in panelConfig.hists:
            data_hist = data[hist]
            # TODO: Consider raising instead
            if len(data_hist) > 0:
                hist_range = np.nanpercentile(data_hist, [panelConfig.lowerRange, panelConfig.upperRange])
                panel_range[0] = nanMin([panel_range[0], hist_range[0]])
                panel_range[1] = nanMax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data.

        Returns
        -------
        num : `int`
            ``len(data)``.
        med : `float`
            NaN-ignoring median.
        mad : `float`
            Sigma-scaled median absolute deviation.
        """
        num = len(data)
        med = nanMedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, meds, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The line is drawn at ``referenceValue`` (plus the first histogram's
        median if ``refRelativeToMedian``). When ``histDensity`` is True a
        unit-sigma normal PDF centred on the same value is also shaded, and
        the line gets no legend label of its own.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Panel axes.
        panel : `str`
            Key of the panel in ``self.panels``.
        panel_range : `list` [`float`]
            ``[lower, upper]`` x range used to sample the reference PDF.
        meds : `list` [`float`]
            Per-histogram medians; ``meds[0]`` is used for relative offsets.
        legend_font_size : `int`, optional
            Font size of the reference legend.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The input axes (its y limits may have been extended).
        """
        panelConfig = self.panels[panel]
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        # The reference value is needed for the vertical line in both the
        # density and the non-density case, so compute it unconditionally.
        reference_value = panelConfig.referenceValue
        if panelConfig.refRelativeToMedian:
            reference_value = reference_value + meds[0]
        if panelConfig.histDensity:
            # The shaded reference PDF carries the legend entry instead.
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(reference_value)
        ax2.axvline(reference_value, ls="-", lw=1, c="black", zorder=0, label=reference_label)
        if panelConfig.histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            # Centre the PDF on the reference line so the two coincide.
            ref_mean = reference_value
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are rendered as a 4-column legend: the first column
        holds the histogram line handles, the other three hold the
        ``stats_dict`` labels and values formatted with ``.3g``.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure`
            Side figure to draw into.
        handles : `list`
            Legend handles of the panel's histograms.
        nums : `list`
            Per-histogram point counts; only its length is used here (to
            offset the legend anchor).
        meds, mads : `list`
            Unused here; the displayed values come from ``stats_dict``.
        stats_dict : `dict`
            ``"statLabels"`` (3 labels) and ``"stat1"``-``"stat3"`` values.
        legend_font_size : `int`, optional
            Legend and title font size.
        yAnchor0 : `float`, optional
            Bottom of the matching histogram panel, used to align the legend.
        nth_row, nth_col : `int`, optional
            Position counters from `makePlot`; the legend is added as an
            extra artist unless both are 0.
        title_str : `str`, optional
            Legend title (the panel label).
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)