Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 12%
361 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-20 02:23 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-20 02:23 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from itertools import chain
27from typing import Mapping, NamedTuple, Optional, cast
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.listField import ListField
33from lsst.pipe.tasks.configurableActions import ConfigurableActionField
34from lsst.skymap import BaseSkyMap
35from matplotlib import gridspec
36from matplotlib.axes import Axes
37from matplotlib.collections import PolyCollection
38from matplotlib.figure import Figure
39from matplotlib.path import Path
40from mpl_toolkits.axes_grid1 import make_axes_locatable
42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, Vector
43from ...statistics import nansigmaMad, sigmaMad
44from ..keyedData.summaryStatistics import SummaryStatisticAction
45from ..scalar import MedianAction
46from ..vector import MagColumnNanoJansky, SnSelector
47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Module-level copy of matplotlib's "coolwarm" colormap. Working on a copy
# means the ``set_bad`` call below does not modify ``plt.cm.coolwarm`` itself.
# The type-ignore is there because coolwarm is actually part of the module
# even though type checkers do not see it.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
# Bad (masked) values are drawn with the colour "none".
cmapPatch.set_bad(color="none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the selection masks and summary statistics used to annotate a
    scatter plot.

    Calling the action applies the high and low signal-to-noise selectors to
    the input data and, for each of the two selections, stores the values
    returned by `SummaryStatisticAction` for ``vectorKey`` together with an
    approximate magnitude (the median magnitude of the selected fluxes).
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def _maskPrefix(self) -> str:
        """Return the lower-cased identity used to prefix the mask keys.

        An unset (``None`` or empty) identity yields an empty prefix.
        """
        return (self.identity or "").lower()

    def _statSuffix(self) -> str:
        """Return the capitalized identity embedded in the statistic keys.

        An unset (``None`` or empty) identity yields an empty string.
        """
        return self.identity.capitalize() if self.identity else ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys this action reads, followed by those of both
        selectors.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by ``__call__``.

        The mask and threshold keys are spelt exactly as ``__call__`` writes
        them: lower-cased identity prefix for the masks, and
        ``highSnThreshold``/``lowSnThreshold`` for the thresholds.
        """
        maskPrefix = self._maskPrefix()
        statSuffix = self._statSuffix()
        schema = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("low", "high"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{stat}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks, statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain ``vectorKey`` and, when a ``band`` keyword
            is supplied, the formatted ``fluxType`` key.
        **kwargs
            Forwarded to the selectors and to the statistics action. ``band``
            (optional) is used to prefix the statistic keys and to look up the
            flux column.

        Returns
        -------
        results : `dict`
            The two selection masks, the per-selection statistics and
            ``approxMag`` values, and the two selector thresholds.
        """
        results = {}
        # Guard against an unset identity: calling ``.lower()`` on ``None``
        # would raise before the ``or ""`` fallback could apply.
        maskPrefix = self._maskPrefix()
        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        band = kwargs.get("band")
        prefix = f"{band}_" if band else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        statSuffix = self._statSuffix()
        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statSuffix}"
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
124def _validatePlotTypes(value):
125 return value in ("stars", "galaxies", "unknown", "any", "mag")
# The ``type: ignore`` on ``count`` is needed because the field name
# conflicts with the ``count`` method on the tuple base class.
class _StatsContainer(NamedTuple):
    """Bundle of the four per-selection statistics printed on the plot.

    The field names match the entries of ``ScatterPlotWithTwoHists._stats``.
    """

    median: Scalar
    sigmaMad: Scalar
    count: Scalar  # type: ignore
    approxMag: Scalar
