Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 11%
375 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-12 03:08 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-12 03:08 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from itertools import chain
27from typing import Mapping, NamedTuple, Optional, cast
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from lsst.pex.config.listField import ListField
34from lsst.skymap import BaseSkyMap
35from matplotlib import gridspec
36from matplotlib.axes import Axes
37from matplotlib.collections import PolyCollection
38from matplotlib.figure import Figure
39from matplotlib.path import Path
40from mpl_toolkits.axes_grid1 import make_axes_locatable
42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
43from ...statistics import nansigmaMad, sigmaMad
44from ..keyedData.summaryStatistics import SummaryStatisticAction
45from ..scalar import MedianAction
46from ..vector import MagColumnNanoJansky, SnSelector
47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Copy of matplotlib's "coolwarm" colormap whose "bad" (masked/invalid) colour
# is fully transparent.  ``plt.cm.coolwarm`` exists at runtime but is not
# visible to static type checkers, hence the ignore.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad("none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    For a high and a low signal-to-noise selection this action produces the
    selection masks plus summary statistics of ``vectorKey`` and an
    approximate magnitude (median magnitude of ``fluxType``) of each
    selection.  Output keys are built from ``identity`` (lower-cased for the
    masks, capitalized for the statistics) and, when given, the ``band``
    keyword argument.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    # Statistic names emitted for each S/N bin; must agree with what
    # ``SummaryStatisticAction`` produces plus the ``approxMag`` added below.
    _outputStats = ("median", "sigmaMad", "count", "approxMag")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__`.

        The key names here are kept in sync with `__call__`: masks use the
        lower-cased ``identity`` and the thresholds are ``highSnThreshold``
        and ``lowSnThreshold`` (the names consumed by
        `ScatterPlotWithTwoHists`).
        """
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""
        schema: list[tuple[str, type[Vector] | ScalarType]] = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("low", "high"):
            for stat in self._outputStats:
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{stat}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the S/N masks and per-bin statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input data containing at least the keys of `getInputSchema`.
        **kwargs
            Forwarded to the selectors and the statistic action.  If a
            ``band`` entry is present it is used to format ``fluxType`` and
            to prefix the statistic keys with ``"{band}_"``.

        Returns
        -------
        results : `KeyedData`
            The high/low S/N masks, the summary statistics and approximate
            magnitude of each S/N bin, and the two selector thresholds.
        """
        results = {}
        highMaskKey = f'{(self.identity or "").lower()}HighSNMask'
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f'{(self.identity or "").lower()}LowSNMask'
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        statSuffix = self.identity.capitalize() if self.identity else ""
        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statSuffix}"
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results
128def _validatePlotTypes(value):
129 return value in ("stars", "galaxies", "unknown", "any", "mag")
# ignore type because the ``count`` field conflicts with ``tuple.count`` on
# the NamedTuple base class
class _StatsContainer(NamedTuple):
    """Summary statistics of one signal-to-noise selection.

    ``ScatterPlotWithTwoHists`` fills one of these for the high and one for
    the low S/N selection of each plotted object type, from the
    ``{band}_{high,low}SN<Type>_<name>`` entries of its input data.
    """

    median: Scalar
    sigmaMad: Scalar
    count: Scalar  # type: ignore
    # Printed in the stats box as the magnitude limit of the selection.
    approxMag: Scalar
class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.

    The main panel shows y against x for each requested object type (see
    ``plotTypes``), with a running median and sigma-MAD bands where there are
    enough points.  A histogram of x is drawn above it and a histogram of y
    to its right.  Precomputed high/low S/N statistics (for example those
    produced by `ScatterPlotStatsAction`) are printed below the axes.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One entry per supported plot type:
    # (plot type, capitalized label used in the statistic keys, x key, y key,
    #  line/marker color).  Mask keys are ``<plot type>{High,Low}SNMask``.
    # The order is the drawing order and the left-to-right order of the
    # statistics boxes.
    _plotTypeKeys = (
        ("stars", "Stars", "xStars", "yStars", "midnightblue"),
        ("galaxies", "Galaxies", "xGalaxies", "yGalaxies", "firebrick"),
        ("unknown", "Unknown", "xUnknown", "yUnknown", "green"),
        ("any", "Any", "x", "y", "purple"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) this action reads from data."""
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for plotType, label, xKey, yKey, _ in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{label}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{label}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key of the input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or a key that should hold a vector
            holds a `Scalar`.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`
            Unused: when ``addSummaryPlot`` is set the summary statistics are
            regenerated from ``data``, ``skymap`` and ``plotInfo``.
        **kwargs
            Must contain ``band`` (used to format the statistic keys).
            Optional ``hlineColor`` / ``hlineStyle`` set the style of the
            horizontal reference line at y=0.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Plots ``y`` against ``x`` for each requested plot type and labels the
        axes with ``xAxisLabel`` and ``yAxisLabel``. A histogram of the points
        collapsed onto each axis is also plotted. If ``addSummaryPlot`` is
        set, a summary panel showing a per-patch statistic of the y value is
        shown in the upper right corner of the resultant plot. The high and
        low S/N masks select which points are highlighted, and the
        precomputed statistics are printed below the plot.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
          xStars, yStars, starsHighSNMask, starsLowSNMask and
          {band}_highSNStars_{name}, {band}_lowSNStars_{name}
          where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknowns then replace stars in the above
          names with galaxies/unknowns.

        * If it is for any (which covers all the points) then it
          becomes x, y, anyHighSNMask, anyLowSNMask and
          {band}_highSNAny_{name}, {band}_lowSNAny_{name}.

        * In every case it is expected that data contains:
          lowSnThreshold, highSnThreshold and patch
          (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        if "hlineColor" not in kwargs:
            kwargs["hlineColor"] = "black"

        if "hlineStyle" not in kwargs:
            kwargs["hlineStyle"] = (0, (1, 4))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _extractStats(self, data: KeyedData, label: str, **kwargs) -> tuple[_StatsContainer, _StatsContainer]:
        """Return the (high S/N, low S/N) statistics for one plot type.

        ``label`` is the capitalized type name used in the keys, e.g.
        ``"Stars"`` for ``{band}_highSNStars_median``.
        """
        highArgs = {}
        lowArgs = {}
        for name in self._stats:
            highArgs[name] = cast(Scalar, data[f"{{band}}_highSN{label}_{name}".format(**kwargs)])
            lowArgs[name] = cast(Scalar, data[f"{{band}}_lowSN{label}_{name}".format(**kwargs)])
        return _StatsContainer(**highArgs), _StatsContainer(**lowArgs)

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel with running statistics.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter-plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin drawn for the last non-empty dataset, or `None` if
            that dataset was drawn without a 2D histogram.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        # Only stars and galaxies get a custom hexbin colormap; the other
        # types use matplotlib's default (cmap=None).
        cmaps = {"stars": newBlues, "galaxies": newReds}

        # Minimum number of points in an x bin for it to get running stats.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        for plotType, label, xKey, yKey, color in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            highStats, lowStats = self._extractStats(data, label, **kwargs)
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    color,
                    cmaps.get(plotType),
                    highStats,
                    lowStats,
                )
            )

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0:
                continue
            elif n_xs < 10:
                xs = [np.nanmedian(xs)]
                sigMads = np.array([nansigmaMad(ys)])
                ys = [np.nanmedian(ys)]
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins with enough points (within +/-5 sigma MAD)
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[inBin])
                    sigMad = sigmaMad(ys[inBin], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper curve fills the first half of the vertices and
                    # the lower curve the second half, in reverse, so the
                    # path traces a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text: one box per S/N selection, high S/N
                # (dashed frame) above low S/N (dotted frame).
                xPos = 0.65 - 0.4 * j
                for stats, thresholdKey, frameStyle, yPos in (
                    (highStats, "highSnThreshold", "--", 0.090),
                    (lowStats, "lowSnThreshold", ":", 0.020),
                ):
                    bbox = dict(edgecolor=color, linestyle=frameStyle, facecolor="none")
                    threshold = data[thresholdKey]
                    statText = (
                        f"S/N > {threshold:0.4g} Stats ({self.magLabel} < {stats.approxMag:0.4g})\n"
                        f"Median: {stats.median:0.4g} "
                        + r"$\sigma_{MAD}$: "
                        + f"{stats.sigmaMad:0.4g} "
                        + r"N$_{points}$: "
                        + f"{stats.count}"
                    )
                    fig.text(xPos, yPos, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # TODO: Make this not work by accident (xs, ys, meds, sigMadYs and
        # the percentiles below are whatever the last dataset left behind).
        if "yStars" in data and (len(cast(Vector, data["yStars"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yStars"]))
        elif "yGalaxies" in data and (len(cast(Vector, data["yGalaxies"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yGalaxies"]))
        else:
            plotMed = np.nan

        # Ignore types below pending making this not working by accident
        if len(xs) < 2:  # type: ignore
            meds = [np.nanmedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed):
            # Widen the y range one sigma MAD at a time (up to 10) until it
            # contains every running median, then pad by one more.
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif len(xs) > 2:  # type: ignore
            if xMin is None:
                xMin = xs1 - 2 * xScale  # type: ignore
            ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter panel."""
        # Top histogram
        totalX: list[Vector] = [
            cast(Vector, data[xKey])
            for plotType, _, xKey, _, _ in self._plotTypeKeys
            if plotType in self.plotTypes  # type: ignore
        ]
        # ``x == x`` drops NaNs
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xGalaxies"],
                bins=100,
                color="firebrick",
                histtype="step",
                log=True,
                label=f"Galaxies ({len(cast(Vector, data['xGalaxies']))})",
            )
        if "stars" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xStars"],
                bins=100,
                color="midnightblue",
                histtype="step",
                log=True,
                label=f"Stars ({len(cast(Vector, data['xStars']))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values to the right of the scatter
        panel, plus a colorbar for ``histIm`` when a 2D histogram was drawn.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: dict[str, Vector] = {}
        for plotType, _, _, yKey, _ in self._plotTypeKeys:
            if plotType in self.plotTypes and yKey in data:  # type: ignore
                totalY[plotType] = cast(Vector, data[yKey])
        # ``y == y`` drops NaNs
        totalYChained = [y for y in chain.from_iterable(totalY.values()) if y == y]

        # Bin over the y range already chosen for the scatter panel so both
        # panels share the same vertical scale.
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )

        # Galaxies are drawn before stars; each gets its full (solid),
        # high S/N (dashed) and low S/N (dotted) histogram.
        for plotType, yKey, color in (
            ("galaxies", "yGalaxies", "firebrick"),
            ("stars", "yStars", "midnightblue"),
        ):
            if plotType not in totalY:
                continue
            yValues = cast(Vector, data[yKey])
            sideHist.hist(
                [v for v in yValues if v == v],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            for maskSuffix, lineStyle in (("HighSNMask", "--"), ("LowSNMask", ":")):
                sideHist.hist(
                    yValues[cast(Vector, data[f"{plotType}{maskSuffix}"])],
                    bins=bins,
                    color=color,
                    histtype="step",
                    orientation="horizontal",
                    log=True,
                    ls=lineStyle,
                )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")