Coverage for python / lsst / analysis / tools / actions / plot / skyPlot.py: 12%
195 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:08 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-22 09:08 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import numpy as np
30from lsst.pex.config import Field, ListField
31from lsst.pex.config.configurableActions import ConfigurableActionField
32from lsst.utils.plotting import (
33 divergent_cmap,
34 galaxies_cmap,
35 galaxies_color,
36 make_figure,
37 set_rubin_plotstyle,
38 stars_cmap,
39 stars_color,
40)
41from matplotlib.figure import Figure
42from matplotlib.patches import Rectangle
44from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction
45from ...math import nanMedian, nanSigmaMad
46from .calculateRange import Med2Mad
47from .plotUtils import addPlotInfo, generateSummaryStats, plotProjectionWithBinning, sortAllArrays
class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.

    The plotting of patch outlines requires patch information
    to be included as an additional parameter.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    showExtremeOutliers = Field[bool](
        doc="Show the x-y positions of extreme outlier values as overlaid scatter points.",
        default=True,
    )

    publicationStyle = Field[bool](
        doc="Make a simplified plot for publication use.",
        default=False,
    )

    divergent = Field[bool](
        doc="Use a divergent colormap?",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the columns required for the configured ``plotTypes``.

        Returns
        -------
        schema : `list` [`tuple`]
            ``(key, Vector)`` pairs; four keys (x, y, z and a statistics
            mask) for every recognised entry found in ``plotTypes``, in the
            order stars, galaxies, unknown, any.
        """
        columnsByType = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        return [
            (column, Vector)
            for plotType, columns in columnsByType
            if plotType in self.plotTypes  # type: ignore
            for column in columns
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        All keyword arguments are forwarded to both `_validateInput` and
        `makePlot`.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by the schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a column
            is a `Scalar` where the schema asks for another type.
        """
        needed = self.getInputSchema(**kwargs)
        requiredKeys = {key.format(**kwargs) for key, _ in needed}
        availableKeys = {key.format(**kwargs) for key in data.keys()}
        if remainder := requiredKeys - availableKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like, optional
            Index or boolean mask applied to ``arr`` before the statistics
            are computed.

        Returns
        -------
        med : `float`
            ``nanMedian`` of the (masked) values.
        sigMad : `float`
            ``nanSigmaMad`` of the (masked) values.
        statsText : `str`
            Text summarising both statistics and the number of points.

        Notes
        -----
        The reported number of points is the length of ``arr`` *before* the
        mask is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = nanMedian(arr)
        sigMad = nanSigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            r"$\sigma_{MAD}$: "
            f"{sigMad:0.2f}\n"
            r"n$_{points}$: "
            f"{numPoints}"
        )

        return med, sigMad, statsText

    def _stageDataset(self, fig, ax, data, columns, cmap, label, boxStyle, boxLoc=0.8):
        """Sort one dataset, annotate its statistics and return its plot entry.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure`
            Figure whose coordinate system positions the statistics box.
        ax : `matplotlib.axes.Axes`
            Axes the statistics text is attached to.
        data : `KeyedData`
            The catalog holding the columns named in ``columns``.
        columns : `tuple` [`str`]
            Keys of the z, x, y and statistics-mask columns, in that order.
        cmap
            Colormap (or colormap name) to carry through to the plot entry.
        label : `str`
            Human readable name of the dataset, used on the colorbar.
        boxStyle : `dict`
            ``facecolor``/``alpha`` of the statistics box.
        boxLoc : `float`, optional
            Horizontal position of the statistics box in figure coordinates.

        Returns
        -------
        entry : `tuple`
            ``(xs, ys, colorVals, cmap, label)``.
        """
        zKey, xKey, yKey, maskKey = columns
        # The z column goes first; the arrays come back in the same order.
        # NOTE(review): presumably all arrays are ordered by the first one --
        # the binning call later passes isSorted=True; confirm in plotUtils.
        colorVals, xs, ys, statMask = sortAllArrays([data[zKey], data[xKey], data[yKey], data[maskKey]])
        # Publication style plots carry no statistics box for any dataset.
        if not self.publicationStyle:
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            bbox = dict(edgecolor="none", **boxStyle)
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
        return xs, ys, colorVals, cmap, label

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch.
        **kwargs
            ``kwargs["skymap"]`` is read when ``sumStats`` is `None`,
            ``plotOutlines`` is set and ``data`` has a ``"patch"`` key.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Datasets with fewer than five finite positions are not drawn.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        set_rubin_plotstyle()
        fig = make_figure()
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        plotTypes = self.plotTypes
        toPlotList = []

        # For galaxies
        if "galaxies" in plotTypes:
            # Every other dataset puts its statistics box at x=0.8, so the
            # galaxy box moves left whenever it is not the only one drawn.
            sharesFigure = any(other in plotTypes for other in ("stars", "unknown", "any"))
            toPlotList.append(
                self._stageDataset(
                    fig,
                    ax,
                    data,
                    ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                    divergent_cmap() if self.divergent else galaxies_cmap(),
                    "Galaxies",
                    dict(facecolor=galaxies_color(), alpha=0.5),
                    boxLoc=0.63 if sharesFigure else 0.8,
                )
            )

        # For stars
        if "stars" in plotTypes:
            toPlotList.append(
                self._stageDataset(
                    fig,
                    ax,
                    data,
                    ("zStars", "xStars", "yStars", "starStatMask"),
                    divergent_cmap() if self.divergent else stars_cmap(),
                    "Stars",
                    dict(facecolor=stars_color(), alpha=0.5),
                )
            )

        # For unknowns
        if "unknown" in plotTypes:
            toPlotList.append(
                self._stageDataset(
                    fig,
                    ax,
                    data,
                    ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                    "viridis",
                    "Unknown",
                    dict(facecolor="green", alpha=0.2),
                )
            )

        if "any" in plotTypes:
            toPlotList.append(
                self._stageDataset(
                    fig,
                    ax,
                    data,
                    ("z", "x", "y", "statMask"),
                    "viridis",
                    "All",
                    dict(facecolor="#bab0ac", alpha=0.2),
                )
            )

        # Plot limits; stay None until the tract outline or a drawn dataset
        # provides them.
        minRa = maxRa = minDec = maxDec = None

        # Corner plot of patches showing summary stat in each
        if self.plotOutlines:
            # NOTE(review): these rectangles are collected but never added to
            # the axes; only the outline drawn with ax.plot below is visible.
            patches = []
            for dataId, (corners, _) in sumStats.items():
                ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
                decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
                ra, dec = ras[0], decs[0]
                width = ras[2] - ra
                height = decs[2] - dec
                patches.append(Rectangle((ra, dec), width, height, alpha=0.3))
                # Repeat the first corner to close the outline.
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    minRa, maxRa = np.min(ras), np.max(ras)
                    minDec, maxDec = np.min(decs), np.max(decs)
                else:
                    ax.annotate(
                        dataId,
                        (ra + width / 2, dec + height / 2),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        useDataBounds = not self.plotOutlines or "tract" not in sumStats
        showExtremeOutliers = False if self.publicationStyle else self.showExtremeOutliers

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]

            # Too few points to bin; skip before doing any further work.
            if len(xs) < 5:
                continue

            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.isfinite(colorVals).any():
                colorVals[:] = 0
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if useDataBounds:
                minRa = np.min(xs)
                maxRa = np.max(xs)
                minDec = np.min(ys)
                maxDec = np.max(ys)
                # Avoid identical end points which causes problems in binning
                if minRa == maxRa:
                    maxRa += 1e-5  # There is no reason to pick this number in particular
                if minDec == maxDec:
                    maxDec += 1e-5  # There is no reason to pick this number in particular

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
                showExtremeOutliers=showExtremeOutliers,
            )
            ax.set_aspect("equal")
            if not self.publicationStyle:
                # One colorbar per dataset, placed side by side.
                cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
                fig.colorbar(plotOut, cax=cax, extend="both")
            else:
                fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.95, bottom=0.15)
                axBbox = ax.get_position()
                cax = fig.add_axes([axBbox.x1, axBbox.y0, 0.04, axBbox.y1 - axBbox.y0])
                fig.colorbar(plotOut, cax=cax)

            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        fig.canvas.draw()

        # Find some useful axis limits
        maxPoints = max((len(xs) for (xs, _, _, _, _) in toPlotList), default=0)
        # Limits are unset when every dataset was skipped and no tract
        # outline exists; fall back to simply flipping the x axis.
        if maxPoints > 1000 and minRa is not None:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # x limits are given high-to-low so the x axis is drawn inverted.
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        if not self.publicationStyle:
            fig.subplots_adjust(wspace=0.0, hspace=0.0)
            fig = addPlotInfo(fig, plotInfo)

        return fig