Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 13%
360 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-08-20 02:31 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
25from itertools import chain
26from typing import Mapping, NamedTuple, Optional, cast
28import matplotlib.pyplot as plt
29import numpy as np
30from lsst.pex.config import Field
31from lsst.pex.config.listField import ListField
32from lsst.pipe.tasks.configurableActions import ConfigurableActionField
33from lsst.skymap import BaseSkyMap
34from matplotlib import gridspec
35from matplotlib.axes import Axes
36from matplotlib.collections import PolyCollection
37from matplotlib.figure import Figure
38from matplotlib.path import Path
39from mpl_toolkits.axes_grid1 import make_axes_locatable
41from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, Vector
42from ..keyedData.summaryStatistics import SummaryStatisticAction, sigmaMad
43from ..scalar import MedianAction
44from ..vector import MagColumnNanoJansky, SnSelector
45from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Colormap with fully transparent "bad" (masked/invalid) values. It is a copy,
# so the shared matplotlib "coolwarm" colormap itself is left untouched.
# The type ignore is needed because coolwarm is created dynamically on the
# plt.cm module and so is invisible to static type checkers.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad("none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute high/low S/N masks and summary statistics for one vector.

    For each of the two S/N selections (``lowSNSelector`` and
    ``highSNSelector``) this action emits:

    - the boolean mask returned by the selector;
    - the statistics from `SummaryStatisticAction` on ``vectorKey``,
      restricted to that mask;
    - an approximate magnitude, which is the median magnitude of
      ``fluxType`` within the mask (NaN when no ``band`` is given).

    The two selector thresholds are also emitted, as ``highSnThreshold`` and
    ``lowSnThreshold``.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def _identityLabels(self) -> tuple[str, str]:
        """Return ``(maskPrefix, statLabel)`` derived from ``identity``.

        ``identity`` may be unset (None). In that case both labels are empty
        strings. The mask prefix is lower-cased, e.g. ``starsHighSNMask``.
        The statistic label is capitalized, e.g. ``{band}_highSNStars_median``.
        """
        identity = self.identity or ""
        return identity.lower(), identity.capitalize()

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Describe the keys produced by `__call__`.

        The mask and threshold keys here match exactly what `__call__`
        writes. Previously the schema used the raw-case identity and
        ``high/lowThreshold``.
        """
        maskPrefix, statLabel = self._identityLabels()
        schema: list[tuple[str, type[Vector] | type[Scalar]]] = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("low", "high"):
            for suffix in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statLabel}_{suffix}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks, per-bin statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data containing ``vectorKey``, ``fluxType`` (formatted with
            ``band``) and whatever the S/N selectors need.
        **kwargs
            Forwarded to the selectors and the statistic action. ``band``,
            if present, prefixes the statistic keys and selects the flux
            column used for ``approxMag``.

        Returns
        -------
        results : `KeyedData`
            Keys as described by `getOutputSchema`, with ``{band}``
            substituted.
        """
        maskPrefix, statLabel = self._identityLabels()
        results = {}
        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statLabel}"
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
122def _validatePlotTypes(value):
123 return value in ("stars", "galaxies", "unknown", "any", "mag")
# ignore type because of conflicting name on tuple baseclass
# (the ``count`` field shadows ``tuple.count``; see the type ignore below)
class _StatsContainer(NamedTuple):
    """Precomputed statistics for one S/N bin of one object type.

    Populated from the ``{band}_{high|low}SN<Type>_<field>`` entries of the
    plot's input data. Printed in the statistics boxes under the scatter
    plot.
    """

    median: Scalar  # printed as "Median"
    sigmaMad: Scalar  # printed as sigma_MAD
    count: Scalar  # type: ignore  # printed as N_points
    approxMag: Scalar  # printed in the "(<magLabel> < approxMag)" header; also drawn as a vertical line
