Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%
383 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-05 01:24 -0700
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-05 01:24 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23from functools import partial
24from itertools import chain
25from typing import Mapping, NamedTuple, Optional, cast
27import astropy.units as u
28import matplotlib.pyplot as plt
29import numpy as np
30import scipy.stats as sps
31from lsst.analysis.tools.actions.scalar.scalarActions import CountAction, MedianAction, SigmaMadAction
32from lsst.pex.config import Field
33from lsst.pex.config.listField import ListField
34from lsst.pipe.tasks.configurableActions import ConfigurableActionField
35from lsst.skymap import BaseSkyMap
36from matplotlib import gridspec
37from matplotlib.axes import Axes
38from matplotlib.collections import PolyCollection
39from matplotlib.figure import Figure
40from matplotlib.path import Path
41from mpl_toolkits.axes_grid1 import make_axes_locatable
43from ...interfaces import (
44 KeyedData,
45 KeyedDataAction,
46 KeyedDataSchema,
47 PlotAction,
48 Scalar,
49 ScalarAction,
50 Vector,
51 VectorAction,
52)
53from ..keyedData import KeyedScalars
54from ..vector import SnSelector
55from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Module-level copy of matplotlib's "coolwarm" colormap in which masked/bad
# values are drawn with no colour.
# ignore because coolwarm is actually part of module
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")

