Coverage for python/lsst/analysis/tools/actions/plot/histPlot.py: 15%
203 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-04 03:14 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-04 03:14 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("HistPanel", "HistPlot")
25import logging
26from collections import defaultdict
27from typing import Mapping
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, FieldValidationError
32from matplotlib.figure import Figure
33from matplotlib.gridspec import GridSpec
34from matplotlib.patches import Rectangle
36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Vector
37from ...statistics import sigmaMad
38from .plotUtils import addPlotInfo
40log = logging.getLogger(__name__)
class HistPanel(Config):
    """Configuration for a single histogram panel.

    Each panel has its own x-axis label, set of histograms, binning,
    x-range definition and optional reference value.
    """

    label = Field[str](
        doc="Panel x-axis label.",
        default="label",
    )
    hists = DictField[str, str](
        doc="A dict specifying the histograms to be plotted in this panel. Keys are used to identify "
        "histogram IDs. Values are used to add to the legend label displayed in the upper corner of the "
        "panel.",
        optional=False,
    )
    yscale = Field[str](
        doc="Y axis scaling.",
        default="linear",
    )
    bins = Field[int](
        doc="Number of x axis bins within plot x-range.",
        default=50,
    )
    rangeType = ChoiceField[str](
        doc="Set the type of range to use for the x-axis. Range bounds will be set according to "
        "the values of lowerRange and upperRange.",
        allowed={
            "percentile": "Upper and lower percentile ranges of the data.",
            "sigmaMad": "Range is (median - lowerRange*sigmaMad, median + upperRange*sigmaMad).",
            "fixed": "Range is fixed to (lowerRange, upperRange).",
        },
        default="percentile",
    )
    lowerRange = Field[float](
        doc="Lower range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the minimum value across all input "
        "data.",
        default=0.0,
    )
    upperRange = Field[float](
        doc="Upper range specifier for the histogram bins. See rangeType for interpretation "
        "based on the type of range requested. If more than one histogram is plotted in a given "
        "panel and rangeType is not set to fixed, the limit is the maximum value across all input "
        "data.",
        default=100.0,
    )
    referenceValue = Field[float](
        doc="Value at which to add a black solid vertical line. Ignored if set to `None`.",
        default=None,
        optional=True,
    )
    histDensity = Field[bool](
        doc="Whether to plot the histogram as a normalized probability distribution. Must also "
        "provide a value for referenceValue",
        default=False,
    )

    def validate(self):
        """Check that the range settings are consistent with ``rangeType``.

        Raises
        ------
        FieldValidationError
            Raised if any of the following hold:

            - ``rangeType`` is ``"percentile"`` and ``lowerRange < 0`` or
              ``upperRange > 100``;
            - ``rangeType`` is ``"sigmaMad"`` and ``lowerRange < 0``;
            - ``rangeType`` is ``"fixed"`` and both bounds are equal;
            - ``histDensity`` is set but ``referenceValue`` is `None`.
        """
        super().validate()
        # The parentheses matter: ``and`` binds tighter than ``or``, so
        # without them the upperRange check would apply to every rangeType.
        if self.rangeType == "percentile" and (self.lowerRange < 0.0 or self.upperRange > 100.0):
            msg = (
                "For rangeType %s, ranges must obey: lowerRange >= 0 and upperRange <= 100." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "sigmaMad" and self.lowerRange < 0.0:
            msg = (
                "For rangeType %s, lower range must obey: lowerRange >= 0 (the lower range is "
                "set as median - lowerRange*sigmaMad." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.rangeType == "fixed" and (self.upperRange - self.lowerRange) == 0.0:
            msg = (
                "For rangeType %s, lower and upper ranges must differ (i.e. must obey: "
                "upperRange - lowerRange != 0)." % self.rangeType
            )
            raise FieldValidationError(self.__class__.rangeType, self, msg)
        if self.histDensity and self.referenceValue is None:
            msg = "Must provide referenceValue if histDensity is True."
            raise FieldValidationError(self.__class__.referenceValue, self, msg)
class HistPlot(PlotAction):
    """Plot action producing a multi-panel figure of histograms.

    Each entry of ``panels`` becomes one panel holding one or more step
    histograms. A side panel lists, for every histogram, the number of
    finite data points, the median and the value returned by `sigmaMad`.
    """

    panels = ConfigDictField(
        doc="A configurable dict describing the panels to be plotted, and the histograms for each panel.",
        keytype=str,
        itemtype=HistPanel,
        default={},
    )
    cmap = Field[str](
        doc="Color map used for histogram lines. All types available via `plt.cm` may be used. "
        "A number of custom color maps are also defined: `newtab10`, `bright`, `vibrant`.",
        default="newtab10",
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Yield a ``(key, Vector)`` pair for every configured histogram.

        The keys are the keys of each panel's ``hists`` dict, i.e. the same
        keys that are used to look the data up when plotting.
        """
        for panel in self.panels:  # type: ignore
            # Iterate over the keys only; the dict values are legend labels,
            # not data identifiers.
            for histKey in self.panels[panel].hists:  # type: ignore
                yield histKey, Vector

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Make the plot; all arguments are forwarded to `makePlot`."""
        return self.makePlot(data, **kwargs)

    def makePlot(
        self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs  # type: ignore
    ) -> Figure:
        """Make an N-panel plot with a user-configurable number of histograms
        displayed in each panel.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            `"run"`
                Output run for the plots (`str`).
            `"tractTableType"`
                Table from which results are taken (`str`).
            `"plotName"`
                Output plot name (`str`)
            `"SN"`
                The global signal-to-noise data threshold (`float`)
            `"skymap"`
                The type of skymap used for the data (`str`).
            `"tract"`
                The tract that the data comes from (`int`).
            `"bands"`
                The bands used for this data (`str` or `list`).
            `"visit"`
                The visit that the data comes from (`int`)
            If `None`, no plot information is added to the figure.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        # set up figure: histograms on the left, statistics on the right
        fig = plt.figure(dpi=300)
        hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
        axs, ncols, nrows = self._makeAxes(hist_fig)

        # loop over each panel; plot histograms
        colors = self._assignColors()
        # nth_panel, nth_col and nth_row count down as panels are drawn; they
        # are only used to position each panel's statistics legend.
        nth_panel = len(self.panels)
        nth_col = ncols
        nth_row = nrows - 1
        label_font_size = max(6, 10 - nrows)
        for panel, ax in zip(self.panels, axs):
            nth_panel -= 1
            nth_col = ncols - 1 if nth_col == 0 else nth_col - 1
            # The last panel spans the remaining columns when the grid has
            # more cells than panels, so shift its column index once more.
            if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
                nth_col -= 1
            # Set font size for legend based on number of panels being plotted.
            n_hists = len(self.panels[panel].hists)  # type: ignore
            legend_font_size = max(4, int(8 - n_hists / 2 - nrows // 2))
            nums, meds, mads = self._makePanel(
                data,
                panel,
                ax,
                colors[panel],
                label_font_size=label_font_size,
                legend_font_size=legend_font_size,
                ncols=ncols,
            )

            # Statistics are reported per panel, so only this panel's
            # handles and values are passed on.
            handles, _ = ax.get_legend_handles_labels()
            title_str = self.panels[panel].label  # type: ignore
            # add side panel; add statistics
            self._addStatisticsPanel(
                side_fig,
                handles,
                nums,
                meds,
                mads,
                legend_font_size=legend_font_size,
                yAnchor0=ax.get_position().y0,
                nth_row=nth_row,
                nth_col=nth_col,
                title_str=title_str,
            )
            nth_row = nth_row - 1 if nth_col == 0 else nth_row

        # add general plot info
        if plotInfo is not None:
            hist_fig = addPlotInfo(hist_fig, plotInfo)

        # finish up
        plt.draw()
        return fig

    def _makeAxes(self, fig):
        """Determine axes layout for main histogram figure.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or sub-figure
            Figure to which the panel axes are added.

        Returns
        -------
        axs : `list`
            One axes object per panel, in row-major order.
        ncols : `int`
            Number of grid columns (1 for at most one panel, otherwise 2).
        nrows : `int`
            Number of grid rows.
        """
        num_panels = len(self.panels)
        if num_panels <= 1:
            ncols = 1
        else:
            ncols = 2
        nrows = int(np.ceil(num_panels / ncols))

        gs = GridSpec(nrows, ncols, left=0.12, right=0.99, bottom=0.1, top=0.88, wspace=0.31, hspace=0.45)

        axs = []
        counter = 0
        for row in range(nrows):
            for col in range(ncols):
                counter += 1
                if counter < num_panels:
                    axs.append(fig.add_subplot(gs[row : row + 1, col : col + 1]))
                else:
                    # Last panel: let it extend over the rest of the row and
                    # stop adding axes for this row.
                    axs.append(fig.add_subplot(gs[row : row + 1, col : np.min([col + 2, ncols + 1])]))
                    break

        return axs, ncols, nrows

    def _assignColors(self):
        """Assign colors to histograms using a given color map.

        Returns
        -------
        colors : `collections.defaultdict` [`str`, `list`]
            For each panel name, one color per histogram in that panel.
            Colors are taken in order from the color map and wrap around
            when the map is exhausted.

        Raises
        ------
        ValueError
            Raised if ``cmap`` is not one of the custom maps and looking up
            ``plt.cm.<cmap>`` or its ``colors`` attribute fails.
        """
        custom_cmaps = dict(
            # https://www.tableau.com/about/blog/2016/7/colors-upgrade-tableau-10-56782
            newtab10=[
                "#4e79a7",
                "#f28e2b",
                "#e15759",
                "#76b7b2",
                "#59a14f",
                "#edc948",
                "#b07aa1",
                "#ff9da7",
                "#9c755f",
                "#bab0ac",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_bright
            bright=[
                "#4477AA",
                "#EE6677",
                "#228833",
                "#CCBB44",
                "#66CCEE",
                "#AA3377",
                "#BBBBBB",
            ],
            # https://personal.sron.nl/~pault/#fig:scheme_vibrant
            vibrant=[
                "#EE7733",
                "#0077BB",
                "#33BBEE",
                "#EE3377",
                "#CC3311",
                "#009988",
                "#BBBBBB",
            ],
        )
        if self.cmap in custom_cmaps:
            all_colors = custom_cmaps[self.cmap]
        else:
            try:
                all_colors = getattr(plt.cm, self.cmap).copy().colors
            except AttributeError as err:
                raise ValueError(f"Unrecognized color map: {self.cmap}") from err

        counter = 0
        colors = defaultdict(list)
        for panel in self.panels:
            for hist in self.panels[panel].hists:
                colors[panel].append(all_colors[counter % len(all_colors)])
                counter += 1
        return colors

    def _makePanel(self, data, panel, ax, colors, label_font_size=9, legend_font_size=7, ncols=1):
        """Plot a single panel containing histograms.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by the keys of the panel's ``hists`` dict.
        panel : `str`
            Name of the panel in ``self.panels``.
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        colors : `list`
            One color per histogram of this panel.
        label_font_size : `int`, optional
            Font size of the axis labels.
        legend_font_size : `int`, optional
            Font size of the legend entries.
        ncols : `int`, optional
            Number of columns in the figure's panel grid.

        Returns
        -------
        nums, meds, mads : `list`
            Per-histogram statistics as returned by `_calcStats`, computed
            on the finite values only.
        """
        panel_config = self.panels[panel]

        # Keep only finite values; compute the mask once per histogram and
        # reuse the filtered data for both the statistics and the plot.
        finite_data = []
        nums, meds, mads = [], [], []
        for hist in panel_config.hists:
            hist_data = data[hist][np.isfinite(data[hist])]
            finite_data.append(hist_data)
            num, med, mad = self._calcStats(hist_data)
            nums.append(num)
            meds.append(med)
            mads.append(mad)
        panel_range = self._getPanelRange(data, panel, mads=mads, meds=meds)

        for i, (hist, hist_data) in enumerate(zip(panel_config.hists, finite_data)):
            ax.hist(
                hist_data,
                range=panel_range,
                bins=panel_config.bins,
                histtype="step",
                density=panel_config.histDensity,
                lw=2,
                color=colors[i],
                label=panel_config.hists[hist],
            )
            # dashed vertical line at the median of this histogram
            ax.axvline(meds[i], ls=(0, (5, 3)), lw=1, c=colors[i])

        ax.legend(fontsize=legend_font_size, loc="upper left", frameon=False)
        ax.set_xlim(panel_range)
        # The following accommodates spacing for ranges with large numbers
        # but small-ish dynamic range (example use case: RA 300-301).
        if ncols > 1 and max(np.abs(panel_range)) >= 100 and (panel_range[1] - panel_range[0]) < 5:
            ax.xaxis.set_major_formatter("{x:.2f}")
            ax.tick_params(axis="x", labelrotation=25, pad=-1)
        ax.set_xlabel(panel_config.label, fontsize=label_font_size)
        y_label = "Normalized (PDF)" if panel_config.histDensity else "Frequency"
        ax.set_ylabel(y_label, fontsize=label_font_size)
        ax.set_yscale(panel_config.yscale)
        ax.tick_params(labelsize=max(5, label_font_size - 2))
        # add a buffer to the top of the plot to allow headspace for labels
        ylims = list(ax.get_ylim())
        if ax.get_yscale() == "log":
            ylims[1] = 10 ** (np.log10(ylims[1]) * 1.1)
        else:
            ylims[1] *= 1.1
        ax.set_ylim(ylims[0], ylims[1])

        # Draw a vertical line at a reference value, if given. If histDensity
        # is True, also plot a reference PDF with mean = referenceValue and
        # sigma = 1 for reference.
        if panel_config.referenceValue is not None:
            ax = self._addReferenceLines(ax, panel, panel_range, legend_font_size=legend_font_size)

        return nums, meds, mads

    def _getPanelRange(self, data, panel, mads=None, meds=None):
        """Determine panel x-axis range based config settings.

        Parameters
        ----------
        data : `KeyedData`
            Input data, indexed by the keys of the panel's ``hists`` dict.
        panel : `str`
            Name of the panel in ``self.panels``.
        mads, meds : `list`, optional
            Per-histogram sigmaMad and median values; only used when the
            panel's ``rangeType`` is ``"sigmaMad"``.

        Returns
        -------
        panel_range : `list`
            Two-element ``[lower, upper]`` x-axis range.

        Raises
        ------
        RuntimeError
            Raised if the panel's ``rangeType`` is not a known value.
        """
        panel_range = [np.nan, np.nan]
        rangeType = self.panels[panel].rangeType
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        if rangeType == "percentile":
            panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "sigmaMad":
            # Set the panel range to extend lowerRange[upperRange] times the
            # maximum sigmaMad for the datasets in the panel to the left[right]
            # from the minimum[maximum] median value of all datasets in the
            # panel.
            maxMad = np.nanmax(mads)
            maxMed = np.nanmax(meds)
            minMed = np.nanmin(meds)
            panel_range = [minMed - lowerRange * maxMad, maxMed + upperRange * maxMad]
            if panel_range[1] - panel_range[0] == 0:
                # Use the module-level logger; this class defines no ``log``
                # attribute of its own.
                log.info(
                    "NOTE: panel_range for %s based on med/sigMad was 0. Computing using "
                    "percentile range instead.",
                    panel,
                )
                panel_range = self._getPercentilePanelRange(data, panel)
        elif rangeType == "fixed":
            panel_range = [lowerRange, upperRange]
        else:
            raise RuntimeError(f"Invalid rangeType: {rangeType}")
        return panel_range

    def _getPercentilePanelRange(self, data, panel):
        """Determine panel x-axis range based on data percentile limits.

        The panel's ``lowerRange`` and ``upperRange`` are used as the lower
        and upper percentiles. The returned range is the minimum lower and
        maximum upper percentile over all histograms in the panel. Only
        finite values are considered, matching the data that is plotted.

        Returns
        -------
        panel_range : `list`
            Two-element ``[lower, upper]`` range; both entries stay NaN if
            the panel has no histograms.
        """
        lowerRange = self.panels[panel].lowerRange
        upperRange = self.panels[panel].upperRange
        panel_range = [np.nan, np.nan]
        for hist in self.panels[panel].hists:
            # Drop +/-inf as well as NaN so that the range stays finite.
            hist_data = data[hist][np.isfinite(data[hist])]
            hist_range = np.nanpercentile(hist_data, [lowerRange, upperRange])
            panel_range[0] = np.nanmin([panel_range[0], hist_range[0]])
            panel_range[1] = np.nanmax([panel_range[1], hist_range[1]])
        return panel_range

    def _calcStats(self, data):
        """Calculate summary statistics of the input data.

        Returns
        -------
        num : `int`
            Number of data points (``len(data)``).
        med : `float`
            NaN-ignoring median of the data.
        mad : `float`
            Value returned by `sigmaMad` for the data.
        """
        num = len(data)
        med = np.nanmedian(data)
        mad = sigmaMad(data)
        return num, med, mad

    def _addReferenceLines(self, ax, panel, panel_range, legend_font_size=7):
        """Draw the vertical reference line and density curve (if requested)
        on the panel.

        The reference artists are drawn on a twin axes with its frame turned
        off, so that they get a separate legend in the upper right corner.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The input axes (its y-limits may have been enlarged).
        """
        ax2 = ax.twinx()
        ax2.axis("off")
        ax2.set_xlim(ax.get_xlim())
        ax2.set_ylim(ax.get_ylim())

        # The line only gets its own legend entry when no reference
        # density curve is drawn.
        if self.panels[panel].histDensity:
            reference_label = None
        else:
            reference_label = "${{\\mu_{{ref}}}}$: {}".format(self.panels[panel].referenceValue)
        ax2.axvline(
            self.panels[panel].referenceValue, ls="-", lw=1, c="black", zorder=0, label=reference_label
        )
        if self.panels[panel].histDensity:
            # Gaussian with mean = referenceValue and unit sigma, sampled at
            # 100 steps across the panel range.
            ref_x = np.arange(panel_range[0], panel_range[1], (panel_range[1] - panel_range[0]) / 100.0)
            ref_mean = self.panels[panel].referenceValue
            ref_std = 1.0
            ref_y = (
                1.0
                / (ref_std * np.sqrt(2.0 * np.pi))
                * np.exp(-((ref_x - ref_mean) ** 2) / (2.0 * ref_std**2))
            )
            ax2.fill_between(ref_x, ref_y, alpha=0.1, color="black", label="P$_{{norm}}(0,1)$", zorder=-1)
            # Make sure the y-axis extends beyond the data plotted and that
            # the y-ranges of both axes are in sync.
            y_max = max(max(ref_y), ax2.get_ylim()[1])
            if ax2.get_ylim()[1] < 1.05 * y_max:
                ax.set_ylim(ax.get_ylim()[0], 1.05 * y_max)
                ax2.set_ylim(ax.get_ylim())
        ax2.legend(fontsize=legend_font_size, handlelength=1.5, loc="upper right", frameon=False)

        return ax

    def _addStatisticsPanel(
        self,
        fig,
        handles,
        nums,
        meds,
        mads,
        legend_font_size=8,
        yAnchor0=0.0,
        nth_row=0,
        nth_col=0,
        title_str=None,
    ):
        """Add an adjoining panel containing histogram summary statistics.

        The statistics are laid out as a four-column legend: histogram
        handle, number of points, median and sigmaMad.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure` or sub-figure
            Figure in which the statistics legend is drawn.
        handles : `list`
            Legend handles of the histograms in the panel.
        nums, meds, mads : `list`
            Per-histogram number of points, median and sigmaMad.
        legend_font_size : `int`, optional
            Font size of the legend text and title.
        yAnchor0 : `float`, optional
            Lower y position of the corresponding histogram panel.
        nth_row, nth_col : `int`, optional
            Row and column counters of the panel, used to offset the legend.
        title_str : `str`, optional
            Legend title.
        """
        ax = fig.add_subplot(1, 1, 1)
        ax.axis("off")
        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.0, top=1.0)
        # empty handle, used to populate the bespoke legend layout
        empty = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor="none", linewidth=0)

        # set up new legend handles and labels
        legend_handles = [empty] + handles + ([empty] * 3 * len(handles)) + ([empty] * 3)
        legend_labels = (
            ([""] * (len(handles) + 1))
            + ["N$_{{data}}$"]
            + nums
            + ["Med"]
            + [f"{x:.2f}" for x in meds]
            + ["${{\\sigma}}_{{MAD}}$"]
            + [f"{x:.2f}" for x in mads]
        )

        # Set the y anchor for the legend such that it roughly lines up with
        # the panels.
        yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size

        nth_legend = ax.legend(
            legend_handles,
            legend_labels,
            loc="lower left",
            bbox_to_anchor=(0.0, yAnchor),
            ncol=4,
            handletextpad=-0.25,
            fontsize=legend_font_size,
            borderpad=0,
            frameon=False,
            columnspacing=-0.25,
            title=title_str,
            title_fontproperties={"weight": "bold", "size": legend_font_size},
        )
        if nth_row + nth_col > 0:
            ax.add_artist(nth_legend)