Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 16%

222 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-11 03:11 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Public names exported by this module.
__all__ = (
    "HistPanel",
    "HistPlot",
    "HistStatsPanel",
)

24 

25import logging 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import ( 

32 ChoiceField, 

33 Config, 

34 ConfigDictField, 

35 ConfigField, 

36 DictField, 

37 Field, 

38 FieldValidationError, 

39 ListField, 

40) 

41from matplotlib.figure import Figure 

42from matplotlib.gridspec import GridSpec 

43from matplotlib.patches import Rectangle 

44 

45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

46from ...statistics import sigmaMad 

47from .plotUtils import addPlotInfo 

48 

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)

50 

51 

class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - The ListField parameter ``statsLabels`` specifies the names of the 3
      stat columns; it accepts LaTeX formatting.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If stat1/stat2/stat3 are not specified then the default behavior persists
    where the stats panel shows N / median / sigma_mad for each group in the
    panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        # A list (not a set literal) so the label order is deterministic and
        # lines up with stat1, stat2 and stat3.
        default=["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )

    def validate(self):
        """Validate the custom stats configuration.

        Either none of stat1/stat2/stat3 is configured (the default stats are
        used), or all three are configured with the same number of entries.

        Raises
        ------
        ValueError
            Raised if only some of the stats are configured, or if the
            configured stats have different lengths.
        """
        super().validate()
        stats = [self.stat1, self.stat2, self.stat3]
        if not any(stats):
            # Nothing customized: the default N / median / sigma_mad is used.
            return
        if not all(stats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")
        # Each stat list has one entry per histogram, so the lengths must agree
        # or the stats-panel legend columns would be misaligned.
        if len({len(stat) for stat in stats}) != 1:
            raise ValueError(f"{self._name}: stat1, stat2 and stat3 must have the same number of entries")

110 

111 

class HistPanel(Config):
    """Configuration for a single panel of a `HistPlot`.

    A panel overlays one histogram per entry in ``hists``. All histograms in
    the panel share the x-axis range selected by ``rangeType``,
    ``lowerRange`` and ``upperRange``.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad), using the "
            "smallest and largest median and the largest sigmaMad of the histograms in the panel.",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="Configuration for the stats shown in the side panel. If stat1/stat2/stat3 are left unset, "
        "the default stats (N, median, sigma MAD) are shown.",
        default=None,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType``
        and that ``histDensity`` has a ``referenceValue``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the configuration is inconsistent.
        """
        super().validate()
        # Parenthesized so that the percentile bounds are only enforced for the
        # percentile range type (``and`` binds tighter than ``or``).
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)

193 

194 

class HistPlot(PlotAction):
    """Plot action drawing an N-panel figure of histograms, with a side panel
    listing summary statistics for each panel.

    Each entry of ``panels`` is a `HistPanel` that describes which data
    vectors are histogrammed together and how the panel's x-range is chosen.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input schema: one `Vector` per configured histogram.

        Yields
        ------
        key : `str`
            A histogram key from one of the panels' ``hists`` dicts.
        type : `type`
            Always `Vector`.
        """
        for panel in self.panels:  # type: ignore
            # Iterate the keys (histogram data IDs); the dict values are only
            # legend labels and must not appear in the schema.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the plot; keyword arguments are forwarded to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. Must contain every histogram key configured in
            ``panels`` (and any custom stat keys from each panel's
            ``statsPanel``).
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Set up the figure: histograms on the left, statistics on the right.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # Loop over each panel; plot histograms.
        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            # The last panel spans the remaining columns of its row, which
            # shifts its column index used to position the stats legend.
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            handles, _ = ax.get_legend_handles_labels()
            # Add side panel with the statistics for this panel.
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # Add general plot info.
        if plotInfo is not None:
            addPlotInfo(hist_fig, plotInfo)

        # Finish up.
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine the axes layout for the main histogram figure.

        Panels fill a grid of at most two columns, row by row; the final
        panel stretches across the remaining columns of its row.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in panel order.
        ncols : `int`
            Number of grid columns.
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        for i in range(num_panels):
            row, col = divmod(i, ncols)
            if i < num_panels - 1:
                axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
            else:
                axs.append(fig.add_subplot(gs[row : row + 1, col:ncols]))
        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using the configured color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel name, one color per histogram, cycling through the
            color map across all panels.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a `plt.cm`
            colormap.
        """
        from matplotlib.colors import Colormap

        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            cmap = getattr(plt.cm, self.cmap, None)
            if not isinstance(cmap, Colormap):
                raise ValueError(f"Unrecognized color map: {self.cmap}")
            all_colors = getattr(cmap, "colors", None)
            if all_colors is None:
                # Continuous colormaps (e.g. LinearSegmentedColormap) have no
                # discrete color list, so sample them evenly instead.
                n_hists = max(1, sum(len(self.panels[panel].hists) for panel in self.panels))
                all_colors = [cmap(x) for x in np.linspace(0.0, 1.0, n_hists)]

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigma MAD of each histogram.
        stats_dict : `dict`
            ``"statLabels"`` and ``"stat1"``/``"stat2"``/``"stat3"`` values
            for the side statistics panel.

        Raises
        ------
        RuntimeError
            Raised if only some of the custom stats are configured.
        """
        panel_config = self.panels[panel]
        # Only finite values are histogrammed and summarized; filter once.
        finite_data = {hist: data[hist][np.isfinite(data[hist])] for hist in panel_config.hists}

        nums, meds, mads = [], [], []
        for hist_data in finite_data.values():
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, (hist, hist_data) in enumerate(finite_data.items()):
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                density=panel_config.histDensity,
                lw=2,
                color=colors[i],
                label=panel_config.hists[hist],
            )
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panel_config.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panel_config.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # Add a buffer to the top of the plot to allow headspace for labels.
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # Pad by 10% of the log span so the headroom stays positive even
            # when the upper limit is below 1 (e.g. log-scaled densities).
            log_lo, log_hi = np.log10(ylims)
            ylims[1] = 10 ** (log_hi + 0.1 * (log_hi - log_lo))
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panel_config.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        # Check if we should use the default stats panel or if a custom one
        # has been configured.
        stats_panel = panel_config.statsPanel
        stat_keys = [stats_panel.stat1, stats_panel.stat2, stats_panel.stat3]
        if not any(stat_keys):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        elif all(stat_keys):
            stats_dict = {
                "statLabels": stats_panel.statsLabels,
                "stat1": [data[stat] for stat in stats_panel.stat1],
                "stat2": [data[stat] for stat in stats_panel.stat2],
                "stat3": [data[stat] for stat in stats_panel.stat3],
            }
        else:
            raise RuntimeError("Invalid configuration of HistStatsPanel")

        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine the panel x-axis range based on the panel config.

        Parameters
        ----------
        mads, meds : `list` [`float`], optional
            Per-histogram sigma MAD and median; required for the
            ``"sigmaMad"`` range type.

        Returns
        -------
        panel_range : `list` [`float`]
            The ``[lower, upper]`` x-axis limits.
        """
        panel_config = self.panels[panel]
        rangeType = panel_config.rangeType
        lowerRange = panel_config.lowerRange
        upperRange = panel_config.upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Use the module-level logger: config-based actions have no
                # ``log`` attribute of their own.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine the panel x-axis range from data percentile limits.

        The range is the union over all histograms in the panel of their
        ``[lowerRange, upperRange]`` percentile intervals (NaNs ignored).
        """
        panel_config = self.panels[panel]
        panel_range = [np.nan, np.nan]
        for hist in panel_config.hists:
            hist_range = np.nanpercentile(data[hist], [panel_config.lowerRange, panel_config.upperRange])
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and sigma MAD of the
        input data.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and, if ``histDensity`` is set, a
        unit-sigma Gaussian reference PDF centred on ``referenceValue``.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The input axes, possibly with an extended y-range.
        """
        panel_config = self.panels[panel]
        # Draw on a hidden twin axes so the reference artists get their own
        # legend, separate from the histogram legend.
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        if panel_config.histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(panel_config.referenceValue)
        ax2.axvline(panel_config.referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label)
        if panel_config.histDensity:
            # linspace (rather than arange with a range-derived step) covers
            # the full range and does not fail on a degenerate range.
            ref_x = np.linspace(panel_range[0], panel_range[1], 101)
            ref_mean = panel_config.referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(np.max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining legend containing histogram summary statistics.

        The legend is laid out in 4 columns: the histogram line handles, then
        one column each for ``stats_dict["stat1"]``, ``["stat2"]`` and
        ``["stat3"]`` headed by ``stats_dict["statLabels"]``.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure`
            The side figure to draw into.
        handles : `list`
            Legend handles of the histograms in the panel.
        nums : `list`
            Per-histogram counts; only its length is used here, to offset the
            legend vertically.
        meds, mads : `list`
            Unused; kept for interface compatibility.
        stats_dict : `dict`
            Labels and values of the statistics, as built by `_makePanel`.
        legend_font_size : `int`, optional
            Font size of the legend text.
        yAnchor0 : `float`, optional
            Bottom of the corresponding histogram panel, used to align the
            legend with it.
        nth_row, nth_col : `int`, optional
            Position of the panel in the grid; used for vertical offset and
            to decide whether to re-add the legend as an artist.
        title_str : `str`, optional
            Bold title shown above the legend.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # Empty handle, used to populate the bespoke legend layout.
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # Set up new legend handles and labels (filled column by column).
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)