# sigma-MAD helper: scipy's median absolute deviation with scale="normal".
# Callers in this module pass ``nan_policy="omit"`` where NaNs are expected.
sigmaMad = partial(sps.median_abs_deviation, scale="normal")  # type: ignore
class _ApproxMedian(ScalarAction):
    """Median of the largest tenth of a (masked) vector, with unit conversion.

    The selected vector is sorted in ascending order, the last ``len // 10``
    entries are kept and their median is returned, converted from
    ``inputUnit`` to ``outputUnit`` when the two differ.
    """

    vectorKey = Field[str](doc="Key for the vector to perform action on", optional=False)
    inputUnit = Field[str](doc="Input unit of the vector", default="nJy")
    outputUnit = Field[str](doc="Output unit of the vector", default="mag(AB)")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the single (formatted key, `Vector`) pair this action reads."""
        return ((self.vectorKey.format(**kwargs), Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Scalar:
        """Compute the approximate median.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the vector named by ``vectorKey`` (after
            formatting with ``kwargs``).
        **kwargs
            Used to format ``vectorKey`` and forwarded to ``getMask``.

        Returns
        -------
        median : `Scalar`
            Median of the top tenth of the finite-or-infinite (non-NaN)
            values, in ``outputUnit``.  NaN when no usable values remain.
        """
        mask = self.getMask(**kwargs)
        value = np.sort(cast(Vector, data[self.vectorKey.format(**kwargs)])[mask])
        # np.sort places NaNs at the end of the array, i.e. exactly in the
        # tail selected below; drop them first so they cannot crowd out (or
        # entirely replace) the real values.
        value = value[~np.isnan(value)]

        nTail = len(value) // 10
        # With fewer than ten entries there is no "top tenth"; fall back to
        # the whole vector (what ``value[-0:]`` evaluates to anyway).
        tail = value[-nTail:] if nTail > 0 else value
        median = np.nanmedian(tail) if len(tail) > 0 else np.nan

        if self.inputUnit != self.outputUnit:
            median = (median * u.Unit(self.inputUnit)).to(u.Unit(self.outputUnit)).value
        return median
class _StatsImpl(KeyedScalars):
    """Bundle of scalar statistics evaluated on one column.

    ``setDefaults`` registers four scalar actions: ``median``, ``sigmaMad``
    and ``count`` on ``vectorKey``, and ``approxMag`` on ``snFluxType``.
    """

    vectorKey = Field[str](doc="Column key to compute scalars")

    snFluxType = Field[str](doc="column key for the flux type used in SN selection")

    def setDefaults(self):
        """Install the four scalar actions using the currently set keys."""
        super().setDefaults()
        # Same registration order as the statistics are later reported in.
        for statName, actionType in (
            ("median", MedianAction),
            ("sigmaMad", SigmaMadAction),
            ("count", CountAction),
        ):
            setattr(self.scalarActions, statName, actionType(vectorKey=self.vectorKey))
        setattr(self.scalarActions, "approxMag", _ApproxMedian(vectorKey=self.snFluxType))

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Run the scalar actions, always passing an explicit ``mask`` keyword.

        ``mask`` is taken from ``kwargs`` when present and is `None`
        otherwise.
        """
        forwarded = {**kwargs, "mask": kwargs.get("mask")}
        return super().__call__(data, **forwarded)
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the masks and summary statistics consumed by the scatter plot.

    For a high and a low signal-to-noise selection this action stores the
    selector mask, the statistics produced by `_StatsImpl` evaluated under
    that mask, and the selector thresholds.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[VectorAction](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[VectorAction](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the statistic and flux vectors plus both selectors' inputs."""
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by ``__call__``.

        The two threshold entries use the ``highSnThreshold`` and
        ``lowSnThreshold`` names, matching what ``__call__`` stores.
        """
        identity = self.identity or ""
        suffix = self.identity.capitalize() if self.identity else ""

        schema: list = [
            (f"{identity}HighSNMask", Vector),
            (f"{identity}LowSNMask", Vector),
        ]
        for typ in ("low", "high"):
            for name in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{typ}SN{suffix}_{name}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Evaluate both selectors and the statistics under each mask.

        Parameters
        ----------
        data : `KeyedData`
            Input vectors.
        **kwargs
            Forwarded to the selectors and to the statistics; a truthy
            ``band`` entry is used as the prefix of the statistic keys.

        Returns
        -------
        results : `KeyedData`
            The two masks, one entry per statistic and selection, and the
            two selector thresholds.
        """
        identity = self.identity or ""
        suffix = self.identity.capitalize() if self.identity else ""

        results = {}
        highMaskKey = f"{identity}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{identity}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""

        stats = _StatsImpl(vectorKey=self.vectorKey, snFluxType=self.fluxType)
        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        stats.setDefaults()
        for maskKey, typ in ((lowMaskKey, "low"), (highMaskKey, "high")):
            for name, value in stats(data, **(kwargs | {"mask": results[maskKey]})).items():
                tmpKey = f"{prefix}{typ}SN{suffix}_{name}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
159def _validatePlotTypes(value):
160 return value in ("stars", "galaxies", "unknown", "any", "mag")
# The ``# type: ignore`` on ``count`` is needed because the field name
# conflicts with the ``count`` method of the tuple base class.
class _StatsContainer(NamedTuple):
    """Statistics for one signal-to-noise selection.

    The field names match the entries of ``ScatterPlotWithTwoHists._stats``,
    which is what allows instances to be built with ``_StatsContainer(**args)``.
    """

    median: Scalar
    sigmaMad: Scalar
    count: Scalar  # type: ignore
    approxMag: Scalar
class ScatterPlotWithTwoHists(PlotAction):
    """Scatter plot of y against x with a histogram collapsed onto each axis.

    Each configured object type is drawn with its own colour, running
    median and sigma-MAD envelopes; statistics for a high and a low
    signal-to-noise selection are printed below the plot.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One row per drawable object type, in plotting order:
    # (plot type, x key, y key, mask key prefix, suffix in statistic keys,
    # line colour).  Every method below derives its data keys from this
    # table so that the schema and the plotting code cannot drift apart.
    _typeTable = (
        ("stars", "xStars", "yStars", "stars", "Stars", "midnightblue"),
        ("galaxies", "xGalaxies", "yGalaxies", "galaxies", "Galaxies", "firebrick"),
        ("unknown", "xUnknown", "yUnknown", "unknown", "Unknown", "green"),
        ("any", "x", "y", "any", "Any", "purple"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """List the keys required for the configured ``plotTypes``.

        Returns
        -------
        base : `list` of `tuple`
            For every configured type: the x and y vectors, the high and low
            S/N masks and the ``{band}``-templated statistics; followed by
            the two S/N thresholds and, if ``addSummaryPlot`` is set,
            ``patch``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        for plotType, xKey, yKey, maskPrefix, suffix, _ in self._typeTable:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{maskPrefix}HighSNMask", Vector))
            base.append((f"{maskPrefix}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{suffix}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{suffix}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the figure."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that all required keys are present and not wrongly scalar.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing from ``data``, or if a key that
            should hold a vector holds a scalar.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        # NOTE(review): ``needed`` is iterated twice (here and in the loop
        # below).  If getFormattedInputSchema returns a one-shot iterator the
        # type check below never sees any entries - confirm.
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make the scatter plot with a collapsed histogram on each axis.

        Parameters
        ----------
        data : `KeyedData`
            The vectors, masks and statistics listed by ``getInputSchema``.
        skymap : `lsst.skymap.BaseSkyMap`
            Passed to ``generateSummaryStats`` when ``addSummaryPlot`` is
            set; unused otherwise.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`
            Accepted for interface compatibility; the value passed in is not
            read.  When ``addSummaryPlot`` is set the summary statistics are
            regenerated from ``data``, ``skymap`` and ``plotInfo``.
        **kwargs
            Forwarded to the panel helpers; used there to format the
            ``{band}`` part of the statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The main panel is labelled with the ``xAxisLabel`` and ``yAxisLabel``
        config options.  If ``plotTypes`` is empty a placeholder figure
        carrying only a "no data" message is returned instead.  With
        ``addSummaryPlot`` set, a summary panel labelled with ``yAxisLabel``
        is added in the upper right grid cell.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Parameters
        ----------
        data : `KeyedData`
            Vectors, masks, statistics and thresholds for the configured
            plot types.
        fig : `matplotlib.figure.Figure`
            Figure to draw into; also receives the statistics text boxes.
        gs : `matplotlib.gridspec.GridSpec`
            Grid of the figure; the panel occupies ``gs[1:, :-1]``.
        **kwargs
            Used to format the ``{band}`` part of the statistic keys.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter panel.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin drawn for the last dataset, or `None` if that dataset
            did not get one.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        cmaps = {
            "stars": mkColormap(["paleturquoise", "midnightBlue"]),
            "galaxies": mkColormap(["lemonchiffon", "firebrick"]),
        }

        binThresh = 5

        linesForLegend = []
        histIm = None

        # Gather (xs, ys, high mask, low mask, colour, cmap, high stats,
        # low stats) for every configured object type.
        toPlotList = []
        for plotType, xKey, yKey, maskPrefix, suffix, typeColor in self._typeTable:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(Scalar, data[f"{{band}}_highSN{suffix}_{name}".format(**kwargs)])
                lowArgs[name] = cast(Scalar, data[f"{{band}}_lowSN{suffix}_{name}".format(**kwargs)])
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{maskPrefix}HighSNMask"],
                    data[f"{maskPrefix}LowSNMask"],
                    typeColor,
                    cmaps.get(plotType),
                    _StatsContainer(**highArgs),
                    _StatsContainer(**lowArgs),
                )
            )

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")

            if len(xs) == 0:
                # An empty vector cannot be drawn against a scalar median;
                # there is nothing to show for this object type.
                histIm = None
                continue

            if len(xs) < 2:
                # Single point: mark the median and +/- 1 sigma-MAD at its x.
                medY = np.nanmedian(ys)
                (medLine,) = ax.plot(xs, medY, color, label=f"Median: {medY:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigMadYs] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    medY + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, medY - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins holding more than binThresh points.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge left-to-right, lower edge right-to-left, so
                    # the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh} Stats ({self.magLabel} < {highStats.approxMag})\n"
                highStatsStr = (
                    f"Median: {highStats.median}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad}    "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh} Stats ({self.magLabel} < {lowStats.approxMag})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad}    "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                nHigh = np.sum(highSn)
                if 0 < nHigh < 100:
                    # White markers underneath give the coloured ones an outline.
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                nLow = np.sum(lowSn)
                if 0 < nLow < 100:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    lowXMin = np.min(cast(Vector, xs[lowSn]))
                    if xMin is None or xMin > lowXMin:
                        xMin = lowXMin
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated bins for a running median: plot every
                # point with a flat median and +/- 1 sigma-MAD line.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits.  The values left over from the last
        # dataset of the loop above (xs, ys, meds, sigMadYs, xs1, xs97,
        # xScale) are reused here, so this is skipped when nothing was
        # plotted.
        if toPlotList:
            # Centre the y range on the first plotted dataset that has data,
            # rather than on hard-coded keys that may not be in ``data``.
            plotMed = np.nan
            for _, ysPlotted, *_unused in toPlotList:
                if len(cast(Vector, ysPlotted)) > 0:
                    plotMed = np.nanmedian(cast(Vector, ysPlotted))
                    break

            if len(xs) < 2:  # type: ignore
                meds = [np.median(ys)]  # type: ignore

            if self.yLims:
                ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
            else:
                # Widen from 4 sigma-MAD until the running medians fit (at
                # most 10), recomputing the limits on every step, then add
                # one more sigma-MAD of margin.
                numSig = 4
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore
                while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                    numSig += 1
                    yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                    yLimMax = plotMed + numSig * sigMadYs  # type: ignore

                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore
                # Leave matplotlib's default range in place when no finite
                # limits could be derived (e.g. only empty or all-NaN data).
                if np.isfinite(yLimMin) and np.isfinite(yLimMax):
                    ax.set_ylim(yLimMin, yLimMax)

            if self.xLims:
                ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
            elif len(xs) > 2:  # type: ignore
                if xMin is None:
                    xMin = xs1 - 2 * xScale  # type: ignore
                ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter panel.

        The grey histogram holds the non-NaN x values of every configured
        type; galaxies and stars additionally get their own step histogram.
        The panel shares its x axis with ``ax``.
        """
        # Top histogram
        totalX: list[Vector] = [
            cast(Vector, data[xKey])
            for plotType, xKey, *_unused in self._typeTable
            if plotType in self.plotTypes  # type: ignore
        ]

        # ``x == x`` is False only for NaN, so this drops the NaNs.
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        for plotType, xKey, typeColor, title in (
            ("galaxies", "xGalaxies", "firebrick", "Galaxies"),
            ("stars", "xStars", "midnightblue", "Stars"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            topHist.hist(
                data[xKey],
                bins=100,
                color=typeColor,
                histtype="step",
                log=True,
                label=f"{title} ({len(cast(Vector, data[xKey]))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values to the right of the scatter panel.

        The grey histogram holds the non-NaN y values of every configured
        type.  Galaxies and stars additionally get step histograms of all
        their points (solid), the high S/N selection (dashed) and the low
        S/N selection (dotted).  If ``plot2DHist`` is set and ``histIm`` is
        not `None`, a colour bar for it is attached to the panel.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = [
            cast(Vector, data[yKey])
            for plotType, _, yKey, *_unused in self._typeTable
            if plotType in self.plotTypes  # type: ignore
        ]
        # ``y == y`` is False only for NaN, so this drops the NaNs.
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # Bin over the y range currently shown on the scatter panel.
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        for plotType, yKey, maskPrefix, typeColor in (
            ("galaxies", "yGalaxies", "galaxies", "firebrick"),
            ("stars", "yStars", "stars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            ysType = cast(Vector, data[yKey])
            sideHist.hist(
                [y for y in ysType if y == y],
                bins=bins,
                color=typeColor,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            sideHist.hist(
                ysType[cast(Vector, data[f"{maskPrefix}HighSNMask"])],
                bins=bins,
                color=typeColor,
                histtype="step",
                orientation="horizontal",
                log=True,
                ls="--",
            )
            sideHist.hist(
                ysType[cast(Vector, data[f"{maskPrefix}LowSNMask"])],
                bins=bins,
                color=typeColor,
                histtype="step",
                orientation="horizontal",
                log=True,
                ls=":",
            )

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")