Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 16%
222 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-30 03:58 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-30 03:58 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot", "HistStatsPanel")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import (
32 ChoiceField,
33 Config,
34 ConfigDictField,
35 ConfigField,
36 DictField,
37 Field,
38 FieldValidationError,
39 ListField,
40)
41from matplotlib.figure import Figure
42from matplotlib.gridspec import GridSpec
43from matplotlib.patches import Rectangle
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
46from ...statistics import sigmaMad
47from .plotUtils import addPlotInfo
49log = logging.getLogger(__name__)
class HistStatsPanel(Config):
    """A Config class that holds parameters to configure the stats panel
    shown for histPlot.

    The fields in this class correspond to the parameters that can be used to
    customize the HistPlot stats panel.

    - The ``statsLabels`` parameter is a list specifying the labels of the
      3 stat columns; it accepts LaTeX formatting.

    - The other parameters (stat1, stat2, stat3) are lists of strings that
      specify vector keys corresponding to scalar values computed in the
      prep/process/produce steps of an analysis tools plot/metric configurable
      action. There should be one key for each group in the HistPanel.

    A separate config class is used instead of constructing
    `~lsst.pex.config.DictField`'s in HistPanel for each parameter for clarity
    and consistency.

    Notes
    -----
    This is intended to be used as a configuration of the HistPlot/HistPanel
    class.

    If no HistStatsPanel is specified then the default behavior persists where
    the stats panel shows N / median / sigma_mad for each group in the panel.
    """

    statsLabels = ListField[str](
        doc="list specifying the labels for stats",
        length=3,
        # An ordered sequence is required here: the labels are matched
        # positionally to stat1/stat2/stat3, so a set (unordered) would not
        # guarantee that each label ends up above the right column.
        default=["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
    )
    stat1 = ListField[str](
        doc="A list specifying the vector keys of the first scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )
    stat2 = ListField[str](
        doc="A list specifying the vector keys of the second scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )
    stat3 = ListField[str](
        doc="A list specifying the vector keys of the third scalar statistic to be shown in this panel."
        "there should be one entry for each hist in the panel",
        default=None,
        optional=True,
    )

    def validate(self):
        """Check that the custom stats are configured all-or-nothing.

        Raises
        ------
        ValueError
            Raised if some, but not all, of stat1/stat2/stat3 are set.
        """
        super().validate()
        stats = [self.stat1, self.stat2, self.stat3]
        if any(stats) and not all(stats):
            raise ValueError(f"{self._name}: If one stat is configured, all 3 stats must be configured")
class HistPanel(Config):
    """Configuration of a single panel of a `HistPlot`.

    A panel holds one or more histograms sharing an x-axis whose range is
    derived from ``rangeType``, ``lowerRange`` and ``upperRange``.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            # The range is centred on the median (see HistPlot._getPanelRange),
            # not on sigmaMad itself.
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )
    statsPanel = ConfigField[HistStatsPanel](
        doc="configuration for stats to be shown on plot, if None then "
        "default stats: N, median, sigma mad are shown",
        default=None,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType``.

        Raises
        ------
        FieldValidationError
            Raised if the percentile bounds fall outside [0, 100], if
            ``lowerRange`` is negative for a sigmaMad range, if a fixed range
            has zero width, or if ``histDensity`` is set without a
            ``referenceValue``.
        """
        super().validate()
        # The parentheses matter: without them ``and`` binds tighter than
        # ``or`` and any upperRange > 100 would be rejected for every
        # rangeType, not just "percentile".
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Plot action producing a multi-panel figure of histograms, with a side
    panel listing summary statistics for every histogram.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield one ``(key, Vector)`` pair per configured histogram.

        The key is the histogram ID (the key of ``hists``), which is the
        name used to index ``data`` when plotting; the dict value is only a
        legend label and is not part of the schema.
        """
        for panel in self.panels:  # type: ignore
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the plot; all arguments are forwarded to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, indexed by the histogram keys configured in
            each panel (and by the stat keys, if a custom stats panel is
            configured).
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                `"run"`
                    Output run for the plots (`str`).
                `"tractTableType"`
                    Table from which results are taken (`str`).
                `"plotName"`
                    Output plot name (`str`)
                `"SN"`
                    The global signal-to-noise data threshold (`float`)
                `"skymap"`
                    The type of skymap used for the data (`str`).
                `"tract"`
                    The tract that the data comes from (`int`).
                `"bands"`
                    The bands used for this data (`str` or `list`).
                `"visit"`
                    The visit that the data comes from (`int`)

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        """

        # set up figure: histograms on the left, statistics on the right
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms
        colors = self._assignColors()
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            # nth_panel/nth_col/nth_row count down towards the last panel and
            # are used to position the matching legend in the side panel.
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2))  # type: ignore
            nums, meds, mads, stats_dict = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            handles, _ = ax.get_legend_handles_labels()
            title_str = self.panels[panel].label  # type: ignore
            # add side panel; add statistics
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                stats_dict,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=title_str,
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Returns
        -------
        axs : `list`
            One axes per panel, in row-major order.
        ncols, nrows : `int`
            The dimensions of the grid the axes were laid out on.
        """
        num_panels = len(self.panels)
        if num_panels <= 1:
            ncols = 1
        else:
            ncols = 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: let it also take up the next column of its
                    # row (if any) and stop adding axes to this row.
                    axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel name, the list of colors of its histograms, in
            the order of the panel's ``hists``. Colors are taken in sequence
            across all panels and wrap around when the map is exhausted.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is neither a custom map nor an attribute of
            `plt.cm` exposing a ``colors`` list.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                all_colors = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError:
                raise ValueError(f"Unrecognized color map: {self.cmap}")

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for hist in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Returns
        -------
        nums, meds, mads : `list`
            Number of finite points, median and sigmaMad of each histogram.
        stats_dict : `dict`
            The labels (``"statLabels"``) and values (``"stat1"``,
            ``"stat2"``, ``"stat3"``) to show in the statistics panel.

        Raises
        ------
        RuntimeError
            Raised if only some of the custom stats are configured.
        """
        nums, meds, mads = [], [], []
        for hist in self.panels[panel].hists:
            hist_data = data[hist][np.isfinite(data[hist])]
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, hist in enumerate(self.panels[panel].hists):
            hist_data = data[hist][np.isfinite(data[hist])]
            ax.hist(
                hist_data,
                range=panel_range,
                bins=self.panels[panel].bins,
                histtype="step",
                density=self.panels[panel].histDensity,
                lw=2,
                color=colors[i],
                label=self.panels[panel].hists[hist],
            )
            # dashed vertical line at the median of this histogram
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(self.panels[panel].label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if self.panels[panel].histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(self.panels[panel].yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if self.panels[panel].referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        # Check if we should use the default stats panel or if a custom one
        # has been created.
        statList = [
            self.panels[panel].statsPanel.stat1,
            self.panels[panel].statsPanel.stat2,
            self.panels[panel].statsPanel.stat3,
        ]
        if not any(statList):
            stats_dict = {
                "statLabels": ["N$_{{data}}$", "Med", "${{\\sigma}}_{{MAD}}$"],
                "stat1": nums,
                "stat2": meds,
                "stat3": mads,
            }
        elif all(statList):
            stat1 = [data[stat] for stat in self.panels[panel].statsPanel.stat1]
            stat2 = [data[stat] for stat in self.panels[panel].statsPanel.stat2]
            stat3 = [data[stat] for stat in self.panels[panel].statsPanel.stat3]
            stats_dict = {
                "statLabels": self.panels[panel].statsPanel.statsLabels,
                "stat1": stat1,
                "stat2": stat2,
                "stat3": stat3,
            }
        else:
            raise RuntimeError("Invalid configuration of HistStatPanel")

        return nums, meds, mads, stats_dict

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based config settings.

        Parameters
        ----------
        data : `KeyedData`
            The data being plotted; only read for percentile ranges.
        panel : `str`
            Name of the panel in ``self.panels``.
        mads, meds : `list`, optional
            sigmaMad and median of each histogram in the panel; required
            when the panel's ``rangeType`` is ``"sigmaMad"``.

        Returns
        -------
        panel_range : `list`
            The ``[lower, upper]`` x-axis limits.

        Raises
        ------
        RuntimeError
            Raised if the panel's ``rangeType`` is not recognized.
        """
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Use the module-level logger: no ``log`` attribute is
                # defined on this class.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The range spans from the smallest lower percentile to the largest
        upper percentile over all histograms in the panel, with NaNs ignored.
        """
        panel_range = [np.nan, np.nan]
        for hist in self.panels[panel].hists:
            hist_range = np.nanpercentile(
                data[hist], [self.panels[panel].lowerRange, self.panels[panel].upperRange]
            )
            # nanmin/nanmax skip the NaN placeholders on the first pass
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate the number of data points, median, and median absolute
        deviation of input data."""
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The lines are drawn on a twin axes kept in sync with ``ax``, so they
        get their own legend in the upper right corner. Returns ``ax``.
        """
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        # In density mode the legend entry comes from the reference PDF, so
        # the vertical line itself is left unlabelled.
        if self.panels[panel].histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(self.panels[panel].referenceValue)
        ax2.axvline(
            self.panels[panel].referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label
        )
        if self.panels[panel].histDensity:
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = self.panels[panel].referenceValue
            ref_std = 1.0
            # Gaussian PDF with mean = referenceValue and sigma = 1
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        stats_dict,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are rendered as a 4-column legend: the histogram
        handles followed by one column per entry of ``stats_dict``, each
        headed by its label. ``meds`` and ``mads`` are accepted but not
        read here; only the length of ``nums`` is used, for positioning.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)

        legend_labels = (
            ([""] * (len(handles) + 1))
            + [stats_dict["statLabels"][0]]
            + [f"{x:.3g}" for x in stats_dict["stat1"]]
            + [stats_dict["statLabels"][1]]
            + [f"{x:.3g}" for x in stats_dict["stat2"]]
            + [stats_dict["statLabels"][2]]
            + [f"{x:.3g}" for x in stats_dict["stat3"]]
        )
        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        # Re-add the legend as an artist so it is not replaced by the legend
        # of a later panel; skipped only when both counters are zero.
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)