Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 16%
215 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-11 04:04 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-11 04:04 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import (
32 ChoiceField,
33 Config,
34 ConfigDictField,
35 ConfigField,
36 DictField,
37 Field,
38 FieldValidationError,
39 ListField,
40)
41from matplotlib.figure import Figure
42from matplotlib.gridspec import GridSpec
43from matplotlib.patches import Rectangle
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
46from ...statistics import sigmaMad
47from .plotUtils import addPlotInfo
49log = logging.getLogger(__name__)
class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - ``statsLabels`` is a list of exactly three labels, one per stat column
      (in the order stat1, stat2, stat3); LaTeX formatting is accepted.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists where
    the stats panel shows N / median / sigma_mad for each group in the panel.
    """

    # Must be an ordered sequence (not a set): the labels are matched to
    # stat1/stat2/stat3 by position, and set iteration order is arbitrary.
    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        default=["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=[],
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=[],
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=[],
    )
class HistPanel(Config):
    """Configuration for a single panel of a `HistPlot`.

    Each panel overlays one or more histograms (``hists``) drawn over a common
    x-axis range whose interpretation is set by ``rangeType`` together with
    ``lowerRange`` and ``upperRange``.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the percentile bounds fall outside [0, 100], if a
            sigmaMad lower range is negative, if a fixed range has zero
            width, or if ``histDensity`` is set without ``referenceValue``.
        """
        super().validate()
        # The percentile bounds only apply to the percentile range type; the
        # explicit parentheses matter because ``and`` binds tighter than ``or``.
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Plot a grid of panels, each overlaying one or more histograms, with an
    adjoining side figure listing summary statistics for every histogram.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield ``(key, Vector)`` for every histogram key in every panel."""
        for panel in self.panels:  # type: ignore
            # Iterate the keys only: the dict values are legend labels, not
            # data keys.
            for hist in self.panels[panel].hists:  # type: ignore
                yield hist, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, indexed by the histogram keys configured in each
            panel (and by any stat keys configured in a panel's statsPanel).
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Histograms go in the left 3/4 of the figure, statistics on the right.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            # The last panel spans the remaining columns of a partial row.
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            handles, _ = ax.get_legend_handles_labels()
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels are laid out in at most two columns; the final panel is given
        a span that can extend across the remaining columns of its row.

        Returns
        -------
        axs : `list` of `matplotlib.axes.Axes`
            One axes per panel, in panel order.
        ncols, nrows : `int`
            The grid dimensions.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : min(col + 2, ncols + 1)]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Custom discrete maps and listed matplotlib maps cycle through their
        colors; continuous matplotlib maps are sampled evenly, one color per
        histogram across all panels.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            Per-panel lists of colors, one per histogram in panel order.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a matplotlib
            colormap available via ``plt.cm``.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            from matplotlib.colors import Colormap

            cmap = getattr(plt.cm, self.cmap, None)
            if not isinstance(cmap, Colormap):
                raise ValueError(f"Unrecognized color map: {self.cmap}")
            # Listed (qualitative) maps expose discrete colors directly;
            # continuous maps have none, so sample them evenly instead.
            all_colors = getattr(cmap, "colors", None)
            if all_colors is None:
                n_hists = sum(len(self.panels[panel].hists) for panel in self.panels)
                all_colors = list(cmap(np.linspace(0.0, 1.0, max(n_hists, 1))))

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Returns
        -------
        nums, meds, mads : `list`
            Count, median and sigmaMad of the finite values of each histogram.
        stats_dict : `dict`
            Labels (``"statLabels"``) and values (``"stat1"``..``"stat3"``)
            for the side statistics panel.
        """
        panel_config = self.panels[panel]

        # Keep only finite values; computed once and reused for plotting.
        finite_data, nums, meds, mads = [], [], [], []
        for hist in panel_config.hists:
            hist_data = data[hist][np.isfinite(data[hist])]
            finite_data.append(hist_data)
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, (hist, hist_data) in enumerate(zip(panel_config.hists, finite_data)):
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                density=panel_config.histDensity,
                lw=2,
                color=colors[i],
                label=panel_config.hists[hist],
            )
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panel_config.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panel_config.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panel_config.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        stats_panel = panel_config.statsPanel
        # A ConfigField is populated with a default HistStatsPanel instance,
        # so "not configured" shows up as empty stat key lists rather than
        # None; fall back to the default N / median / sigmaMad in both cases.
        if stats_panel is None or not (
            len(stats_panel.stat1) or len(stats_panel.stat2) or len(stats_panel.stat3)
        ):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        else:
            stats_dict = {
                "statLabels": stats_panel.statsLabels,
                "stat1": [data[stat] for stat in stats_panel.stat1],
                "stat2": [data[stat] for stat in stats_panel.stat2],
                "stat3": [data[stat] for stat in stats_panel.stat3],
            }
        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based on config settings.

        Raises
        ------
        RuntimeError
            Raised if the panel's ``rangeType`` is not recognized.
        """
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            # A zero-width or NaN range (e.g. no finite data, or identical
            # values) cannot be binned; fall back to the percentile range.
            if not np.all(np.isfinite(panel_range)) or panel_range[1] - panel_range[0] == 0:
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0 or non-finite. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range spans the smallest lower percentile and largest upper
        percentile across all histograms in the panel. Non-finite values are
        ignored so that e.g. an ``inf`` entry cannot produce an infinite range.
        """
        panel_range = [np.nan, np.nan]
        for hist in self.panels[panel].hists:
            values = np.asarray(data[hist], dtype=float)
            # Mask +/-inf as NaN (rather than dropping them) so the array
            # shape is kept and nanpercentile still returns two values.
            values = np.where(np.isfinite(values), values, np.nan)
            hist_range = np.nanpercentile(values, [self.panels[panel].lowerRange, self.panels[panel].upperRange])
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data."""
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The reference artists are drawn on a hidden twin axes whose limits are
        kept in sync with ``ax``; the original ``ax`` is returned.
        """
        panel_config = self.panels[panel]
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        if panel_config.histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(panel_config.referenceValue)
        ax2.axvline(panel_config.referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label)
        if panel_config.histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = panel_config.referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are rendered as a 4-column legend: the histogram
        handles, then one column per entry of ``stats_dict``.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        # Keep earlier legends on this axes: a new ax.legend() call would
        # otherwise replace the previous one.
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)