Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 12%
179 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 13:15 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 13:15 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from lsst.pex.config.configurableActions import ConfigurableActionField
33from matplotlib.figure import Figure
34from matplotlib.patches import Rectangle
36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction
37from ...statistics import nansigmaMad
38from .calculateRange import Med2Mad
39from .plotUtils import addPlotInfo, generateSummaryStats, mkColormap, plotProjectionWithBinning, sortAllArrays
class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics (median, sigma MAD, number of points) for each
    object type and includes those on the plot.

    The plotting of patch outlines requires patch information
    to be included as an additional parameter.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    # NOTE(review): makePlot only acts on "stars", "galaxies", "unknown" and
    # "any"; "mag" is listed in the doc below but is not handled by this class.
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the (key, type) pairs required for the selected plotTypes.

        For each selected type, the x, y and z columns are listed first,
        followed by that type's statistics mask.
        """
        # (plotType, suffix of the x/y/z column names, stat-mask column name)
        layout = (
            ("stars", "Stars", "starStatMask"),
            ("galaxies", "Galaxies", "galaxyStatMask"),
            ("unknown", "Unknowns", "unknownStatMask"),
            ("any", "", "statMask"),
        )
        base = []
        for plotType, suffix, maskName in layout:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((f"{axis}{suffix}", Vector) for axis in "xyz")
                base.append((maskName, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` contains every key required by the schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key that should hold
            a `Vector` holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarise.
        mask : `numpy.ndarray`, optional
            Selection applied to ``arr`` before the median and sigma MAD
            are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the selected values.
        sigMad : `float`
            Sigma MAD of the selected values, as returned by ``nansigmaMad``.
        statsText : `str`
            Multi-line text giving the two statistics and the number of points.

        Notes
        -----
        The reported number of points is ``len(arr)`` *before* ``mask`` is
        applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    @staticmethod
    def _plotOutlines(ax, sumStats):
        """Draw every outline in ``sumStats`` and label all but the tract.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `Mapping`
            Values are 2-tuples whose first element is a sequence of
            ``(ra, dec)`` corner pairs; each coordinate must provide
            ``asDegrees()``.

        Returns
        -------
        tractLimits : `tuple` or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees of the ``"tract"``
            entry, or `None` if ``sumStats`` has no such entry.
        """
        tractLimits = None
        for dataId, (corners, _) in sumStats.items():
            ras = [raAngle.asDegrees() for raAngle, _decAngle in corners]
            decs = [decAngle.asDegrees() for _raAngle, decAngle in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                # Label at the midpoint between corners 0 and 2
                # (presumably opposite corners of the patch/ccd).
                cenX = ras[0] + (ras[2] - ras[0]) / 2
                cenY = decs[0] + (decs[2] - decs[0]) / 2
                ax.annotate(
                    dataId,
                    (cenX, cenY),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`, optional
            Outlines to draw, keyed by an identifier (e.g. patch id). Each
            value is a 2-tuple whose first element is a sequence of
            ``(ra, dec)`` corners providing ``asDegrees()``. A ``"tract"``
            entry is drawn unlabelled and, if present, sets the plot extent.
            If `None` and ``plotOutlines`` is set and ``data`` has a
            ``"patch"`` key, it is generated with ``generateSummaryStats``
            using ``kwargs["skymap"]``.
        **kwargs
            ``skymap`` is read from here when ``sumStats`` must be generated.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Points with non-finite x or y are dropped, together with their z
        values, before plotting; a type with fewer than 5 remaining points
        is not plotted.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        toPlotList = []
        # For galaxies
        if "galaxies" in self.plotTypes:
            colorValsGalaxies, xsGalaxies, ysGalaxies, statGalaxies = sortAllArrays(
                [data["zGalaxies"], data["xGalaxies"], data["yGalaxies"], data["galaxyStatMask"]]
            )
            _, _, galStatsText = self.statsAndText(colorValsGalaxies, mask=statGalaxies)
            # Every other stats box is drawn at x=0.8; if one of them is also
            # going to be drawn, shift the galaxy box left so both are visible.
            sharesBoxSpot = any(other in self.plotTypes for other in ("stars", "unknown", "any"))
            boxLoc = 0.63 if sharesBoxSpot else 0.8
            bbox = dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none")
            ax.text(boxLoc, 0.91, galStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsGalaxies, ysGalaxies, colorValsGalaxies, redPurple, "Galaxies"))

        # For stars
        if "stars" in self.plotTypes:
            colorValsStars, xsStars, ysStars, statStars = sortAllArrays(
                [data["zStars"], data["xStars"], data["yStars"], data["starStatMask"]]
            )
            _, _, starStatsText = self.statsAndText(colorValsStars, mask=statStars)
            bbox = dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none")
            ax.text(0.8, 0.91, starStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsStars, ysStars, colorValsStars, blueGreen, "Stars"))

        # For unknowns
        if "unknown" in self.plotTypes:
            colorValsUnknowns, xsUnknowns, ysUnknowns, statUnknowns = sortAllArrays(
                [data["zUnknowns"], data["xUnknowns"], data["yUnknowns"], data["unknownStatMask"]]
            )
            _, _, unknownStatsText = self.statsAndText(colorValsUnknowns, mask=statUnknowns)
            bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, unknownStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsUnknowns, ysUnknowns, colorValsUnknowns, "viridis", "Unknown"))

        # For all objects regardless of type
        if "any" in self.plotTypes:
            colorValsAny, xsAny, ysAny, statAny = sortAllArrays(
                [data["z"], data["x"], data["y"], data["statMask"]]
            )
            _, _, anyStatsText = self.statsAndText(colorValsAny, mask=statAny)
            bbox = dict(facecolor="purple", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, anyStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsAny, ysAny, colorValsAny, orangeBlue, "All"))

        # Outlines of the patches/ccds; a tract outline also fixes the extent.
        tractLimits = self._plotOutlines(ax, sumStats) if self.plotOutlines else None

        # (minRa, maxRa, minDec, maxDec) of every type actually plotted.
        plottedLimits = []
        maxPointsPlotted = 0
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            # Drop points without usable positions. The colour values must be
            # filtered with the same mask so all three arrays stay aligned.
            # Boolean indexing keeps the order, so the data remain sorted.
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            if len(xs) < 5:
                continue
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.isfinite(colorVals).any():
                colorVals = np.zeros_like(colorVals)
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if tractLimits is not None:
                minRa, maxRa, minDec, maxDec = tractLimits
            else:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            plottedLimits.append((minRa, maxRa, minDec, maxDec))
            maxPointsPlotted = max(maxPointsPlotted, len(xs))

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # With several colorbars side by side, put the first one's ticks
            # on the left.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits: span every plotted type (or the tract),
        # padded by 10%, with RA increasing to the left.
        if maxPointsPlotted > 1000:
            raMins, raMaxs, decMins, decMaxs = zip(*plottedLimits)
            minRa, maxRa = min(raMins), max(raMaxs)
            minDec, maxDec = min(decMins), max(decMaxs)
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig