Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%
342 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 04:10 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-24 04:10 -0700
# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from typing import Mapping, NamedTuple, Optional, cast
28import matplotlib.colors
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from lsst.pex.config.listField import ListField
34from lsst.skymap import BaseSkyMap
35from matplotlib import gridspec
36from matplotlib.axes import Axes
37from matplotlib.collections import PolyCollection
38from matplotlib.figure import Figure
39from matplotlib.path import Path
40from mpl_toolkits.axes_grid1 import make_axes_locatable
42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
43from ...math import nanMedian, nanSigmaMad
44from ..keyedData.summaryStatistics import SummaryStatisticAction
45from ..scalar import MedianAction
46from ..vector import ConvertFluxToMag, SnSelector
47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Private copy of the "coolwarm" colormap whose "bad" (masked/NaN) color is
# fully transparent; copying keeps the registered colormap untouched.
cmapPatch = plt.get_cmap("coolwarm").copy()
cmapPatch.set_bad(color="none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the statistics shown by the scatter plot with two histograms.

    The high and low S/N selectors are applied to the data. For each of the
    two selections the action returns the selection mask, the summary
    statistics of ``vectorKey`` restricted to that mask, and the median
    magnitude of the ``fluxType`` column within it (``approxMag``). Both
    selector thresholds are returned as well.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        # Configured prefix first, then the action identity, then nothing.
        return self.prefix or self.identity or ""

    def _get_prefix_variants(self) -> tuple[str, str]:
        """Return the key prefix as (lower-case, capitalized) strings.

        Both are empty strings when no prefix is available.
        """
        keyPrefix = self._get_key_prefix()
        if not keyPrefix:
            return "", ""
        return keyPrefix.lower(), keyPrefix.capitalize()

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the input keys: the statistic vector, the flux column (its
        ``{band}`` placeholder still unformatted) and the selectors' inputs.
        """
        yield from ((self.vectorKey, Vector), (self.fluxType, Vector))
        for selector in (self.highSNSelector, self.lowSNSelector):
            yield from selector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the output keys and types produced by `__call__`."""
        lower, upper = self._get_prefix_variants()
        sfx = self.suffix
        schema = [
            (f"{lower}HighSNMask{sfx}", Vector),
            (f"{lower}LowSNMask{sfx}", Vector),
        ]
        for level in ("low", "high"):
            for statName in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{level}SN{upper}_{statName}{sfx}", Scalar))
        schema.append((f"{lower}HighSNThreshold{sfx}", Scalar))
        schema.append((f"{lower}LowSNThreshold{sfx}", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Apply both S/N selections and compute their statistics.

        When no ``band`` keyword is given, ``approxMag`` is NaN and the
        statistic keys carry no band prefix.
        """
        lower, upper = self._get_prefix_variants()
        sfx = self.suffix
        highMaskKey = f"{lower}HighSNMask{sfx}"
        lowMaskKey = f"{lower}LowSNMask{sfx}"
        results = {
            highMaskKey: self.highSNSelector(data, **kwargs),
            lowMaskKey: self.lowSNSelector(data, **kwargs),
        }

        band = kwargs.get("band")
        bandPrefix = f"{band}_" if band else ""
        fluxes = None if band is None else data[self.fluxType.format(band=band)]

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)
        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            mask = results[maskKey]
            name = f"{bandPrefix}{binName}SN{upper}"
            if band is None:
                approxMag = np.nan
            else:
                # The median magnitude of the objects in this S/N selection.
                approxMag = medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
            results[f"{name}_approxMag{sfx}".format(**kwargs)] = approxMag
            for statName, value in statAction(data, **(kwargs | {"mask": mask})).items():
                results[f"{name}_{statName}{sfx}".format(**kwargs)] = value

        results[f"{lower}HighSNThreshold{sfx}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{lower}LowSNThreshold{sfx}"] = self.lowSNSelector.threshold  # type: ignore
        return results
146def _validObjectTypes(value):
147 return value in ("stars", "galaxies", "unknown", "any")
class _StatsContainer(NamedTuple):
    """Summary statistics of one S/N selection (high or low) for one object type."""

    median: Scalar
    sigmaMad: Scalar
    # ignore type because the ``count`` field conflicts with the
    # ``tuple.count`` method of the base class
    count: Scalar  # type: ignore
    approxMag: Scalar
class DataTypeDefaults(NamedTuple):
    """Per-object-type key naming and plot styling for ``ScatterPlotWithTwoHists``."""

    # Inserted into the statistic keys (``{band}_highSN<suffix_stat>_<stat>``)
    # and used in the top-histogram legend labels.
    suffix_stat: str
    # Inserted after ``x``/``y`` in the vector keys (empty for "any").
    suffix_xy: str
    # Matplotlib color for this type's lines, points and histograms.
    color: str
    # Colormap for the hexbin density; None lets matplotlib pick its default.
    colormap: matplotlib.colors.Colormap | None
class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )
    suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="")
    suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="")
    suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="")

    # Statistics read for each S/N selection; they match the fields of
    # ``_StatsContainer``.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key suffixes and styling for each supported entry of ``plotTypes``.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (with their types) required in the input data.

        Keys containing ``{band}`` are formatted with the keyword arguments
        passed to `__call__` before they are looked up.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector))
            base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector))
            base.append((f"{name_datatype}HighSNMask{self.suffix_stat}", Vector))
            base.append((f"{name_datatype}LowSNMask{self.suffix_stat}", Vector))
            # statistics
            for name in self._stats:
                base.append(
                    (f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
                )
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}LowSNThreshold{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}HighSNThreshold{self.suffix_stat}", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` contains every key this action needs.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or a key that should hold a vector
            holds a scalar.
        """
        # The schema is iterated twice below; materialize it so that a
        # one-shot iterator does not leave the type-checking loop empty.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(
                f"Task needs keys {remainder} but they were not found in input keys" f" {list(data.keys())}"
            )
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Used to format the ``{band}`` placeholders of the input keys
            (e.g. ``band``). ``hlineColor`` and ``hlineStyle`` style the
            horizontal reference line at zero; they default to ``"black"``
            and a loosely dotted line.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        For each object type in ``plotTypes`` the ``y`` vector is plotted
        against the ``x`` vector with running median and sigma MAD lines
        and, if ``plot2DHist`` is set, a hexbin density of the points lying
        within 3 sigma MAD of the running median. Histograms of the points
        collapsed onto each axis are drawn above and to the right of the
        scatter panel. If ``addSummaryPlot`` is set, a summary panel built by
        ``generateSummaryStats`` from the data, skymap and plotInfo is shown
        in the upper right corner. The high and low S/N masks select the
        points whose precomputed statistics are printed beneath the plot.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data (``suffix_x`` and ``suffix_y`` are appended to the x and
        y keys, ``suffix_stat`` to all the other keys).

        * If stars is in self.plotTypes:
          xStars, yStars, starsHighSNMask, starsLowSNMask,
          starsHighSNThreshold, starsLowSNThreshold and
          {band}_highSNStars_{name}, {band}_lowSNStars_{name}
          where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars (Stars) in the
          above names with galaxies/unknown (Galaxies/Unknown).

        * If it is for any (which covers all the points) then it
          becomes x, y, and any (Any) instead of stars (Stars) for the
          other parameters given above.

        * If the summary plot is being plotted, data must also contain
          patch.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        kwargs.setdefault("hlineColor", "black")
        kwargs.setdefault("hlineStyle", (0, (1, 4)))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel and its statistics text boxes.

        Each object type in ``plotTypes`` is drawn in one of three ways:

        * fewer than 10 points: a lone median point with +/- 1 sigma MAD;
        * more than one populated x bin: a running median with 1, 2 and 3
          sigma MAD lines, points outside 3 sigma MAD drawn individually,
          and (if ``plot2DHist``) a hexbin density of the points inside;
        * otherwise: all points with a flat median and +/- 1 sigma MAD.

        Parameters
        ----------
        data : `KeyedData`
            The input data; see `makePlot` for the required keys.
        fig : `matplotlib.figure.Figure`
            The figure to draw into.
        gs : `matplotlib.gridspec.GridSpec`
            The grid; the scatter panel occupies ``gs[1:, :-1]``.
        **kwargs
            Must contain ``hlineColor`` and ``hlineStyle`` plus any keys
            (e.g. ``band``) needed to format the statistic keys.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter-panel axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin image to attach a colorbar to; branches that draw no
            2D histogram reset it to `None`.
        """
        suf_x = self.suffix_x
        suf_y = self.suffix_y
        suf_stat = self.suffix_stat
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # An x bin must contain more than this many points (within +/- 5
        # sigma MAD in y) to be used for the running statistics.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
                lowArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}{suf_x}"],
                    data[f"y{config_datatype.suffix_xy}{suf_y}"],
                    data[f"{name_datatype}HighSNMask{suf_stat}"],
                    data[f"{name_datatype}LowSNMask{suf_stat}"],
                    data[f"{name_datatype}HighSNThreshold{suf_stat}"],
                    data[f"{name_datatype}LowSNThreshold{suf_stat}"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]
        # The y limits below use the values left by the last object type
        # processed; initialise them so they exist even if every type is
        # skipped.
        xs = ys = np.array([])
        sigMadYs = np.nan
        meds = np.array([])
        for j, (
            xs,
            ys,
            highSn,
            lowSn,
            highThresh,
            lowThresh,
            color,
            cmap,
            highStats,
            lowStats,
        ) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nanSigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0 or not np.isfinite(sigMadYs):
                continue
            elif n_xs < 10:
                xs = [nanMedian(xs)]
                sigMads = np.array([nanSigmaMad(ys)])
                ys = np.array([nanMedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Chop off 1/3% from only the finite xs values
                # (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = nanMedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            # Only the per-x-bin totals (within +/- 5 sigma MAD in y) are
            # needed, to decide which x bins are populated enough.
            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = nanMedian(ys[inBin])
                    sigMad = nanSigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper boundary runs forwards, lower boundary backwards,
                    # so the vertices trace a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                nHighSn = np.sum(highSn)
                if 0 < nHighSn < 100:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                nLowSn = np.sum(lowSn)
                if 0 < nLowSn < 100:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([nanMedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nanSigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # TODO: Make this not work by accident
        if f"yStars{suf_y}" in data and (len(cast(Vector, data[f"yStars{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yStars{suf_y}"]))
        elif f"yGalaxies{suf_y}" in data and (len(cast(Vector, data[f"yGalaxies{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_y}"]))
        else:
            plotMed = np.nan

        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])
        elif np.isfinite(plotMed) and np.isfinite(sigMadYs) and sigMadYs > 0:
            # Medians of the last object type processed: the lone median if
            # it had too few points, otherwise its running/flat medians.
            refMeds = np.array([nanMedian(ys)]) if len(xs) < 2 else np.asarray(meds, dtype=float)
            refMeds = refMeds[np.isfinite(refMeds)]
            # Start at 4 sigma MAD and widen (up to 10) until the medians fit.
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs
            yLimMax = plotMed + numSig * sigMadYs
            while (
                refMeds.size > 0
                and (yLimMax < refMeds.max() or yLimMin > refMeds.min())
                and numSig < 10
            ):
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs
                yLimMax = plotMed + numSig * sigMadYs

            # One extra sigma MAD of padding.
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs
            yLimMax = plotMed + numSig * sigMadYs
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc=self.legendLocation,
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter panel.

        A filled grey histogram shows all points (the "any" vector if it is
        plotted, else the concatenation of every plotted type when there is
        more than one); each remaining type is drawn as a colored outline.
        """
        suf_x = self.suffix_x
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)

        if "any" in self.plotTypes:
            # Use the "any" data itself (not its key) as the background.
            x_any = data[f"x{self._datatypes['any'].suffix_xy}{suf_x}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
        else:
            x_any = (
                np.concatenate([data[f"x{self._datatypes[key].suffix_xy}{suf_x}"] for key in self.plotTypes])
                if (len(self.plotTypes) > 1)
                else None
            )
            keys_notany = self.plotTypes
        if x_any is not None:
            topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=True, label=f"Any ({len(x_any)})")

        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}{suf_x}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values to the right of the scatter panel.

        A filled grey histogram shows all points (as in `_makeTopHistogram`).
        Each remaining type is drawn as a solid outline, plus dashed and
        dotted outlines for its high and low S/N selections. If a hexbin
        image was drawn, its colorbar is attached to the right.
        """
        suf_y = self.suffix_y
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)

        if "any" in self.plotTypes:
            # Use the "any" data itself (not its key) as the background.
            y_any = data[f"y{self._datatypes['any'].suffix_xy}{suf_y}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
        else:
            y_any = (
                np.concatenate([data[f"y{self._datatypes[key].suffix_xy}{suf_y}"] for key in self.plotTypes])
                if (len(self.plotTypes) > 1)
                else None
            )
            keys_notany = self.plotTypes
        if y_any is not None:
            sideHist.hist(
                np.array(y_any),
                bins=bins,
                color="grey",
                alpha=0.3,
                orientation="horizontal",
                log=True,
            )
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}{suf_y}"]
            sideHist.hist(
                vector,
                color=config_datatype.color,
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")