Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 15%
233 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 11:50 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-22 11:50 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import (
32 ChoiceField,
33 Config,
34 ConfigDictField,
35 ConfigField,
36 DictField,
37 Field,
38 FieldValidationError,
39 ListField,
40)
41from matplotlib.figure import Figure
42from matplotlib.gridspec import GridSpec
43from matplotlib.patches import Rectangle
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
46from ...math import nanMax, nanMedian, nanMin, sigmaMad
47from .plotUtils import addPlotInfo
49log = logging.getLogger(__name__)
class HistStatsPanel(Config):
    """Configuration of the statistics panel drawn next to a ``HistPanel``.

    The fields of this class are the parameters that customize the stats
    panel of a ``HistPlot``:

    - ``statsLabels`` lists the labels of the three stat columns; LaTeX
      formatting is accepted.

    - ``stat1``, ``stat2`` and ``stat3`` are lists of strings naming the
      vector keys of scalar values computed in the prep/process/produce
      steps of an analysis tools plot/metric configurable action. There
      should be one key for each group in the HistPanel.

    A dedicated config class is used, rather than one
    `~lsst.pex.config.DictField` per parameter in HistPanel, for clarity and
    consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists,
    where the stats panel shows N / median / sigma_mad for each group in the
    panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        default=("N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"),
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )

    def validate(self):
        """Check that the three stats are configured all together or not at
        all.

        Raises
        ------
        ValueError
            Raised if only some of ``stat1``, ``stat2`` and ``stat3`` are set.
        """
        super().validate()
        configured = [bool(stat) for stat in (self.stat1, self.stat2, self.stat3)]
        partiallyConfigured = any(configured) and not all(configured)
        if partiallyConfigured:
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")
class HistPanel(Config):
    """A Config class that holds parameters to configure a single panel of a
    histogram plot. This class is intended to be used within the ``HistPlot``
    class.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            # The range is centred on the median (see HistPlot._getPanelRange),
            # not on sigmaMad itself.
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    refRelativeToMedian = Field[bool](
        doc="Is the referenceValue meant to be an offset from the median?",
        default=False,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType`` and
        that ``histDensity`` comes with a ``referenceValue``.

        Raises
        ------
        FieldValidationError
            Raised if the percentile bounds fall outside [0, 100], if
            ``lowerRange`` is negative for a ``sigmaMad`` range, if a
            ``fixed`` range has zero width, or if ``histDensity`` is set
            without a ``referenceValue``.
        """
        super().validate()
        # The bounds check is parenthesised so that it only applies to
        # percentile ranges; without the parentheses ``and`` binds tighter
        # than ``or`` and any panel with upperRange > 100 was rejected, even
        # for "fixed" or "sigmaMad" ranges.
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Make an N-panel plot with a configurable number of histograms displayed
    in each panel. Reference lines showing values of interest may also be added
    to each histogram. Panels are configured using the ``HistPanel`` class.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the data keys needed to draw every configured histogram.

        Yields
        ------
        key : `str`
            A key of the ``hists`` dict of one of the configured panels.
        type : `type`
            Always ``Vector``.
        """
        for panel in self.panels:  # type: ignore
            # Iterate over the dict keys only: the values are legend labels,
            # and _makePanel looks the data up by key (``data[hist]``).
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Produce the figure; a thin wrapper around `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, indexed by the keys of each panel's ``hists``
            config (and by the ``statsPanel`` keys, when configured).
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:

            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

            If `None`, no plot information is added to the figure.
        **kwargs
            Additional keyword arguments; not used by this method.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example histogram plot may be seen below:

        .. image:: /_static/analysis_tools/histPlotExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        # set up figure: histograms on the left, statistics on the right
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms
        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            # nth_panel, nth_col and nth_row count down towards the last
            # panel; they are used to position the statistics legend.
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            # add side panel; add statistics
            handles, _ = ax.get_legend_handles_labels()
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or `matplotlib.figure.SubFigure`
            The figure to which the axes are added.

        Returns
        -------
        axs : `list`
            One axes object per configured panel.
        ncols : `int`
            Number of grid columns (1 for at most one panel, otherwise 2).
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        if num_panels <= 1:
            ncols = 1
        else:
            ncols = 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # The last panel is allowed to stretch over the remaining
                    # columns of its row.
                    axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel name, one color per histogram of that panel.
            Colors are taken in sequence across all panels and wrap around
            when the color map is exhausted.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is not one of the custom maps and no list of
            colors can be obtained from ``plt.cm``.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                all_colors = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, indexed by histogram key.
        panel : `str`
            Name of the panel (a key of ``self.panels``).
        ax : `matplotlib.axes.Axes`
            The axes to draw on.
        colors : `list`
            One color per histogram of this panel.
        label_font_size : `int`, optional
            Font size of the axis labels.
        legend_font_size : `int`, optional
            Font size of the legend.
        ncols : `int`, optional
            Number of panel columns in the figure.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigmaMad of each histogram.
        stats_dict : `dict`
            The labels (``"statLabels"``) and values (``"stat1"``,
            ``"stat2"``, ``"stat3"``) to show in the statistics panel.

        Raises
        ------
        RuntimeError
            Raised if only some of the stats of the ``statsPanel`` are
            configured.
        """
        panel_config = self.panels[panel]
        nums, meds, mads = [], [], []
        for hist in panel_config.hists:
            hist_data = data[hist][np.isfinite(data[hist])]
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)
        if all(np.isfinite(panel_range)):
            for i, hist in enumerate(panel_config.hists):
                hist_data = data[hist][np.isfinite(data[hist])]
                if len(hist_data) > 0:
                    ax.hist(
                        hist_data,
                        range=panel_range,
                        bins=panel_config.bins,
                        histtype="step",
                        density=panel_config.histDensity,
                        lw=2,
                        color=colors[i],
                        label=panel_config.hists[hist],
                    )
                    # Dashed vertical line at the median of this histogram.
                    ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

            ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
            ax.set_xlim(panel_range)
            # The following accommodates spacing for ranges with large numbers
            # but small-ish dynamic range (example use case: RA 300-301).
            if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
                ax.xaxis.set_major_formatter("{x:.2f}")
                ax.tick_params(axis="x", labelrotation=25, pad=-1)
            ax.set_xlabel(panel_config.label, fontsize=label_font_size)
            y_label = "Normalized (PDF)" if panel_config.histDensity else "Frequency"
            ax.set_ylabel(y_label, fontsize=label_font_size)
            ax.set_yscale(panel_config.yscale)
            ax.tick_params(labelsize=max(5, label_font_size - 2))
            # add a buffer to the top of the plot to allow headspace for labels
            ylims = list(ax.get_ylim())
            if ax.get_yscale() == "log":
                ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
            else:
                ylims[1] *= 1.1
            ax.set_ylim(ylims[0], ylims[1])

            # Draw a vertical line at a reference value, if given.
            # If histDensity is True, also plot a reference PDF with
            # mean = referenceValue and sigma = 1 for reference.
            if panel_config.referenceValue is not None:
                ax = self._addReferenceLines(ax, panel, panel_range, meds, legend_font_size=legend_font_size)

            # Check if we should use the default stats panel or if a custom one
            # has been created.
            statList = [
                panel_config.statsPanel.stat1,
                panel_config.statsPanel.stat2,
                panel_config.statsPanel.stat3,
            ]
            if not any(statList):
                stats_dict = {
                    "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                    "stat1": nums,
                    "stat2": meds,
                    "stat3": mads,
                }
            elif all(statList):
                stat1 = [data[stat] for stat in panel_config.statsPanel.stat1]
                stat2 = [data[stat] for stat in panel_config.statsPanel.stat2]
                stat3 = [data[stat] for stat in panel_config.statsPanel.stat3]
                stats_dict = {
                    "statLabels": panel_config.statsPanel.statsLabels,
                    "stat1": stat1,
                    "stat2": stat2,
                    "stat3": stat3,
                }
            else:
                raise RuntimeError("Invalid configuration of HistStatPanel")
        else:
            # No finite range could be determined, so nothing is drawn and
            # the statistics panel is left empty.
            stats_dict = {key: [] for key in ("stat1", "stat2", "stat3")}
            stats_dict["statLabels"] = [""] * 3
        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based config settings.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, indexed by histogram key.
        panel : `str`
            Name of the panel (a key of ``self.panels``).
        mads : `list`, optional
            sigmaMad of each histogram; used for ``rangeType="sigmaMad"``.
        meds : `list`, optional
            Median of each histogram; used for ``rangeType="sigmaMad"``.

        Returns
        -------
        panel_range : `list`
            The lower and upper limit of the x-axis.

        Raises
        ------
        RuntimeError
            Raised if the configured ``rangeType`` is not recognized.
        """
        panel_range = [np.nan, np.nan]
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = nanMax(mads)
            maxMed = nanMax(meds)
            minMed = nanMin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, indexed by histogram key.
        panel : `str`
            Name of the panel (a key of ``self.panels``).

        Returns
        -------
        panel_range : `list`
            The smallest lower and largest upper percentile limit over all
            non-empty histograms of the panel; ``[nan, nan]`` if every
            histogram is empty.
        """
        panel_range = [np.nan, np.nan]
        for hist in self.panels[panel].hists:
            data_hist = data[hist]
            # TODO: Consider raising instead
            if len(data_hist) > 0:
                hist_range = np.nanpercentile(
                    data_hist, [self.panels[panel].lowerRange, self.panels[panel].upperRange]
                )
                panel_range[0] = nanMin([panel_range[0], hist_range[0]])
                panel_range[1] = nanMax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data.

        Parameters
        ----------
        data : `Vector`
            The values to summarize.

        Returns
        -------
        num : `int`
            Number of entries in ``data``.
        med : `float`
            Result of ``nanMedian(data)``.
        mad : `float`
            Result of ``sigmaMad(data)``.
        """
        num = len(data)
        med = nanMedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, meds, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            The axes of the panel.
        panel : `str`
            Name of the panel (a key of ``self.panels``).
        panel_range : `list`
            The lower and upper limit of the x-axis.
        meds : `list`
            Median of each histogram of the panel; the first entry is the
            offset applied when ``refRelativeToMedian`` is set.
        legend_font_size : `int`, optional
            Font size of the legend of the reference lines.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The axes passed in.
        """
        # The reference artists live on a twin axes kept in sync with ``ax``.
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        panel_config = self.panels[panel]
        # Compute the line position up front so that it is defined whether or
        # not histDensity is set; previously it was only assigned in the
        # non-density branch.
        reference_value = panel_config.referenceValue
        if panel_config.refRelativeToMedian:
            reference_value = reference_value + meds[0]
        if panel_config.histDensity:
            # The legend entry is left to the reference PDF drawn below.
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(reference_value)
        ax2.axvline(reference_value, ls="-", lw=1, c="black", zorder=0, label=reference_label)
        if panel_config.histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            # The PDF is centered on the configured referenceValue itself,
            # without any median offset.
            ref_mean = panel_config.referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are laid out as a four-column legend: the histogram
        handles followed by one column per statistic.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or `matplotlib.figure.SubFigure`
            The figure to which the statistics are added.
        handles : `list`
            Legend handles of the histograms of the panel.
        nums : `list`
            Number of data points of each histogram; only its length is used
            here, to space the legend.
        meds, mads : `list`
            Not used by this method.
        stats_dict : `dict`
            The labels (``"statLabels"``) and values (``"stat1"``,
            ``"stat2"``, ``"stat3"``) to show.
        legend_font_size : `int`, optional
            Font size of the legend text and title.
        yAnchor0 : `float`, optional
            Lower edge of the corresponding histogram panel.
        nth_row, nth_col : `int`, optional
            Row and column counters of the panel, as computed in `makePlot`.
        title_str : `str`, optional
            Title of the legend.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        # NOTE(review): the legend is re-added as an artist for every panel
        # except the one with nth_row == nth_col == 0; presumably this keeps
        # earlier legends visible -- confirm.
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)