Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 15%
228 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-28 04:49 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-28 04:49 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Public names exported by this module.
__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import (
32 ChoiceField,
33 Config,
34 ConfigDictField,
35 ConfigField,
36 DictField,
37 Field,
38 FieldValidationError,
39 ListField,
40)
41from matplotlib.figure import Figure
42from matplotlib.gridspec import GridSpec
43from matplotlib.patches import Rectangle
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
46from ...statistics import sigmaMad
47from .plotUtils import addPlotInfo
# Module-level logger for this file.
log = logging.getLogger(__name__)
class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - The ListField parameter ``statsLabels`` specifies the names of the 3
      stat columns; LaTeX formatting is accepted.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists where
    the stats panel shows N / median / sigma_mad for each group in the panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        default=("N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"),
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel. "
        "There should be one entry for each hist in the panel.",
        default=None,
        optional=True,
    )

    def validate(self):
        """Validate the config.

        Raises
        ------
        ValueError
            Raised if some, but not all, of ``stat1``, ``stat2`` and ``stat3``
            are set; the stats panel needs either all three or none.
        """
        super().validate()
        configuredStats = [self.stat1, self.stat2, self.stat3]
        if any(configuredStats) and not all(configuredStats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")
class HistPanel(Config):
    """A Config class that holds parameters to configure a single panel of a
    histogram plot. This class is intended to be used within the ``HistPlot``
    class.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            # The range is centered on the median (see HistPlot._getPanelRange),
            # not on sigmaMad as this description previously stated.
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Validate the range settings against ``rangeType``.

        Raises
        ------
        FieldValidationError
            Raised if the ranges are inconsistent with ``rangeType``, or if
            ``histDensity`` is set without a ``referenceValue``.
        """
        super().validate()
        # The bound checks must be grouped: ``and`` binds tighter than ``or``,
        # so without parentheses an upperRange > 100 was rejected for every
        # rangeType (e.g. a perfectly valid fixed range of (0, 500)).
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Make an N-panel plot with a configurable number of histograms displayed
    in each panel. Reference lines showing values of interest may also be added
    to each histogram. Panels are configured using the ``HistPanel`` class.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield the data keys read by this action.

        Yields
        ------
        key : `str`
            A histogram key configured in one of the panels.
        type : `type`
            Always ``Vector``.
        """
        for panel in self.panels:  # type: ignore
            # Iterate keys only: a schema entry is a ``(key, type)`` pair, so
            # the legend label (the dict value) must not be part of the key.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the plot; all arguments are forwarded to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; ``data[key]`` is read for every histogram key
            (and every custom stat key) configured in ``panels``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                `"run"`
                    Output run for the plots (`str`).
                `"tractTableType"`
                    Table from which results are taken (`str`).
                `"plotName"`
                    Output plot name (`str`)
                `"SN"`
                    The global signal-to-noise data threshold (`float`)
                `"skymap"`
                    The type of skymap used for the data (`str`).
                `"tract"`
                    The tract that the data comes from (`int`).
                `"bands"`
                    The bands used for this data (`str` or `list`).
                `"visit"`
                    The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example histogram plot may be seen below:

        .. image:: /_static/analysis_tools/histPlotExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        # set up figure: histogram panels on the left, statistics on the right
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms
        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            handles, _ = ax.get_legend_handles_labels()
            # add side panel; add statistics
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=self.panels[panel].label,  # type: ignore
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Panels are laid out in up to two columns; the final panel is widened
        to span the remainder of its row.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or `matplotlib.figure.SubFigure`
            Figure to add the axes to.

        Returns
        -------
        axs : `list` [`matplotlib.axes.Axes`]
            One axes per panel, in row-major order.
        ncols : `int`
            Number of grid columns (1 or 2).
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        ncols = 1 if num_panels <= 1 else 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: widen it over the rest of its row.
                    axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign a color to every histogram using the configured color map.

        Colors are handed out in panel/histogram order, cycling through the
        palette when there are more histograms than colors.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            Mapping from panel name to the colors of that panel's histograms,
            in the same order as ``HistPanel.hists``.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor a usable
            `matplotlib.pyplot.cm` color map.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                cmap = getattr(plt.cm, self.cmap).copy()
            except AttributeError:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from None
            all_colors = getattr(cmap, "colors", None)
            if all_colors is None:
                # Continuous maps (e.g. LinearSegmentedColormap) have no
                # discrete palette: sample one color per histogram instead.
                if not callable(cmap):
                    raise ValueError(f"Unrecognized color map: {self.cmap}")
                nHists = sum(len(self.panels[panel].hists) for panel in self.panels)
                all_colors = [cmap(x) for x in np.linspace(0.0, 1.0, max(nHists, 1))]

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for _ in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Input data; ``data[hist]`` is read for each histogram key of the
            panel and ``data[key]`` for each configured custom stat key.
        panel : `str`
            Key of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw into.
        colors : `list`
            One color per histogram in the panel.
        label_font_size : `int`, optional
            Font size of the axis labels.
        legend_font_size : `int`, optional
            Font size of the legend.
        ncols : `int`, optional
            Number of panel columns in the figure (used to decide whether
            compact x tick formatting is needed).

        Returns
        -------
        nums, meds, mads : `list`
            Per-histogram count, median and sigmaMad of the finite data.
        stats_dict : `dict`
            ``"statLabels"`` plus ``"stat1"``/``"stat2"``/``"stat3"`` values
            for the side statistics panel; empty if the panel range is not
            finite.

        Raises
        ------
        RuntimeError
            Raised if only some of the custom stats are configured.
        """
        panelConfig = self.panels[panel]
        # Non-finite values are excluded from both the statistics and the
        # histograms; filter each input once and reuse it below.
        finiteData = [data[hist][np.isfinite(data[hist])] for hist in panelConfig.hists]
        nums, meds, mads = [], [], []
        for hist_data in finiteData:
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        if not all(np.isfinite(panel_range)):
            # Nothing can be drawn; return empty stats for the side panel.
            stats_dict = {key: [] for key in ("stat1", "stat2", "stat3")}
            stats_dict["statLabels"] = [""] * 3
            return nums, meds, mads, stats_dict

        for i, (hist, hist_data) in enumerate(zip(panelConfig.hists, finiteData)):
            if len(hist_data) > 0:
                ax.hist(
                    hist_data,
                    range=panel_range,
                    bins=panelConfig.bins,
                    histtype="step",
                    density=panelConfig.histDensity,
                    lw=2,
                    color=colors[i],
                    label=panelConfig.hists[hist],
                )
                ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panelConfig.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panelConfig.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panelConfig.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given.
        # If histDensity is True, also plot a reference PDF with
        # mean = referenceValue and sigma = 1 for reference.
        if panelConfig.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        # Check if we should use the default stats panel or if a custom one
        # has been created.
        statsPanel = panelConfig.statsPanel
        statList = [statsPanel.stat1, statsPanel.stat2, statsPanel.stat3]
        if not any(statList):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        elif all(statList):
            stats_dict = {
                "statLabels": statsPanel.statsLabels,
                "stat1": [data[stat] for stat in statsPanel.stat1],
                "stat2": [data[stat] for stat in statsPanel.stat2],
                "stat3": [data[stat] for stat in statsPanel.stat3],
            }
        else:
            raise RuntimeError("Invalid configuration of HistStatPanel")
        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine the panel x-axis range based on the config settings.

        Parameters
        ----------
        data : `KeyedData`
            Input data (used for percentile ranges).
        panel : `str`
            Key of the panel in ``self.panels``.
        mads, meds : `list` [`float`], optional
            Per-histogram sigmaMad and median; required for ``sigmaMad``.

        Returns
        -------
        panel_range : `list` [`float`]
            ``[lower, upper]`` x-axis limits (may contain NaN).

        Raises
        ------
        RuntimeError
            Raised for an unknown ``rangeType``.
        """
        panelConfig = self.panels[panel]
        rangeType = panelConfig.rangeType
        lowerRange = panelConfig.lowerRange
        upperRange = panelConfig.upperRange
        if rangeType == "percentile":
            return self._getPercentilePanelRange(data, panel)
        if rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Use the module logger: the action itself provides no logger.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
            return panel_range
        if rangeType == "fixed":
            return [lowerRange, upperRange]
        raise RuntimeError(f"Invalid rangeType: {rangeType}")

    def _getPercentilePanelRange(self, data, panel):
        """Determine the panel x-axis range from data percentile limits.

        The range is the union over all histograms in the panel of the
        ``[lowerRange, upperRange]`` percentiles of each histogram's finite
        values. Histograms with no finite values are skipped.

        Parameters
        ----------
        data : `KeyedData`
            Input data; ``data[hist]`` is read for each histogram key.
        panel : `str`
            Key of the panel in ``self.panels``.

        Returns
        -------
        panel_range : `list` [`float`]
            ``[lower, upper]``; ``[nan, nan]`` if no histogram has finite
            data.
        """
        panelConfig = self.panels[panel]
        percentiles = [panelConfig.lowerRange, panelConfig.upperRange]
        panel_range = [np.nan, np.nan]
        for hist in panelConfig.hists:
            data_hist = data[hist]
            # nanpercentile ignores NaN but not +/-inf, and a single infinity
            # would make the range infinite and leave the panel blank; use
            # the same finite selection as the plotted histograms.
            finite_hist = data_hist[np.isfinite(data_hist)]
            # TODO: Consider raising instead
            if len(finite_hist) > 0:
                hist_range = np.percentile(finite_hist, percentiles)
                panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
                panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, the median, and the
        MAD-based sigma (``sigmaMad``) of the input data.

        Returns
        -------
        num : `int`
            ``len(data)``.
        med : `float`
            NaN-ignoring median.
        mad : `float`
            ``sigmaMad(data)``.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The line is drawn on a hidden twin axis at ``referenceValue``. If
        ``histDensity`` is set, a unit-width Gaussian centered on
        ``referenceValue`` is shaded and the y-limits of both axes are kept in
        sync.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The input axes (possibly with adjusted y-limits).
        """
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        if self.panels[panel].histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(self.panels[panel].referenceValue)
        ax2.axvline(
            self.panels[panel].referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label
        )
        if self.panels[panel].histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = self.panels[panel].referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are rendered as a 4-column legend (handle column plus
        three stat columns) whose values come from ``stats_dict``. Only
        ``len(nums)`` is used from ``nums``; ``meds`` and ``mads`` are not
        read here (the values shown come from ``stats_dict``).
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        # Keep earlier legends: add_artist prevents this legend from
        # replacing a previously drawn one on the same axes.
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)