Coverage for python / lsst / analysis / tools / actions / plot / scatterplotWithTwoHists.py: 13%
445 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:19 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:19 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")
26import math
27from collections.abc import Mapping
28from typing import NamedTuple, cast
30import matplotlib.colors
31import matplotlib.patheffects as pathEffects
32import numpy as np
33from matplotlib import gridspec
34from matplotlib.axes import Axes
35from matplotlib.collections import PolyCollection
36from matplotlib.figure import Figure
37from matplotlib.path import Path
38from matplotlib.ticker import LogFormatterMathtext, NullFormatter
39from mpl_toolkits.axes_grid1 import make_axes_locatable
41from lsst.pex.config import Field
42from lsst.pex.config.configurableActions import ConfigurableActionField
43from lsst.pex.config.listField import ListField
44from lsst.utils.plotting import (
45 galaxies_cmap,
46 galaxies_color,
47 make_figure,
48 set_rubin_plotstyle,
49 stars_cmap,
50 stars_color,
51)
53from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector
54from ...math import nanMedian, nanSigmaMad
55from ..keyedData.summaryStatistics import SummaryStatisticAction
56from ..scalar import MedianAction
57from ..vector import ConvertFluxToMag, SnSelector
58from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats
class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the masks and summary statistics consumed by the
    scatter plot with two histograms.

    Two signal-to-noise selections (``highSNSelector`` and
    ``lowSNSelector``) are applied to the input data. For each selection the
    boolean mask, the summary statistics of ``vectorKey`` and an approximate
    magnitude (the median magnitude of the selected ``fluxType`` values) are
    returned, together with the two thresholds themselves.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        """Return the configured prefix, else the identity, else ``""``."""
        if self.prefix:
            return self.prefix
        if self.identity:
            return self.identity
        return ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys this action reads: the statistic vector, the flux
        vector and whatever the two S/N selectors require.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        for selector in (self.highSNSelector, self.lowSNSelector):
            yield from selector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the (key, type) pairs produced by `__call__`.

        Mask and threshold keys use the lower-cased prefix; the per-band
        statistic keys use the capitalised prefix.
        """
        key_prefix = self._get_key_prefix()
        lower = key_prefix.lower() if key_prefix else ""
        upper = key_prefix.capitalize() if key_prefix else ""
        suffix = self.suffix

        masks = tuple((f"{lower}{sn}SNMask{suffix}", Vector) for sn in ("High", "Low"))
        statistics = tuple(
            (f"{{band}}_{sn}SN{upper}_{stat}{suffix}", Scalar)
            for sn in ("low", "high")
            for stat in ("median", "sigmaMad", "count", "approxMag")
        )
        thresholds = tuple((f"{lower}{sn}SNThreshold{suffix}", Scalar) for sn in ("High", "Low"))
        return masks + statistics + thresholds

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Evaluate both S/N selections and their statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input data containing the keys listed by `getInputSchema`.
        **kwargs
            Formatting arguments; ``band`` (if given) is used to look up the
            flux vector and to prefix the statistic keys.

        Returns
        -------
        results : `KeyedData`
            The masks, per-selection statistics, approximate magnitudes and
            the two S/N thresholds.
        """
        key_prefix = self._get_key_prefix()
        lower = key_prefix.lower() if key_prefix else ""
        upper = key_prefix.capitalize() if key_prefix else ""
        suffix = self.suffix

        results = {}
        highMaskKey = f"{lower}HighSNMask{suffix}"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{lower}LowSNMask{suffix}"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        band = kwargs.get("band")
        band_prefix = f"{band}_" if band else ""
        # Without a band there is no flux column to turn into a magnitude.
        fluxes = None if band is None else data[self.fluxType.format(band=band)]

        summaryAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        summaryAction.setDefaults()

        toMedian = MedianAction(vectorKey="mag")
        toMag = ConvertFluxToMag(vectorKey="flux")

        for maskKey, snName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            mask = results[maskKey]
            base = f"{band_prefix}{snName}SN{upper}"

            # The approximate magnitude is the median magnitude of the
            # objects passing this S/N selection.
            if band is not None:
                approxMag = toMedian({"mag": toMag({"flux": fluxes[mask]})})  # type: ignore
            else:
                approxMag = np.nan
            results[f"{base}_approxMag{suffix}".format(**kwargs)] = approxMag

            selectionStats = summaryAction(data, **(kwargs | {"mask": mask}))
            for statName, statValue in selectionStats.items():
                results[f"{base}_{statName}{suffix}".format(**kwargs)] = statValue

        results[f"{lower}HighSNThreshold{suffix}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{lower}LowSNThreshold{suffix}"] = self.lowSNSelector.threshold  # type: ignore

        return results