class ScatterPlotWithTwoHists(PlotAction):
    """Plot action drawing a scatter plot with a histogram of each axis.

    The main panel shows the points (optionally with a hexbin 2D histogram),
    running median and sigma-MAD lines; the top and right panels show the
    points collapsed onto the x and y axes. Which datasets are drawn is
    controlled by ``plotTypes``.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the statistics read for each signal-to-noise selection; they
    # match the fields of ``_StatsContainer``.
    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One row per plottable type: (plotTypes entry, name fragment used in the
    # statistic keys, x data key, y data key). The mask keys are built as
    # "<plotTypes entry>HighSNMask" / "<plotTypes entry>LowSNMask".
    _plotTypeKeys = (
        ("stars", "Stars", "xStars", "yStars"),
        ("galaxies", "Galaxies", "xGalaxies", "yGalaxies"),
        ("unknown", "Unknown", "xUnknown", "yUnknown"),
        ("any", "Any", "x", "y"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys required for the configured ``plotTypes``.

        For every selected type this is the x/y vectors, the high and low
        signal-to-noise masks and the per-selection statistics, followed by
        the two thresholds and, when ``addSummaryPlot`` is set, ``patch``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        for plotType, statName, xKey, yKey in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{statName}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{statName}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key in the formatted input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key holds a Scalar
            where the schema asks for a different type.
        """
        # NOTE(review): getFormattedInputSchema is defined outside this file.
        # The schema is iterated twice below, so it is materialised in case a
        # one-shot iterator is returned -- confirm against the base class.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The keyed vectors and scalars to plot; see ``getInputSchema``.
        skymap : `lsst.skymap.BaseSkyMap`
            Passed to ``generateSummaryStats`` when ``addSummaryPlot`` is set.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch. Any value passed in is replaced by the
            output of ``generateSummaryStats`` when ``addSummaryPlot`` is set
            and is otherwise unused.
        **kwargs
            Forwarded to the panel-drawing helpers; ``band`` is needed to
            format the statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The main panel is a scatter plot of the x values against the y
        values. A histogram of the points collapsed onto each axis is also
        plotted. When ``addSummaryPlot`` is set a summary panel is added in
        the upper right corner of the figure.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _getStats(self, data: KeyedData, snBin: str, statName: str, **kwargs) -> _StatsContainer:
        """Collect the ``_stats`` values for one type and one S/N selection.

        ``snBin`` is ``"high"`` or ``"low"`` and ``statName`` is the name
        fragment from ``_plotTypeKeys``; the keys read are
        ``{band}_<snBin>SN<statName>_<stat>`` formatted with ``kwargs``.
        """
        values = {
            name: cast(Scalar, data[f"{{band}}_{snBin}SN{statName}_{name}".format(**kwargs)])
            for name in self._stats
        }
        return _StatsContainer(**values)

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel and return it with the hexbin artist.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin drawn for the last dataset, or `None` when no hexbin
            was drawn for it.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        # Line/point colour and hexbin colormap for each plot type.
        styles = {
            "stars": ("midnightblue", newBlues),
            "galaxies": ("firebrick", newReds),
            "unknown": ("green", None),
            "any": ("purple", None),
        }

        # Minimum number of points an x bin needs before it contributes to
        # the running statistics.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        for plotType, statName, xKey, yKey in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            color, cmap = styles[plotType]
            highStats = self._getStats(data, "high", statName, **kwargs)
            lowStats = self._getStats(data, "low", statName, **kwargs)
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    color,
                    cmap,
                    highStats,
                    lowStats,
                )
            )

        xMin = None
        for (j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats)) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            if len(xs) < 2:
                (medLine,) = ax.plot(
                    xs, np.nanmedian(ys), color, label=f"Median: {np.nanmedian(ys):.2g}", lw=0.8
                )
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    np.nanmedian(ys) + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, np.nanmedian(ys) - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    binIds = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[binIds])
                    sigMad = sigmaMad(ys[binIds], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge is filled left to right, lower edge right to
                    # left, so the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    # White markers underneath give the coloured ones an
                    # outline.
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated bins for running statistics: plot every
                # point with flat median and sigma-MAD lines.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits. Everything below uses values left over
        # from the last dataset drawn, so it is skipped when nothing was
        # plotted (e.g. plotTypes only contains "mag").
        if toPlotList:
            # Centre the y range on the stars when available, then the
            # galaxies, and otherwise on the last dataset plotted, rather
            # than failing on a missing key.
            if "yStars" in data and len(cast(Vector, data["yStars"])) > 0:
                plotMed = np.nanmedian(cast(Vector, data["yStars"]))
            elif "yGalaxies" in data:
                plotMed = np.nanmedian(cast(Vector, data["yGalaxies"]))
            else:
                plotMed = np.nanmedian(ys)  # type: ignore
            if len(xs) < 2:  # type: ignore
                meds = [np.nanmedian(ys)]  # type: ignore
            if self.yLims:
                ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
            else:
                numSig = 4
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore
                # Widen the range one sigma at a time until it contains the
                # running medians (or the cap is reached). The limits are
                # recomputed on every pass so the condition can change.
                while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                    numSig += 1
                    yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                    yLimMax = plotMed + numSig * sigMadYs  # type: ignore

                # One extra sigma of padding around the chosen range.
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore
                ax.set_ylim(yLimMin, yLimMax)

            if self.xLims:
                ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
            elif len(xs) > 2:  # type: ignore
                if xMin is None:
                    xMin = xs1 - 2 * xScale  # type: ignore
                ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter plot.

        The grey histogram covers every selected plot type; stars and
        galaxies additionally get their own outlined histogram.
        """
        # Top histogram
        totalX: list[Vector] = []
        for plotType, _, xKey, _ in self._plotTypeKeys:
            if plotType in self.plotTypes:  # type: ignore
                totalX.append(cast(Vector, data[xKey]))

        # ``x == x`` is False only for NaN, so this drops the NaN entries.
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            # NaNs are filtered as in the side histogram; the label keeps the
            # full length of the column.
            topHist.hist(
                [g for g in cast(Vector, data["xGalaxies"]) if g == g],
                bins=100,
                color="firebrick",
                histtype="step",
                log=True,
                label=f"Galaxies ({len(cast(Vector, data['xGalaxies']))})",
            )
        if "stars" in self.plotTypes:  # type: ignore
            topHist.hist(
                [s for s in cast(Vector, data["xStars"]) if s == s],
                bins=100,
                color="midnightblue",
                histtype="step",
                log=True,
                label=f"Stars ({len(cast(Vector, data['xStars']))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    # Side histogram
    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values to the right of the scatter plot.

        Stars and galaxies get an outline for all points plus dashed and
        dotted outlines for the high and low S/N selections. When ``histIm``
        is given and ``plot2DHist`` is set, a colorbar for it is appended.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = []
        for plotType, _, _, yKey in self._plotTypeKeys:
            if plotType in self.plotTypes:  # type: ignore
                totalY.append(cast(Vector, data[yKey]))
        # ``y == y`` is False only for NaN, so this drops the NaN entries.
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # Bin over the y range currently shown on the scatter plot.
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            sideHist.hist(
                [g for g in cast(Vector, data["yGalaxies"]) if g == g],
                bins=bins,
                color="firebrick",
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            sideHist.hist(
                cast(Vector, data["yGalaxies"])[cast(Vector, data["galaxiesHighSNMask"])],
                bins=bins,
                color="firebrick",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls="--",
            )
            sideHist.hist(
                cast(Vector, data["yGalaxies"])[cast(Vector, data["galaxiesLowSNMask"])],
                bins=bins,
                color="firebrick",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls=":",
            )

        if "stars" in self.plotTypes:  # type: ignore
            sideHist.hist(
                [s for s in cast(Vector, data["yStars"]) if s == s],
                bins=bins,
                color="midnightblue",
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            sideHist.hist(
                cast(Vector, data["yStars"])[cast(Vector, data["starsHighSNMask"])],
                bins=bins,
                color="midnightblue",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls="--",
            )
            sideHist.hist(
                cast(Vector, data["yStars"])[cast(Vector, data["starsLowSNMask"])],
                bins=bins,
                color="midnightblue",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls=":",
            )

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")