Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%
320 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-28 12:44 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-28 12:44 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from typing import Mapping, NamedTuple, Optional, cast
28import matplotlib.colors
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from lsst.pex.config.listField import ListField
34from lsst.skymap import BaseSkyMap
35from matplotlib import gridspec
36from matplotlib.axes import Axes
37from matplotlib.collections import PolyCollection
38from matplotlib.figure import Figure
39from matplotlib.path import Path
40from mpl_toolkits.axes_grid1 import make_axes_locatable
42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
43from ...math import nanMedian, nanSigmaMad
44from ..keyedData.summaryStatistics import SummaryStatisticAction
45from ..scalar import MedianAction
46from ..vector import ConvertFluxToMag, SnSelector
47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Module-level colormap: a private copy of matplotlib's "coolwarm" map in
# which masked/invalid ("bad") values are drawn with no colour at all.
# The type checker cannot see ``coolwarm`` on ``plt.cm`` even though it is
# really there, hence the ignore.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    Notes
    -----
    Calling the action produces, for both a "low" and a "high" signal to
    noise selection:

    * a mask vector keyed ``{identity}HighSNMask`` / ``{identity}LowSNMask``
      (identity lower-cased, empty if no identity is set),
    * every entry returned by `SummaryStatisticAction` for ``vectorKey``
      restricted to that mask, keyed
      ``{band}_{low|high}SN{Identity}_{suffix}``,
    * an ``approxMag`` scalar: the median magnitude of the selected
      ``fluxType`` fluxes (NaN if no ``band`` keyword was supplied),

    plus the two selector thresholds under ``highSnThreshold`` and
    ``lowSnThreshold``.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys (and types) this action reads from the input data.

        The schema is the statistics vector, the flux vector and whatever
        the two signal to noise selectors declare themselves.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys (and types) that ``__call__`` writes.

        The key names are kept in step with ``__call__``: the mask keys use
        the lower-cased identity and the thresholds are published as
        ``highSnThreshold`` / ``lowSnThreshold``.
        """
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        schema: list = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        # Same ordering as before: all low S/N statistics, then all high.
        for binName in ("low", "high"):
            for statName in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{statName}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the S/N masks and the per-selection summary statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain the keys given by ``getInputSchema``.
        **kwargs
            Formatting keywords. ``band`` (if present) is used to build the
            output key prefix and to look up the flux vector.

        Returns
        -------
        results : `KeyedData`
            Masks, statistics and thresholds as described in the class
            notes.
        """
        results = {}
        identity = self.identity or ""
        statSuffix = identity.capitalize()

        highMaskKey = f"{identity.lower()}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{identity.lower()}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statSuffix}"
            mask = results[maskKey]
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": mask})).items()
            for suffix, value in stats:
                results[f"{name}_{suffix}".format(**kwargs)] = value

        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
128def _validObjectTypes(value):
129 return value in ("stars", "galaxies", "unknown", "any")
132# ignore type because of conflicting name on tuple baseclass
133class _StatsContainer(NamedTuple):
134 median: Scalar
135 sigmaMad: Scalar
136 count: Scalar # type: ignore
137 approxMag: Scalar
class DataTypeDefaults(NamedTuple):
    """Per object type naming and colour defaults used by the plot."""

    # Suffix used in the statistics keys, e.g.
    # ``{band}_highSN<suffix_stat>_median``.
    suffix_stat: str
    # Suffix appended to ``x`` / ``y`` to form the data vector keys.
    suffix_xy: str
    # Matplotlib colour used for lines and points of this type.
    color: str
    # Colormap handed to the 2D histogram (hexbin) for this type, or None.
    colormap: matplotlib.colors.Colormap | None
