Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%
342 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 02:38 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-15 02:38 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26from typing import Mapping, NamedTuple, Optional, cast
28import matplotlib.colors
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from lsst.pex.config.listField import ListField
34from matplotlib import gridspec
35from matplotlib.axes import Axes
36from matplotlib.collections import PolyCollection
37from matplotlib.figure import Figure
38from matplotlib.path import Path
39from mpl_toolkits.axes_grid1 import make_axes_locatable
41from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
42from ...math import nanMedian, nanSigmaMad
43from ..keyedData.summaryStatistics import SummaryStatisticAction
44from ..scalar import MedianAction
45from ..vector import ConvertFluxToMag, SnSelector
46from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap
# Module-level copy of matplotlib's "coolwarm" colormap whose "bad" colour is
# set to "none". ``coolwarm`` really is an attribute of ``plt.cm``; the
# ``type: ignore`` only silences the type checker, which cannot see it.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the statistics consumed by the scatter plot with two hists.

    A high and a low signal-to-noise selector are applied to the input. For
    each of the two selections the summary statistics of ``vectorKey`` are
    computed, together with an approximate magnitude (the median magnitude
    of ``fluxType`` inside the selection). The masks and the selector
    thresholds are returned as well.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        # Explicit prefix first, then the action identity, else no prefix.
        return self.prefix or self.identity or ""

    def _prefixForms(self) -> tuple[str, str]:
        """Return the (lower-cased, capitalized) forms of the key prefix.

        Both are empty strings when there is no prefix.
        """
        prefix = self._get_key_prefix()
        if not prefix:
            return "", ""
        return prefix.lower(), prefix.capitalize()

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        yield from ((self.vectorKey, Vector), (self.fluxType, Vector))
        for selector in (self.highSNSelector, self.lowSNSelector):
            yield from selector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        lower, upper = self._prefixForms()
        sfx = self.suffix
        schema = [
            (f"{lower}HighSNMask{sfx}", Vector),
            (f"{lower}LowSNMask{sfx}", Vector),
        ]
        # One scalar per statistic, low-S/N bin first, then the high-S/N bin.
        for snBin in ("lowSN", "highSN"):
            for statName in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{snBin}{upper}_{statName}{sfx}", Scalar))
        schema.extend(
            [
                (f"{lower}HighSNThreshold{sfx}", Scalar),
                (f"{lower}LowSNThreshold{sfx}", Scalar),
            ]
        )
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        lower, upper = self._prefixForms()
        sfx = self.suffix
        highMaskKey = f"{lower}HighSNMask{sfx}"
        lowMaskKey = f"{lower}LowSNMask{sfx}"

        results = {}
        results[highMaskKey] = self.highSNSelector(data, **kwargs)
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        band = kwargs.get("band")
        bandPrefix = f"{band}_" if band else ""
        fluxes = None if band is None else data[self.fluxType.format(band=band)]

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)
        # Defaults are applied explicitly here: the original authors note
        # that pex_config's behaviour around this is broken and risky to fix.
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            mask = results[maskKey]
            name = f"{bandPrefix}{binName}SN{upper}"
            # The approximate magnitude is the median magnitude inside the
            # selection; it can only be computed when a band is known.
            if band is None:
                approxMag = np.nan
            else:
                approxMag = medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
            results[f"{name}_approxMag{sfx}".format(**kwargs)] = approxMag

            stats = statAction(data, **(kwargs | {"mask": mask}))
            for statName, value in stats.items():
                results[f"{name}_{statName}{sfx}".format(**kwargs)] = value

        results[f"{lower}HighSNThreshold{sfx}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{lower}LowSNThreshold{sfx}"] = self.lowSNSelector.threshold  # type: ignore
        return results
145def _validObjectTypes(value):
146 return value in ("stars", "galaxies", "unknown", "any")
class _StatsContainer(NamedTuple):
    """Summary statistics of one signal-to-noise selection."""

    median: Scalar
    sigmaMad: Scalar
    # ``count`` clashes with the inherited ``tuple.count`` method, hence the
    # type-checker suppression.
    count: Scalar  # type: ignore
    approxMag: Scalar
class DataTypeDefaults(NamedTuple):
    """Per-object-type key suffixes and plotting styles."""

    # Inserted into the statistic key names and used as the histogram label.
    suffix_stat: str
    # Appended to the ``x``/``y`` vector key names.
    suffix_xy: str
    # Matplotlib colour for this type's lines, markers and histograms.
    color: str
    # Colormap for the 2D hexbin; None falls back to matplotlib's default.
    colormap: matplotlib.colors.Colormap | None