153def _validObjectTypes(value):
154 return value in ("stars", "galaxies", "unknown", "any")
# ignore type because of conflicting name on tuple baseclass
# (``count`` shadows the ``tuple.count`` method, hence the ``type: ignore``).
class _StatsContainer(NamedTuple):
    """Statistics of one S/N selection, as printed on the plot.

    The field names match the entries of
    ``ScatterPlotWithTwoHists._stats``, which are used as keyword arguments
    when this container is built.
    """

    # Median of the plotted quantity for the selection.
    median: Scalar
    # Sigma MAD of the plotted quantity for the selection.
    sigmaMad: Scalar
    # Number of points in the selection.
    count: Scalar  # type: ignore
    # Approximate magnitude associated with the selection's S/N cut.
    approxMag: Scalar
class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and colour defaults used by
    `ScatterPlotWithTwoHists`.
    """

    # Capitalised type name embedded in statistic keys
    # (``{band}_highSN<suffix_stat>_<stat>``) and used as the histogram
    # legend label.
    suffix_stat: str
    # Suffix appended to ``x``/``y`` to form the coordinate vector keys
    # (empty for the "any" type).
    suffix_xy: str
    # Matplotlib colour used for lines, points and histogram outlines.
    color: str
    # Colormap passed to the hexbin 2D histogram; `None` means the
    # matplotlib default is used.
    colormap: matplotlib.colors.Colormap | None
class LogFormatterExponentSci(LogFormatterMathtext):
    """Format non-decade log-axis tick values compactly.

    A value whose coefficient is (nearly) an integer is written out in full
    when the base is 10 and the exponent lies between 0 and 3 (e.g. 500 is
    printed as ``500``); every other value is written as
    ``<coefficient>e<exponent>`` with one decimal place in the coefficient.
    """

    def _non_decade_format(self, sign_string, base, fx, usetex):
        """Return string for non-decade locations."""
        power = math.floor(fx)
        mantissa = float(base) ** (fx - power)
        nearest = round(mantissa)
        if math.isclose(mantissa, nearest):
            if base == "10" and 0 <= power <= 3:
                # Small base-10 values read better spelled out in full.
                zeros = "0" * int(power)
                return f"{sign_string}{nearest}{zeros}"
            mantissa = nearest
        return f"{sign_string}{mantissa:1.1f}e{power}"
class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    # --- Axis limits and labels -------------------------------------------
    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    # --- Appearance of the main panel -------------------------------------
    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        # The trailing space keeps the two concatenated sentences apart.
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    # --- Summary panel and marginal histograms ----------------------------
    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )
    histMinimum = Field[float](
        doc="Minimum value for the histogram count axis",
        default=0.3,
    )
    xHistMaxLabels = Field[int](
        doc="Maximum number of labels for ticks on the x-axis marginal histogram",
        default=3,
        check=lambda x: x >= 2,
    )
    yHistMaxLabels = Field[int](
        doc="Maximum number of labels for ticks on the y-axis marginal histogram",
        default=3,
        check=lambda x: x >= 2,
    )

    # --- Suffixes appended to the input key names -------------------------
    suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="")
    suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="")
    suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="")

    publicationStyle = Field[bool](doc="Slimmed down publication style plot?", default=False)

    # Names of the per-selection statistics; these double as the field names
    # of `_StatsContainer`.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key-name suffixes and colours for each value accepted by ``plotTypes``.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color=galaxies_color(),
            colormap=galaxies_cmap(single_color=True),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color=stars_color(),
            colormap=stars_cmap(single_color=True),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }
279 def getInputSchema(self) -> KeyedDataSchema:
280 base: list[tuple[str, type[Vector] | ScalarType]] = []
281 for name_datatype in self.plotTypes:
282 config_datatype = self._datatypes[name_datatype]
283 if not self.publicationStyle:
284 base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector))
285 base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector))
286 base.append((f"{name_datatype}HighSNMask{self.suffix_stat}", Vector))
287 base.append((f"{name_datatype}LowSNMask{self.suffix_stat}", Vector))
288 # statistics
289 for name in self._stats:
290 base.append(
291 (f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
292 )
293 base.append(
294 (f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
295 )
296 base.append((f"{name_datatype}LowSNThreshold{self.suffix_stat}", Scalar))
297 base.append((f"{name_datatype}HighSNThreshold{self.suffix_stat}", Scalar))
299 if self.addSummaryPlot and not self.publicationStyle:
300 base.append(("patch", Vector))
302 return base
304 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
305 self._validateInput(data, **kwargs)
306 return self.makePlot(data, **kwargs)
308 def _validateInput(self, data: KeyedData, **kwargs) -> None:
309 """NOTE currently can only check that something is not a Scalar, not
310 check that the data is consistent with Vector
311 """
312 needed = self.getFormattedInputSchema(**kwargs)
313 if remainder := {key.format(**kwargs) for key, _ in needed} - {
314 key.format(**kwargs) for key in data.keys()
315 }:
316 raise ValueError(
317 f"Task needs keys {remainder} but they were not found in input keys" f" {list(data.keys())}"
318 )
319 for name, typ in needed:
320 isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
321 if isScalar and typ != Scalar:
322 raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")
324 def makePlot(
325 self,
326 data: KeyedData,
327 plotInfo: Mapping[str, str],
328 **kwargs,
329 ) -> Figure:
330 """Makes a generic plot with a 2D histogram and collapsed histograms of
331 each axis.
333 Parameters
334 ----------
335 data : `KeyedData`
336 The catalog to plot the points from.
337 plotInfo : `dict`
338 A dictionary of information about the data being plotted with keys:
340 * ``"run"``
341 The output run for the plots (`str`).
342 * ``"skymap"``
343 The type of skymap used for the data (`str`).
344 * ``"filter"``
345 The filter used for this data (`str`).
346 * ``"tract"``
347 The tract that the data comes from (`str`).
349 Returns
350 -------
351 fig : `matplotlib.figure.Figure`
352 The resulting figure.
354 Notes
355 -----
356 Uses the axisLabels config options `x` and `y` and the axisAction
357 config options `xAction` and `yAction` to plot a scatter
358 plot of the values against each other. A histogram of the points
359 collapsed onto each axis is also plotted. A summary panel showing the
360 median of the y value in each patch is shown in the upper right corner
361 of the resultant plot. The code uses the selectorActions to decide
362 which points to plot and the statisticSelector actions to determine
363 which points to use for the printed statistics.
365 If this function is being used within the pipetask framework
366 that takes care of making sure that data has all the required
367 elements but if you are running this as a standalone function
368 then you will need to provide the following things in the
369 input data.
371 * If stars is in self.plotTypes:
372 xStars, yStars, starsHighSNMask, starsLowSNMask and
373 {band}_highSNStars_{name}, {band}_lowSNStars_{name}
374 where name is median, sigma_Mad, count and approxMag.
376 * If it is for galaxies/unknowns then replace stars in the above
377 names with galaxies/unknowns.
379 * If it is for any (which covers all the points) then it
380 becomes, x, y, and any instead of stars for the other
381 parameters given above.
383 * In every case it is expected that data contains:
384 lowSnThreshold, highSnThreshold and patch
385 (if the summary plot is being plotted).
387 Examples
388 --------
389 An example of the plot produced from this code is here:
391 .. image:: /_static/analysis_tools/scatterPlotExample.png
393 For a detailed example of how to make a plot from the command line
394 please see the
395 :ref:`getting started guide<analysis-tools-getting-started>`.
396 """
397 if not self.plotTypes:
398 noDataFig = Figure()
399 noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
400 noDataFig = addPlotInfo(noDataFig, plotInfo)
401 return noDataFig
403 # Set default color and line style for the horizontal
404 # reference line at 0
405 if "hlineColor" not in kwargs:
406 kwargs["hlineColor"] = "black"
408 if "hlineStyle" not in kwargs:
409 kwargs["hlineStyle"] = (0, (1, 4))
411 set_rubin_plotstyle()
412 fig = make_figure()
413 gs = gridspec.GridSpec(4, 4)
415 # add the various plot elements
416 ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
417 if ax is None:
418 noDataFig = Figure()
419 noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
420 if not self.publicationStyle:
421 noDataFig = addPlotInfo(noDataFig, plotInfo)
422 return noDataFig
424 self._makeTopHistogram(data, fig, gs, ax, **kwargs)
425 self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
426 # Needs info from run quantum
427 skymap = kwargs.get("skymap", None)
428 if self.addSummaryPlot and skymap is not None and not self.publicationStyle:
429 sumStats = generateSummaryStats(data, skymap, plotInfo)
430 label = self.yAxisLabel
431 fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)
433 fig.canvas.draw()
434 # TODO: Check if these spacings can be defined less arbitrarily
435 fig.subplots_adjust(
436 wspace=0.0,
437 hspace=0.0,
438 bottom=0.13 if self.publicationStyle else 0.22,
439 left=0.18 if self.publicationStyle else 0.21,
440 right=0.92 if self.publicationStyle else None,
441 top=0.98 if self.publicationStyle else None,
442 )
443 if not self.publicationStyle:
444 fig = addPlotInfo(fig, plotInfo)
445 return fig
447 def _scatterPlot(
448 self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
449 ) -> tuple[Axes, PolyCollection | None]:
450 suf_x = self.suffix_x
451 suf_y = self.suffix_y
452 suf_stat = self.suffix_stat
453 # Main scatter plot
454 ax = fig.add_subplot(gs[1:, :-1])
456 binThresh = 5
457 min_n_xs_for_stats = 10
459 yBinsOut = []
460 linesForLegend = []
462 toPlotList = []
463 histIm = None
464 highStats: _StatsContainer
465 lowStats: _StatsContainer
467 magLabel = self.magLabel
468 kwargs_in_label = {k: v for k, v in kwargs.items() if f"{{{k}}}" in magLabel}
469 if kwargs_in_label:
470 magLabel = magLabel.format(**kwargs_in_label)
472 for name_datatype in self.plotTypes:
473 config_datatype = self._datatypes[name_datatype]
474 highArgs = {}
475 lowArgs = {}
476 if not self.publicationStyle:
477 for name in self._stats:
478 highArgs[name] = cast(
479 Scalar,
480 data[
481 f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)
482 ],
483 )
484 lowArgs[name] = cast(
485 Scalar,
486 data[
487 f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)
488 ],
489 )
490 highStats = _StatsContainer(**highArgs)
491 lowStats = _StatsContainer(**lowArgs)
493 toPlotList.append(
494 (
495 data[f"x{config_datatype.suffix_xy}{suf_x}"],
496 data[f"y{config_datatype.suffix_xy}{suf_y}"],
497 data[f"{name_datatype}HighSNMask{suf_stat}"],
498 data[f"{name_datatype}LowSNMask{suf_stat}"],
499 data[f"{name_datatype}HighSNThreshold{suf_stat}"],
500 data[f"{name_datatype}LowSNThreshold{suf_stat}"],
501 config_datatype.color,
502 config_datatype.colormap,
503 highStats,
504 lowStats,
505 )
506 )
507 else:
508 toPlotList.append(
509 (
510 data[f"x{config_datatype.suffix_xy}{suf_x}"],
511 data[f"y{config_datatype.suffix_xy}{suf_y}"],
512 [],
513 [],
514 [],
515 [],
516 config_datatype.color,
517 config_datatype.colormap,
518 [],
519 [],
520 )
521 )
523 xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]
525 # If there is no data to plot make a
526 # no data figure
527 numData = 0
528 for xs, _, _, _, _, _, _, _, _, _ in toPlotList:
529 numData += np.count_nonzero(np.isfinite(xs))
530 if numData == 0:
531 return None, None
533 for j, (
534 xs,
535 ys,
536 highSn,
537 lowSn,
538 highThresh,
539 lowThresh,
540 color,
541 cmap,
542 highStats,
543 lowStats,
544 ) in enumerate(toPlotList):
545 highSn = cast(Vector, highSn)
546 lowSn = cast(Vector, lowSn)
547 # ensure the columns are actually array
548 xs = np.array(xs)
549 ys = np.array(ys)
550 sigMadYs = nanSigmaMad(ys)
551 # plot lone median point if there's not enough data to measure more
552 n_xs = np.count_nonzero(np.isfinite(xs))
553 if n_xs <= 1 or not (np.isfinite(sigMadYs) and sigMadYs >= 0.0):
554 continue
555 elif n_xs < min_n_xs_for_stats:
556 xs = [nanMedian(xs)]
557 sigMads = np.array([nanSigmaMad(ys)])
558 ys = np.array([nanMedian(ys)])
559 (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
560 linesForLegend.append(medLine)
561 (sigMadLine,) = ax.plot(
562 xs,
563 ys + 1.0 * sigMads,
564 color,
565 alpha=0.8,
566 lw=0.8,
567 label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
568 )
569 ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
570 linesForLegend.append(sigMadLine)
571 histIm = None
572 continue
574 if self.xLims:
575 xMin, xMax = self.xLims
576 else:
577 # Chop off 1/3% from only the finite xs values
578 # (there may be +/-np.Inf values)
579 # TODO: This should be configurable
580 # but is there a good way to avoid redundant config params
581 # without using slightly annoying subconfigs?
582 xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
583 xScale = (xs97 - xs1) / 20.0 # This is ~5% of the data range
584 xMin, xMax = (xs1 - xScale, xs97 + xScale)
585 xLims[0] = min(xLims[0], xMin)
586 xLims[1] = max(xLims[1], xMax)
588 xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
589 medYs = nanMedian(ys)
590 fiveSigmaHigh = medYs + 5.0 * sigMadYs
591 fiveSigmaLow = medYs - 5.0 * sigMadYs
592 binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
593 # When the binsize is 0 try using the 1st and 99th
594 # percentile instead of the sigmas.
595 if binSize == 0.0:
596 p1, p99 = np.nanpercentile(ys, [1, 99])
597 binSize = (p99 - p1) / 101.0
599 # If fiveSigmaHigh and fiveSigmaLow are the same
600 # then use the 1st and 99th percentiles to define the
601 # yEdges.
602 yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)
603 if fiveSigmaLow == fiveSigmaHigh:
604 yEdges = np.arange(p1, p99, binSize)
606 counts, xBins, yBins = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
607 yBinsOut.append(yBins)
608 countsYs = np.sum(counts, axis=1)
610 ids = np.where(countsYs > binThresh)[0]
611 xEdgesPlot = xEdges[ids][1:]
612 xEdges = xEdges[ids]
614 if len(ids) > 1:
615 # Create the codes needed to turn the sigmaMad lines
616 # into a path to speed up checking which points are
617 # inside the area.
618 codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
619 codes[0] = Path.MOVETO
620 codes[-1] = Path.CLOSEPOLY
622 meds = np.zeros(len(xEdgesPlot))
623 threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
624 sigMads = np.zeros(len(xEdgesPlot))
626 for i, xEdge in enumerate(xEdgesPlot):
627 ids = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
628 med = nanMedian(ys[ids])
629 sigMad = nanSigmaMad(ys[ids])
630 meds[i] = med
631 sigMads[i] = sigMad
632 threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
633 threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]
635 if self.publicationStyle:
636 linecolor = "k"
637 else:
638 linecolor = color
640 (medLine,) = ax.plot(xEdgesPlot, meds, linecolor, label="Running Median")
641 linesForLegend.append(medLine)
643 # Make path to check which points lie within one sigma mad
644 threeSigMadPath = Path(threeSigMadVerts, codes)
646 if not self.publicationStyle:
647 # Add lines for the median +/- 3 * sigma MAD
648 (threeSigMadLine,) = ax.plot(
649 xEdgesPlot,
650 threeSigMadVerts[: len(xEdgesPlot), 1],
651 linecolor,
652 alpha=0.4,
653 label=r"3$\sigma_{MAD}$",
654 )
655 ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], linecolor, alpha=0.4)
657 # Add lines for the median +/- 1 * sigma MAD
658 (sigMadLine,) = ax.plot(
659 xEdgesPlot,
660 meds + 1.0 * sigMads,
661 linecolor,
662 alpha=0.8,
663 label=r"$\sigma_{MAD}$",
664 ls="dashed",
665 )
666 linesForLegend.append(sigMadLine)
667 ax.plot(xEdgesPlot, meds - 1.0 * sigMads, linecolor, alpha=0.8, ls="dashed")
669 if not self.publicationStyle:
670 # Add lines for the median +/- 2 * sigma MAD
671 (twoSigMadLine,) = ax.plot(
672 xEdgesPlot, meds + 2.0 * sigMads, linecolor, alpha=0.6, label=r"2$\sigma_{MAD}$"
673 )
674 linesForLegend.append(twoSigMadLine)
675 linesForLegend.append(threeSigMadLine)
676 ax.plot(xEdgesPlot, meds - 2.0 * sigMads, linecolor, alpha=0.6)
678 # Check which points are outside 3 sigma MAD of the median
679 # and plot these as points.
680 inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
681 ax.plot(xs[~inside], ys[~inside], ".", ms=5, alpha=0.2, mfc=color, mec="none", zorder=-1)
683 if not self.publicationStyle:
684 # Add some stats text
685 xPos = 0.65 - 0.4 * j
686 bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
687 statText = f"S/N > {highThresh:0.4g} Stats ({magLabel} < {highStats.approxMag:0.4g})\n"
688 highStatsStr = (
689 f"Median: {highStats.median:0.4g} "
690 + r"$\sigma_{MAD}$: "
691 + f"{highStats.sigmaMad:0.4g} "
692 + r"N$_{points}$: "
693 + f"{highStats.count}"
694 )
695 statText += highStatsStr
696 fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
698 bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
699 statText = f"S/N > {lowThresh:0.4g} Stats ({magLabel} < {lowStats.approxMag:0.4g})\n"
700 lowStatsStr = (
701 f"Median: {lowStats.median:0.4g} "
702 + r"$\sigma_{MAD}$: "
703 + f"{lowStats.sigmaMad:0.4g} "
704 + r"N$_{points}$: "
705 + f"{lowStats.count}"
706 )
707 statText += lowStatsStr
708 fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
710 if self.plot2DHist:
711 extent = [xLims[0], xLims[1], self.yLims[0], self.yLims[1]] if self.yLims else None
712 histIm = ax.hexbin(
713 xs[inside],
714 ys[inside],
715 gridsize=75,
716 extent=extent,
717 cmap=cmap,
718 mincnt=1,
719 zorder=-3,
720 edgecolors=None,
721 )
722 else:
723 ax.plot(xs[inside], ys[inside], ".", ms=3, alpha=0.2, mfc=color, mec=color, zorder=-1)
725 if not self.publicationStyle:
726 # If there are not many sources being used for the
727 # statistics then plot them individually as just
728 # plotting a line makes the statistics look wrong
729 # as the magnitude estimation is iffy for low
730 # numbers of sources.
731 if np.sum(highSn) < 100 and np.sum(highSn) > 0:
732 ax.plot(
733 cast(Vector, xs[highSn]),
734 cast(Vector, ys[highSn]),
735 marker="x",
736 ms=4,
737 mec="w",
738 mew=2,
739 ls="none",
740 )
741 (highSnLine,) = ax.plot(
742 cast(Vector, xs[highSn]),
743 cast(Vector, ys[highSn]),
744 color=color,
745 marker="x",
746 ms=4,
747 ls="none",
748 label="High SN",
749 )
750 linesForLegend.append(highSnLine)
751 else:
752 ax.axvline(highStats.approxMag, color=color, ls="--")
754 if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
755 ax.plot(
756 cast(Vector, xs[lowSn]),
757 cast(Vector, ys[lowSn]),
758 marker="+",
759 ms=4,
760 mec="w",
761 mew=2,
762 ls="none",
763 )
764 (lowSnLine,) = ax.plot(
765 cast(Vector, xs[lowSn]),
766 cast(Vector, ys[lowSn]),
767 color=color,
768 marker="+",
769 ms=4,
770 ls="none",
771 label="Low SN",
772 )
773 linesForLegend.append(lowSnLine)
774 else:
775 ax.axvline(lowStats.approxMag, color=color, ls=":")
777 else:
778 ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
779 meds = np.array([nanMedian(ys)] * len(xs))
780 (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
781 linesForLegend.append(medLine)
782 sigMads = np.array([nanSigmaMad(ys)] * len(xs))
783 (sigMadLine,) = ax.plot(
784 xs,
785 meds + 1.0 * sigMads,
786 color,
787 alpha=0.8,
788 lw=0.8,
789 label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
790 )
791 ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
792 linesForLegend.append(sigMadLine)
793 histIm = None
795 # Add a horizontal reference line at 0 to the scatter plot
796 ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)
798 # Set the scatter plot limits
799 suf_x = self.suffix_y
800 # TODO: Make this not work by accident
801 if f"yStars{suf_x}" in data and (len(cast(Vector, data[f"yStars{suf_x}"])) > 0):
802 plotMed = nanMedian(cast(Vector, data[f"yStars{suf_x}"]))
803 elif f"yGalaxies{suf_x}" in data and (len(cast(Vector, data[f"yGalaxies{suf_x}"])) > 0):
804 plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_x}"]))
805 else:
806 plotMed = np.nan
808 # Ignore types below pending making this not working my accident
809 # If len(xs) < min_n_xs_for_stats then `meds` doesn't exist.
810 if len(xs) < min_n_xs_for_stats: # type: ignore
811 meds = [nanMedian(ys)] # type: ignore
812 if self.yLims:
813 ax.set_ylim(self.yLims[0], self.yLims[1]) # type: ignore
814 elif np.isfinite(plotMed):
815 numSig = 4
816 yLimMin = plotMed - numSig * sigMadYs # type: ignore
817 yLimMax = plotMed + numSig * sigMadYs # type: ignore
818 while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10: # type: ignore
819 numSig += 1
821 numSig += 1
822 yLimMin = plotMed - numSig * sigMadYs # type: ignore
823 yLimMax = plotMed + numSig * sigMadYs # type: ignore
825 # If len(y) == 0 for ys in toPlotList then sigMadY = NaN
826 # ... in which case nothing was plotted and limits are irrelevant.
827 if all(np.isfinite([yLimMin, yLimMax])):
828 ax.set_ylim(yLimMin, yLimMax)
830 # This could be false if len(x) == 0 for xs in toPlotList
831 # ... in which case nothing was plotted and limits are irrelevant
832 if all(np.isfinite(xLims)):
833 ax.set_xlim(xLims)
835 # Add a line legend
836 ax.legend(
837 handles=linesForLegend,
838 ncol=4,
839 fontsize=6,
840 loc=self.legendLocation,
841 framealpha=0.9,
842 edgecolor="k",
843 borderpad=0.4,
844 handlelength=3,
845 )
847 # Add axes labels
848 band = kwargs.get("band", "unspecified")
849 xlabel = self.xAxisLabel
850 ylabel = self.yAxisLabel
851 if "{band}" in xlabel:
852 xlabel = xlabel.format(band=band)
853 if "{band}" in ylabel:
854 ylabel = ylabel.format(band=band)
855 if self.publicationStyle:
856 ax.set_ylabel(ylabel, labelpad=10)
857 ax.set_xlabel(xlabel, labelpad=2)
858 else:
859 ax.set_ylabel(ylabel, labelpad=10, fontsize=10)
860 ax.set_xlabel(xlabel, labelpad=2, fontsize=10)
861 ax.tick_params(labelsize=8)
863 return ax, histIm
865 def _makeTopHistogram(
866 self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
867 ) -> None:
868 suf_x = self.suffix_x
869 # Top histogram
870 topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
871 x_min, x_max = ax.get_xlim()
872 bins = np.linspace(x_min, x_max, 100)
874 if "any" in self.plotTypes:
875 x_any = f"x{self._datatypes['any'].suffix_xy}{suf_x}"
876 keys_notany = [x for x in self.plotTypes if x != "any"]
877 else:
878 x_any = (
879 np.concatenate([data[f"x{self._datatypes[key].suffix_xy}{suf_x}"] for key in self.plotTypes])
880 if (len(self.plotTypes) > 1)
881 else None
882 )
883 keys_notany = self.plotTypes
884 if x_any is not None:
885 if np.sum(x_any > 0) > 0:
886 log = True
887 else:
888 log = False
889 topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=log, label=f"Any ({len(x_any)})")
891 for key in keys_notany:
892 config_datatype = self._datatypes[key]
893 vector = np.array(data[f"x{config_datatype.suffix_xy}{suf_x}"])
894 if np.sum(vector > 0) > 0:
895 log = True
896 else:
897 log = False
898 topHist.hist(
899 vector,
900 bins=bins,
901 color=config_datatype.color,
902 histtype="step",
903 log=log,
904 label=f"{config_datatype.suffix_stat} ({len(vector)})",
905 )
906 topHist.axes.get_xaxis().set_visible(False)
907 topHist.set_ylabel("Count", fontsize=10 + 4 * self.publicationStyle)
908 if not self.publicationStyle:
909 topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")
910 topHist.tick_params(labelsize=8)
912 self._modifyHistogramTicks(topHist, do_x=False, max_labels=self.xHistMaxLabels)
914 def _makeSideHistogram(
915 self,
916 data: KeyedData,
917 figure: Figure,
918 gs: gridspec.Gridspec,
919 ax: Axes,
920 histIm: PolyCollection | None,
921 **kwargs,
922 ) -> None:
923 suf_y = self.suffix_y
924 # Side histogram
925 sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
926 y_min, y_max = ax.get_ylim()
927 bins = np.linspace(y_min, y_max, 100)
929 if "any" in self.plotTypes:
930 y_any = np.array(data[f"y{self._datatypes['any'].suffix_xy}{suf_y}"])
931 keys_notany = [x for x in self.plotTypes if x != "any"]
932 else:
933 y_any = (
934 np.concatenate(
935 [np.array(data[f"y{self._datatypes[key].suffix_xy}{suf_y}"]) for key in self.plotTypes]
936 )
937 if (len(self.plotTypes) > 1)
938 else None
939 )
940 keys_notany = self.plotTypes
941 if y_any is not None:
942 sideHist.hist(
943 np.array(y_any),
944 bins=bins,
945 color="grey",
946 alpha=0.3,
947 orientation="horizontal",
948 log=np.any(y_any > 0),
949 )
951 kwargs_hist = dict(
952 bins=bins,
953 histtype="step",
954 log=True,
955 orientation="horizontal",
956 )
957 for key in keys_notany:
958 config_datatype = self._datatypes[key]
959 # If the data has no positive values then it
960 # cannot be log scaled and it prints a bunch
961 # of irritating warnings, in this case don't
962 # try.
963 numPos = np.sum(data[f"y{config_datatype.suffix_xy}{suf_y}"] > 0)
965 if numPos <= 0:
966 kwargs_hist["log"] = False
968 vector = data[f"y{config_datatype.suffix_xy}{suf_y}"]
969 sideHist.hist(
970 vector,
971 color=config_datatype.color,
972 **kwargs_hist,
973 )
974 if not self.publicationStyle:
975 sideHist.hist(
976 vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
977 color=config_datatype.color,
978 linestyle="--",
979 **kwargs_hist,
980 )
981 sideHist.hist(
982 vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
983 color=config_datatype.color,
984 **kwargs_hist,
985 linestyle=":",
986 )
988 # Add a horizontal reference line at 0 to the side histogram
989 sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)
991 sideHist.axes.get_yaxis().set_visible(False)
992 sideHist.set_xlabel("Count", fontsize=10 + 4 * self.publicationStyle)
993 self._modifyHistogramTicks(sideHist, do_x=True, max_labels=self.yHistMaxLabels)
995 if not self.publicationStyle:
996 sideHist.tick_params(labelsize=8)
997 if self.plot2DHist and histIm is not None:
998 divider = make_axes_locatable(sideHist)
999 cax = divider.append_axes("right", size="25%", pad=0)
1000 sideHist.get_figure().colorbar(histIm, cax=cax, orientation="vertical")
1001 text = cax.text(
1002 0.5,
1003 0.5,
1004 "Points Per Bin",
1005 color="k",
1006 rotation="vertical",
1007 transform=cax.transAxes,
1008 ha="center",
1009 va="center",
1010 fontsize=10,
1011 )
1012 text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
    def _modifyHistogramTicks(self, histogram, do_x: bool, max_labels: int):
        """Tidy the count axis of a marginal histogram.

        The lower limit of the count axis is raised to at least
        ``histMinimum``, the upper limit is rounded up to a power of 10, and
        the tick labels are thinned out to at most ``max_labels``.

        Parameters
        ----------
        histogram
            The histogram axes to modify.
            NOTE(review): presumably a `matplotlib.axes.Axes` -- only its
            axis/limit/tick accessors are used here.
        do_x : `bool`
            Modify the x axis if `True`, otherwise the y axis.
        max_labels : `int`
            Maximum number of tick labels to keep.
        """
        axis = histogram.get_xaxis() if do_x else histogram.get_yaxis()
        limits = list(histogram.get_xlim() if do_x else histogram.get_ylim())
        get_ticks = histogram.get_xticks if do_x else histogram.get_yticks
        ticks = get_ticks()
        # Let the minimum be larger than specified if the histogram has large
        # values everywhere, but cut it down a little so the lowest-valued bin
        # is still easily visible
        limits[0] = max(self.histMinimum, 0.9 * limits[0])
        # Round the upper limit up to the next power of 10
        limits[1] = 10 ** (np.ceil(np.log10(limits[1]))) if (limits[1] > 0) else limits[1]
        # First pass looks at the major ticks; a second pass over the minor
        # ticks only happens if fewer than two major labels are in range.
        for minor in (False, True):
            # Ignore ticks that are outside the new axis limits
            valid = (ticks >= limits[0]) & (ticks <= limits[1])
            labels = [label for label, _valid in zip(axis.get_ticklabels(minor=minor), valid) if _valid]
            if (n_labels := len(labels)) > max_labels:
                labels_new = [""] * n_labels
                # Skip the first label if we're not using minor axis labels
                # This helps avoid overlap with the scatter plot labels
                # (``1 - minor`` is 1 for major ticks and 0 for minor ticks.)
                for idx_fill in np.round(np.linspace(1 - minor, n_labels - 1, max_labels)).astype(int):
                    labels_new[idx_fill] = labels[idx_fill]
                axis.set_ticks(ticks[valid], labels_new)
            # If there are enough tick labels, disable minor tick labels
            if len(labels) >= 2:
                axis.set_minor_formatter(NullFormatter())
                break
            else:
                # Too few labels: label the minor ticks as well and repeat
                # the thinning for the minor tick positions.
                axis.set_minor_formatter(
                    LogFormatterExponentSci(minor_thresholds=(1, self.histMinimum / 10.0))
                )
                ticks = get_ticks(minor=True)

        (histogram.set_xlim if do_x else histogram.set_ylim)(limits)