Coverage for python / lsst / analysis / tools / actions / plot / scatterplotWithTwoHists.py: 13%
444 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-15 00:23 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26import math
27from typing import Mapping, NamedTuple, Optional, cast
29import matplotlib.colors
30import matplotlib.patheffects as pathEffects
31import numpy as np
32from lsst.pex.config import Field
33from lsst.pex.config.configurableActions import ConfigurableActionField
34from lsst.pex.config.listField import ListField
35from lsst.utils.plotting import (
36 galaxies_cmap,
37 galaxies_color,
38 make_figure,
39 set_rubin_plotstyle,
40 stars_cmap,
41 stars_color,
42)
43from matplotlib import gridspec
44from matplotlib.axes import Axes
45from matplotlib.collections import PolyCollection
46from matplotlib.figure import Figure
47from matplotlib.path import Path
48from matplotlib.ticker import LogFormatterMathtext, NullFormatter
49from mpl_toolkits.axes_grid1 import make_axes_locatable
51from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
52from ...math import nanMedian, nanSigmaMad
53from ..keyedData.summaryStatistics import SummaryStatisticAction
54from ..scalar import MedianAction
55from ..vector import ConvertFluxToMag, SnSelector
56from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the per-selection statistics shown on a scatter plot with
    two marginal histograms.

    For a low and a high signal-to-noise selection this action returns the
    selection masks, the summary statistics of ``vectorKey`` within each
    selection, an approximate magnitude of each selection and the
    thresholds of the two selectors.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        """Return the configured prefix, falling back to ``self.identity``
        and finally to the empty string.
        """
        return self.prefix or self.identity or ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys required in the input data: the statistics vector,
        the flux vector and whatever the two selectors need.
        """
        for key in (self.vectorKey, self.fluxType):
            yield (key, Vector)
        for selector in (self.highSNSelector, self.lowSNSelector):
            yield from selector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys produced by `__call__` and their types.

        The order is: both masks, the low-S/N statistics, the high-S/N
        statistics and finally both thresholds.
        """
        prefix = self._get_key_prefix()
        lower = prefix.lower() if prefix else ""
        upper = prefix.capitalize() if prefix else ""
        suffix = self.suffix

        masks = tuple((f"{lower}{level}SNMask{suffix}", Vector) for level in ("High", "Low"))
        stats = tuple(
            (f"{{band}}_{level}SN{upper}_{stat}{suffix}", Scalar)
            for level in ("low", "high")
            for stat in ("median", "sigmaMad", "count", "approxMag")
        )
        thresholds = tuple((f"{lower}{level}SNThreshold{suffix}", Scalar) for level in ("High", "Low"))
        return masks + stats + thresholds

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks, statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must provide the keys from `getInputSchema`.
        **kwargs
            Forwarded to the selectors and the summary statistic action.
            If ``band`` is present it is used to format ``fluxType`` and to
            prefix the statistic keys.

        Returns
        -------
        results : `KeyedData`
            Mapping with the mask vectors, the per-selection statistics
            and the two S/N thresholds.
        """
        prefix = self._get_key_prefix()
        lower = prefix.lower() if prefix else ""
        upper = prefix.capitalize() if prefix else ""
        suffix = self.suffix

        keyMaskHigh = f"{lower}HighSNMask{suffix}"
        keyMaskLow = f"{lower}LowSNMask{suffix}"
        results = {
            keyMaskHigh: self.highSNSelector(data, **kwargs),
            keyMaskLow: self.lowSNSelector(data, **kwargs),
        }

        band = kwargs.get("band")
        prefix_band = f"{band}_" if band else ""
        fluxes = None if band is None else data[self.fluxType.format(band=band)]

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)
        # Defaults have to be applied by hand here: pex_config does not do
        # it for an action constructed this way, and changing that
        # behavior in pex_config is considered dangerous.
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for keyMask, nameBin in ((keyMaskLow, "low"), (keyMaskHigh, "high")):
            mask = results[keyMask]
            name = f"{prefix_band}{nameBin}SN{upper}"
            # The approximate magnitude is the median magnitude of the
            # objects in the S/N selection; it is undefined without a band.
            if band is None:
                approxMag = np.nan
            else:
                mags = magAction({"flux": fluxes[mask]})  # type: ignore
                approxMag = medianAction({"mag": mags})
            results[f"{name}_approxMag{suffix}".format(**kwargs)] = approxMag

            stats = statAction(data, **(kwargs | {"mask": mask}))
            for nameStat, value in stats.items():
                results[f"{name}_{nameStat}{suffix}".format(**kwargs)] = value

        results[f"{lower}HighSNThreshold{suffix}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{lower}LowSNThreshold{suffix}"] = self.lowSNSelector.threshold  # type: ignore

        return results
151def _validObjectTypes(value):
152 return value in ("stars", "galaxies", "unknown", "any")
155# ignore type because of conflicting name on tuple baseclass
156class _StatsContainer(NamedTuple):
157 median: Scalar
158 sigmaMad: Scalar
159 count: Scalar # type: ignore
160 approxMag: Scalar
class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and colour defaults used when plotting."""

    # Name fragment used in the keys of the binned statistics.
    suffix_stat: str
    # Name fragment appended to "x"/"y" to form the data vector keys.
    suffix_xy: str
    # Colour used for lines and points of this object type.
    color: str
    # Colour map passed to the 2D histogram, if any.
    colormap: matplotlib.colors.Colormap | None
