Coverage for python / lsst / analysis / tools / actions / plot / wholeSkyPlot.py: 14%
241 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:55 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:55 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("WholeSkyPlot",)
26import importlib.resources as importResources
27import json
28from collections.abc import Mapping
30import matplotlib.patheffects as pathEffects
31import numpy as np
32import yaml
33from matplotlib import gridspec
34from matplotlib.collections import PatchCollection
35from matplotlib.colors import CenteredNorm
36from matplotlib.figure import Figure
37from matplotlib.patches import Patch, Polygon
39import lsst.analysis.tools
40from lsst.pex.config import ChoiceField, Field, ListField
41from lsst.utils.plotting import (
42 accent_color,
43 divergent_cmap,
44 make_figure,
45 set_rubin_plotstyle,
46 stars_cmap,
47 stars_color,
48)
50from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
51from ...math import nanSigmaMad
52from ...utils import getTractCorners
53from .plotUtils import addPlotInfo
class AnnotatedFigure(Figure):
    """A `matplotlib.figure.Figure` that also declares a ``metadata``
    attribute.
    """

    # `WholeSkyPlot.makePlot` assigns a dict with the keys "label" and
    # "boxes" to this attribute before returning the figure.
    metadata: dict
60class WholeSkyPlot(PlotAction):
61 """Plots the on sky distribution of a parameter.
63 Plots the values of the parameter given for the z axis
64 according to the positions given for x and y. Optimised
65 for use with RA and Dec. Also calculates some basic
66 statistics and includes those on the plot.
68 The default axes limits and figure size were chosen to plot HSC PDR2.
69 """
71 xAxisLabel = Field[str](doc="Label to use for the x axis.", default="RA (degrees)")
72 yAxisLabel = Field[str](doc="Label to use for the y axis.", default="Dec (degrees)")
73 zAxisLabel = Field[str](doc="Label to use for the z axis.", default="")
74 autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True)
75 xLimits = ListField[float](doc="Plotting limits for the x axis.", default=[-5.0, 365.0])
76 yLimits = ListField[float](doc="Plotting limits for the y axis.", default=[-10.0, 60.0])
77 autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True)
78 colorBarMin = Field[float](doc="The minimum value of the color bar.", optional=True)
79 colorBarMax = Field[float](doc="The minimum value of the color bar.", optional=True)
80 colorBarRange = Field[float](
81 doc="The multiplier for the color bar range. The max/min range values are: median +/- N * sigmaMad"
82 ", where N is this config value.",
83 default=3.0,
84 )
85 colorMapType = ChoiceField[str](
86 doc="Type of color map to use for the color bar. Options: sequential, divergent, userDefined.",
87 allowed={cmType: cmType for cmType in ("sequential", "divergent")},
88 default="divergent",
89 )
90 colorMap = ListField[str](
91 doc="List of hexidecimal colors for a user-defined color map.",
92 optional=True,
93 )
94 showOutliers = Field[bool](
95 doc="Show the outliers on the plot. "
96 "Outliers are values whose absolute value is > colorBarRange * sigmaMAD.",
97 default=True,
98 )
99 showNaNs = Field[bool](doc="Show the NaNs on the plot.", default=True)
100 labelTracts = Field[bool](doc="Label the tracts.", default=False)
102 addThresholds = Field[bool](
103 doc="Read in the predefined thresholds and indicate them on the histogram.",
104 default=True,
105 )
107 def getInputSchema(self, **kwargs) -> KeyedDataSchema:
108 base = []
109 base.append(("z", Vector))
110 base.append(("tract", Vector))
111 return base
113 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
114 self._validateInput(data, **kwargs)
115 return self.makePlot(data, **kwargs)
117 def _validateInput(self, data: KeyedData, **kwargs) -> None:
118 """NOTE currently can only check that something is not a Scalar, not
119 check that the data is consistent with Vector
120 """
121 needed = self.getInputSchema(**kwargs)
122 if remainder := {key.format(**kwargs) for key, _ in needed} - {
123 key.format(**kwargs) for key in data.keys()
124 }:
125 raise ValueError(f"Task needs keys {remainder} but they were not found in input")
126 for name, typ in needed:
127 isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
128 if isScalar and typ != Scalar:
129 raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")
131 def _getAxesLimits(self, xs: list, ys: list) -> tuple(list, list):
132 """Get the x and y axes limits in degrees.
134 Parameters
135 ----------
136 xs : `list`
137 X coordinates for the tracts to plot.
138 ys : `list`
139 Y coordinates for the tracts to plot.
141 Returns
142 -------
143 xlim : `list`
144 Minimun and maximum x axis values.
145 ylim : `list`
146 Minimun and maximum y axis values.
147 """
149 # Add some blank space on the edges of the plot.
150 xlim = [np.nanmin(xs) - 5, np.nanmax(xs) + 5]
151 ylim = [np.nanmin(ys) - 5, np.nanmax(ys) + 5]
153 # Limit to only show real RA/Dec values.
154 if xlim[0] < 0.0:
155 xlim[0] = 0.0
156 if xlim[1] > 360.0:
157 xlim[1] = 360.0
158 if ylim[0] < -90.0:
159 ylim[0] = -90.0
160 if ylim[1] > 90.0:
161 ylim[1] = 90.0
163 return (xlim, ylim)
165 def _getMaxOutlierVals(self, multiplier: float, tracts: list, values: list, outlierInds: list) -> str:
166 """Get the 5 largest outlier values in a string.
168 Parameters
169 ----------
170 multiplier : `float`
171 Select values whose absolute value is > multiplier * sigmaMAD.
172 tracts : `list`
173 All the tracts.
174 values : `list`
175 All the metric values.
176 outlierInds : `list`
177 Indicies of outlier values.
179 Returns
180 -------
181 text : `str`
182 A string containing the 10 tracts with the largest outlier values.
183 """
184 if self.addThresholds:
185 text = "Tracts with value outside thresholds: "
186 else:
187 text = f"Tracts with |value| > {multiplier}" + r"$\sigma_{MAD}$" + ": "
188 if len(outlierInds) > 0:
189 outlierValues = np.array(values)[outlierInds]
190 outlierTracts = np.array(tracts)[outlierInds]
191 # Sort values in descending (-) absolute value order discounting
192 # NaNs.
193 maxInds = np.argsort(-np.abs(outlierValues))
194 # Show up to ten values on the plot.
195 for ind in maxInds[:10]:
196 val = outlierValues[ind]
197 tract = outlierTracts[ind]
198 text += f"{tract}, {val:.3}; "
199 # Remove the final trailing comma and whitespace.
200 text = text[:-2]
201 else:
202 text += "None"
204 return text
206 def makePlot(
207 self,
208 data: KeyedData,
209 plotInfo: Mapping[str, str] | None = None,
210 **kwargs,
211 ) -> AnnotatedFigure:
212 """Make a WholeSkyPlot of the given data.
214 Parameters
215 ----------
216 data : `KeyedData`
217 The catalog to plot the points from.
218 plotInfo : `dict`
219 A dictionary of information about the data being plotted with keys:
221 ``"run"``
222 The output run for the plots (`str`).
223 ``"skymap"``
224 The type of skymap used for the data (`str`).
225 ``"filter"``
226 The filter used for this data (`str`).
227 ``"tract"``
228 The tract that the data comes from (`str`).
230 Returns
231 -------
232 `pipeBase.Struct` containing:
233 skyPlot : `matplotlib.figure.Figure`
234 The resulting figure.
237 Examples
238 --------
239 An example of the plot produced from this code is here:
241 .. image:: /_static/analysis_tools/wholeSkyPlotExample.png
243 For a detailed example of how to make a plot from the command line
244 please see the
245 :ref:`getting started guide<analysis-tools-getting-started>`.
246 """
247 skymap = kwargs["skymap"]
248 if plotInfo is None:
249 plotInfo = {}
251 if self.addThresholds:
252 metricThresholdFile = importResources.read_text(lsst.analysis.tools, "metricInformation.yaml")
253 metricDefs = yaml.safe_load(metricThresholdFile)
255 # Prevent Bands in the plot info showing a list of bands.
256 # If bands is a list, it implies that parameterizedBand=False,
257 # and that the metric is not band-specific.
258 if "bands" in plotInfo:
259 if isinstance(plotInfo["bands"], list):
260 plotInfo["bands"] = "N/A"
262 colorMap = self.colorMap
263 match self.colorMapType:
264 case "sequential":
265 if colorMap is None:
266 colorMap = stars_cmap()
267 outlierColor = "red"
268 norm = None
269 case "divergent":
270 if colorMap is None:
271 colorMap = divergent_cmap()
272 outlierColor = "fuchsia"
273 norm = CenteredNorm()
275 # Create patches using the corners of each tract.
276 patches = []
277 colBarVals = []
278 tracts = []
279 ras = []
280 decs = []
281 mid_ras = []
282 mid_decs = []
283 for i, tract in enumerate(data["tract"]):
284 corners = getTractCorners(skymap, tract)
285 patches.append(Polygon(corners, closed=True))
286 colBarVals.append(data["z"][i])
287 tracts.append(tract)
288 ras.append(corners[0][0])
289 decs.append(corners[0][1])
290 mid_ras.append((corners[0][0] + corners[1][0]) / 2)
291 mid_decs.append((corners[0][1] + corners[2][1]) / 2)
293 # Setup figure.
294 fig: AnnotatedFigure = make_figure(dpi=300, figsize=(12, 3.5))
295 set_rubin_plotstyle()
296 gs = gridspec.GridSpec(1, 4)
297 ax = fig.add_subplot(gs[:3])
298 # Add colored patches showing tract metric values.
299 patchCollection = PatchCollection(patches, cmap=colorMap, norm=norm)
300 ax.add_collection(patchCollection)
302 # Define color bar range.
303 if np.sum(np.isfinite(colBarVals)) > 0:
304 med = np.nanmedian(colBarVals)
305 else:
306 med = np.nan
307 sigmaMad = nanSigmaMad(colBarVals)
308 if self.colorBarMin is not None:
309 vmin = np.float64(self.colorBarMin)
310 else:
311 vmin = med - self.colorBarRange * sigmaMad
312 if self.colorBarMax is not None:
313 vmax = np.float64(self.colorBarMax)
314 else:
315 vmax = med + self.colorBarRange * sigmaMad
317 dataName = self.zAxisLabel.format_map(kwargs)
318 colBarVals = np.array(colBarVals)
319 if self.addThresholds and dataName in metricDefs:
320 if "lowThreshold" in metricDefs[dataName].keys():
321 lowThreshold = metricDefs[dataName]["lowThreshold"]
322 else:
323 lowThreshold = np.nan
324 if "highThreshold" in metricDefs[dataName].keys():
325 highThreshold = metricDefs[dataName]["highThreshold"]
326 else:
327 highThreshold = np.nan
328 outlierInds = np.where((colBarVals < lowThreshold) | (colBarVals > highThreshold))[0]
329 else:
330 # Note tracts with metrics outside (vmin, vmax) as outliers.
331 outlierInds = np.where((colBarVals < vmin) | (colBarVals > vmax))[0]
333 # Initialize legend handles.
334 handles = []
336 if self.showOutliers:
337 # Plot the outlier patches.
338 outlierPatches = []
339 if len(outlierInds) > 0:
340 for ind in outlierInds:
341 outlierPatches.append(patches[ind])
342 outlierPatchCollection = PatchCollection(
343 outlierPatches,
344 cmap=colorMap,
345 norm=norm,
346 facecolors="none",
347 edgecolors=outlierColor,
348 linewidths=0.5,
349 zorder=100,
350 )
351 ax.add_collection(outlierPatchCollection)
352 # Add legend information.
353 outlierPatch = Patch(
354 facecolor="none",
355 edgecolor=outlierColor,
356 linewidth=0.5,
357 label="Outlier",
358 )
359 handles.append(outlierPatch)
361 if self.showNaNs:
362 # Plot tracts with NaN metric values.
363 nanInds = np.where(~np.isfinite(colBarVals))[0]
364 nanPatches = []
365 if len(nanInds) > 0:
366 for ind in nanInds:
367 nanPatches.append(patches[ind])
368 nanPatchCollection = PatchCollection(
369 nanPatches,
370 cmap=None,
371 norm=norm,
372 facecolors="white",
373 edgecolors="grey",
374 linestyles="dotted",
375 linewidths=0.5,
376 zorder=100,
377 )
378 ax.add_collection(nanPatchCollection)
379 # Add legend information.
380 nanPatch = Patch(
381 facecolor="white",
382 edgecolor="grey",
383 linestyle="dotted",
384 linewidth=0.5,
385 label="NaN",
386 )
387 handles.append(nanPatch)
389 if len(handles) > 0:
390 fig.legend(handles=handles)
392 if self.labelTracts:
393 # Label the tracts
394 for i, tract in enumerate(tracts):
395 ax.text(
396 mid_ras[i],
397 mid_decs[i],
398 f"{tract}",
399 ha="center",
400 va="center",
401 fontsize=2,
402 alpha=0.7,
403 zorder=100,
404 )
406 ax.set_aspect("equal")
407 axPos = ax.get_position()
408 ax1 = fig.add_axes([0.73, 0.25, 0.20, 0.47])
410 if np.sum(np.isfinite(data["z"])) > 0:
411 ax1.hist(data["z"], bins=len(data["z"] / 10), color=stars_color(), histtype="step")
412 else:
413 ax1.text(0.5, 0.5, "Data all NaN/Inf")
414 ax1.set_xlabel("Metric Values")
415 ax1.set_ylabel("Number")
416 ax1.yaxis.set_label_position("right")
417 ax1.yaxis.tick_right()
419 if self.addThresholds and dataName in metricDefs:
420 # Check the thresholds are finite and set them to
421 # the min/max of the data if they aren't to calculate
422 # the x range of the plot
423 if np.isfinite(lowThreshold):
424 ax1.axvline(lowThreshold, color=accent_color())
425 else:
426 lowThreshold = np.nanmin(colBarVals)
427 if np.isfinite(highThreshold):
428 ax1.axvline(highThreshold, color=accent_color())
429 else:
430 highThreshold = np.nanmax(colBarVals)
432 widthThreshold = highThreshold - lowThreshold
433 upperLim = highThreshold + 0.5 * widthThreshold
434 lowerLim = lowThreshold - 0.5 * widthThreshold
435 ax1.set_xlim(lowerLim, upperLim)
436 numOutside = np.sum((data["z"] > upperLim) | (data["z"] < lowerLim))
437 ax1.set_title("Outside plot limits: " + str(numOutside))
439 else:
440 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax):
441 ax1.set_xlim(vmin, vmax)
443 if self.autoAxesLimits:
444 xlim, ylim = self._getAxesLimits(ras, decs)
445 else:
446 xlim, ylim = self.xLimits, self.yLimits
447 ax.set_xlim(xlim)
448 ax.set_ylim(ylim)
449 ax.set_xlabel(self.xAxisLabel)
450 ax.set_ylabel(self.yAxisLabel)
451 ax.invert_xaxis()
453 if self.showOutliers:
454 # Add text boxes to show the number of tracts, number of NaNs,
455 # median, sigma MAD, and the five largest outlier values.
456 outlierText = self._getMaxOutlierVals(self.colorBarRange, tracts, colBarVals, outlierInds)
457 # Make vertical text spacing readable for different figure sizes.
458 multiplier = 3.5 / fig.get_size_inches()[1]
459 verticalSpacing = 0.028 * multiplier
460 fig.text(
461 0.01,
462 0.01 + 3 * verticalSpacing,
463 f"Num tracts: {len(tracts)}",
464 transform=fig.transFigure,
465 fontsize=8,
466 alpha=0.7,
467 )
468 if self.showNaNs:
469 fig.text(
470 0.01,
471 0.01 + 2 * verticalSpacing,
472 f"Num nans: {len(nanInds)}",
473 transform=fig.transFigure,
474 fontsize=8,
475 alpha=0.7,
476 )
477 fig.text(
478 0.01,
479 0.01 + verticalSpacing,
480 f"Median: {med:.3f}; " + r"$\sigma_{MAD}$" + f": {sigmaMad:.3f}",
481 transform=fig.transFigure,
482 fontsize=8,
483 alpha=0.7,
484 )
485 if self.showOutliers:
486 fig.text(0.01, 0.01, outlierText, transform=fig.transFigure, fontsize=8, alpha=0.7)
488 # Truncate the color range to (vmin, vmax).
489 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax):
490 colBarVals = np.clip(np.array(colBarVals), vmin, vmax)
491 patchCollection.set_array(colBarVals)
492 # Make the color bar with a metric label.
493 axPos = ax.get_position()
494 cax = fig.add_axes([0.084, axPos.y1 + 0.02, 0.62, 0.07])
495 fig.colorbar(
496 patchCollection,
497 cax=cax,
498 shrink=0.7,
499 extend="both",
500 location="top",
501 orientation="horizontal",
502 )
503 cbarText = "Metric Values"
505 text = cax.text(
506 0.5,
507 0.5,
508 cbarText,
509 transform=cax.transAxes,
510 ha="center",
511 va="center",
512 fontsize=10,
513 zorder=100,
514 )
515 text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
517 # Finalize plot appearance.
518 ax.grid()
519 ax.set_axisbelow(True)
520 addPlotInfo(fig, plotInfo)
521 fig.subplots_adjust(left=0.08, right=0.92, top=0.8, bottom=0.17, wspace=0.05)
522 titleText = self.zAxisLabel.format_map(kwargs)
523 if "zUnit" in data and data["zUnit"] != "":
524 titleText += f" ({data['zUnit']})"
525 fig.suptitle("Metric: " + titleText, fontsize=20)
527 # This saves metadata in the PNG that allows the plot-navigator
528 # to provide tract numbers and metric values on mouseover.
529 #
530 # PNG metadata is a set of string keys and string values.
531 # The WholeSkyPlot stores two keys:
532 # - label: the string describing the regions ('tract')
533 # - boxes, JSON string of a list of per-region dictionaries,
534 # where each dictionary has fields:
535 # - min_x, max_x, min_y, max_y for the pixel coordinates of
536 # the four corners of the region
537 # - id: the identifier of the region (e.g. tract number)
538 # - value: the region's metric, as a string.
539 #
540 def make_patch_md(patch, id_field, value, ax):
541 path = ax.transData.transform_path(patch.get_path())
542 x_path = [int(x) for x in path.vertices[:, 0].tolist()]
543 y_path = [int(y) for y in path.vertices[:, 1].tolist()]
544 return {
545 "min_x": min(x_path),
546 "max_x": max(x_path),
547 "min_y": min(y_path),
548 "max_y": max(y_path),
549 "id": f"{id_field}",
550 "value": f"{value:.3}",
551 }
553 # After ax.set_aspect(), the figure needs to be drawn for the axes
554 # transformations to be updated to the right values.
555 fig.canvas.draw_idle()
557 patch_coordinate_entries = [
558 make_patch_md(patch, tract, value, ax)
559 for (patch, tract, value) in zip(patches, tracts, colBarVals)
560 ]
562 fig.metadata = {"label": "Tract", "boxes": json.dumps(patch_coordinate_entries)}
564 return fig