Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%
319 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-06 04:05 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-06 04:05 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from typing import Mapping, NamedTuple, Optional, cast
28import matplotlib.colors
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from lsst.pex.config.listField import ListField
34from lsst.skymap import BaseSkyMap
35from matplotlib import gridspec
36from matplotlib.axes import Axes
37from matplotlib.collections import PolyCollection
38from matplotlib.figure import Figure
39from matplotlib.path import Path
40from mpl_toolkits.axes_grid1 import make_axes_locatable
42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
43from ...statistics import nansigmaMad, sigmaMad
44from ..keyedData.summaryStatistics import SummaryStatisticAction
45from ..scalar import MedianAction
46from ..vector import ConvertFluxToMag, SnSelector
47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Module-level colormap: a private copy of matplotlib's "coolwarm" map, so
# that customising it below does not alter the globally registered colormap.
# The type-ignore is needed because coolwarm is actually part of the module
# but is not visible to static type checkers.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
# Draw bad (masked/invalid) values with no colour at all.
cmapPatch.set_bad(color="none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    Two signal-to-noise selections ("high" and "low") are evaluated with the
    configured selectors. For each selection the action stores the selection
    mask, an approximate magnitude, and the summary statistics of
    ``vectorKey`` restricted to that selection. The selector thresholds are
    stored as well.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys read by this action.

        These are ``vectorKey``, ``fluxType`` and everything requested by the
        two signal-to-noise selectors.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__`.

        The mask keys are prefixed with the lower-cased identity and the
        threshold keys are ``highSnThreshold`` / ``lowSnThreshold``, which is
        exactly what `__call__` stores in its result.
        """
        # Same spelling rules as __call__: lower-case identity for the mask
        # keys, capitalised identity inside the statistic keys.
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        schema: list[tuple[str, type[Vector] | ScalarType]] = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("low", "high"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{stat}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks, statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must provide the keys of `getInputSchema`.
        **kwargs
            Forwarded to the selectors and the statistics action. If a
            ``band`` entry is present it is used to prefix the statistic keys
            and to format ``fluxType``.

        Returns
        -------
        results : `KeyedData`
            The high/low S/N masks, the per-selection statistics (including
            ``approxMag``) and the two selector thresholds.
        """
        results = {}
        highMaskKey = f'{(self.identity or "").lower()}HighSNMask'
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f'{(self.identity or "").lower()}LowSNMask'
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        # Without a band there is no flux column to look up, so approxMag
        # falls back to NaN below.
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{self.identity.capitalize() if self.identity else ''}"
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            # Summary statistics of vectorKey restricted to this selection;
            # each returned entry is stored under "<name>_<suffix>".
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
128def _validObjectTypes(value):
129 return value in ("stars", "galaxies", "unknown", "any")
132# ignore type because of conflicting name on tuple baseclass
133class _StatsContainer(NamedTuple):
134 median: Scalar
135 sigmaMad: Scalar
136 count: Scalar # type: ignore
137 approxMag: Scalar
class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and styling defaults for the plot."""

    # Suffix used in the statistic keys, e.g. "{band}_highSN<suffix_stat>_median".
    suffix_stat: str
    # Suffix used in the coordinate keys, i.e. "x<suffix_xy>" and "y<suffix_xy>".
    suffix_xy: str
    # Matplotlib colour used for lines, points and text boxes of this type.
    color: str
    # Colormap used for the 2D density of this type; None leaves the choice
    # to matplotlib.
    colormap: matplotlib.colors.Colormap | None
class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.

    The object types to draw are chosen with ``plotTypes``; the per-type key
    suffixes, colours and colormaps come from ``_datatypes``.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the per-selection statistics read from the input data.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key suffixes and styling for every value allowed in ``plotTypes``.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys needed for the configured ``plotTypes``.

        For every plot type these are the x/y vectors, the high and low S/N
        masks and the high/low S/N statistics; in addition the two S/N
        thresholds and, if the summary plot is enabled, ``patch``.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}", Vector))
            base.append((f"y{config_datatype.suffix_xy}", Vector))
            base.append((f"{name_datatype}HighSNMask", Vector))
            base.append((f"{name_datatype}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{config_datatype.suffix_stat}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` and return the figure made by `makePlot`."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a key
            holds a scalar where the schema asks for something else.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Used to format the ``{band}`` placeholder of the statistic keys.
            ``hlineColor`` and ``hlineStyle`` set the colour and line style of
            the horizontal reference line at 0; they default to ``"black"``
            and ``(0, (1, 4))``.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Plots a scatter plot of the y values against the x values, labelled
        with the ``xAxisLabel`` and ``yAxisLabel`` config options. A histogram
        of the points collapsed onto each axis is also plotted. If
        ``addSummaryPlot`` is set, a summary panel is added in the upper right
        corner of the resultant plot. The high and low signal to noise masks
        and statistics provided in the input data are used for the
        highlighted points and the printed statistics.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars in the above
            names with galaxies/unknown.

        * If it is for any (which covers all the points) then it
            becomes, x, y, and any instead of stars for the other
            parameters given above.

        * In every case it is expected that data contains:
            lowSnThreshold, highSnThreshold and patch
            (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        if "hlineColor" not in kwargs:
            kwargs["hlineColor"] = "black"

        if "hlineStyle" not in kwargs:
            kwargs["hlineStyle"] = (0, (1, 4))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        For every configured plot type this draws the running median with its
        1, 2 and 3 sigma-MAD envelopes, the points outside the 3 sigma-MAD
        envelope, optionally a hexbin density of the points inside it, the
        high/low S/N markers and the statistics text boxes.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin artist of the last dataset drawn, or `None` if that
            dataset was drawn without a hexbin.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # Minimum number of points an x bin needs to enter the running stats.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar, data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
                lowArgs[name] = cast(
                    Scalar, data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}"],
                    data[f"y{config_datatype.suffix_xy}"],
                    data[f"{name_datatype}HighSNMask"],
                    data[f"{name_datatype}LowSNMask"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
        xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0:
                continue
            elif n_xs < 10:
                xs = [np.nanmedian(xs)]
                sigMads = np.array([nansigmaMad(ys)])
                ys = np.array([np.nanmedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Chop off 1/3% from only the finite xs values
                # (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins that hold enough points.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[inBin])
                    sigMad = sigmaMad(ys[inBin], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper envelope runs forwards through the first half of
                    # the vertices, lower envelope backwards through the
                    # second half, giving one closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                nHighSn = np.sum(highSn)
                if 0 < nHighSn < 100:
                    # White markers underneath give the coloured ones an outline.
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                nLowSn = np.sum(lowSn)
                if 0 < nLowSn < 100:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated bins for running statistics: plot every
                # point plus flat median and sigma-MAD lines.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # TODO: Make this not work by accident
        # NOTE: xs, ys, sigMadYs and meds below are whatever the last loop
        # iteration left behind.
        if "yStars" in data and (len(cast(Vector, data["yStars"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yStars"]))
        elif "yGalaxies" in data and (len(cast(Vector, data["yGalaxies"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yGalaxies"]))
        else:
            plotMed = np.nan

        # Ignore types below pending making this not working my accident
        if len(xs) < 2:  # type: ignore
            meds = [np.nanmedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed):
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen the window one sigma at a time until it contains the
            # running medians. The limits are recomputed inside the loop so
            # that widening stops as soon as the medians fit.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # One extra sigma of padding around the chosen window.
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter plot.

        A filled grey histogram shows all points and a step histogram is
        overlaid for each individual plot type. When ``"any"`` is one of the
        plot types its x vector is used directly as the "all points" sample
        and gets no separate overlay, since the two would be identical.
        """
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)

        if "any" in self.plotTypes:
            # "any" already holds every point; concatenating it with the
            # other types would count points twice.
            x_all = data[f"x{self._datatypes['any'].suffix_xy}"]
            keys_notall = [x for x in self.plotTypes if x != "any"]
        else:
            x_all = np.concatenate([data[f"x{self._datatypes[key].suffix_xy}"] for key in self.plotTypes])
            keys_notall = self.plotTypes

        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)
        topHist.hist(x_all, bins=bins, color="grey", alpha=0.3, log=True, label=f"All ({len(x_all)})")
        for key in keys_notall:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values beside the scatter plot.

        A filled grey histogram shows all points. For every plot type the
        high S/N (dashed) and low S/N (dotted) subsets are overlaid, and for
        every type other than ``"any"`` the full sample is overlaid as a
        solid step. ``"any"`` gets no solid step because it is identical to
        the grey histogram. A colorbar for ``histIm`` is appended when the 2D
        histogram is enabled and ``histIm`` is not `None`.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        if "any" in self.plotTypes:
            # "any" already holds every point; concatenating it with the
            # other types would count points twice.
            y_all = data[f"y{self._datatypes['any'].suffix_xy}"]
        else:
            y_all = np.concatenate([data[f"y{self._datatypes[key].suffix_xy}"] for key in self.plotTypes])

        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)
        sideHist.hist(y_all, bins=bins, color="grey", alpha=0.3, orientation="horizontal", log=True)
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in self.plotTypes:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}"]
            if key != "any":
                sideHist.hist(
                    vector,
                    color=config_datatype.color,
                    **kwargs_hist,
                )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")