class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.

    Notes
    -----
    The figure is laid out on a 4x4 grid: the scatter plot occupies the
    lower-left 3x3 cells, the x histogram sits above it, the y histogram to
    its right and (optionally) a summary plot in the top-right cell.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )

    # Names of the per-selection statistics; these are also the field names
    # of ``_StatsContainer``.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key suffixes and colours for every value allowed in ``plotTypes``.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and types) the plot reads from its input data.

        For every configured plot type this is the x/y vectors, the two
        S/N masks and the high/low S/N statistics; in addition the two S/N
        thresholds and, when the summary plot is enabled, ``patch``.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}", Vector))
            base.append((f"y{config_datatype.suffix_xy}", Vector))
            base.append((f"{name_datatype}HighSNMask", Vector))
            base.append((f"{name_datatype}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{config_datatype.suffix_stat}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}", Scalar))
        # The thresholds are shared by all plot types, so they are listed once.
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` and build the figure; see ``makePlot``."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key of the formatted input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a value is a Scalar
            where the schema asks for something else.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                  The output run for the plots (`str`).
            * ``"skymap"``
                  The type of skymap used for the data (`str`).
            * ``"filter"``
                  The filter used for this data (`str`).
            * ``"tract"``
                  The tract that the data comes from (`str`).
        **kwargs
            Formatting keywords (e.g. ``band``) plus the optional
            ``hlineColor`` and ``hlineStyle`` of the horizontal reference
            line at 0 (defaults: black, loosely dotted).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknowns then replace stars in the above
          names with galaxies/unknowns.

        * If it is for any (which covers all the points) then it
          becomes, x, y, and any instead of stars for the other
          parameters given above.

        * In every case it is expected that data contains:
            lowSnThreshold, highSnThreshold and patch
            (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        kwargs.setdefault("hlineColor", "black")
        kwargs.setdefault("hlineStyle", (0, (1, 4)))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        For every configured plot type the running median and the 1, 2 and
        3 sigma-MAD envelopes are drawn. Points outside the 3 sigma-MAD
        envelope are plotted individually; points inside it are shown as a
        hexbin 2D histogram when ``plot2DHist`` is set. Datasets with fewer
        than 10 points only get a single median / sigma-MAD marker.

        Parameters
        ----------
        data : `KeyedData`
            Input data, see ``getInputSchema``.
        fig : `matplotlib.figure.Figure`
            Figure to draw on.
        gs : `matplotlib.gridspec.GridSpec`
            Grid describing the panel layout.
        **kwargs
            Formatting keywords (``band``) and the optional ``hlineColor``
            / ``hlineStyle`` of the reference line at 0.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin artist of the last dataset that drew one, `None`
            otherwise.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # Minimum number of points an x bin needs to take part in the
        # running statistics.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar, data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
                lowArgs[name] = cast(
                    Scalar, data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}"],
                    data[f"y{config_datatype.suffix_xy}"],
                    data[f"{name_datatype}HighSNMask"],
                    data[f"{name_datatype}LowSNMask"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        # ``np.inf`` rather than the ``np.Inf`` alias, which NumPy 2.0 removed.
        xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]
        # Running medians of the most recent dataset that computed them; read
        # again after the loop when choosing the y limits.
        meds = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nanSigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0 or not np.isfinite(sigMadYs):
                continue
            elif n_xs < 10:
                xs = [nanMedian(xs)]
                sigMads = np.array([sigMadYs])
                ys = np.array([nanMedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Chop off 1/3% from only the finite xs values
                # (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = nanMedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins that are populated well enough.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = nanMedian(ys[inBin])
                    sigMad = nanSigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper envelope goes left-to-right in the first half of
                    # the vertex array, lower envelope right-to-left in the
                    # second half, so the vertices form a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if 0 < np.sum(highSn) < 100:
                    # White, thicker marker underneath acts as an outline.
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if 0 < np.sum(lowSn) < 100:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated bins for running statistics: plot every
                # point and a flat median / sigma MAD over the whole sample.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.full(len(xs), medYs)
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {medYs:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.full(len(xs), sigMadYs)
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(
            0,
            color=kwargs.get("hlineColor", "black"),
            ls=kwargs.get("hlineStyle", (0, (1, 4))),
            alpha=0.7,
            zorder=-2,
        )

        # Set the scatter plot limits
        # TODO: Make this not work by accident
        if "yStars" in data and (len(cast(Vector, data["yStars"])) > 0):
            plotMed = nanMedian(cast(Vector, data["yStars"]))
        elif "yGalaxies" in data and (len(cast(Vector, data["yGalaxies"])) > 0):
            plotMed = nanMedian(cast(Vector, data["yGalaxies"]))
        else:
            plotMed = np.nan

        # Ignore types below pending making this not working by accident.
        # NOTE: xs, ys and sigMadYs are whatever the last dataset of the loop
        # above left behind. ``meds`` is still None when no dataset computed
        # running medians (e.g. the only dataset had a non-finite sigma MAD).
        if meds is None or len(xs) < 2:  # type: ignore
            meds = [nanMedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed):
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen the limits one sigma MAD at a time (up to 10) until all
            # the running medians fit. The limits must be recomputed on
            # every pass, otherwise the condition never changes.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # One extra sigma MAD of padding around the medians.
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc=self.legendLocation,
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter plot.

        Parameters
        ----------
        data : `KeyedData`
            Input data, see ``getInputSchema``.
        figure : `matplotlib.figure.Figure`
            Figure to draw on.
        gs : `matplotlib.gridspec.GridSpec`
            Grid describing the panel layout.
        ax : `matplotlib.axes.Axes`
            The scatter plot axes; its x axis is shared and its current x
            limits define the 100 bin edges.
        **kwargs
            Unused.
        """
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)

        if "any" in self.plotTypes:
            # Histogram the data vector itself; passing the key string to
            # ``hist`` (and taking ``len`` of it) is meaningless.
            x_any = data[f"x{self._datatypes['any'].suffix_xy}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
            topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=True, label=f"Any ({len(x_any)})")
        else:
            keys_notany = self.plotTypes

        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values to the right of the scatter plot.

        For every plot type other than "any" three step histograms are
        drawn: all points (solid), the high S/N selection (dashed) and the
        low S/N selection (dotted).

        Parameters
        ----------
        data : `KeyedData`
            Input data, see ``getInputSchema``.
        figure : `matplotlib.figure.Figure`
            Figure to draw on.
        gs : `matplotlib.gridspec.GridSpec`
            Grid describing the panel layout.
        ax : `matplotlib.axes.Axes`
            The scatter plot axes; its y axis is shared and its current y
            limits define the 100 bin edges.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            Hexbin artist returned by ``_scatterPlot``; when present (and
            ``plot2DHist`` is set) a colorbar for it is attached.
        **kwargs
            Optional ``hlineColor`` / ``hlineStyle`` of the reference line
            at 0.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)

        if "any" in self.plotTypes:
            # Histogram the data vector itself, not the key string.
            y_any = data[f"y{self._datatypes['any'].suffix_xy}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
            sideHist.hist(y_any, bins=bins, color="grey", alpha=0.3, orientation="horizontal", log=True)
        else:
            keys_notany = self.plotTypes

        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}"]
            sideHist.hist(
                vector,
                color=config_datatype.color,
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(
            0,
            color=kwargs.get("hlineColor", "black"),
            ls=kwargs.get("hlineStyle", (0, (1, 4))),
            alpha=0.7,
            zorder=-2,
        )

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")