Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 12%
367 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-04 11:09 +0000
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-04 11:09 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from itertools import chain
27from typing import Mapping, NamedTuple, Optional, cast
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from lsst.pex.config.listField import ListField
34from lsst.skymap import BaseSkyMap
35from matplotlib import gridspec
36from matplotlib.axes import Axes
37from matplotlib.collections import PolyCollection
38from matplotlib.figure import Figure
39from matplotlib.path import Path
40from mpl_toolkits.axes_grid1 import make_axes_locatable
42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
43from ...statistics import nansigmaMad, sigmaMad
44from ..keyedData.summaryStatistics import SummaryStatisticAction
45from ..scalar import MedianAction
46from ..vector import MagColumnNanoJansky, SnSelector
47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# A private copy of matplotlib's "coolwarm" colormap in which "bad" values
# (masked or NaN entries) are drawn fully transparent. The type ignore is
# needed because ``coolwarm`` is created dynamically on ``plt.cm`` and is not
# visible to static type checkers.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad("none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    For each of a high and a low signal-to-noise (S/N) selection, this action
    produces:

    * the boolean selection mask,
    * the summary statistics of ``vectorKey`` within that mask, and
    * ``approxMag``, the median magnitude of the selected objects.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by ``__call__``.

        These keys must mirror ``__call__`` exactly:

        * the mask prefix is the lower-cased identity;
        * the statistic keys use the capitalised identity;
        * the thresholds are stored as ``highSnThreshold`` and
          ``lowSnThreshold``.
        """
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""
        schema: list[tuple[str, type[Vector] | ScalarType]] = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("low", "high"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{stat}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the S/N masks, per-selection statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input columns. These must include ``vectorKey``, the formatted
            ``fluxType`` column (when ``band`` is given) and whatever the two
            selectors read.
        **kwargs
            Passed on to the selectors and the statistic action. ``band``, if
            present, prefixes the statistic keys and selects the flux column.

        Returns
        -------
        results : `KeyedData`
            The two masks, the ``<prefix><low|high>SN<Identity>_<stat>``
            statistics and the two S/N thresholds.
        """
        results = {}
        highMaskKey = f'{(self.identity or "").lower()}HighSNMask'
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f'{(self.identity or "").lower()}LowSNMask'
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{self.identity.capitalize() if self.identity else ''}"
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
128def _validatePlotTypes(value):
129 return value in ("stars", "galaxies", "unknown", "any", "mag")
132# ignore type because of conflicting name on tuple baseclass
133class _StatsContainer(NamedTuple):
134 median: Scalar
135 sigmaMad: Scalar
136 count: Scalar # type: ignore
137 approxMag: Scalar
class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the per-selection summary statistics read from the input data.
    _stats = ("median", "sigmaMad", "count", "approxMag")

    # Key naming and colour for each plot type this action can draw, listed in
    # the fixed order in which the types are processed. Each value is:
    #   (suffix of the "x"/"y" columns,
    #    prefix of the "HighSNMask"/"LowSNMask" columns,
    #    name used in the "{band}_{high,low}SN<name>_<stat>" keys,
    #    line/marker colour)
    # "mag" passes ``_validatePlotTypes`` but has no entry here, so it draws
    # nothing.
    _typeInfo = {
        "stars": ("Stars", "stars", "Stars", "midnightblue"),
        "galaxies": ("Galaxies", "galaxies", "Galaxies", "firebrick"),
        "unknown": ("Unknown", "unknown", "Unknown", "green"),
        "any": ("", "any", "Any", "purple"),
    }

    def _selectedTypes(self) -> list[str]:
        """Return the requested plot types this action can draw.

        Types are returned in the fixed order stars, galaxies, unknown, any.
        """
        return [plotType for plotType in self._typeInfo if plotType in self.plotTypes]  # type: ignore

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) this action reads.

        ``{band}`` placeholders are left unformatted.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for plotType in self._selectedTypes():
            colSuffix, maskPrefix, statName, _ = self._typeInfo[plotType]
            base.append((f"x{colSuffix}", Vector))
            base.append((f"y{colSuffix}", Vector))
            base.append((f"{maskPrefix}HighSNMask", Vector))
            base.append((f"{maskPrefix}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{statName}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{statName}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every input this action needs.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key that should hold
            a vector holds a scalar.
        """
        # Materialize the schema. It may be a one-shot iterator, and it is
        # traversed twice below; otherwise the type check would silently be
        # skipped.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Currently ignored. When ``addSummaryPlot`` is set, the per-patch
            summary statistics are recomputed from ``data`` and ``skymap``.
        **kwargs
            ``band`` is used to format the statistic keys. ``hlineColor`` and
            ``hlineStyle`` optionally override the look of the horizontal
            reference line at zero.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        For each selected plot type, the ``x`` and ``y`` vectors are drawn
        against each other, labelled with the ``xAxisLabel`` and
        ``yAxisLabel`` config options. Where there are enough points, the plot
        also shows running medians with 1, 2 and 3 sigma MAD bands. A
        histogram of the points collapsed onto each axis is also plotted. When
        ``addSummaryPlot`` is set, a summary panel showing the median of the y
        value in each patch is shown in the upper right corner of the
        resultant plot. The high and low S/N masks select the points used for
        the printed statistics.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknowns then replace stars in the above
            names with galaxies/unknowns.

        * If it is for any (which covers all the points) then it
            becomes, x, y, and any instead of stars for the other
            parameters given above.

        * In every case it is expected that data contains:
            lowSnThreshold, highSnThreshold and patch
            (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        kwargs.setdefault("hlineColor", "black")
        kwargs.setdefault("hlineStyle", (0, (1, 4)))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _getStats(self, data: KeyedData, binName: str, **kwargs) -> _StatsContainer:
        """Gather the ``{band}_<binName>_<stat>`` scalars into a container.

        ``binName`` is, for example, ``"highSNStars"``. ``kwargs`` must
        provide the values (such as ``band``) used to format the keys.
        """
        return _StatsContainer(
            **{name: cast(Scalar, data[f"{{band}}_{binName}_{name}".format(**kwargs)]) for name in self._stats}
        )

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel with running statistics and stats text.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter panel.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin image of the last dataset drawn, or `None` if none was
            drawn for it.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # Colormaps for the hexbin of the dense region. Types without an entry
        # use matplotlib's default colormap.
        cmaps = {
            "stars": mkColormap(["paleturquoise", "midnightBlue"]),
            "galaxies": mkColormap(["lemonchiffon", "firebrick"]),
        }

        # x bins need more than this many points to be used for the running
        # statistics.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        for plotType in self._selectedTypes():
            colSuffix, maskPrefix, statName, color = self._typeInfo[plotType]
            toPlotList.append(
                (
                    data[f"x{colSuffix}"],
                    data[f"y{colSuffix}"],
                    data[f"{maskPrefix}HighSNMask"],
                    data[f"{maskPrefix}LowSNMask"],
                    color,
                    cmaps.get(plotType),
                    self._getStats(data, f"highSN{statName}", **kwargs),
                    self._getStats(data, f"lowSN{statName}", **kwargs),
                )
            )

        histIm = None
        xMin = None
        # Values from the most recently drawn dataset. These drive the default
        # axis limits below.
        xs = ys = np.array([])
        sigMadYs = np.nan
        meds = np.array([np.nan])
        xs1 = xs97 = xScale = np.nan
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            if len(xs) < 2:
                # Too few points for running statistics. Draw flat median and
                # +/- 1 sigma MAD lines instead. The arrays are broadcast to
                # len(xs) so an empty dataset does not crash.
                medY = np.nanmedian(ys)
                medYs = np.full(len(xs), medY)
                sigMads = np.full(len(xs), sigMadYs)
                (medLine,) = ax.plot(xs, medYs, color, label=f"Median: {medY:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    medYs + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:.2g}",
                )
                ax.plot(xs, medYs - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                meds = np.array([medY])
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins with enough points for useful statistics.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[inBin])
                    sigMad = sigmaMad(ys[inBin], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits.
        # The y range is centred on the median of the first non-empty dataset,
        # in the order stars, galaxies, unknown, any. Its width comes from the
        # sigma MAD and medians of the last dataset drawn.
        plotMed = np.nan
        for entry in toPlotList:
            yVals = cast(Vector, entry[1])
            if len(yVals) > 0:
                plotMed = np.nanmedian(yVals)
                break
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        else:
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs
            yLimMax = plotMed + numSig * sigMadYs
            # Widen the range (up to 10 sigma MAD) until it contains all of the
            # medians, then add one more sigma MAD of margin.
            while (yLimMax < np.nanmax(meds) or yLimMin > np.nanmin(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs
                yLimMax = plotMed + numSig * sigMadYs
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs
            yLimMax = plotMed + numSig * sigMadYs
            # matplotlib rejects NaN/inf limits (e.g. when there are no finite
            # y values), so leave autoscaling in place in that case.
            if np.isfinite(yLimMin) and np.isfinite(yLimMax):
                ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2 * xScale
            ax.set_xlim(xMin, xs97 + 2 * xScale)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter panel.

        The histogram panel shares its x axis with ``ax``. It shows all
        selected types combined, plus step histograms for galaxies and stars.
        """
        # Top histogram
        totalX: list[Vector] = [
            cast(Vector, data[f"x{self._typeInfo[plotType][0]}"]) for plotType in self._selectedTypes()
        ]
        # NaN != NaN, so this drops NaNs from the combined histogram.
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        for plotType in ("galaxies", "stars"):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            colSuffix, _, _, color = self._typeInfo[plotType]
            xVals = cast(Vector, data[f"x{colSuffix}"])
            topHist.hist(
                xVals,
                bins=100,
                color=color,
                histtype="step",
                log=True,
                label=f"{plotType.capitalize()} ({len(xVals)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values to the right of the scatter panel.

        The histogram panel shares its y axis with ``ax`` and bins over the
        y range already set on ``ax``. For galaxies and stars it also shows
        the high (dashed) and low (dotted) S/N selections. If a hexbin image
        was drawn, its colorbar is attached to this panel.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = [
            cast(Vector, data[f"y{self._typeInfo[plotType][0]}"]) for plotType in self._selectedTypes()
        ]
        # NaN != NaN, so this drops NaNs from the combined histogram.
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # Bin over the y range already chosen for the scatter panel
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        for plotType in ("galaxies", "stars"):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            colSuffix, maskPrefix, _, color = self._typeInfo[plotType]
            yVals = cast(Vector, data[f"y{colSuffix}"])
            sideHist.hist(
                [y for y in yVals if y == y],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            sideHist.hist(
                yVals[cast(Vector, data[f"{maskPrefix}HighSNMask"])],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
                ls="--",
            )
            sideHist.hist(
                yVals[cast(Vector, data[f"{maskPrefix}LowSNMask"])],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")