Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 12%
180 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-20 13:15 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-20 13:15 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from matplotlib.figure import Figure
34from matplotlib.patches import Rectangle
36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction
37from ...math import nanMedian, nanSigmaMad
38from .calculateRange import Med2Mad
39from .plotUtils import addPlotInfo, generateSummaryStats, mkColormap, plotProjectionWithBinning, sortAllArrays
class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.

    The plotting of patch outlines requires patch information
    to be included as an additional parameter.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    showExtremeOutliers = Field[bool](
        doc="Show the x-y positions of extreme outlier values as overlaid scatter points.",
        default=True,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys this action reads from ``data``.

        Each population selected in ``plotTypes`` requires four vectors:
        the x and y positions, the z values used for colouring, and a mask
        selecting the points used to compute the statistics.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            ``(key, Vector)`` pairs, ordered stars, galaxies, unknown, any.
        """
        populations = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        return [
            (key, Vector)
            for plotType, keys in populations
            if plotType in self.plotTypes  # type: ignore
            for key in keys
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key required by the schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing from ``data``, or a key that should
            hold a vector holds a `Scalar` instead.
        """
        needed = self.getInputSchema(**kwargs)
        available = {key.format(**kwargs) for key in data.keys()}
        if remainder := {key.format(**kwargs) for key, _ in needed} - available:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : optional
            Selection applied to ``arr`` before the median and sigma MAD are
            computed. If `None`, every value is used.

        Returns
        -------
        med : `float`
            Median of the selected values, computed with ``nanMedian``.
        sigMad : `float`
            Sigma MAD of the selected values, computed with ``nanSigmaMad``.
        statsText : `str`
            Multi-line text giving the median, sigma MAD and number of points.
            The number of points is the length of ``arr`` *before* masking.
        """
        numPoints = len(arr)
        selected = arr if mask is None else arr[mask]
        med = nanMedian(selected)
        sigMad = nanSigmaMad(selected)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    @staticmethod
    def _plotPatchOutlines(ax, sumStats):
        """Draw the outline of every entry of ``sumStats`` on ``ax``.

        Every entry except ``"tract"`` is also labelled with its key, placed
        halfway along the diagonal between its first and third corners.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `dict`
            Mapping of id -> ``(corners, stats)``. Each corner is an
            ``(ra, dec)`` pair of objects providing ``asDegrees()``.

        Returns
        -------
        tractBounds : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees of the ``"tract"``
            entry, or `None` if ``sumStats`` has no ``"tract"`` key.
        """
        tractBounds = None
        for dataId, (corners, _) in sumStats.items():
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue
            cenX = ras[0] + (ras[2] - ras[0]) / 2
            cenY = decs[0] + (decs[2] - decs[0]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractBounds

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary whose keys are patch ids. Each value is a
            ``(corners, stats)`` pair, where ``corners`` holds the R.A. and
            dec of the corners of the patch. The special key ``"tract"``
            gives the tract outline, which is then used for the plot limits.
            If `None`, it is generated from ``data`` when ``plotOutlines`` is
            set and ``data`` has a ``"patch"`` key; this requires
            ``kwargs["skymap"]``.
        **kwargs
            ``"skymap"`` is read when ``sumStats`` has to be generated.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Populations with fewer than five points with finite positions are
        not plotted.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Each entry: (xs, ys, colorVals, cmap, label), in drawing order.
        toPlotList = []

        if "galaxies" in self.plotTypes:
            colorValsGalaxies, xsGalaxies, ysGalaxies, statGalaxies = sortAllArrays(
                [data["zGalaxies"], data["xGalaxies"], data["yGalaxies"], data["galaxyStatMask"]]
            )
            _, _, galStatsText = self.statsAndText(colorValsGalaxies, mask=statGalaxies)
            # Every other population draws its stats box at x=0.8; if any of
            # them is present, move the galaxy box left so both can be seen.
            if any(plotType in self.plotTypes for plotType in ("stars", "unknown", "any")):
                boxLoc = 0.63
            else:
                boxLoc = 0.8
            bbox = dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none")
            ax.text(boxLoc, 0.91, galStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsGalaxies, ysGalaxies, colorValsGalaxies, redPurple, "Galaxies"))

        if "stars" in self.plotTypes:
            colorValsStars, xsStars, ysStars, statStars = sortAllArrays(
                [data["zStars"], data["xStars"], data["yStars"], data["starStatMask"]]
            )
            _, _, starStatsText = self.statsAndText(colorValsStars, mask=statStars)
            bbox = dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none")
            ax.text(0.8, 0.91, starStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsStars, ysStars, colorValsStars, blueGreen, "Stars"))

        if "unknown" in self.plotTypes:
            colorValsUnknowns, xsUnknowns, ysUnknowns, statUnknowns = sortAllArrays(
                [data["zUnknowns"], data["xUnknowns"], data["yUnknowns"], data["unknownStatMask"]]
            )
            _, _, unknownStatsText = self.statsAndText(colorValsUnknowns, mask=statUnknowns)
            bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, unknownStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsUnknowns, ysUnknowns, colorValsUnknowns, "viridis", "Unknown"))

        if "any" in self.plotTypes:
            colorValsAny, xsAny, ysAny, statAny = sortAllArrays(
                [data["z"], data["x"], data["y"], data["statMask"]]
            )
            _, _, anyStatsText = self.statsAndText(colorValsAny, mask=statAny)
            bbox = dict(facecolor="purple", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, anyStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsAny, ysAny, colorValsAny, orangeBlue, "All"))

        # Corner plot of patches showing summary stat in each
        tractBounds = self._plotPatchOutlines(ax, sumStats) if self.plotOutlines else None

        # (minRa, maxRa, minDec, maxDec): the tract outline if available,
        # otherwise the extent of the last population actually plotted.
        limits = tractBounds
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            # Drop points without a finite position, keeping the colour
            # values aligned with the positions that remain.
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            if len(xs) < 5:
                continue
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if tractBounds is None:
                limits = (np.min(xs), np.max(xs), np.min(ys), np.max(ys))
            minRa, maxRa, minDec, maxDec = limits
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            limits = (minRa, maxRa, minDec, maxDec)

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
                showExtremeOutliers=self.showExtremeOutliers,
            )
            # One colorbar per population, stacked left to right.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = f"{self.zAxisLabel}: {label}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits; RA increases to the left.
        maxPoints = max((len(entry[0]) for entry in toPlotList), default=0)
        if maxPoints > 1000 and limits is not None:
            minRa, maxRa, minDec, maxDec = limits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig