Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%
173 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-23 04:22 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-23 04:22 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
# Public API of this module: only the SkyPlot action is exported.
__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
36from ...statistics import nansigmaMad
37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays
39# from .plotUtils import generateSummaryStats, parsePlotInfo
class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (all `Vector`) this action reads from ``data``.

        Each requested plot type contributes its x, y and z vectors followed
        by its statistics mask. Plot types not listed below (e.g. ``mag``)
        add no keys.
        """
        # (plot type, suffix appended to x/y/z, name of the stats mask key)
        keySets = (
            ("stars", "Stars", "starStatMask"),
            ("galaxies", "Galaxies", "galaxyStatMask"),
            ("unknown", "Unknowns", "unknownStatMask"),
            ("any", "", "statMask"),
        )
        base = []
        for plotType, suffix, maskName in keySets:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((f"{axis}{suffix}", Vector) for axis in ("x", "y", "z"))
                base.append((maskName, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key required by the schema is present in ``data``.

        NOTE: this can currently only check that a value is not a Scalar;
        it does not check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a required non-Scalar key holds
            a Scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate the median and sigma_MAD of an array plus a text summary.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like of `bool`, optional
            If given, only ``arr[mask]`` is used for the median and sigma_MAD.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma_MAD of the (masked) values.
        statsText : `str`
            Multi-line text for the plot. The reported number of points is
            the length of ``arr`` before masking.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _drawOutlines(self, ax, sumStats):
        """Draw and label the outline of every entry in ``sumStats``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `Mapping`
            Maps a data id to a pair whose first element is a sequence of
            ``(ra, dec)`` corner angles (objects with ``asDegrees()``).

        Returns
        -------
        tractLimits : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees taken from the
            ``"tract"`` entry, or `None` when there is no such entry.
        """
        tractLimits = None
        # NOTE(review): these rectangles are built but never added to the
        # axes; kept to preserve the original behaviour exactly.
        patches = []
        for dataId, (corners, _) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            # Corners 0 and 2 are treated as opposite corners of the box.
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height, alpha=0.3))
            ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
            decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                ax.annotate(
                    dataId,
                    (ra + width / 2, dec + height / 2),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary keyed by data id (e.g. patch id) whose values are
            pairs; the first element holds the R.A. and dec of the corners
            of that region. A ``"tract"`` entry, if present, is drawn
            unlabelled and supplies the plot limits.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Points with a non-finite x or y are dropped (together with their
        z value) before plotting; a type with fewer than 5 remaining
        points is not plotted.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        toPlotList = []
        # For galaxies
        if "galaxies" in self.plotTypes:
            colorValsGalaxies, xsGalaxies, ysGalaxies, statGalaxies = sortAllArrays(
                [data["zGalaxies"], data["xGalaxies"], data["yGalaxies"], data["galaxyStatMask"]]
            )
            _, _, galStatsText = self.statsAndText(colorValsGalaxies, mask=statGalaxies)
            # Add statistics
            bbox = dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none")
            # Stars, unknowns and "any" all draw their stats box at x=0.8,
            # so shift the galaxy box left whenever one of them is also
            # plotted, so that both boxes can be seen.
            if {"stars", "unknown", "any"}.intersection(self.plotTypes):
                boxLoc = 0.63
            else:
                boxLoc = 0.8
            ax.text(boxLoc, 0.91, galStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsGalaxies, ysGalaxies, colorValsGalaxies, redPurple, "Galaxies"))

        # For stars
        if "stars" in self.plotTypes:
            colorValsStars, xsStars, ysStars, statStars = sortAllArrays(
                [data["zStars"], data["xStars"], data["yStars"], data["starStatMask"]]
            )
            _, _, starStatsText = self.statsAndText(colorValsStars, mask=statStars)
            # Add statistics
            bbox = dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none")
            ax.text(0.8, 0.91, starStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsStars, ysStars, colorValsStars, blueGreen, "Stars"))

        # For unknowns
        if "unknown" in self.plotTypes:
            colorValsUnknowns, xsUnknowns, ysUnknowns, statUnknowns = sortAllArrays(
                [data["zUnknowns"], data["xUnknowns"], data["yUnknowns"], data["unknownStatMask"]]
            )
            _, _, unknownStatsText = self.statsAndText(colorValsUnknowns, mask=statUnknowns)
            bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, unknownStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsUnknowns, ysUnknowns, colorValsUnknowns, "viridis", "Unknown"))

        if "any" in self.plotTypes:
            colorValsAny, xsAny, ysAny, statAny = sortAllArrays(
                [data["z"], data["x"], data["y"], data["statMask"]]
            )
            _, _, anyStatsText = self.statsAndText(colorValsAny, mask=statAny)
            bbox = dict(facecolor="purple", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, anyStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsAny, ysAny, colorValsAny, orangeBlue, "All"))

        # Corner plot of patches showing summary stat in each
        tractLimits = self._drawOutlines(ax, sumStats) if self.plotOutlines else None

        # (minRa, maxRa, minDec, maxDec) used for the final axis limits; it
        # stays None if there is no tract outline and nothing gets plotted.
        limits = tractLimits
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            # Keep the colour values aligned with the surviving positions.
            colorVals = colorVals[finite]
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0

            if len(xs) < 5:
                continue
            if tractLimits is None:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            else:
                minRa, maxRa, minDec, maxDec = tractLimits
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            limits = (minRa, maxRa, minDec, maxDec)

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One colorbar per plot type, stacked to the right of the axes.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = f"{self.zAxisLabel}: {label}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # Move the first colorbar's ticks left so they do not collide
            # with the neighbouring colorbar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits; R.A. increases to the left.
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs and max(lenXs) > 1000 and limits is not None:
            minRa, maxRa, minDec, maxDec = limits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig