Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 16%

222 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-02 11:55 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel") 

24 

25import logging 

26from collections import defaultdict 

27from typing import Mapping 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import ( 

32 ChoiceField, 

33 Config, 

34 ConfigDictField, 

35 ConfigField, 

36 DictField, 

37 Field, 

38 FieldValidationError, 

39 ListField, 

40) 

41from matplotlib.figure import Figure 

42from matplotlib.gridspec import GridSpec 

43from matplotlib.patches import Rectangle 

44 

45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector 

46from ...statistics import sigmaMad 

47from .plotUtils import addPlotInfo 

48 

49log = logging.getLogger(__name__) 

50 

51 

class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - ``statsLabels`` is a list of exactly 3 labels naming the stat columns;
      the labels may contain LaTeX (mathtext) formatting.

    - The other parameters (``stat1``, ``stat2``, ``stat3``) are lists of
      strings that specify vector keys corresponding to scalar values computed
      in the prep/process/produce steps of an analysis tools plot/metric
      configurable action. There should be one key for each group in the
      HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists where
    the stats panel shows N / median / sigma_mad for each group in the panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        # A list (not a set literal) so the label order is deterministic and
        # lines up with the stat1/stat2/stat3 columns.
        default=["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )

    def validate(self):
        """Check that either none or all three of the stats are configured.

        Raises
        ------
        ValueError
            Raised if some, but not all, of ``stat1``, ``stat2`` and
            ``stat3`` are set.
        """
        super().validate()
        stats = [self.stat1, self.stat2, self.stat3]
        if any(stats) and not all(stats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")

108 

109 

class HistPanel(Config):
    """Configuration of a single panel of a `HistPlot`.

    Each panel draws one or more histograms (the keys of ``hists``) over a
    shared x-axis range chosen according to ``rangeType``, ``lowerRange`` and
    ``upperRange``.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Check the range settings and the reference-value requirement.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the range bounds are invalid for the chosen
            ``rangeType``, or if ``histDensity`` is set without a
            ``referenceValue``.
        """
        super().validate()
        # The bound checks must be parenthesized: ``and`` binds tighter than
        # ``or``, so without them ``upperRange > 100`` would be rejected for
        # every rangeType rather than only for "percentile".
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = f"For rangeType {self.rangeType}, ranges must obey: lowerRange >= 0 and upperRange <= 100."
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                f"For rangeType {self.rangeType}, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)."
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                f"For rangeType {self.rangeType}, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)."
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)

191 

192 

class HistPlot(PlotAction):
    """Plot one or more panels of histograms, each accompanied by a table of
    summary statistics drawn in an adjoining side figure.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the vector keys read by this action.

        Yields
        ------
        key, type : `tuple` [`str`, `type`]
            Every histogram key (the keys of each panel's ``hists`` dict),
            paired with `Vector`.
        """
        for panel in self.panels:  # type: ignore
            # Iterate the keys only: ``hists.items()`` would yield
            # ``(key, legend_label)`` tuples, which are not data keys.
            for hist in self.panels[panel].hists:  # type: ignore
                yield hist, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. Every histogram key of every panel, and every
            key listed in a panel's custom ``statsPanel``, is read from it.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Histogram panels take the left 3/4 of the figure; the statistics
        # tables are drawn in the right 1/4.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        colors = self._assignColors()
        num_panels = len(self.panels)
        # Position bookkeeping (panels remaining, current column/row) that is
        # passed to _addStatisticsPanel to offset each statistics table.
        nth_panel = num_panels
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            panel_config = self.panels[panel]  # type: ignore
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - num_panels > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(panel_config.hists) / 2 - nrows // 2))
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            # Reuse the histogram legend handles (line styles/colors) as the
            # row markers of this panel's statistics table.
            handles, _ = ax.get_legend_handles_labels()
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=panel_config.label,
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels are placed on a grid with one column for a single panel and
        two columns otherwise. The last panel's grid slice extends up to two
        columns from its starting column, so with an odd number of panels it
        occupies the whole final row.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axis per panel, in ``self.panels`` order.
        ncols : `int`
            Number of grid columns.
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : min(col + 2, ncols + 1)]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Colors are handed out in panel order then ``hists`` order, cycling
        through the color map when there are more histograms than colors.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            Mapping from panel name to the colors of that panel's histograms.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither one of the custom color maps below
            nor an attribute of `matplotlib.pyplot.cm`.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                cmap = getattr(plt.cm, self.cmap).copy()
            except AttributeError as err:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err
            if hasattr(cmap, "colors"):
                # Color maps exposing a discrete ``colors`` sequence use it
                # directly.
                all_colors = cmap.colors
            else:
                # Continuous color maps have no ``colors`` attribute; sample
                # one color per histogram evenly along the map instead.
                num_hists = sum(len(self.panels[p].hists) for p in self.panels)  # type: ignore
                all_colors = list(cmap(np.linspace(0.0, 1.0, max(num_hists, 1))))

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Source of the histogram vectors and of any custom stat values.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axis to draw on.
        colors : `list`
            One color per histogram, in ``hists`` order.
        label_font_size : `int`, optional
            Font size of the axis labels (tick labels use 2 less, minimum 5).
        legend_font_size : `int`, optional
            Font size of the legend.
        ncols : `int`, optional
            Number of panel columns; the large-value x tick reformatting is
            only applied when ``ncols > 1``.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigmaMad of each histogram.
        stats_dict : `dict`
            ``"statLabels"`` and the ``"stat1"``/``"stat2"``/``"stat3"``
            values to show in the statistics panel.

        Raises
        ------
        RuntimeError
            Raised if only some of the custom ``statsPanel`` stats are set.
        """
        panel_config = self.panels[panel]
        # Drop non-finite values once; the statistics and the drawn
        # histograms use the same filtered arrays.
        finite_data = [data[hist][np.isfinite(data[hist])] for hist in panel_config.hists]
        nums, meds, mads = [], [], []
        for hist_data in finite_data:
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, (hist, hist_data) in enumerate(zip(panel_config.hists, finite_data)):
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                density=panel_config.histDensity,
                lw=2,
                color=colors[i],
                label=panel_config.hists[hist],
            )
            # Dashed line marking this histogram's median.
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panel_config.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panel_config.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # Pad the exponent by 10% of its magnitude. Scaling the exponent
            # by 1.1 (equivalent for ymax >= 1) would lower the limit when
            # ymax < 1, e.g. for log-scaled density histograms.
            exponent = np.log10(ylims[1])
            ylims[1] = 10 ** (exponent + 0.1 * abs(exponent))
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panel_config.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        # Use the default stats (N / median / sigmaMad) unless a custom stats
        # panel is fully configured.
        stats_panel = panel_config.statsPanel
        stat_keys = [stats_panel.stat1, stats_panel.stat2, stats_panel.stat3]
        if not any(stat_keys):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        elif all(stat_keys):
            stats_dict = {
                "statLabels": stats_panel.statsLabels,
                "stat1": [data[stat] for stat in stats_panel.stat1],
                "stat2": [data[stat] for stat in stats_panel.stat2],
                "stat3": [data[stat] for stat in stats_panel.stat3],
            }
        else:
            raise RuntimeError("Invalid configuration of HistStatPanel")

        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based on config settings.

        Parameters
        ----------
        data : `KeyedData`
            Source of the histogram vectors (used for percentile ranges).
        panel : `str`
            Key of the panel in ``self.panels``.
        mads, meds : `list` [`float`], optional
            Per-histogram sigmaMad and median; required for
            ``rangeType="sigmaMad"``.

        Returns
        -------
        panel_range : `list` [`float`]
            The ``[lower, upper]`` x-axis limits.

        Raises
        ------
        RuntimeError
            Raised for an unknown ``rangeType``.
        """
        panel_config = self.panels[panel]
        rangeType = panel_config.rangeType
        lowerRange = panel_config.lowerRange
        upperRange = panel_config.upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Use the module-level logger; the action itself has no
                # ``log`` attribute.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The lower[upper] limit is the smallest[largest] of the per-histogram
        ``lowerRange``[``upperRange``] NaN-ignoring percentiles.
        """
        panel_config = self.panels[panel]
        panel_range = [np.nan, np.nan]
        for hist in panel_config.hists:
            hist_range = np.nanpercentile(data[hist], [panel_config.lowerRange, panel_config.upperRange])
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation (as sigmaMad) of input data."""
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The reference elements are drawn on a twin axis (``ax.twinx()``) with
        its axis decorations turned off. When ``histDensity`` is set, a
        Gaussian with mean ``referenceValue`` and sigma 1 is shaded, and the
        y-limits of both axes are raised if needed so the curve fits.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The input ``ax``; its y-limits may have been changed.
        """
        panel_config = self.panels[panel]
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        # The reference line is only labelled when no density curve is drawn.
        if panel_config.histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(panel_config.referenceValue)
        ax2.axvline(panel_config.referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label)
        if panel_config.histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = panel_config.referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The table is built as a 4-column legend: the first column holds the
        histogram handles, the other three hold the labels and values from
        ``stats_dict``.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure`
            Figure to add the (axis-less) statistics axis to.
        handles : `list`
            Legend handles of the panel's histograms.
        nums : `list`
            Only its length is used, to position the table.
        meds, mads : `list`
            Not used here; the displayed values come from ``stats_dict``.
        stats_dict : `dict`
            ``"statLabels"`` (3 column titles) and ``"stat1"``..``"stat3"``
            (one value per histogram, formatted with ``.3g``).
        legend_font_size : `int`, optional
            Font size of the table and its title.
        yAnchor0 : `float`, optional
            Bottom y position of the matching histogram panel.
        nth_row, nth_col : `int`, optional
            Panel position bookkeeping from `makePlot`; ``nth_col`` offsets
            the table vertically.
        title_str : `str`, optional
            Table title (drawn bold).
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)