Coverage for python / lsst / analysis / tools / actions / plot / skyPlot.py: 12%
210 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:26 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:26 +0000
# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from collections.abc import Mapping
28import matplotlib.patheffects as pathEffects
29import numpy as np
30from matplotlib.figure import Figure
31from matplotlib.patches import Rectangle
33from lsst.pex.config import Field, ListField
34from lsst.pex.config.configurableActions import ConfigurableActionField
35from lsst.utils.plotting import (
36 divergent_cmap,
37 galaxies_cmap,
38 galaxies_color,
39 make_figure,
40 set_rubin_plotstyle,
41 stars_cmap,
42 stars_color,
43)
45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction
46from ...math import nanMedian, nanSigmaMad
47from .calculateRange import Med2Mad
48from .plotUtils import addPlotInfo, generateSummaryStats, plotProjectionWithBinning, sortAllArrays
class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.

    The plotting of patch outlines requires patch information
    to be included as an additional parameter.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    alpha = Field[float](
        doc="Transparency for scatter plot points.",
        default=1.0,
    )

    scatPtSize = Field[float](
        doc="Marker size for scatter plot points.",
        default=7,
    )

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    doBinning = Field[bool](
        doc="Spatially bin the data? Extreme outliers can be shown using" " `showExtremeOutliers`.",
        optional=True,
        default=True,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    showExtremeOutliers = Field[bool](
        doc="Show the x-y positions of extreme outlier values as overlaid scatter points.",
        default=True,
    )

    publicationStyle = Field[bool](
        doc="Make a simplified plot for publication use.",
        default=False,
    )

    divergent = Field[bool](
        doc="Use a divergent colormap?",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the vector keys required for the configured ``plotTypes``.

        Each requested plot type contributes its x, y, z and statistics-mask
        keys, all typed as `Vector`. Types are visited in the fixed order
        stars, galaxies, unknown, any, so the schema order is stable.
        """
        # (plot type, keys required when that type is requested).
        requiredKeys = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        return [
            (key, Vector)
            for plotType, keys in requiredKeys
            if plotType in self.plotTypes  # type: ignore
            for key in keys
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        ``kwargs`` are used to format the schema keys during validation and
        are then forwarded unchanged to `makePlot`.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every schema key is present and not a `Scalar`.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or if a key's value is a `Scalar`
            where a non-scalar type is required.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like of `bool`, optional
            If given, the median and sigma_MAD are computed only over
            ``arr[mask]``.

        Returns
        -------
        med : `float`
            NaN-aware median of the (masked) values.
        sigMad : `float`
            NaN-aware sigma_MAD of the (masked) values.
        statsText : `str`
            Multi-line text for the plot. The point count is ``len(arr)``,
            taken *before* the mask is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = nanMedian(arr)
        sigMad = nanSigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str] | None = None,
        sumStats: Mapping | None = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

            If ``plotName`` is configured it is written into this mapping
            under ``"plotName"``. An empty dict is used when not given.
        sumStats : `dict`, optional
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch. If not given, it is generated
            from ``data`` when outlines are requested and ``data`` has a
            ``"patch"`` column; otherwise no outlines are drawn.
        **kwargs
            Must contain ``"skymap"`` when ``sumStats`` has to be generated.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        set_rubin_plotstyle()

        # plotInfo must be a real mapping before it is written to below or
        # passed to generateSummaryStats/addPlotInfo.
        if plotInfo is None:
            plotInfo = {}

        # 'plotName' by default is constructed from the attribute specified in
        # 'atools.<attribute>' in the pipeline YAML. If the atool sets
        # self.produce.plot.plotName, it will override this default.
        if self.plotName:
            plotInfo["plotName"] = self.plotName

        fig = make_figure()
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        # Default colormaps; swapped for a divergent map below if requested.
        starsCmap = stars_cmap()
        galsCmap = galaxies_cmap()

        # Each entry is (x values, y values, colour values, colormap, label).
        toPlotList = []

        # For galaxies
        if "galaxies" in self.plotTypes:
            colorValsGalaxies, xsGalaxies, ysGalaxies, statGalaxies = sortAllArrays(
                [data["zGalaxies"], data["xGalaxies"], data["yGalaxies"], data["galaxyStatMask"]]
            )
            _, _, galStatsText = self.statsAndText(colorValsGalaxies, mask=statGalaxies)
            if not self.publicationStyle:
                # The stars, unknown and any statistics boxes are all drawn at
                # x=0.8; shift the galaxy box left when one of them will also
                # be drawn so that both remain visible.
                otherBoxDrawn = any(t in self.plotTypes for t in ("stars", "unknown", "any"))
                boxLoc = 0.63 if otherBoxDrawn else 0.8
                bbox = dict(facecolor=galaxies_color(), alpha=0.5, edgecolor="none")
                ax.text(boxLoc, 0.91, galStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            if self.divergent:
                galsCmap = divergent_cmap()
            toPlotList.append((xsGalaxies, ysGalaxies, colorValsGalaxies, galsCmap, "Galaxies"))

        # For stars
        if "stars" in self.plotTypes:
            colorValsStars, xsStars, ysStars, statStars = sortAllArrays(
                [data["zStars"], data["xStars"], data["yStars"], data["starStatMask"]]
            )
            _, _, starStatsText = self.statsAndText(colorValsStars, mask=statStars)
            if not self.publicationStyle:
                bbox = dict(facecolor=stars_color(), alpha=0.5, edgecolor="none")
                ax.text(0.8, 0.91, starStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            if self.divergent:
                starsCmap = divergent_cmap()
            toPlotList.append((xsStars, ysStars, colorValsStars, starsCmap, "Stars"))

        # For unknowns
        if "unknown" in self.plotTypes:
            colorValsUnknowns, xsUnknowns, ysUnknowns, statUnknowns = sortAllArrays(
                [data["zUnknowns"], data["xUnknowns"], data["yUnknowns"], data["unknownStatMask"]]
            )
            _, _, unknownStatsText = self.statsAndText(colorValsUnknowns, mask=statUnknowns)
            if not self.publicationStyle:
                bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
                ax.text(0.8, 0.91, unknownStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsUnknowns, ysUnknowns, colorValsUnknowns, "viridis", "Unknown"))

        # For all objects regardless of type
        if "any" in self.plotTypes:
            colorValsAny, xsAny, ysAny, statAny = sortAllArrays(
                [data["z"], data["x"], data["y"], data["statMask"]]
            )
            _, _, anyStatsText = self.statsAndText(colorValsAny, mask=statAny)
            if not self.publicationStyle:
                bbox = dict(facecolor="#bab0ac", alpha=0.2, edgecolor="none")
                ax.text(0.8, 0.91, anyStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsAny, ysAny, colorValsAny, "viridis", ""))

        # Sky extent used for binning and for the final axis limits. It comes
        # from the tract outline when one is drawn, otherwise from the plotted
        # data; it stays None if neither source provides it.
        minRa = maxRa = minDec = maxDec = None

        # Corner plot of patches showing summary stat in each
        if self.plotOutlines:
            # NOTE(review): these rectangles are collected but never added to
            # the axes; only the outlines drawn with ax.plot are visible.
            patches = []
            for dataId, (corners, _) in sumStats.items():
                ras = [ra.asDegrees() for ra, _ in corners]
                decs = [dec.asDegrees() for _, dec in corners]
                ra0, dec0 = ras[0], decs[0]
                width = ras[2] - ra0
                height = decs[2] - dec0
                patches.append(Rectangle((ra0, dec0), width, height, alpha=0.3))
                # Closed outline: repeat the first corner at the end.
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    minRa, maxRa = np.min(ras), np.max(ras)
                    minDec, maxDec = np.min(decs), np.max(decs)
                else:
                    ax.annotate(
                        dataId,
                        (ra0 + width / 2, dec0 + height / 2),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        # Loop invariants for the plotting below.
        useDataLimits = not self.plotOutlines or "tract" not in sumStats.keys()
        showExtremeOutliers = False if self.publicationStyle else self.showExtremeOutliers
        # If transparency is being used, point edgecolor matches facecolor
        edgecolor = "white" if self.alpha == 1.0 else "face"

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            nPoints = len(xs)
            if nPoints == 0:
                continue

            # colorVal column is unusable so zero it out; this should be
            # obvious on the plot. colorVals is a new array produced by the
            # boolean indexing above, so the input data are not modified.
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if useDataLimits:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
                # Avoid identical end points which causes problems in binning
                if minRa == maxRa:
                    maxRa += 1e-5  # There is no reason to pick this number in particular
                if minDec == maxDec:
                    maxDec += 1e-5  # There is no reason to pick this number in particular
            if nPoints < 5:
                continue

            # Bin above 5000 points; with binning off, the threshold can never
            # be reached, so every point is drawn as a true scatter plot.
            nPointBinThresh = 5000 if self.doBinning else nPoints + 1

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                alpha=self.alpha,
                edgecolor=edgecolor,
                fixAroundZero=self.fixAroundZero,
                nPointBinThresh=nPointBinThresh,
                isSorted=True,
                showExtremeOutliers=showExtremeOutliers,
                scatPtSize=self.scatPtSize,
            )
            ax.set_aspect("equal")
            if not self.publicationStyle:
                # One colourbar per dataset, stacked to the right of the axes.
                cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
                fig.colorbar(plotOut, cax=cax, extend="both")
            else:
                fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.95, bottom=0.15)
                axBbox = ax.get_position()
                cax = fig.add_axes([axBbox.x1, axBbox.y0, 0.04, axBbox.y1 - axBbox.y0])
                fig.colorbar(plotOut, cax=cax)

            colorBarLabel = f"{self.zAxisLabel}: {label}" if label else f"{self.zAxisLabel}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

            # Put the first colourbar's ticks on its left so they do not
            # collide with the neighbouring colourbar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

            ax.set_xlabel(self.xAxisLabel)
            ax.set_ylabel(self.yAxisLabel)

        fig.canvas.draw()

        # Find some useful axis limits. RA increases to the left, so the x
        # axis is reversed either explicitly or via invert_xaxis.
        lenXs = [len(xs) for xs, *_ in toPlotList]
        if minRa is not None and lenXs and max(lenXs) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        if not self.publicationStyle:
            fig.subplots_adjust(wspace=0.0, hspace=0.0)
            fig = addPlotInfo(fig, plotInfo)

        return fig