class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.

    Each entry of ``plotTypes`` contributes one dataset. Its input keys are
    built from the matching entry of ``_datatypes`` plus the ``suffix_x``,
    ``suffix_y`` and ``suffix_stat`` config options.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )
    suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="")
    suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="")
    suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="")

    _stats = ("median", "sigmaMad", "count", "approxMag")
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector))
            base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector))
            base.append((f"{name_datatype}HighSNMask{self.suffix_stat}", Vector))
            base.append((f"{name_datatype}LowSNMask{self.suffix_stat}", Vector))
            # statistics
            for name in self._stats:
                base.append(
                    (f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
                )
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}LowSNThreshold{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}HighSNThreshold{self.suffix_stat}", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        Raises
        ------
        ValueError
            Raised if a required key is missing or a vector input is a
            scalar.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(
                f"Task needs keys {remainder} but they were not found in input keys" f" {list(data.keys())}"
            )
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Used to fill the ``{band}`` placeholder of the statistic keys.
            ``skymap`` enables the summary panel (together with
            ``addSummaryPlot``). ``hlineColor`` and ``hlineStyle`` style the
            horizontal reference line at zero; they default to black and a
            dotted dash pattern.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        For every requested object type the x and y vectors are plotted
        against each other and labelled with ``xAxisLabel`` and
        ``yAxisLabel``. Running medians and sigma-MAD envelopes are drawn
        where there are enough points. A histogram of the points collapsed
        onto each axis is drawn above and to the right of the main panel.
        When ``addSummaryPlot`` is set and a ``skymap`` is passed in
        ``kwargs``, a summary panel showing the median of the y value per
        patch is drawn in the upper right corner.

        The printed statistics are read directly from the input statistic
        keys. The high/low S/N masks select the points drawn individually
        (when fewer than 100) and the masked side histograms.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
          xStars, yStars, starsHighSNMask, starsLowSNMask,
          starsHighSNThreshold, starsLowSNThreshold and
          {band}_highSNStars_{name}, {band}_lowSNStars_{name}
          where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars (Stars) in the
          above names with galaxies/unknown (Galaxies/Unknown).

        * If it is for any (which covers all the points) then the x and y
          keys become x and y, the mask/threshold keys use any, and the
          statistics keys use Any.

        * The x and y keys additionally carry ``suffix_x``/``suffix_y``;
          every mask, threshold and statistic key carries ``suffix_stat``.

        * ``patch`` is required when the summary plot is being plotted.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        if "hlineColor" not in kwargs:
            kwargs["hlineColor"] = "black"

        if "hlineStyle" not in kwargs:
            kwargs["hlineStyle"] = (0, (1, 4))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        skymap = kwargs.get("skymap", None)
        if self.addSummaryPlot and skymap is not None:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel, its running statistics and stats text.

        Parameters
        ----------
        data : `KeyedData`
            Input data; see `makePlot` for the required keys.
        fig : `matplotlib.figure.Figure`
            Figure the panel is added to (at ``gs[1:, :-1]``).
        gs : `matplotlib.gridspec.GridSpec`
            Layout grid of the figure.
        **kwargs
            Used to format the ``{band}`` placeholder of the statistic keys.
            Must contain ``hlineColor`` and ``hlineStyle``.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter-plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin collection left by the last dataset that set it.
            This is `None` if that dataset was drawn as a lone median or
            as plain markers.
        """
        suf_x = self.suffix_x
        suf_y = self.suffix_y
        suf_stat = self.suffix_stat
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # An x bin needs more than this many points to enter the running
        # statistics.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
                lowArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}{suf_x}"],
                    data[f"y{config_datatype.suffix_xy}{suf_y}"],
                    data[f"{name_datatype}HighSNMask{suf_stat}"],
                    data[f"{name_datatype}LowSNMask{suf_stat}"],
                    data[f"{name_datatype}HighSNThreshold{suf_stat}"],
                    data[f"{name_datatype}LowSNThreshold{suf_stat}"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        # The automatic y limits come from the most recent dataset that was
        # actually drawn. These stay None when every dataset is skipped.
        limMeds = None
        limSigMad = None

        # np.inf (not the np.Inf alias, which NumPy 2.0 removed)
        xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]
        for j, (
            xs,
            ys,
            highSn,
            lowSn,
            highThresh,
            lowThresh,
            color,
            cmap,
            highStats,
            lowStats,
        ) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nanSigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0 or not np.isfinite(sigMadYs):
                continue
            elif n_xs < 10:
                xs = [nanMedian(xs)]
                sigMads = np.array([nanSigmaMad(ys)])
                ys = np.array([nanMedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                limMeds, limSigMad = ys, sigMadYs
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Keep the 1st to 97th percentile of the finite xs values
                # (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = nanMedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = nanMedian(ys[inBin])
                    sigMad = nanSigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Closed path around the +/- 3 sigma MAD envelope, used to
                # find the points that lie inside it
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text; each dataset gets its own column
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                statText = self._formatStatsText(highThresh, highStats)
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                statText = self._formatStatsText(lowThresh, lowStats)
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated x bins for running statistics: plot all
                # points with a single global median and sigma MAD.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([nanMedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nanSigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

            limMeds, limSigMad = meds, sigMadYs

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits, centred on the star (else galaxy)
        # y median.
        # TODO: Make this not work by accident
        if f"yStars{suf_y}" in data and (len(cast(Vector, data[f"yStars{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yStars{suf_y}"]))
        elif f"yGalaxies{suf_y}" in data and (len(cast(Vector, data[f"yGalaxies{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_y}"]))
        else:
            plotMed = np.nan

        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])
        elif np.isfinite(plotMed) and limMeds is not None and np.isfinite(limSigMad):
            # Widen the range (up to 10 sigma MAD) until the drawn medians
            # fit, then add one more sigma MAD of margin.
            numSig = 4
            yLimMin = plotMed - numSig * limSigMad
            yLimMax = plotMed + numSig * limSigMad
            while (yLimMax < np.max(limMeds) or yLimMin > np.min(limMeds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig * limSigMad
                yLimMax = plotMed + numSig * limSigMad

            numSig += 1
            yLimMin = plotMed - numSig * limSigMad
            yLimMax = plotMed + numSig * limSigMad
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc=self.legendLocation,
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _formatStatsText(self, threshold, stats: _StatsContainer) -> str:
        """Return the two-line stats-box text for one S/N selection."""
        header = f"S/N > {threshold:0.4g} Stats ({self.magLabel} < {stats.approxMag:0.4g})\n"
        body = (
            f"Median: {stats.median:0.4g} "
            + r"$\sigma_{MAD}$: "
            + f"{stats.sigmaMad:0.4g} "
            + r"N$_{points}$: "
            + f"{stats.count}"
        )
        return header + body

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter panel.

        The "any" dataset (or, without it, the concatenation of all
        requested types when there is more than one) is drawn as a filled
        grey histogram. Every other type is drawn as a coloured outline.
        """
        suf_x = self.suffix_x
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)

        if "any" in self.plotTypes:
            # Histogram the data vector, not its key name
            x_any = data[f"x{self._datatypes['any'].suffix_xy}{suf_x}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
        else:
            x_any = (
                np.concatenate([data[f"x{self._datatypes[key].suffix_xy}{suf_x}"] for key in self.plotTypes])
                if (len(self.plotTypes) > 1)
                else None
            )
            keys_notany = self.plotTypes
        if x_any is not None:
            topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=True, label=f"Any ({len(x_any)})")

        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}{suf_x}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values to the right of the scatter panel.

        Each non-"any" type gets an outline histogram of all its points, a
        dashed one for its high-S/N mask and a dotted one for its low-S/N
        mask. When ``plot2DHist`` is set and ``histIm`` is not None, a
        colorbar for the hexbin image is attached to the side of the panel.
        """
        suf_y = self.suffix_y
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)

        if "any" in self.plotTypes:
            # Histogram the data vector, not its key name
            y_any = data[f"y{self._datatypes['any'].suffix_xy}{suf_y}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
        else:
            y_any = (
                np.concatenate([data[f"y{self._datatypes[key].suffix_xy}{suf_y}"] for key in self.plotTypes])
                if (len(self.plotTypes) > 1)
                else None
            )
            keys_notany = self.plotTypes
        if y_any is not None:
            sideHist.hist(
                np.array(y_any),
                bins=bins,
                color="grey",
                alpha=0.3,
                orientation="horizontal",
                log=True,
            )
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}{suf_y}"]
            sideHist.hist(
                vector,
                color=config_datatype.color,
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")