Coverage for python / lsst / analysis / tools / actions / plot / wholeSkyPlot.py: 13%
230 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:08 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:08 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("WholeSkyPlot",)
26import importlib.resources as importResources
27from typing import Mapping, Optional
29import lsst.analysis.tools
30import matplotlib.patheffects as pathEffects
31import numpy as np
32import yaml
33from lsst.pex.config import ChoiceField, Field, ListField
34from lsst.utils.plotting import (
35 accent_color,
36 divergent_cmap,
37 make_figure,
38 set_rubin_plotstyle,
39 stars_cmap,
40 stars_color,
41)
42from matplotlib import gridspec
43from matplotlib.collections import PatchCollection
44from matplotlib.colors import CenteredNorm
45from matplotlib.figure import Figure
46from matplotlib.patches import Patch, Polygon
48from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
49from ...math import nanSigmaMad
50from ...utils import getTractCorners
51from .plotUtils import addPlotInfo
54class WholeSkyPlot(PlotAction):
55 """Plots the on sky distribution of a parameter.
57 Plots the values of the parameter given for the z axis
58 according to the positions given for x and y. Optimised
59 for use with RA and Dec. Also calculates some basic
60 statistics and includes those on the plot.
62 The default axes limits and figure size were chosen to plot HSC PDR2.
63 """
65 xAxisLabel = Field[str](doc="Label to use for the x axis.", default="RA (degrees)")
66 yAxisLabel = Field[str](doc="Label to use for the y axis.", default="Dec (degrees)")
67 zAxisLabel = Field[str](doc="Label to use for the z axis.", default="")
68 autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True)
69 xLimits = ListField[float](doc="Plotting limits for the x axis.", default=[-5.0, 365.0])
70 yLimits = ListField[float](doc="Plotting limits for the y axis.", default=[-10.0, 60.0])
71 autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True)
72 colorBarMin = Field[float](doc="The minimum value of the color bar.", optional=True)
73 colorBarMax = Field[float](doc="The minimum value of the color bar.", optional=True)
74 colorBarRange = Field[float](
75 doc="The multiplier for the color bar range. The max/min range values are: median +/- N * sigmaMad"
76 ", where N is this config value.",
77 default=3.0,
78 )
79 colorMapType = ChoiceField[str](
80 doc="Type of color map to use for the color bar. Options: sequential, divergent, userDefined.",
81 allowed={cmType: cmType for cmType in ("sequential", "divergent")},
82 default="divergent",
83 )
84 colorMap = ListField[str](
85 doc="List of hexidecimal colors for a user-defined color map.",
86 optional=True,
87 )
88 showOutliers = Field[bool](
89 doc="Show the outliers on the plot. "
90 "Outliers are values whose absolute value is > colorBarRange * sigmaMAD.",
91 default=True,
92 )
93 showNaNs = Field[bool](doc="Show the NaNs on the plot.", default=True)
94 labelTracts = Field[bool](doc="Label the tracts.", default=False)
96 addThresholds = Field[bool](
97 doc="Read in the predefined thresholds and indicate them on the histogram.",
98 default=True,
99 )
101 def getInputSchema(self, **kwargs) -> KeyedDataSchema:
102 base = []
103 base.append(("z", Vector))
104 base.append(("tract", Vector))
105 return base
107 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
108 self._validateInput(data, **kwargs)
109 return self.makePlot(data, **kwargs)
111 def _validateInput(self, data: KeyedData, **kwargs) -> None:
112 """NOTE currently can only check that something is not a Scalar, not
113 check that the data is consistent with Vector
114 """
115 needed = self.getInputSchema(**kwargs)
116 if remainder := {key.format(**kwargs) for key, _ in needed} - {
117 key.format(**kwargs) for key in data.keys()
118 }:
119 raise ValueError(f"Task needs keys {remainder} but they were not found in input")
120 for name, typ in needed:
121 isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
122 if isScalar and typ != Scalar:
123 raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")
125 def _getAxesLimits(self, xs: list, ys: list) -> tuple(list, list):
126 """Get the x and y axes limits in degrees.
128 Parameters
129 ----------
130 xs : `list`
131 X coordinates for the tracts to plot.
132 ys : `list`
133 Y coordinates for the tracts to plot.
135 Returns
136 -------
137 xlim : `list`
138 Minimun and maximum x axis values.
139 ylim : `list`
140 Minimun and maximum y axis values.
141 """
143 # Add some blank space on the edges of the plot.
144 xlim = [np.nanmin(xs) - 5, np.nanmax(xs) + 5]
145 ylim = [np.nanmin(ys) - 5, np.nanmax(ys) + 5]
147 # Limit to only show real RA/Dec values.
148 if xlim[0] < 0.0:
149 xlim[0] = 0.0
150 if xlim[1] > 360.0:
151 xlim[1] = 360.0
152 if ylim[0] < -90.0:
153 ylim[0] = -90.0
154 if ylim[1] > 90.0:
155 ylim[1] = 90.0
157 return (xlim, ylim)
159 def _getMaxOutlierVals(self, multiplier: float, tracts: list, values: list, outlierInds: list) -> str:
160 """Get the 5 largest outlier values in a string.
162 Parameters
163 ----------
164 multiplier : `float`
165 Select values whose absolute value is > multiplier * sigmaMAD.
166 tracts : `list`
167 All the tracts.
168 values : `list`
169 All the metric values.
170 outlierInds : `list`
171 Indicies of outlier values.
173 Returns
174 -------
175 text : `str`
176 A string containing the 10 tracts with the largest outlier values.
177 """
178 if self.addThresholds:
179 text = "Tracts with value outside thresholds: "
180 else:
181 text = f"Tracts with |value| > {multiplier}" + r"$\sigma_{MAD}$" + ": "
182 if len(outlierInds) > 0:
183 outlierValues = np.array(values)[outlierInds]
184 outlierTracts = np.array(tracts)[outlierInds]
185 # Sort values in descending (-) absolute value order discounting
186 # NaNs.
187 maxInds = np.argsort(-np.abs(outlierValues))
188 # Show up to ten values on the plot.
189 for ind in maxInds[:10]:
190 val = outlierValues[ind]
191 tract = outlierTracts[ind]
192 text += f"{tract}, {val:.3}; "
193 # Remove the final trailing comma and whitespace.
194 text = text[:-2]
195 else:
196 text += "None"
198 return text
200 def makePlot(
201 self,
202 data: KeyedData,
203 plotInfo: Optional[Mapping[str, str]] = None,
204 **kwargs,
205 ) -> Figure:
206 """Make a WholeSkyPlot of the given data.
208 Parameters
209 ----------
210 data : `KeyedData`
211 The catalog to plot the points from.
212 plotInfo : `dict`
213 A dictionary of information about the data being plotted with keys:
215 ``"run"``
216 The output run for the plots (`str`).
217 ``"skymap"``
218 The type of skymap used for the data (`str`).
219 ``"filter"``
220 The filter used for this data (`str`).
221 ``"tract"``
222 The tract that the data comes from (`str`).
224 Returns
225 -------
226 `pipeBase.Struct` containing:
227 skyPlot : `matplotlib.figure.Figure`
228 The resulting figure.
231 Examples
232 --------
233 An example of the plot produced from this code is here:
235 .. image:: /_static/analysis_tools/wholeSkyPlotExample.png
237 For a detailed example of how to make a plot from the command line
238 please see the
239 :ref:`getting started guide<analysis-tools-getting-started>`.
240 """
241 skymap = kwargs["skymap"]
242 if plotInfo is None:
243 plotInfo = {}
245 if self.addThresholds:
246 metricThresholdFile = importResources.read_text(lsst.analysis.tools, "metricInformation.yaml")
247 metricDefs = yaml.safe_load(metricThresholdFile)
249 # Prevent Bands in the plot info showing a list of bands.
250 # If bands is a list, it implies that parameterizedBand=False,
251 # and that the metric is not band-specific.
252 if "bands" in plotInfo:
253 if isinstance(plotInfo["bands"], list):
254 plotInfo["bands"] = "N/A"
256 colorMap = self.colorMap
257 match self.colorMapType:
258 case "sequential":
259 if colorMap is None:
260 colorMap = stars_cmap()
261 outlierColor = "red"
262 norm = None
263 case "divergent":
264 if colorMap is None:
265 colorMap = divergent_cmap()
266 outlierColor = "fuchsia"
267 norm = CenteredNorm()
269 # Create patches using the corners of each tract.
270 patches = []
271 colBarVals = []
272 tracts = []
273 ras = []
274 decs = []
275 mid_ras = []
276 mid_decs = []
277 for i, tract in enumerate(data["tract"]):
278 corners = getTractCorners(skymap, tract)
279 patches.append(Polygon(corners, closed=True))
280 colBarVals.append(data["z"][i])
281 tracts.append(tract)
282 ras.append(corners[0][0])
283 decs.append(corners[0][1])
284 mid_ras.append((corners[0][0] + corners[1][0]) / 2)
285 mid_decs.append((corners[0][1] + corners[2][1]) / 2)
287 # Setup figure.
288 fig = make_figure(dpi=300, figsize=(12, 3.5))
289 set_rubin_plotstyle()
290 gs = gridspec.GridSpec(1, 4)
291 ax = fig.add_subplot(gs[:3])
292 # Add colored patches showing tract metric values.
293 patchCollection = PatchCollection(patches, cmap=colorMap, norm=norm)
294 ax.add_collection(patchCollection)
296 # Define color bar range.
297 if np.sum(np.isfinite(colBarVals)) > 0:
298 med = np.nanmedian(colBarVals)
299 else:
300 med = np.nan
301 sigmaMad = nanSigmaMad(colBarVals)
302 if self.colorBarMin is not None:
303 vmin = np.float64(self.colorBarMin)
304 else:
305 vmin = med - self.colorBarRange * sigmaMad
306 if self.colorBarMax is not None:
307 vmax = np.float64(self.colorBarMax)
308 else:
309 vmax = med + self.colorBarRange * sigmaMad
311 dataName = self.zAxisLabel.format_map(kwargs)
312 colBarVals = np.array(colBarVals)
313 if self.addThresholds and dataName in metricDefs:
314 if "lowThreshold" in metricDefs[dataName].keys():
315 lowThreshold = metricDefs[dataName]["lowThreshold"]
316 else:
317 lowThreshold = np.nan
318 if "highThreshold" in metricDefs[dataName].keys():
319 highThreshold = metricDefs[dataName]["highThreshold"]
320 else:
321 highThreshold = np.nan
322 outlierInds = np.where((colBarVals < lowThreshold) | (colBarVals > highThreshold))[0]
323 else:
324 # Note tracts with metrics outside (vmin, vmax) as outliers.
325 outlierInds = np.where((colBarVals < vmin) | (colBarVals > vmax))[0]
327 # Initialize legend handles.
328 handles = []
330 if self.showOutliers:
331 # Plot the outlier patches.
332 outlierPatches = []
333 if len(outlierInds) > 0:
334 for ind in outlierInds:
335 outlierPatches.append(patches[ind])
336 outlierPatchCollection = PatchCollection(
337 outlierPatches,
338 cmap=colorMap,
339 norm=norm,
340 facecolors="none",
341 edgecolors=outlierColor,
342 linewidths=0.5,
343 zorder=100,
344 )
345 ax.add_collection(outlierPatchCollection)
346 # Add legend information.
347 outlierPatch = Patch(
348 facecolor="none",
349 edgecolor=outlierColor,
350 linewidth=0.5,
351 label="Outlier",
352 )
353 handles.append(outlierPatch)
355 if self.showNaNs:
356 # Plot tracts with NaN metric values.
357 nanInds = np.where(~np.isfinite(colBarVals))[0]
358 nanPatches = []
359 if len(nanInds) > 0:
360 for ind in nanInds:
361 nanPatches.append(patches[ind])
362 nanPatchCollection = PatchCollection(
363 nanPatches,
364 cmap=None,
365 norm=norm,
366 facecolors="white",
367 edgecolors="grey",
368 linestyles="dotted",
369 linewidths=0.5,
370 zorder=100,
371 )
372 ax.add_collection(nanPatchCollection)
373 # Add legend information.
374 nanPatch = Patch(
375 facecolor="white",
376 edgecolor="grey",
377 linestyle="dotted",
378 linewidth=0.5,
379 label="NaN",
380 )
381 handles.append(nanPatch)
383 if len(handles) > 0:
384 fig.legend(handles=handles)
386 if self.labelTracts:
387 # Label the tracts
388 for i, tract in enumerate(tracts):
389 ax.text(
390 mid_ras[i],
391 mid_decs[i],
392 f"{tract}",
393 ha="center",
394 va="center",
395 fontsize=2,
396 alpha=0.7,
397 zorder=100,
398 )
400 ax.set_aspect("equal")
401 axPos = ax.get_position()
402 ax1 = fig.add_axes([0.73, 0.25, 0.20, 0.47])
404 if np.sum(np.isfinite(data["z"])) > 0:
405 ax1.hist(data["z"], bins=len(data["z"] / 10), color=stars_color(), histtype="step")
406 else:
407 ax1.text(0.5, 0.5, "Data all NaN/Inf")
408 ax1.set_xlabel("Metric Values")
409 ax1.set_ylabel("Number")
410 ax1.yaxis.set_label_position("right")
411 ax1.yaxis.tick_right()
413 if self.addThresholds and dataName in metricDefs:
414 # Check the thresholds are finite and set them to
415 # the min/max of the data if they aren't to calculate
416 # the x range of the plot
417 if np.isfinite(lowThreshold):
418 ax1.axvline(lowThreshold, color=accent_color())
419 else:
420 lowThreshold = np.nanmin(colBarVals)
421 if np.isfinite(highThreshold):
422 ax1.axvline(highThreshold, color=accent_color())
423 else:
424 highThreshold = np.nanmax(colBarVals)
426 widthThreshold = highThreshold - lowThreshold
427 upperLim = highThreshold + 0.5 * widthThreshold
428 lowerLim = lowThreshold - 0.5 * widthThreshold
429 ax1.set_xlim(lowerLim, upperLim)
430 numOutside = np.sum(((data["z"] > upperLim) | (data["z"] < lowerLim)))
431 ax1.set_title("Outside plot limits: " + str(numOutside))
433 else:
434 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax):
435 ax1.set_xlim(vmin, vmax)
437 if self.autoAxesLimits:
438 xlim, ylim = self._getAxesLimits(ras, decs)
439 else:
440 xlim, ylim = self.xLimits, self.yLimits
441 ax.set_xlim(xlim)
442 ax.set_ylim(ylim)
443 ax.set_xlabel(self.xAxisLabel)
444 ax.set_ylabel(self.yAxisLabel)
445 ax.invert_xaxis()
447 if self.showOutliers:
448 # Add text boxes to show the number of tracts, number of NaNs,
449 # median, sigma MAD, and the five largest outlier values.
450 outlierText = self._getMaxOutlierVals(self.colorBarRange, tracts, colBarVals, outlierInds)
451 # Make vertical text spacing readable for different figure sizes.
452 multiplier = 3.5 / fig.get_size_inches()[1]
453 verticalSpacing = 0.028 * multiplier
454 fig.text(
455 0.01,
456 0.01 + 3 * verticalSpacing,
457 f"Num tracts: {len(tracts)}",
458 transform=fig.transFigure,
459 fontsize=8,
460 alpha=0.7,
461 )
462 if self.showNaNs:
463 fig.text(
464 0.01,
465 0.01 + 2 * verticalSpacing,
466 f"Num nans: {len(nanInds)}",
467 transform=fig.transFigure,
468 fontsize=8,
469 alpha=0.7,
470 )
471 fig.text(
472 0.01,
473 0.01 + verticalSpacing,
474 f"Median: {med:.3f}; " + r"$\sigma_{MAD}$" + f": {sigmaMad:.3f}",
475 transform=fig.transFigure,
476 fontsize=8,
477 alpha=0.7,
478 )
479 if self.showOutliers:
480 fig.text(0.01, 0.01, outlierText, transform=fig.transFigure, fontsize=8, alpha=0.7)
482 # Truncate the color range to (vmin, vmax).
483 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax):
484 colBarVals = np.clip(np.array(colBarVals), vmin, vmax)
485 patchCollection.set_array(colBarVals)
486 # Make the color bar with a metric label.
487 axPos = ax.get_position()
488 cax = fig.add_axes([0.084, axPos.y1 + 0.02, 0.62, 0.07])
489 fig.colorbar(
490 patchCollection,
491 cax=cax,
492 shrink=0.7,
493 extend="both",
494 location="top",
495 orientation="horizontal",
496 )
497 cbarText = "Metric Values"
499 text = cax.text(
500 0.5,
501 0.5,
502 cbarText,
503 transform=cax.transAxes,
504 ha="center",
505 va="center",
506 fontsize=10,
507 zorder=100,
508 )
509 text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
511 # Finalize plot appearance.
512 ax.grid()
513 ax.set_axisbelow(True)
514 fig = addPlotInfo(fig, plotInfo)
515 fig.subplots_adjust(left=0.08, right=0.92, top=0.8, bottom=0.17, wspace=0.05)
516 titleText = self.zAxisLabel.format_map(kwargs)
517 if "zUnit" in data and data["zUnit"] != "":
518 titleText += f" ({data['zUnit']})"
519 fig.suptitle("Metric: " + titleText, fontsize=20)
521 return fig