Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 12%
367 statements
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-30 03:04 -0700
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-30 03:04 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from itertools import chain
27from typing import Mapping, NamedTuple, Optional, cast
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from lsst.pex.config.listField import ListField
34from lsst.skymap import BaseSkyMap
35from matplotlib import gridspec
36from matplotlib.axes import Axes
37from matplotlib.collections import PolyCollection
38from matplotlib.figure import Figure
39from matplotlib.path import Path
40from mpl_toolkits.axes_grid1 import make_axes_locatable
42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
43from ...statistics import nansigmaMad, sigmaMad
44from ..keyedData.summaryStatistics import SummaryStatisticAction
45from ..scalar import MedianAction
46from ..vector import MagColumnNanoJansky, SnSelector
47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Module-level colormap derived from matplotlib's "coolwarm". A copy is taken
# before customising it, presumably so the colormap shared through
# ``plt.cm`` is left unmodified.
# ignore because coolwarm is actually part of module
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
# Render masked/invalid ("bad") values with no colour at all.
cmapPatch.set_bad(color="none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute summary statistics of a vector for low and high S/N selections.

    For each of the two signal-to-noise selections the action stores the
    selection mask, the statistics returned by `SummaryStatisticAction`
    evaluated with that mask, and an approximate magnitude (the median
    magnitude of the selected fluxes). The two selector thresholds are stored
    as well.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def _maskKey(self, binName: str) -> str:
        """Return the output key of the mask for ``binName`` ("High"/"Low")."""
        return f"{(self.identity or '').lower()}{binName}SNMask"

    def _statsName(self, binName: str) -> str:
        """Return the un-prefixed stem of the statistic keys for ``binName``
        ("high"/"low"), e.g. ``highSNStars`` for identity ``stars``.
        """
        return f"{binName}SN{self.identity.capitalize() if self.identity else ''}"

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys this action reads: the vector itself, the flux
        column and whatever the two S/N selectors require.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__`.

        The keys are built with the same helpers `__call__` uses, so the
        declared schema matches the produced keys (lower-cased identity in
        the mask keys, ``highSnThreshold``/``lowSnThreshold`` for the
        thresholds).
        """
        schema = [(self._maskKey("High"), Vector), (self._maskKey("Low"), Vector)]
        for binName in ("low", "high"):
            stem = self._statsName(binName)
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{stem}_{stat}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks, statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain the keys from `getInputSchema`.
        **kwargs
            Passed to the selectors and to the statistics action. When a
            ``band`` entry is present it is used to prefix the statistic keys
            and to select the flux column for the approximate magnitude.

        Returns
        -------
        results : `KeyedData`
            The two masks, the per-selection statistics and ``approxMag``,
            and the ``highSnThreshold``/``lowSnThreshold`` values.
        """
        results = {}
        highMaskKey = self._maskKey("High")
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = self._maskKey("Low")
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        band = kwargs.get("band")
        prefix = f"{band}_" if band else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{self._statsName(binName)}"
            # set the approxMag to the median mag in the SN selection; without
            # a band there is no flux column to use, so it is left undefined
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            # the statistics action receives the selection through "mask"
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
124def _validatePlotTypes(value):
125 return value in ("stars", "galaxies", "unknown", "any", "mag")
128# ignore type because of conflicting name on tuple baseclass
129class _StatsContainer(NamedTuple):
130 median: Scalar
131 sigmaMad: Scalar
132 count: Scalar # type: ignore
133 approxMag: Scalar
class ScatterPlotWithTwoHists(PlotAction):
    """Scatter plot of ``y`` against ``x`` with a histogram along each axis.

    The figure is made of a main scatter panel, a histogram of the x values
    above it, a histogram of the y values to its right and, optionally, a
    summary panel in the top right corner. Which object types are drawn is
    controlled by ``plotTypes``; the data keys read for each type are listed
    in ``_plotTypeKeys``.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the statistics read for each S/N selection; these are also the
    # field names of `_StatsContainer`.
    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One row per drawable plot type, in drawing order:
    # (plot type, x vector key, y vector key, identity used in statistic keys).
    # Mask keys are "<plot type>HighSNMask" / "<plot type>LowSNMask".
    _plotTypeKeys = (
        ("stars", "xStars", "yStars", "Stars"),
        ("galaxies", "xGalaxies", "yGalaxies", "Galaxies"),
        ("unknown", "xUnknown", "yUnknown", "Unknown"),
        ("any", "x", "y", "Any"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the data keys needed for the configured ``plotTypes``.

        For every selected plot type this is the x and y vectors, the two S/N
        masks and the high/low S/N statistics. The two S/N thresholds are
        always required, and ``patch`` is required when ``addSummaryPlot`` is
        set.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for plotType, xKey, yKey, identity in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{identity}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{identity}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a value
            is a Scalar where the schema does not ask for one.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must provide the keys listed by
            `getInputSchema`.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap passed to ``generateSummaryStats`` when
            ``addSummaryPlot`` is set.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Not read; when ``addSummaryPlot`` is set the summary statistics
            are regenerated from ``data``, ``skymap`` and ``plotInfo``.
        **kwargs
            ``hlineColor`` and ``hlineStyle`` set the colour and line style of
            the horizontal reference line at 0 (defaults: black, loosely
            dotted). The remaining entries (e.g. ``band``) are used to format
            the statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The x and y vectors of every selected plot type are drawn against each
        other in the main panel, with a histogram of the points collapsed
        onto each axis. When ``addSummaryPlot`` is set, a summary panel is
        added in the upper right corner of the figure.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        kwargs.setdefault("hlineColor", "black")
        kwargs.setdefault("hlineStyle", (0, (1, 4)))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot.
        fig : `matplotlib.figure.Figure`
            The figure to add the panel and the statistics text to.
        gs : `matplotlib.gridspec.GridSpec`
            The grid on which the panel is placed (all rows but the first,
            all columns but the last).
        **kwargs
            Must contain ``hlineColor`` and ``hlineStyle``; the entries are
            also used to format the statistic keys.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The axes of the scatter panel.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin drawn for the last dataset, or `None` if the last
            dataset was not drawn with one.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        # line/point colour and hexbin colormap for each plot type
        styles = {
            "stars": ("midnightblue", newBlues),
            "galaxies": ("firebrick", newReds),
            "unknown": ("green", None),
            "any": ("purple", None),
        }

        binThresh = 5

        linesForLegend = []
        histIm = None

        # Gather, for every selected plot type, the vectors, masks, style and
        # statistics. Each type reads its own statistic keys.
        toPlotList = []
        for plotType, xKey, yKey, identity in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(Scalar, data[f"{{band}}_highSN{identity}_{name}".format(**kwargs)])
                lowArgs[name] = cast(Scalar, data[f"{{band}}_lowSN{identity}_{name}".format(**kwargs)])
            color, cmap = styles[plotType]
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    color,
                    cmap,
                    _StatsContainer(**highArgs),
                    _StatsContainer(**lowArgs),
                )
            )

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            medYs = np.nanmedian(ys)
            if len(xs) < 2:
                # Too few points for a running median: draw flat lines only.
                # Arrays sized like xs keep this valid for an empty dataset.
                flatMeds = np.full(len(xs), medYs)
                (medLine,) = ax.plot(xs, flatMeds, color, label=f"Median: {medYs:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    flatMeds + 1.0 * sigMadYs,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:.2g}",
                )
                ax.plot(xs, flatMeds - 1.0 * sigMadYs, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # keep only the x bins holding more than binThresh points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[inBin])
                    sigMad = sigmaMad(ys[inBin], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # upper boundary runs forwards, lower boundary backwards,
                    # so that the vertices trace a closed outline
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    # white, thicker marker underneath acts as an outline
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Not enough populated bins for a running median: plot every
                # point with a flat median and sigma MAD line.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([medYs] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {medYs:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigMadYs] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # Centre the y range on the median of the first non-empty y vector:
        # stars, then galaxies when those keys are present, otherwise the
        # vectors that were actually plotted (so "unknown"/"any" only plots
        # do not fail on a missing "yStars" key).
        yCandidates = [cast(Vector, data[key]) for key in ("yStars", "yGalaxies") if key in data]
        yCandidates.extend(entry[1] for entry in toPlotList)
        plotMed = next((np.nanmedian(ys_) for ys_ in yCandidates if len(ys_) > 0), np.nan)
        # NOTE: xs, ys, sigMadYs, meds, xs1, xs97 and xScale below are the
        # values left over from the last dataset drawn in the loop above.
        if len(xs) < 2:  # type: ignore
            meds = [np.nanmedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        else:
            # Widen the range one sigma MAD at a time (up to 10) until the
            # medians fit inside it; the limits are recomputed on every step.
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # one more sigma MAD of padding
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # matplotlib rejects NaN/inf limits; keep autoscaling instead
            if np.isfinite(yLimMin) and np.isfinite(yLimMax):
                ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif len(xs) > 2:  # type: ignore
            if xMin is None:
                xMin = xs1 - 2 * xScale  # type: ignore
            ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter panel.

        All selected plot types contribute to the grey "All" histogram;
        galaxies and stars additionally get their own outlined histogram.
        The panel shares its x axis with ``ax``.
        """
        # Top histogram
        totalX: list[Vector] = []
        for plotType, xKey, _, _ in self._plotTypeKeys:
            if plotType in self.plotTypes:  # type: ignore
                totalX.append(cast(Vector, data[xKey]))

        # x == x is False only for NaN, so this drops the NaN entries
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        for plotType, xKey, color, label in (
            ("galaxies", "xGalaxies", "firebrick", "Galaxies"),
            ("stars", "xStars", "midnightblue", "Stars"),
        ):
            if plotType in self.plotTypes:  # type: ignore
                topHist.hist(
                    data[xKey],
                    bins=100,
                    color=color,
                    histtype="step",
                    log=True,
                    label=f"{label} ({len(cast(Vector, data[xKey]))})",
                )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values to the right of the scatter panel.

        All selected plot types contribute to the grey histogram; galaxies
        and stars additionally get outlined histograms of all their points
        (solid), their high S/N points (dashed) and their low S/N points
        (dotted). The panel shares its y axis with ``ax``. When
        ``plot2DHist`` is set and ``histIm`` is given, a colour bar for it is
        attached to the right of the panel.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = []
        for plotType, _, yKey, _ in self._plotTypeKeys:
            if plotType in self.plotTypes:  # type: ignore
                totalY.append(cast(Vector, data[yKey]))
        # y == y is False only for NaN, so this drops the NaN entries
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # bin over the y range currently shown in the scatter panel
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        for plotType, yKey, color in (
            ("galaxies", "yGalaxies", "firebrick"),
            ("stars", "yStars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            ysType = cast(Vector, data[yKey])
            sideHist.hist(
                [y for y in ysType if y == y],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            sideHist.hist(
                ysType[cast(Vector, data[f"{plotType}HighSNMask"])],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
                ls="--",
            )
            sideHist.hist(
                ysType[cast(Vector, data[f"{plotType}LowSNMask"])],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")