class ScatterPlotWithTwoHists(PlotAction):
    """Scatter plot of y against x, with histograms of each axis.

    The main panel shows, for each requested object type:

    - the running median of ``y`` in bins of ``x``;
    - the 1, 2 and 3 sigma_MAD envelopes around it;
    - an optional hexbin density inside the 3 sigma_MAD envelope.

    A histogram of ``x`` is drawn above the main panel and a histogram of
    ``y`` to its right. Summary statistics for the high and low S/N subsets
    are printed below the plot. They are read from the input data (e.g. as
    produced by `ScatterPlotStatsAction`), not computed here.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One entry per drawable plot type: (plotType, x key, y key, colour).
    # Mask keys are "{plotType}HighSNMask"/"{plotType}LowSNMask".
    # Statistic keys use plotType.capitalize(), e.g. "{band}_highSNStars_median".
    # The order fixes the drawing order and the schema order.
    _plotTypeInfo = (
        ("stars", "xStars", "yStars", "midnightblue"),
        ("galaxies", "xGalaxies", "yGalaxies", "firebrick"),
        ("unknown", "xUnknown", "yUnknown", "green"),
        ("any", "x", "y", "purple"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and types) needed for the requested plot types.

        Every plot type needs its x/y vectors, both S/N masks and the
        high/low S/N statistics. The S/N thresholds are always needed because
        the statistics boxes print them for every type. The ``patch`` vector
        is needed only for the summary plot.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        for plotType, xKey, yKey, _ in self._plotTypeInfo:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            label = plotType.capitalize()
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{label}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{label}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key this action needs.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key that must hold a
            `Vector` holds a `Scalar`.
        """
        # Materialize once: the schema is iterated twice below. If
        # getFormattedInputSchema returns a one-shot iterable, the second
        # pass (the type check) would otherwise silently see nothing.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a scatter plot of y against x with histograms of each axis.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, keyed as described by `getInputSchema`:
            x/y vectors, S/N masks, precomputed statistics and S/N
            thresholds.
        skymap : `lsst.skymap.BaseSkyMap`
            Passed to ``generateSummaryStats``. Only used when
            ``addSummaryPlot`` is set.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Currently ignored. When ``addSummaryPlot`` is set, the per-patch
            summary statistics are regenerated from ``data`` with
            ``generateSummaryStats``, overriding this argument.
        **kwargs
            Passed to the panel helpers. Used to format ``{band}`` in the
            statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If ``plotTypes`` is empty, a figure with
            only a "no data" message is returned.

        Notes
        -----
        For each requested plot type, the main panel draws the running
        median of ``y`` in bins of ``x``, with 1/2/3 sigma_MAD envelopes.
        Points outside 3 sigma_MAD are drawn individually. Points inside are
        optionally drawn as a hexbin density (``plot2DHist``). Histograms of
        ``x`` (top) and ``y`` (right) are added. If ``addSummaryPlot`` is
        set, a per-patch summary panel is drawn in the upper right corner.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _getStats(self, data: KeyedData, snBin: str, plotType: str, **kwargs) -> _StatsContainer:
        """Collect the precomputed statistics for one S/N bin of one type.

        ``snBin`` is ``"high"`` or ``"low"``. ``kwargs`` supplies ``band``
        for formatting the keys.
        """
        label = plotType.capitalize()
        values = {
            name: cast(Scalar, data[f"{{band}}_{snBin}SN{label}_{name}".format(**kwargs)])
            for name in self._stats
        }
        return _StatsContainer(**values)

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter panel.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin image drawn for the data. `None` if the last dataset
            was drawn without a 2D histogram.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        cmaps = {"stars": newBlues, "galaxies": newReds}

        # An x bin is used for the running statistics only if it holds more
        # than this many points.
        binThresh = 5

        linesForLegend = []
        histIm = None

        toPlotList = []
        for plotType, xKey, yKey, color in self._plotTypeInfo:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    color,
                    cmaps.get(plotType),
                    self._getStats(data, "high", plotType, **kwargs),
                    self._getStats(data, "low", plotType, **kwargs),
                )
            )

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                # Too few points for running statistics: draw a flat median
                # and +/- 1 sigma_MAD. The median is broadcast to len(xs) so an
                # empty input does not trip matplotlib's x/y length check.
                medY = np.nanmedian(ys)
                medLineYs = np.full(len(xs), medY)
                (medLine,) = ax.plot(xs, medLineYs, color, label=f"Median: {medY:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    medLineYs + 1.0 * sigMadYs,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:.2g}",
                )
                ax.plot(xs, medLineYs - 1.0 * sigMadYs, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # nBins x bins spanning the data plus a margin of xScale each side
            # (40 was used as the default number of bins because it looked good)
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    binIds = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[binIds])
                    sigMad = sigmaMad(ys[binIds])
                    meds[i] = med
                    sigMads[i] = sigMad
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    # Keep the smallest x over all datasets (was overwritten).
                    highMin = np.min(cast(Vector, xs[highSn]))
                    if xMin is None or xMin > highMin:
                        xMin = highMin
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    lowMin = np.min(cast(Vector, xs[lowSn]))
                    if xMin is None or xMin > lowMin:
                        xMin = lowMin
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits. Nothing to size them on if no
        # drawable type was requested (e.g. only "mag").
        if toPlotList:
            # Centre the y range on the median y of the first non-empty
            # requested type (stars, then galaxies, unknown, any). This
            # replaces an unconditional read of yStars/yGalaxies, which
            # raised KeyError when those types were not plotted.
            plotMed = np.nan
            for plotType, _, yKey, _ in self._plotTypeInfo:
                if plotType in self.plotTypes:  # type: ignore
                    yVals = cast(Vector, data[yKey])
                    if len(yVals) > 0:
                        plotMed = np.nanmedian(yVals)
                        break
            # xs, ys, meds, sigMadYs, xs1, xs97 and xScale below are the
            # values left over from the last dataset drawn in the loop above.
            if len(xs) < 2:  # type: ignore
                meds = [np.median(ys)]  # type: ignore
            if self.yLims:
                ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
            else:
                numSig = 4
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore
                # Widen the range (up to 10 sigma_MAD) until it contains all
                # the running medians, recomputing the limits at each step.
                while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                    numSig += 1
                    yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                    yLimMax = plotMed + numSig * sigMadYs  # type: ignore
                # One extra sigma_MAD of margin.
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore
                # matplotlib rejects NaN/inf limits; keep autoscaling then.
                if np.isfinite(yLimMin) and np.isfinite(yLimMax):
                    ax.set_ylim(yLimMin, yLimMax)

            if self.xLims:
                ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
            elif len(xs) > 2:  # type: ignore
                if xMin is None:
                    xMin = xs1 - 2 * xScale  # type: ignore
                ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x above the scatter panel (shared x axis).

        The grey histogram covers all requested types, with NaNs dropped.
        Step histograms are overlaid for galaxies and stars when requested.
        """
        # Top histogram
        totalX: list[Vector] = [
            cast(Vector, data[xKey])
            for plotType, xKey, _, _ in self._plotTypeInfo
            if plotType in self.plotTypes  # type: ignore
        ]

        # x == x is False only for NaN, so this drops NaNs
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xGalaxies"],
                bins=100,
                color="firebrick",
                histtype="step",
                log=True,
                label=f"Galaxies ({len(cast(Vector, data['xGalaxies']))})",
            )
        if "stars" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xStars"],
                bins=100,
                color="midnightblue",
                histtype="step",
                log=True,
                label=f"Stars ({len(cast(Vector, data['xStars']))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    # Side histogram

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y to the right of the scatter panel.

        The panel shares the y axis with the scatter panel and bins over its
        current y limits. For galaxies and stars, the all, high-S/N (dashed)
        and low-S/N (dotted) subsets are overlaid. If a hexbin image was
        drawn, its colour bar is attached to the right of this panel.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = [
            cast(Vector, data[yKey])
            for plotType, _, yKey, _ in self._plotTypeInfo
            if plotType in self.plotTypes  # type: ignore
        ]
        # y == y is False only for NaN, so this drops NaNs
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # Bin over the y range already chosen for the scatter panel
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            sideHist.hist(
                [g for g in cast(Vector, data["yGalaxies"]) if g == g],
                bins=bins,
                color="firebrick",
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            sideHist.hist(
                cast(Vector, data["yGalaxies"])[cast(Vector, data["galaxiesHighSNMask"])],
                bins=bins,
                color="firebrick",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls="--",
            )
            sideHist.hist(
                cast(Vector, data["yGalaxies"])[cast(Vector, data["galaxiesLowSNMask"])],
                bins=bins,
                color="firebrick",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls=":",
            )

        if "stars" in self.plotTypes:  # type: ignore
            sideHist.hist(
                [s for s in cast(Vector, data["yStars"]) if s == s],
                bins=bins,
                color="midnightblue",
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            sideHist.hist(
                cast(Vector, data["yStars"])[cast(Vector, data["starsHighSNMask"])],
                bins=bins,
                color="midnightblue",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls="--",
            )
            sideHist.hist(
                cast(Vector, data["yStars"])[cast(Vector, data["starsLowSNMask"])],
                bins=bins,
                color="midnightblue",
                histtype="step",
                orientation="horizontal",
                log=True,
                ls=":",
            )

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")