Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 15%
233 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 12:31 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-02 12:31 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import (
32 ChoiceField,
33 Config,
34 ConfigDictField,
35 ConfigField,
36 DictField,
37 Field,
38 FieldValidationError,
39 ListField,
40)
41from matplotlib.figure import Figure
42from matplotlib.gridspec import GridSpec
43from matplotlib.patches import Rectangle
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
46from ...statistics import sigmaMad
47from .plotUtils import addPlotInfo
49log = logging.getLogger(__name__)
class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - The ListField parameter ``statsLabels`` specifies the names of the 3
      stat columns; it accepts LaTeX formatting.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If none of stat1, stat2 and stat3 is configured, the default behavior
    persists where the stats panel shows N / median / sigma_mad for each group
    in the panel. Either none or all three of them must be configured.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        default=("N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"),
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )

    def validate(self):
        """Validate the config.

        Raises
        ------
        ValueError
            Raised if some, but not all, of stat1/stat2/stat3 are configured.
        """
        super().validate()
        stats = [self.stat1, self.stat2, self.stat3]
        # The panel is either fully default (no stats) or fully custom
        # (all three stats); a partial configuration is rejected.
        if any(stats) and not all(stats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")
class HistPanel(Config):
    """A Config class that holds parameters to configure a single panel of a
    histogram plot. This class is intended to be used within the ``HistPlot``
    class.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    refRelativeToMedian = Field[bool](
        doc="Is the referenceValue meant to be an offset from the median?",
        default=False,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="Configuration for the stats to be shown on the plot. If none of its stats "
        "(stat1, stat2, stat3) are set, the default stats N, median and sigma MAD are shown.",
        default=None,
    )

    def validate(self):
        """Validate the range and reference-value settings of this panel.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the range bounds are inconsistent with ``rangeType``, or
            if ``histDensity`` is set without a ``referenceValue``.
        """
        super().validate()
        # Parenthesized on purpose: ``and`` binds tighter than ``or``, and the
        # percentile bounds must only be enforced for the percentile range
        # type (fixed/sigmaMad ranges may legitimately exceed 100).
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Make an N-panel plot with a configurable number of histograms displayed
    in each panel. Reference lines showing values of interest may also be added
    to each histogram. Panels are configured using the ``HistPanel`` class.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the input keys this action reads.

        Yields
        ------
        key, type : `tuple`
            One ``(histogramKey, Vector)`` pair for each histogram configured
            in every panel.
        """
        for panel in self.panels:  # type: ignore
            # ``hists`` maps data keys to legend labels; only the keys name
            # input data, so iterate the keys rather than the items.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            Mapping from keys to data. Must contain every histogram key
            configured in ``panels`` and, for panels with a custom
            ``statsPanel``, the keys listed in its stat1/stat2/stat3.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example histogram plot may be seen below:

        .. image:: /_static/analysis_tools/histPlotExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        # Set up figure: histogram panels on the left (3/4 of the width),
        # statistics legends on the right.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # Loop over each panel; plot histograms and add their statistics.
        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            # nth_col counts down from ncols - 1 to 0 and wraps for each row.
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            # The last panel of an incomplete final row spans the leftover
            # columns, so it is treated as the row's final column.
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            handles, _ = ax.get_legend_handles_labels()
            # Add side panel with the statistics of this panel's histograms.
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            # Move to the next row once the row's final column is done.
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # Add general plot info.
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # Finish up.
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels are placed on a grid with one column for a single panel and
        two columns otherwise; the final panel extends to the right edge of
        the grid.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in row-major order.
        ncols, nrows : `int`
            Number of grid columns and rows.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: span through to the right edge (GridSpec
                    # clips a column slice that runs past ncols).
                    axs.append(fig.add_subplot(gs[row : row + 1, col : min(col + 2, ncols + 1)]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            Mapping from panel name to one color per histogram in that panel.
            Colors cycle through the color map across all panels.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a colormap found in
            ``plt.cm``.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            cmap = getattr(plt.cm, self.cmap, None)
            if hasattr(cmap, "colors"):
                # Discrete (listed) color maps expose their colors directly.
                all_colors = cmap.colors
            elif callable(cmap) and hasattr(cmap, "N"):
                # Continuous color maps have no color list; sample one color
                # per histogram evenly across the map instead.
                num_hists = sum(len(self.panels[panel].hists) for panel in self.panels)
                all_colors = cmap(np.linspace(0.0, 1.0, max(num_hists, 1)))
            else:
                raise ValueError(f"Unrecognized color map: {self.cmap}")

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain every histogram key of the panel.
        panel : `str`
            Name of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        colors : `list`
            One color per histogram in the panel.
        label_font_size, legend_font_size : `int`, optional
            Font sizes for axis labels and the legend.
        ncols : `int`, optional
            Number of panel columns in the figure; used to decide whether
            x tick labels need rotating.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigma MAD of each histogram.
        stats_dict : `dict`
            Labels (``"statLabels"``) and values (``"stat1"``-``"stat3"``) to
            show in the statistics panel.

        Raises
        ------
        RuntimeError
            Raised if only some of the statsPanel stats are configured.
        """
        panelConfig = self.panels[panel]
        # Drop non-finite values once per histogram; reused for the
        # statistics and for the plotted histograms.
        finiteData = {hist: data[hist][np.isfinite(data[hist])] for hist in panelConfig.hists}
        nums, meds, mads = [], [], []
        for hist_data in finiteData.values():
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)
        if all(np.isfinite(panel_range)):
            for i, (hist, hist_data) in enumerate(finiteData.items()):
                if len(hist_data) > 0:
                    ax.hist(
                        hist_data,
                        range=panel_range,
                        bins=panelConfig.bins,
                        histtype="step",
                        density=panelConfig.histDensity,
                        lw=2,
                        color=colors[i],
                        label=panelConfig.hists[hist],
                    )
                    # Dashed line marking the median of this histogram.
                    ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

            ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
            ax.set_xlim(panel_range)
            # The following accommodates spacing for ranges with large numbers
            # but small-ish dynamic range (example use case: RA 300-301).
            if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
                ax.xaxis.set_major_formatter("{x:.2f}")
                ax.tick_params(axis="x", labelrotation=25, pad=-1)
            ax.set_xlabel(panelConfig.label, fontsize=label_font_size)
            y_label = "Normalized (PDF)" if panelConfig.histDensity else "Frequency"
            ax.set_ylabel(y_label, fontsize=label_font_size)
            ax.set_yscale(panelConfig.yscale)
            ax.tick_params(labelsize=max(5, label_font_size - 2))
            # Add a buffer to the top of the plot to allow headspace for labels.
            ylims = list(ax.get_ylim())
            if ax.get_yscale() == "log":
                # Raise the top by 10% of its decade exponent. For tops >= 1
                # this equals top**1.1; using abs() keeps the limit growing
                # (instead of shrinking and clipping data) for tops < 1, as
                # happens with normalized densities.
                ylims[1] *= 10 ** (0.1 * abs(np.log10(ylims[1])))
            else:
                ylims[1] *= 1.1
            ax.set_ylim(ylims[0], ylims[1])

            # Draw a vertical line at a reference value, if given.
            # If histDensity is True, also plot a reference PDF with
            # mean = referenceValue and sigma = 1 for reference.
            if panelConfig.referenceValue is not None:
                ax = self._addReferenceLines(ax, panel, panel_range, meds, legend_font_size=legend_font_size)

            # Check if we should use the default stats panel or if a custom one
            # has been created.
            statsPanel = panelConfig.statsPanel
            statList = [statsPanel.stat1, statsPanel.stat2, statsPanel.stat3]
            if not any(statList):
                stats_dict = {
                    "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                    "stat1": nums,
                    "stat2": meds,
                    "stat3": mads,
                }
            elif all(statList):
                stats_dict = {
                    "statLabels": statsPanel.statsLabels,
                    "stat1": [data[stat] for stat in statsPanel.stat1],
                    "stat2": [data[stat] for stat in statsPanel.stat2],
                    "stat3": [data[stat] for stat in statsPanel.stat3],
                }
            else:
                raise RuntimeError("Invalid configuration of HistStatPanel")
        else:
            # No usable range: nothing is plotted and the stats panel is blank.
            stats_dict = {key: [] for key in ("stat1", "stat2", "stat3")}
            stats_dict["statLabels"] = [""] * 3
        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based config settings.

        Parameters
        ----------
        data : `KeyedData`
            Input data; used for the percentile range types.
        panel : `str`
            Name of the panel in ``self.panels``.
        mads, meds : `list`, optional
            Per-histogram sigma MADs and medians; required for ``sigmaMad``.

        Returns
        -------
        panel_range : `list` [`float`]
            ``[lower, upper]`` x-axis limits; entries may be NaN when no
            usable data exist.

        Raises
        ------
        RuntimeError
            Raised if the panel's rangeType is not recognized.
        """
        panel_range = [np.nan, np.nan]
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Uses the module-level logger; this action has no ``log``
                # attribute of its own.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range is the union of the ``[lowerRange, upperRange]`` percentile
        intervals of every non-empty histogram in the panel; NaN entries are
        ignored and remain NaN if no histogram has data.
        """
        lowerPercentile = self.panels[panel].lowerRange
        upperPercentile = self.panels[panel].upperRange
        panel_range = [np.nan, np.nan]
        for hist in self.panels[panel].hists:
            data_hist = data[hist]
            # TODO: Consider raising instead
            if len(data_hist) > 0:
                hist_range = np.nanpercentile(data_hist, [lowerPercentile, upperPercentile])
                panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
                panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, the median, and the sigma MAD
        (via ``sigmaMad``) of the input data."""
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, meds, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The line and curve are drawn on a hidden twin axes whose limits are
        kept in sync with ``ax``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Panel axes.
        panel : `str`
            Name of the panel in ``self.panels``.
        panel_range : `list` [`float`]
            x-axis range of the panel; used to sample the reference PDF.
        meds : `list` [`float`]
            Per-histogram medians; ``meds[0]`` is the offset used when
            ``refRelativeToMedian`` is set.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The input axes (with possibly extended y-limits).
        """
        panelConfig = self.panels[panel]
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        reference_value = panelConfig.referenceValue
        if panelConfig.histDensity:
            # In density mode the line marks the mean of the reference PDF
            # drawn below, which is labelled in its place.
            reference_label = None
        else:
            if panelConfig.refRelativeToMedian:
                reference_value = reference_value + meds[0]
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(reference_value)
        ax2.axvline(reference_value, ls="-", lw=1, c="black", zorder=0, label=reference_label)
        if panelConfig.histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = panelConfig.referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are rendered as a 4-column legend: the histogram line
        handles, then one column per entry of ``stats_dict["statLabels"]``
        with the matching ``stat1``/``stat2``/``stat3`` values formatted with
        ``.3g``.

        Parameters
        ----------
        fig : `matplotlib.figure.SubFigure`
            Side figure that hosts the statistics legends.
        handles : `list`
            Legend handles of the histograms in the panel.
        nums : `list`
            Per-histogram point counts; only its length is used here, to
            offset the legend's y anchor.
        meds, mads : `list`
            Per-histogram medians and sigma MADs (unused here; the values
            shown come from ``stats_dict``).
        stats_dict : `dict`
            Labels and values of the three statistics.
        legend_font_size : `int`, optional
            Legend font size.
        yAnchor0 : `float`, optional
            Bottom of the matching histogram panel, used to align the legend.
        nth_row, nth_col : `int`, optional
            Position counters of the panel; the legend is re-added as an
            artist unless both are zero.
        title_str : `str`, optional
            Legend title (the panel label).
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)