Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 16%
222 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-04 11:09 +0000
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-04 11:09 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import (
32 ChoiceField,
33 Config,
34 ConfigDictField,
35 ConfigField,
36 DictField,
37 Field,
38 FieldValidationError,
39 ListField,
40)
41from matplotlib.figure import Figure
42from matplotlib.gridspec import GridSpec
43from matplotlib.patches import Rectangle
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
46from ...statistics import sigmaMad
47from .plotUtils import addPlotInfo
49log = logging.getLogger(__name__)
class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - The ListField parameter ``statsLabels`` specifies the names of the 3
      stat columns; it accepts LaTeX formatting.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists where
    the stats panel shows N / median / sigma_mad for each group in the panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        # A list (not a set literal) so the label order is deterministic and
        # lines up with stat1/stat2/stat3.
        default=["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )

    def validate(self):
        """Check that the custom stats are either all set or all unset.

        Raises
        ------
        ValueError
            Raised if some, but not all, of ``stat1``, ``stat2`` and ``stat3``
            are configured.
        """
        super().validate()
        stats = [self.stat1, self.stat2, self.stat3]
        if any(stats) and not all(stats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")
class HistPanel(Config):
    """A Config class that holds parameters to configure a single panel of a
    histogram plot. This class is intended to be used within the ``HistPlot``
    class.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Validate the range settings against ``rangeType``.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if ``lowerRange``/``upperRange`` are inconsistent with the
            chosen ``rangeType``, or if ``histDensity`` is set without a
            ``referenceValue``.
        """
        super().validate()
        # Parenthesized: percentile bounds apply only to the percentile range
        # type (``and`` binds tighter than ``or``, so without the parentheses
        # upperRange > 100 was rejected for every rangeType).
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Make an N-panel plot with a configurable number of histograms displayed
    in each panel. Reference lines showing values of interest may also be added
    to each histogram. Panels are configured using the ``HistPanel`` class.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the keys of the vectors histogrammed in every panel.

        Yields
        ------
        key : `str`
            A histogram key from the ``hists`` dict of one of the panels.
        type : `type`
            Always `Vector`.
        """
        for panel in self.panels:  # type: ignore
            # ``hists`` maps data key -> legend label; only the key is an
            # input (iterating ``.items()`` yielded (key, label) tuples).
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the figure; all arguments are forwarded to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `lsst.analysis.tools.interfaces.KeyedData`
            The data to plot. It must contain every key listed in the
            ``hists`` of each panel, plus the stat keys of any panel that
            configures a custom ``statsPanel``.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
                `"run"`
                    Output run for the plots (`str`).
                `"tractTableType"`
                    Table from which results are taken (`str`).
                `"plotName"`
                    Output plot name (`str`)
                `"SN"`
                    The global signal-to-noise data threshold (`float`)
                `"skymap"`
                    The type of skymap used for the data (`str`).
                `"tract"`
                    The tract that the data comes from (`int`).
                `"bands"`
                    The bands used for this data (`str` or `list`).
                `"visit"`
                    The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example histogram plot may be seen below:

        .. image:: /_static/analysis_tools/histPlotExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        # Set up the figure: histograms on the left, statistics on the right.
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # Loop over each panel and plot its histograms.
        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            handles, _ = ax.get_legend_handles_labels()
            # Add the side panel with this panel's statistics.
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # Add general plot info.
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # Finish up.
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels are placed row by row in one column (a single panel) or two
        columns. The last panel's column slice extends one column further
        (capped at ``ncols + 1``) so a lone panel in the final row spans the
        full width.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in row-major order.
        ncols : `int`
            Number of grid columns.
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        if num_panels <= 1:
            ncols = 1
        else:
            ncols = 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Colors are taken in order and cycle across all panels. Custom maps and
        `plt.cm` maps with a discrete ``colors`` list use those colors.
        Continuous maps are sampled at evenly spaced points, one per histogram.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel key, one color per histogram in ``hists`` order.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom color map nor a usable
            color map available via `plt.cm`.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                cmap = getattr(plt.cm, self.cmap).copy()
            except AttributeError:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from None
            all_colors = getattr(cmap, "colors", None)
            if all_colors is None:
                # Continuous maps have no discrete color list; sample them
                # evenly instead of rejecting a valid color map.
                if not callable(cmap):
                    raise ValueError(f"Unrecognized color map: {self.cmap}")
                num_hists = sum(len(self.panels[panel].hists) for panel in self.panels)
                all_colors = [cmap(x) for x in np.linspace(0.0, 1.0, max(num_hists, 1))]

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `lsst.analysis.tools.interfaces.KeyedData`
            Input data, indexed by the panel's histogram (and stat) keys.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw into.
        colors : `list`
            One color per histogram in the panel.
        label_font_size, legend_font_size : `int`, optional
            Font sizes for axis labels and the legend.
        ncols : `int`, optional
            Number of panel columns in the figure; used to decide whether
            x tick labels need reformatting.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigmaMad of each histogram.
        stats_dict : `dict`
            Labels and values for the statistics side panel.

        Raises
        ------
        RuntimeError
            Raised if only some of the custom stats are configured.
        """
        panel_config = self.panels[panel]
        # Drop non-finite values once; they are excluded from both the
        # statistics and the histograms.
        finite_data = {hist: data[hist][np.isfinite(data[hist])] for hist in panel_config.hists}

        nums, meds, mads = [], [], []
        for hist_data in finite_data.values():
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, (hist, hist_data) in enumerate(finite_data.items()):
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                density=panel_config.histDensity,
                lw=2,
                color=colors[i],
                label=panel_config.hists[hist],
            )
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panel_config.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panel_config.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # Add a buffer to the top of the plot to allow headspace for labels.
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            # Equivalent to top**1.1 for top >= 1, but uses abs() so the limit
            # still grows (top**0.9) when top < 1, e.g. for a PDF; the former
            # 10**(log10(top)*1.1) shrank the axis in that case.
            ylims[1] = ylims[1] * 10 ** (0.1 * abs(np.log10(ylims[1])))
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panel_config.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        # Check if we should use the default stats panel or if a custom one
        # has been created.
        stats_panel = panel_config.statsPanel
        statList = [stats_panel.stat1, stats_panel.stat2, stats_panel.stat3]
        if not any(statList):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        elif all(statList):
            stats_dict = {
                "statLabels": stats_panel.statsLabels,
                "stat1": [data[stat] for stat in stats_panel.stat1],
                "stat2": [data[stat] for stat in stats_panel.stat2],
                "stat3": [data[stat] for stat in stats_panel.stat3],
            }
        else:
            raise RuntimeError("Invalid configuration of HistStatPanel")

        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based on config settings.

        For ``sigmaMad`` ranges whose width is zero or non-finite (e.g. every
        sigmaMad is NaN), the percentile range is used instead.

        Returns
        -------
        panel_range : `list` [`float`]
            Lower and upper x-axis limits.

        Raises
        ------
        RuntimeError
            Raised for an unknown ``rangeType``.
        """
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            width = panel_range[1] - panel_range[0]
            if width == 0 or not np.isfinite(width):
                # Use the module-level logger; Config-based actions have no
                # ``self.log`` attribute.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad had zero or non-finite width. "
                    "Computing using percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range spans the smallest lower and largest upper percentile over
        all histograms in the panel. Only finite values are used, so +/-inf
        cannot produce an infinite range. Histograms with no finite values are
        skipped; if none remain the range is ``[nan, nan]``.
        """
        panel_range = [np.nan, np.nan]
        percentiles = [self.panels[panel].lowerRange, self.panels[panel].upperRange]
        for hist in self.panels[panel].hists:
            hist_data = data[hist][np.isfinite(data[hist])]
            if np.size(hist_data) == 0:
                # nanpercentile returns a scalar for empty input, which cannot
                # be indexed below.
                continue
            hist_range = np.nanpercentile(hist_data, percentiles)
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and sigmaMad of the
        input data.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The line and curve are drawn on a hidden twin axes that shares the
        panel's x/y limits. When ``histDensity`` is set, a unit-sigma Gaussian
        centred on ``referenceValue`` is shaded and the panel's y limit may be
        raised to fit it.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The original (primary) axes.
        """
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        if self.panels[panel].histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(self.panels[panel].referenceValue)
        ax2.axvline(
            self.panels[panel].referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label
        )
        if self.panels[panel].histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = self.panels[panel].referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are rendered as a 4-column legend: the histogram
        handles, then one column per entry of ``stats_dict``. ``nums`` only
        sets the vertical offset of the legend; ``meds`` and ``mads`` are not
        read here (the displayed values come from ``stats_dict``).
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        # Keep earlier legends: a new ax.legend() call replaces the axes'
        # legend, so all but the first are re-attached as artists.
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)