Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 16%
222 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-02 11:55 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-02 11:55 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import (
32 ChoiceField,
33 Config,
34 ConfigDictField,
35 ConfigField,
36 DictField,
37 Field,
38 FieldValidationError,
39 ListField,
40)
41from matplotlib.figure import Figure
42from matplotlib.gridspec import GridSpec
43from matplotlib.patches import Rectangle
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
46from ...statistics import sigmaMad
47from .plotUtils import addPlotInfo
49log = logging.getLogger(__name__)
class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - ``statsLabels`` is a list of the names of the 3 stat columns; it accepts
      LaTeX formatting.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists where
    the stats panel shows N / median / sigma_mad for each group in the panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        # A list (not a set literal) so the label order is deterministic and
        # lines up with stat1, stat2 and stat3.
        default=["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )

    def validate(self):
        """Check that the custom statistics are configured all-or-nothing.

        Raises
        ------
        ValueError
            Raised if some, but not all, of ``stat1``, ``stat2`` and ``stat3``
            are set.
        """
        super().validate()
        stats = [self.stat1, self.stat2, self.stat3]
        if any(stats) and not all(stats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")
class HistPanel(Config):
    """Configuration of a single `HistPlot` panel.

    It selects the histograms drawn in the panel and how the shared x-axis
    range is chosen. It can also add an optional reference line or density
    curve and a custom statistics panel.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Check that the range and reference settings are consistent.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the lower/upper ranges are invalid for ``rangeType``, or
            if ``histDensity`` is set without a ``referenceValue``.
        """
        super().validate()
        # Parenthesized so that both percentile bounds are checked only for
        # the "percentile" range type. Without the parentheses, ``and`` binds
        # tighter than ``or``, and upperRange > 100 is rejected for every
        # range type.
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Plot action that draws one or more panels of histograms.

    Each panel is accompanied by a legend of summary statistics on a side
    sub-figure.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the key of every histogram vector this action reads.

        Yields
        ------
        key, type : `tuple` [`str`, `type`]
            A histogram key (a key of some panel's ``hists`` dict) paired with
            `Vector`.

        Notes
        -----
        Keys named in a custom ``statsPanel`` (``stat1``/``stat2``/``stat3``)
        are read by `makePlot` but are not yielded here.
        """
        for panel in self.panels:  # type: ignore
            # The ``hists`` keys name the data vectors; the values are only
            # legend labels, so they must not be part of the schema key.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the histogram figure; this is a thin wrapper around
        `makePlot`.
        """
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from the histogram keys configured in ``panels`` (and any
            custom statistic keys) to the values to plot.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # Histograms go in the left sub-figure, statistics legends in the
        # right one.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        colors = self._assignColors()
        # Position counters for the current panel. They are passed to
        # _addStatisticsPanel to offset and stack each panel's stats legend.
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            handles, _ = ax.get_legend_handles_labels()
            title_str = self.panels[panel].label  # type: ignore
            # Add this panel's statistics legend to the side figure.
            self._addStatisticsPanel(
                side_fig,
                list(handles),
                list(nums),
                list(meds),
                list(mads),
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=title_str,
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # Add general plot info.
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels fill a grid of one column (a single panel) or two columns. The
        final panel spans the remaining columns of its row.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in ``self.panels`` order.
        ncols, nrows : `int`
            The grid dimensions.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        placed = 0
        for row in range(nrows):
            for col in range(ncols):
                placed += 1
                if placed < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: stretch it across the rest of its row.
                    axs.append(fig.add_subplot(gs[row : row + 1, col:ncols]))
                    break
        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel key, one color per histogram in ``hists`` order.
            Colors are taken sequentially across all panels and wrap around
            when the color map is exhausted.

        Raises
        ------
        ValueError
            Raised if ``self.cmap`` is neither a custom color map nor a
            colormap available in `plt.cm`.
        """
        from matplotlib.colors import Colormap

        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            cmap = getattr(plt.cm, self.cmap, None)
            if not isinstance(cmap, Colormap):
                raise ValueError(f"Unrecognized color map: {self.cmap}")
            if hasattr(cmap, "colors"):
                # Listed colormaps carry an explicit list of colors.
                all_colors = cmap.copy().colors
            else:
                # Continuous colormaps have no color list; sample one evenly
                # spaced color per histogram instead.
                num_hists = sum(len(self.panels[panel].hists) for panel in self.panels)
                all_colors = [cmap(x) for x in np.linspace(0.0, 1.0, max(num_hists, 1))]

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by the panel's histogram keys and, if a custom
            stats panel is configured, by its statistic keys.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw the histograms on.
        colors : `list`
            One color per histogram, in ``hists`` order.
        label_font_size, legend_font_size : `int`, optional
            Font sizes for the axis labels and the legend.
        ncols : `int`, optional
            Number of panel columns in the figure. It is used to decide
            whether x tick labels need compact formatting.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite values, median and sigma MAD of each histogram.
        stats_dict : `dict`
            Labels (``"statLabels"``) and values (``"stat1"``, ``"stat2"``,
            ``"stat3"``) for the statistics panel.

        Raises
        ------
        RuntimeError
            Raised if only some of the custom statistics are configured.
        """
        panelConfig = self.panels[panel]
        # Non-finite values are dropped before computing statistics and
        # plotting. Compute the filtered arrays once and reuse them.
        finite_data = [data[hist][np.isfinite(data[hist])] for hist in panelConfig.hists]
        nums, meds, mads = [], [], []
        for hist_data in finite_data:
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, (hist, hist_data) in enumerate(zip(panelConfig.hists, finite_data)):
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panelConfig.bins,
                histtype="step",
                density=panelConfig.histDensity,
                lw=2,
                color=colors[i],
                label=panelConfig.hists[hist],
            )
            # Dashed line at the median of each histogram.
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panelConfig.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panelConfig.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panelConfig.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # Add a buffer to the top of the plot to allow headspace for labels.
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panelConfig.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        # Check if we should use the default stats panel or if a custom one
        # has been created.
        statsPanel = panelConfig.statsPanel
        statList = [statsPanel.stat1, statsPanel.stat2, statsPanel.stat3]
        if not any(statList):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        elif all(statList):
            stats_dict = {
                "statLabels": statsPanel.statsLabels,
                "stat1": [data[stat] for stat in statsPanel.stat1],
                "stat2": [data[stat] for stat in statsPanel.stat2],
                "stat3": [data[stat] for stat in statsPanel.stat3],
            }
        else:
            raise RuntimeError("Invalid configuration of HistStatPanel")

        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based on config settings.

        Parameters
        ----------
        data : `KeyedData`
            Input data (used for percentile-based ranges).
        panel : `str`
            Key of the panel in ``self.panels``.
        mads, meds : `list` [`float`], optional
            Per-histogram sigma MAD and median; required for ``sigmaMad``.

        Returns
        -------
        panel_range : `list` [`float`]
            The ``[lower, upper]`` x-axis limits.

        Raises
        ------
        RuntimeError
            Raised for an unknown ``rangeType``.
        """
        panelConfig = self.panels[panel]
        rangeType = panelConfig.rangeType
        lowerRange = panelConfig.lowerRange
        upperRange = panelConfig.upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            # A zero-width or NaN range cannot be used for binning, so fall
            # back to the percentile range.
            if not np.all(np.isfinite(panel_range)) or panel_range[1] - panel_range[0] == 0:
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0 or non-finite. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range runs from the smallest lower percentile to the largest upper
        percentile over all histograms in the panel. Histograms with no finite
        values are ignored. The result is ``[nan, nan]`` if no histogram has
        finite values.
        """
        panelConfig = self.panels[panel]
        # Clip so the fallback from the sigmaMad range type, whose bounds are
        # not restricted to [0, 100], still yields valid percentiles. This is
        # a no-op for validated "percentile" panels.
        percentiles = np.clip([panelConfig.lowerRange, panelConfig.upperRange], 0.0, 100.0)
        panel_range = [np.nan, np.nan]
        for hist in panelConfig.hists:
            # Match _makePanel, which drops non-finite values before plotting.
            # Otherwise +/-inf can leak into the range and break ax.hist.
            hist_data = data[hist][np.isfinite(data[hist])]
            if len(hist_data) == 0:
                continue
            hist_range = np.nanpercentile(hist_data, percentiles)
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, the NaN-ignoring median, and
        the sigma estimate from `sigmaMad` of the input data.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The line and curve are drawn on a hidden twin axes that shares the
        x-axis. If ``histDensity`` is set, a normal PDF with mean
        ``referenceValue`` and unit sigma is shaded, and the y-limits of both
        axes are extended to fit it.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The original (primary) axes.
        """
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        if self.panels[panel].histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(self.panels[panel].referenceValue)
        ax2.axvline(
            self.panels[panel].referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label
        )
        if self.panels[panel].histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = self.panels[panel].referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are rendered as a 4-column legend. The first column
        holds the histogram handles, and the others hold the three statistics
        from ``stats_dict``, each formatted with ``.3g``. The legend is
        vertically anchored near ``yAnchor0``. It is kept as an extra artist
        (``ax.add_artist``) when ``nth_row + nth_col > 0``, so that later
        legends on the same axes do not replace it.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # Invisible handle, used to pad the bespoke legend layout.
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # Set up new legend handles and labels.
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)