Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%
173 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-17 10:45 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-17 10:45 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
36from ...statistics import nansigmaMad
37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays
class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their types) this action needs from the data.

        One group of x, y, z and statistics-mask vectors is requested for
        each of ``stars``, ``galaxies``, ``unknown`` and ``any`` that
        appears in ``plotTypes``.
        """
        base = []
        if "stars" in self.plotTypes:  # type: ignore
            base.append(("xStars", Vector))
            base.append(("yStars", Vector))
            base.append(("zStars", Vector))
            base.append(("starStatMask", Vector))
        if "galaxies" in self.plotTypes:  # type: ignore
            base.append(("xGalaxies", Vector))
            base.append(("yGalaxies", Vector))
            base.append(("zGalaxies", Vector))
            base.append(("galaxyStatMask", Vector))
        if "unknown" in self.plotTypes:  # type: ignore
            base.append(("xUnknowns", Vector))
            base.append(("yUnknowns", Vector))
            base.append(("zUnknowns", Vector))
            base.append(("unknownStatMask", Vector))
        if "any" in self.plotTypes:  # type: ignore
            base.append(("x", Vector))
            base.append(("y", Vector))
            base.append(("z", Vector))
            base.append(("statMask", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key required by the input schema is present.

        NOTE: this can currently only check that a required entry is not a
        `Scalar`; it cannot check that the data is consistent with `Vector`.

        Raises
        ------
        ValueError
            If required keys are missing, or a required key holds a
            `Scalar` where a non-scalar type was requested.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like, optional
            Selection applied to ``arr`` before the median and sigma MAD
            are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma MAD of the (masked) values.
        statsText : `str`
            Multi-line text summarising the statistics. The point count
            is the length of ``arr`` before ``mask`` is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary keyed by patch ID, or ``"tract"`` for the tract
            itself. Each value is a tuple whose first element is the
            sequence of (RA, Dec) corners of that region, as angle objects
            providing ``asDegrees()``. The second element is not used here.
            It is only read when ``plotOutlines`` is `True`. If a
            ``"tract"`` entry is present, its corners set the plot limits.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        xCol = self.xAxisLabel
        yCol = self.yAxisLabel

        # Plot limits. They are filled from the tract outline (if drawn) or
        # from the plotted points. None means no limits were found, which
        # the axis-limit step below checks for.
        minRa = maxRa = minDec = maxDec = None

        toPlotList = []
        # For galaxies
        if "galaxies" in self.plotTypes:
            sortedArrs = sortAllArrays(
                [data["zGalaxies"], data["xGalaxies"], data["yGalaxies"], data["galaxyStatMask"]]
            )
            [colorValsGalaxies, xsGalaxies, ysGalaxies, statGalaxies] = sortedArrs
            _, _, galStatsText = self.statsAndText(colorValsGalaxies, mask=statGalaxies)
            # Add statistics
            bbox = dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none")
            # Move the galaxy text box left when more than two plot types
            # are configured, so it does not sit on top of the other box.
            # NOTE(review): the original comment said this needs to be > 2
            # "because not being plotted points are assigned 0"; the intent
            # is unclear from here, so confirm against the config.
            if len(self.plotTypes) > 2:
                boxLoc = 0.63
            else:
                boxLoc = 0.8
            ax.text(boxLoc, 0.91, galStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsGalaxies, ysGalaxies, colorValsGalaxies, redPurple, "Galaxies"))

        # For stars
        if "stars" in self.plotTypes:
            sortedArrs = sortAllArrays([data["zStars"], data["xStars"], data["yStars"], data["starStatMask"]])
            [colorValsStars, xsStars, ysStars, statStars] = sortedArrs
            _, _, starStatsText = self.statsAndText(colorValsStars, mask=statStars)
            # Add statistics
            bbox = dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none")
            ax.text(0.8, 0.91, starStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsStars, ysStars, colorValsStars, blueGreen, "Stars"))

        # For unknowns
        if "unknown" in self.plotTypes:
            sortedArrs = sortAllArrays(
                [data["zUnknowns"], data["xUnknowns"], data["yUnknowns"], data["unknownStatMask"]]
            )
            [colorValsUnknowns, xsUnknowns, ysUnknowns, statUnknowns] = sortedArrs
            _, _, unknownStatsText = self.statsAndText(colorValsUnknowns, mask=statUnknowns)
            bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, unknownStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsUnknowns, ysUnknowns, colorValsUnknowns, "viridis", "Unknown"))

        if "any" in self.plotTypes:
            sortedArrs = sortAllArrays([data["z"], data["x"], data["y"], data["statMask"]])
            [colorValsAny, xs, ys, statAny] = sortedArrs
            _, _, anyStatsText = self.statsAndText(colorValsAny, mask=statAny)
            bbox = dict(facecolor="purple", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, anyStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorValsAny, orangeBlue, "All"))

        # Corner plot of patches showing summary stat in each
        if self.plotOutlines:
            # NOTE(review): these rectangles are collected but never added
            # to the axes; only the line outlines below are drawn.
            patches = []
            for dataId, (corners, _) in sumStats.items():
                ra = corners[0][0].asDegrees()
                dec = corners[0][1].asDegrees()
                width = corners[2][0].asDegrees() - ra
                height = corners[2][1].asDegrees() - dec
                patches.append(Rectangle((ra, dec), width, height, alpha=0.3))
                ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
                decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
                # Close the polygon by repeating the first corner.
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    # The tract outline defines the plot limits.
                    minRa = np.min(ras)
                    minDec = np.min(decs)
                    maxRa = np.max(ras)
                    maxDec = np.max(decs)
                else:
                    cenX = ra + width / 2
                    cenY = dec + height / 2
                    ax.annotate(
                        dataId,
                        (cenX, cenY),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            # Drop points without finite positions. colorVals must be
            # filtered with the same mask so all three arrays stay aligned.
            # Boolean indexing copies, so the in-place zeroing below cannot
            # modify the caller's data.
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            n_xs = len(xs)
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0

            if n_xs < 5:
                continue
            if not self.plotOutlines or "tract" not in sumStats:
                minRa = np.min(xs)
                maxRa = np.max(xs)
                minDec = np.min(ys)
                maxDec = np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One colorbar per plotted type, stacked to the right of the axes.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            plt.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(xCol)
        ax.set_ylabel(yCol)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        plt.draw()

        # Find some useful axis limits. The limits hold the values from the
        # tract outline or from the last plotted type. RA increases to the
        # left, so the x axis is reversed.
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs and np.max(lenXs) > 1000 and minRa is not None:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        plt.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = plt.gcf()
        fig = addPlotInfo(fig, plotInfo)

        return fig