class LogFormatterExponentSci(LogFormatterMathtext):
    """Format non-decade log-axis tick values in a compact scientific style.

    A value whose coefficient is (nearly) an integer is written out in full
    when the base is 10 and the exponent is between 0 and 3, e.g. 500 is
    printed as ``500``. Every other non-decade value is printed as
    ``<coefficient>e<exponent>`` with one decimal place in the coefficient,
    e.g. ``5.0e4`` or ``2.5e1``.
    """

    def _non_decade_format(self, sign_string, base, fx, usetex):
        """Return string for non-decade locations.

        Parameters
        ----------
        sign_string : `str`
            Text placed in front of the formatted value.
        base : `str`
            Base of the logarithm; it is converted with `float` and
            compared against the string ``"10"``.
        fx : `float`
            Exponent of the tick value; the coefficient is computed as
            ``base ** (fx - floor(fx))``.
        usetex : `bool`
            Unused.
        """
        exponent = math.floor(fx)
        coeff = float(base) ** (fx - exponent)
        nearest = round(coeff)
        if not math.isclose(coeff, nearest):
            return f"{sign_string}{coeff:1.1f}e{exponent}"
        if base == "10" and 0 <= exponent <= 3:
            zeros = "0" * int(exponent)
            return f"{sign_string}{nearest}{zeros}"
        return f"{sign_string}{nearest:1.1f}e{exponent}"
class ScatterPlotWithTwoHists(PlotAction):
    """Plot y against x as a scatter plot (optionally with a 2D histogram
    of the dense regions) together with a marginal histogram of each axis.
    """

    # --- Axis limits and labels -------------------------------------------
    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )
    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    # --- Layout and content of the main panel -----------------------------
    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )
    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )

    # --- Marginal histograms ----------------------------------------------
    histMinimum = Field[float](
        doc="Minimum value for the histogram count axis",
        default=0.3,
    )
    xHistMaxLabels = Field[int](
        doc="Maximum number of labels for ticks on the x-axis marginal histogram",
        default=3,
        check=lambda x: x >= 2,
    )
    yHistMaxLabels = Field[int](
        doc="Maximum number of labels for ticks on the y-axis marginal histogram",
        default=3,
        check=lambda x: x >= 2,
    )

    # --- Suffixes appended to the input keys ------------------------------
    suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="")
    suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="")
    suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="")

    publicationStyle = Field[bool](doc="Slimmed down publication style plot?", default=False)

    # Names of the statistics read for each S/N selection; they match the
    # fields of `_StatsContainer`.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key fragments, colour and colour map for each supported object type.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color=galaxies_color(),
            colormap=galaxies_cmap(single_color=True),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color=stars_color(),
            colormap=stars_cmap(single_color=True),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }
277 def getInputSchema(self) -> KeyedDataSchema:
278 base: list[tuple[str, type[Vector] | ScalarType]] = []
279 for name_datatype in self.plotTypes:
280 config_datatype = self._datatypes[name_datatype]
281 if not self.publicationStyle:
282 base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector))
283 base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector))
284 base.append((f"{name_datatype}HighSNMask{self.suffix_stat}", Vector))
285 base.append((f"{name_datatype}LowSNMask{self.suffix_stat}", Vector))
286 # statistics
287 for name in self._stats:
288 base.append(
289 (f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
290 )
291 base.append(
292 (f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
293 )
294 base.append((f"{name_datatype}LowSNThreshold{self.suffix_stat}", Scalar))
295 base.append((f"{name_datatype}HighSNThreshold{self.suffix_stat}", Scalar))
297 if self.addSummaryPlot and not self.publicationStyle:
298 base.append(("patch", Vector))
300 return base
302 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
303 self._validateInput(data, **kwargs)
304 return self.makePlot(data, **kwargs)
306 def _validateInput(self, data: KeyedData, **kwargs) -> None:
307 """NOTE currently can only check that something is not a Scalar, not
308 check that the data is consistent with Vector
309 """
310 needed = self.getFormattedInputSchema(**kwargs)
311 if remainder := {key.format(**kwargs) for key, _ in needed} - {
312 key.format(**kwargs) for key in data.keys()
313 }:
314 raise ValueError(
315 f"Task needs keys {remainder} but they were not found in input keys" f" {list(data.keys())}"
316 )
317 for name, typ in needed:
318 isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
319 if isScalar and typ != Scalar:
320 raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")
322 def makePlot(
323 self,
324 data: KeyedData,
325 plotInfo: Mapping[str, str],
326 **kwargs,
327 ) -> Figure:
328 """Makes a generic plot with a 2D histogram and collapsed histograms of
329 each axis.
331 Parameters
332 ----------
333 data : `KeyedData`
334 The catalog to plot the points from.
335 plotInfo : `dict`
336 A dictionary of information about the data being plotted with keys:
338 * ``"run"``
339 The output run for the plots (`str`).
340 * ``"skymap"``
341 The type of skymap used for the data (`str`).
342 * ``"filter"``
343 The filter used for this data (`str`).
344 * ``"tract"``
345 The tract that the data comes from (`str`).
347 Returns
348 -------
349 fig : `matplotlib.figure.Figure`
350 The resulting figure.
352 Notes
353 -----
354 Uses the axisLabels config options `x` and `y` and the axisAction
355 config options `xAction` and `yAction` to plot a scatter
356 plot of the values against each other. A histogram of the points
357 collapsed onto each axis is also plotted. A summary panel showing the
358 median of the y value in each patch is shown in the upper right corner
359 of the resultant plot. The code uses the selectorActions to decide
360 which points to plot and the statisticSelector actions to determine
361 which points to use for the printed statistics.
363 If this function is being used within the pipetask framework
364 that takes care of making sure that data has all the required
365 elements but if you are running this as a standalone function
366 then you will need to provide the following things in the
367 input data.
369 * If stars is in self.plotTypes:
370 xStars, yStars, starsHighSNMask, starsLowSNMask and
371 {band}_highSNStars_{name}, {band}_lowSNStars_{name}
372 where name is median, sigma_Mad, count and approxMag.
374 * If it is for galaxies/unknowns then replace stars in the above
375 names with galaxies/unknowns.
377 * If it is for any (which covers all the points) then it
378 becomes, x, y, and any instead of stars for the other
379 parameters given above.
381 * In every case it is expected that data contains:
382 lowSnThreshold, highSnThreshold and patch
383 (if the summary plot is being plotted).
385 Examples
386 --------
387 An example of the plot produced from this code is here:
389 .. image:: /_static/analysis_tools/scatterPlotExample.png
391 For a detailed example of how to make a plot from the command line
392 please see the
393 :ref:`getting started guide<analysis-tools-getting-started>`.
394 """
395 if not self.plotTypes:
396 noDataFig = Figure()
397 noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
398 noDataFig = addPlotInfo(noDataFig, plotInfo)
399 return noDataFig
401 # Set default color and line style for the horizontal
402 # reference line at 0
403 if "hlineColor" not in kwargs:
404 kwargs["hlineColor"] = "black"
406 if "hlineStyle" not in kwargs:
407 kwargs["hlineStyle"] = (0, (1, 4))
409 set_rubin_plotstyle()
410 fig = make_figure()
411 gs = gridspec.GridSpec(4, 4)
413 # add the various plot elements
414 ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
415 if ax is None:
416 noDataFig = Figure()
417 noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
418 if not self.publicationStyle:
419 noDataFig = addPlotInfo(noDataFig, plotInfo)
420 return noDataFig
422 self._makeTopHistogram(data, fig, gs, ax, **kwargs)
423 self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
424 # Needs info from run quantum
425 skymap = kwargs.get("skymap", None)
426 if self.addSummaryPlot and skymap is not None and not self.publicationStyle:
427 sumStats = generateSummaryStats(data, skymap, plotInfo)
428 label = self.yAxisLabel
429 fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)
431 fig.canvas.draw()
432 # TODO: Check if these spacings can be defined less arbitrarily
433 fig.subplots_adjust(
434 wspace=0.0,
435 hspace=0.0,
436 bottom=0.13 if self.publicationStyle else 0.22,
437 left=0.18 if self.publicationStyle else 0.21,
438 right=0.92 if self.publicationStyle else None,
439 top=0.98 if self.publicationStyle else None,
440 )
441 if not self.publicationStyle:
442 fig = addPlotInfo(fig, plotInfo)
443 return fig
445 def _scatterPlot(
446 self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
447 ) -> tuple[Axes, Optional[PolyCollection]]:
448 suf_x = self.suffix_x
449 suf_y = self.suffix_y
450 suf_stat = self.suffix_stat
451 # Main scatter plot
452 ax = fig.add_subplot(gs[1:, :-1])
454 binThresh = 5
455 min_n_xs_for_stats = 10
457 yBinsOut = []
458 linesForLegend = []
460 toPlotList = []
461 histIm = None
462 highStats: _StatsContainer
463 lowStats: _StatsContainer
465 magLabel = self.magLabel
466 kwargs_in_label = {k: v for k, v in kwargs.items() if f"{{{k}}}" in magLabel}
467 if kwargs_in_label:
468 magLabel = magLabel.format(**kwargs_in_label)
470 for name_datatype in self.plotTypes:
471 config_datatype = self._datatypes[name_datatype]
472 highArgs = {}
473 lowArgs = {}
474 if not self.publicationStyle:
475 for name in self._stats:
476 highArgs[name] = cast(
477 Scalar,
478 data[
479 f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)
480 ],
481 )
482 lowArgs[name] = cast(
483 Scalar,
484 data[
485 f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)
486 ],
487 )
488 highStats = _StatsContainer(**highArgs)
489 lowStats = _StatsContainer(**lowArgs)
491 toPlotList.append(
492 (
493 data[f"x{config_datatype.suffix_xy}{suf_x}"],
494 data[f"y{config_datatype.suffix_xy}{suf_y}"],
495 data[f"{name_datatype}HighSNMask{suf_stat}"],
496 data[f"{name_datatype}LowSNMask{suf_stat}"],
497 data[f"{name_datatype}HighSNThreshold{suf_stat}"],
498 data[f"{name_datatype}LowSNThreshold{suf_stat}"],
499 config_datatype.color,
500 config_datatype.colormap,
501 highStats,
502 lowStats,
503 )
504 )
505 else:
506 toPlotList.append(
507 (
508 data[f"x{config_datatype.suffix_xy}{suf_x}"],
509 data[f"y{config_datatype.suffix_xy}{suf_y}"],
510 [],
511 [],
512 [],
513 [],
514 config_datatype.color,
515 config_datatype.colormap,
516 [],
517 [],
518 )
519 )
521 xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]
523 # If there is no data to plot make a
524 # no data figure
525 numData = 0
526 for xs, _, _, _, _, _, _, _, _, _ in toPlotList:
527 numData += np.count_nonzero(np.isfinite(xs))
528 if numData == 0:
529 return None, None
531 for j, (
532 xs,
533 ys,
534 highSn,
535 lowSn,
536 highThresh,
537 lowThresh,
538 color,
539 cmap,
540 highStats,
541 lowStats,
542 ) in enumerate(toPlotList):
543 highSn = cast(Vector, highSn)
544 lowSn = cast(Vector, lowSn)
545 # ensure the columns are actually array
546 xs = np.array(xs)
547 ys = np.array(ys)
548 sigMadYs = nanSigmaMad(ys)
549 # plot lone median point if there's not enough data to measure more
550 n_xs = np.count_nonzero(np.isfinite(xs))
551 if n_xs <= 1 or not (np.isfinite(sigMadYs) and sigMadYs >= 0.0):
552 continue
553 elif n_xs < min_n_xs_for_stats:
554 xs = [nanMedian(xs)]
555 sigMads = np.array([nanSigmaMad(ys)])
556 ys = np.array([nanMedian(ys)])
557 (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
558 linesForLegend.append(medLine)
559 (sigMadLine,) = ax.plot(
560 xs,
561 ys + 1.0 * sigMads,
562 color,
563 alpha=0.8,
564 lw=0.8,
565 label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
566 )
567 ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
568 linesForLegend.append(sigMadLine)
569 histIm = None
570 continue
572 if self.xLims:
573 xMin, xMax = self.xLims
574 else:
575 # Chop off 1/3% from only the finite xs values
576 # (there may be +/-np.Inf values)
577 # TODO: This should be configurable
578 # but is there a good way to avoid redundant config params
579 # without using slightly annoying subconfigs?
580 xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
581 xScale = (xs97 - xs1) / 20.0 # This is ~5% of the data range
582 xMin, xMax = (xs1 - xScale, xs97 + xScale)
583 xLims[0] = min(xLims[0], xMin)
584 xLims[1] = max(xLims[1], xMax)
586 xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
587 medYs = nanMedian(ys)
588 fiveSigmaHigh = medYs + 5.0 * sigMadYs
589 fiveSigmaLow = medYs - 5.0 * sigMadYs
590 binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
591 # When the binsize is 0 try using the 1st and 99th
592 # percentile instead of the sigmas.
593 if binSize == 0.0:
594 p1, p99 = np.nanpercentile(ys, [1, 99])
595 binSize = (p99 - p1) / 101.0
597 # If fiveSigmaHigh and fiveSigmaLow are the same
598 # then use the 1st and 99th percentiles to define the
599 # yEdges.
600 yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)
601 if fiveSigmaLow == fiveSigmaHigh:
602 yEdges = np.arange(p1, p99, binSize)
604 counts, xBins, yBins = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
605 yBinsOut.append(yBins)
606 countsYs = np.sum(counts, axis=1)
608 ids = np.where((countsYs > binThresh))[0]
609 xEdgesPlot = xEdges[ids][1:]
610 xEdges = xEdges[ids]
612 if len(ids) > 1:
613 # Create the codes needed to turn the sigmaMad lines
614 # into a path to speed up checking which points are
615 # inside the area.
616 codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
617 codes[0] = Path.MOVETO
618 codes[-1] = Path.CLOSEPOLY
620 meds = np.zeros(len(xEdgesPlot))
621 threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
622 sigMads = np.zeros(len(xEdgesPlot))
624 for i, xEdge in enumerate(xEdgesPlot):
625 ids = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
626 med = nanMedian(ys[ids])
627 sigMad = nanSigmaMad(ys[ids])
628 meds[i] = med
629 sigMads[i] = sigMad
630 threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
631 threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]
633 if self.publicationStyle:
634 linecolor = "k"
635 else:
636 linecolor = color
638 (medLine,) = ax.plot(xEdgesPlot, meds, linecolor, label="Running Median")
639 linesForLegend.append(medLine)
641 # Make path to check which points lie within one sigma mad
642 threeSigMadPath = Path(threeSigMadVerts, codes)
644 if not self.publicationStyle:
645 # Add lines for the median +/- 3 * sigma MAD
646 (threeSigMadLine,) = ax.plot(
647 xEdgesPlot,
648 threeSigMadVerts[: len(xEdgesPlot), 1],
649 linecolor,
650 alpha=0.4,
651 label=r"3$\sigma_{MAD}$",
652 )
653 ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], linecolor, alpha=0.4)
655 # Add lines for the median +/- 1 * sigma MAD
656 (sigMadLine,) = ax.plot(
657 xEdgesPlot,
658 meds + 1.0 * sigMads,
659 linecolor,
660 alpha=0.8,
661 label=r"$\sigma_{MAD}$",
662 ls="dashed",
663 )
664 linesForLegend.append(sigMadLine)
665 ax.plot(xEdgesPlot, meds - 1.0 * sigMads, linecolor, alpha=0.8, ls="dashed")
667 if not self.publicationStyle:
668 # Add lines for the median +/- 2 * sigma MAD
669 (twoSigMadLine,) = ax.plot(
670 xEdgesPlot, meds + 2.0 * sigMads, linecolor, alpha=0.6, label=r"2$\sigma_{MAD}$"
671 )
672 linesForLegend.append(twoSigMadLine)
673 linesForLegend.append(threeSigMadLine)
674 ax.plot(xEdgesPlot, meds - 2.0 * sigMads, linecolor, alpha=0.6)
676 # Check which points are outside 3 sigma MAD of the median
677 # and plot these as points.
678 inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
679 ax.plot(xs[~inside], ys[~inside], ".", ms=5, alpha=0.2, mfc=color, mec="none", zorder=-1)
681 if not self.publicationStyle:
682 # Add some stats text
683 xPos = 0.65 - 0.4 * j
684 bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
685 statText = f"S/N > {highThresh:0.4g} Stats ({magLabel} < {highStats.approxMag:0.4g})\n"
686 highStatsStr = (
687 f"Median: {highStats.median:0.4g} "
688 + r"$\sigma_{MAD}$: "
689 + f"{highStats.sigmaMad:0.4g} "
690 + r"N$_{points}$: "
691 + f"{highStats.count}"
692 )
693 statText += highStatsStr
694 fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
696 bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
697 statText = f"S/N > {lowThresh:0.4g} Stats ({magLabel} < {lowStats.approxMag:0.4g})\n"
698 lowStatsStr = (
699 f"Median: {lowStats.median:0.4g} "
700 + r"$\sigma_{MAD}$: "
701 + f"{lowStats.sigmaMad:0.4g} "
702 + r"N$_{points}$: "
703 + f"{lowStats.count}"
704 )
705 statText += lowStatsStr
706 fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
708 if self.plot2DHist:
709 extent = [xLims[0], xLims[1], self.yLims[0], self.yLims[1]] if self.yLims else None
710 histIm = ax.hexbin(
711 xs[inside],
712 ys[inside],
713 gridsize=75,
714 extent=extent,
715 cmap=cmap,
716 mincnt=1,
717 zorder=-3,
718 edgecolors=None,
719 )
720 else:
721 ax.plot(xs[inside], ys[inside], ".", ms=3, alpha=0.2, mfc=color, mec=color, zorder=-1)
723 if not self.publicationStyle:
724 # If there are not many sources being used for the
725 # statistics then plot them individually as just
726 # plotting a line makes the statistics look wrong
727 # as the magnitude estimation is iffy for low
728 # numbers of sources.
729 if np.sum(highSn) < 100 and np.sum(highSn) > 0:
730 ax.plot(
731 cast(Vector, xs[highSn]),
732 cast(Vector, ys[highSn]),
733 marker="x",
734 ms=4,
735 mec="w",
736 mew=2,
737 ls="none",
738 )
739 (highSnLine,) = ax.plot(
740 cast(Vector, xs[highSn]),
741 cast(Vector, ys[highSn]),
742 color=color,
743 marker="x",
744 ms=4,
745 ls="none",
746 label="High SN",
747 )
748 linesForLegend.append(highSnLine)
749 else:
750 ax.axvline(highStats.approxMag, color=color, ls="--")
752 if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
753 ax.plot(
754 cast(Vector, xs[lowSn]),
755 cast(Vector, ys[lowSn]),
756 marker="+",
757 ms=4,
758 mec="w",
759 mew=2,
760 ls="none",
761 )
762 (lowSnLine,) = ax.plot(
763 cast(Vector, xs[lowSn]),
764 cast(Vector, ys[lowSn]),
765 color=color,
766 marker="+",
767 ms=4,
768 ls="none",
769 label="Low SN",
770 )
771 linesForLegend.append(lowSnLine)
772 else:
773 ax.axvline(lowStats.approxMag, color=color, ls=":")
775 else:
776 ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
777 meds = np.array([nanMedian(ys)] * len(xs))
778 (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
779 linesForLegend.append(medLine)
780 sigMads = np.array([nanSigmaMad(ys)] * len(xs))
781 (sigMadLine,) = ax.plot(
782 xs,
783 meds + 1.0 * sigMads,
784 color,
785 alpha=0.8,
786 lw=0.8,
787 label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
788 )
789 ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
790 linesForLegend.append(sigMadLine)
791 histIm = None
793 # Add a horizontal reference line at 0 to the scatter plot
794 ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)
796 # Set the scatter plot limits
797 suf_x = self.suffix_y
798 # TODO: Make this not work by accident
799 if f"yStars{suf_x}" in data and (len(cast(Vector, data[f"yStars{suf_x}"])) > 0):
800 plotMed = nanMedian(cast(Vector, data[f"yStars{suf_x}"]))
801 elif f"yGalaxies{suf_x}" in data and (len(cast(Vector, data[f"yGalaxies{suf_x}"])) > 0):
802 plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_x}"]))
803 else:
804 plotMed = np.nan
806 # Ignore types below pending making this not working my accident
807 # If len(xs) < min_n_xs_for_stats then `meds` doesn't exist.
808 if len(xs) < min_n_xs_for_stats: # type: ignore
809 meds = [nanMedian(ys)] # type: ignore
810 if self.yLims:
811 ax.set_ylim(self.yLims[0], self.yLims[1]) # type: ignore
812 elif np.isfinite(plotMed):
813 numSig = 4
814 yLimMin = plotMed - numSig * sigMadYs # type: ignore
815 yLimMax = plotMed + numSig * sigMadYs # type: ignore
816 while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10: # type: ignore
817 numSig += 1
819 numSig += 1
820 yLimMin = plotMed - numSig * sigMadYs # type: ignore
821 yLimMax = plotMed + numSig * sigMadYs # type: ignore
823 # If len(y) == 0 for ys in toPlotList then sigMadY = NaN
824 # ... in which case nothing was plotted and limits are irrelevant.
825 if all(np.isfinite([yLimMin, yLimMax])):
826 ax.set_ylim(yLimMin, yLimMax)
828 # This could be false if len(x) == 0 for xs in toPlotList
829 # ... in which case nothing was plotted and limits are irrelevant
830 if all(np.isfinite(xLims)):
831 ax.set_xlim(xLims)
833 # Add a line legend
834 ax.legend(
835 handles=linesForLegend,
836 ncol=4,
837 fontsize=6,
838 loc=self.legendLocation,
839 framealpha=0.9,
840 edgecolor="k",
841 borderpad=0.4,
842 handlelength=3,
843 )
845 # Add axes labels
846 band = kwargs.get("band", "unspecified")
847 xlabel = self.xAxisLabel
848 ylabel = self.yAxisLabel
849 if "{band}" in xlabel:
850 xlabel = xlabel.format(band=band)
851 if "{band}" in ylabel:
852 ylabel = ylabel.format(band=band)
853 if self.publicationStyle:
854 ax.set_ylabel(ylabel, labelpad=10)
855 ax.set_xlabel(xlabel, labelpad=2)
856 else:
857 ax.set_ylabel(ylabel, labelpad=10, fontsize=10)
858 ax.set_xlabel(xlabel, labelpad=2, fontsize=10)
859 ax.tick_params(labelsize=8)
861 return ax, histIm
863 def _makeTopHistogram(
864 self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
865 ) -> None:
866 suf_x = self.suffix_x
867 # Top histogram
868 topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
869 x_min, x_max = ax.get_xlim()
870 bins = np.linspace(x_min, x_max, 100)
872 if "any" in self.plotTypes:
873 x_any = f"x{self._datatypes['any'].suffix_xy}{suf_x}"
874 keys_notany = [x for x in self.plotTypes if x != "any"]
875 else:
876 x_any = (
877 np.concatenate([data[f"x{self._datatypes[key].suffix_xy}{suf_x}"] for key in self.plotTypes])
878 if (len(self.plotTypes) > 1)
879 else None
880 )
881 keys_notany = self.plotTypes
882 if x_any is not None:
883 if np.sum(x_any > 0) > 0:
884 log = True
885 else:
886 log = False
887 topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=log, label=f"Any ({len(x_any)})")
889 for key in keys_notany:
890 config_datatype = self._datatypes[key]
891 vector = np.array(data[f"x{config_datatype.suffix_xy}{suf_x}"])
892 if np.sum(vector > 0) > 0:
893 log = True
894 else:
895 log = False
896 topHist.hist(
897 vector,
898 bins=bins,
899 color=config_datatype.color,
900 histtype="step",
901 log=log,
902 label=f"{config_datatype.suffix_stat} ({len(vector)})",
903 )
904 topHist.axes.get_xaxis().set_visible(False)
905 topHist.set_ylabel("Count", fontsize=10 + 4 * self.publicationStyle)
906 if not self.publicationStyle:
907 topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")
908 topHist.tick_params(labelsize=8)
910 self._modifyHistogramTicks(topHist, do_x=False, max_labels=self.xHistMaxLabels)
912 def _makeSideHistogram(
913 self,
914 data: KeyedData,
915 figure: Figure,
916 gs: gridspec.Gridspec,
917 ax: Axes,
918 histIm: Optional[PolyCollection],
919 **kwargs,
920 ) -> None:
921 suf_y = self.suffix_y
922 # Side histogram
923 sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
924 y_min, y_max = ax.get_ylim()
925 bins = np.linspace(y_min, y_max, 100)
927 if "any" in self.plotTypes:
928 y_any = np.array(data[f"y{self._datatypes['any'].suffix_xy}{suf_y}"])
929 keys_notany = [x for x in self.plotTypes if x != "any"]
930 else:
931 y_any = (
932 np.concatenate(
933 [np.array(data[f"y{self._datatypes[key].suffix_xy}{suf_y}"]) for key in self.plotTypes]
934 )
935 if (len(self.plotTypes) > 1)
936 else None
937 )
938 keys_notany = self.plotTypes
939 if y_any is not None:
940 sideHist.hist(
941 np.array(y_any),
942 bins=bins,
943 color="grey",
944 alpha=0.3,
945 orientation="horizontal",
946 log=np.any(y_any > 0),
947 )
949 kwargs_hist = dict(
950 bins=bins,
951 histtype="step",
952 log=True,
953 orientation="horizontal",
954 )
955 for key in keys_notany:
956 config_datatype = self._datatypes[key]
957 # If the data has no positive values then it
958 # cannot be log scaled and it prints a bunch
959 # of irritating warnings, in this case don't
960 # try.
961 numPos = np.sum(data[f"y{config_datatype.suffix_xy}{suf_y}"] > 0)
963 if numPos <= 0:
964 kwargs_hist["log"] = False
966 vector = data[f"y{config_datatype.suffix_xy}{suf_y}"]
967 sideHist.hist(
968 vector,
969 color=config_datatype.color,
970 **kwargs_hist,
971 )
972 if not self.publicationStyle:
973 sideHist.hist(
974 vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
975 color=config_datatype.color,
976 linestyle="--",
977 **kwargs_hist,
978 )
979 sideHist.hist(
980 vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
981 color=config_datatype.color,
982 **kwargs_hist,
983 linestyle=":",
984 )
986 # Add a horizontal reference line at 0 to the side histogram
987 sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)
989 sideHist.axes.get_yaxis().set_visible(False)
990 sideHist.set_xlabel("Count", fontsize=10 + 4 * self.publicationStyle)
991 self._modifyHistogramTicks(sideHist, do_x=True, max_labels=self.yHistMaxLabels)
993 if not self.publicationStyle:
994 sideHist.tick_params(labelsize=8)
995 if self.plot2DHist and histIm is not None:
996 divider = make_axes_locatable(sideHist)
997 cax = divider.append_axes("right", size="25%", pad=0)
998 sideHist.get_figure().colorbar(histIm, cax=cax, orientation="vertical")
999 text = cax.text(
1000 0.5,
1001 0.5,
1002 "Points Per Bin",
1003 color="k",
1004 rotation="vertical",
1005 transform=cax.transAxes,
1006 ha="center",
1007 va="center",
1008 fontsize=10,
1009 )
1010 text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
1012 def _modifyHistogramTicks(self, histogram, do_x: bool, max_labels: int):
1013 axis = histogram.get_xaxis() if do_x else histogram.get_yaxis()
1014 limits = list(histogram.get_xlim() if do_x else histogram.get_ylim())
1015 get_ticks = histogram.get_xticks if do_x else histogram.get_yticks
1016 ticks = get_ticks()
1017 # Let the minimum be larger then specified if the histogram has large
1018 # values everywhere, but cut it down a little so the lowest-valued bin
1019 # is still easily visible
1020 limits[0] = max(self.histMinimum, 0.9 * limits[0])
1021 # Round the upper limit to the nearest power of 10
1022 limits[1] = 10 ** (np.ceil(np.log10(limits[1]))) if (limits[1] > 0) else limits[1]
1023 for minor in (False, True):
1024 # Ignore ticks that are below the minimum value
1025 valid = (ticks >= limits[0]) & (ticks <= limits[1])
1026 labels = [label for label, _valid in zip(axis.get_ticklabels(minor=minor), valid) if _valid]
1027 if (n_labels := len(labels)) > max_labels:
1028 labels_new = [""] * n_labels
1029 # Skip the first label if we're not using minor axis labels
1030 # This helps avoid overlap with the scatter plot labels
1031 for idx_fill in np.round(np.linspace(1 - minor, n_labels - 1, max_labels)).astype(int):
1032 labels_new[idx_fill] = labels[idx_fill]
1033 axis.set_ticks(ticks[valid], labels_new)
1034 # If there are enough major tick labels, disable minor tick labels
1035 if len(labels) >= 2:
1036 axis.set_minor_formatter(NullFormatter())
1037 break
1038 else:
1039 axis.set_minor_formatter(
1040 LogFormatterExponentSci(minor_thresholds=(1, self.histMinimum / 10.0))
1041 )
1042 ticks = get_ticks(minor=True)
1044 (histogram.set_xlim if do_x else histogram.set_ylim)(limits)