Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%
173 statements
Generated by coverage.py v7.2.5 on 2023-05-10 10:36 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
36from ...statistics import nansigmaMad
37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays
39# from .plotUtils import generateSummaryStats, parsePlotInfo
# Input keys required for each supported plot type, in the order
# (x, y, z, statistics mask). The dictionary order fixes the order of the
# schema returned by ``SkyPlot.getInputSchema``.
_POPULATION_KEYS = {
    "stars": ("xStars", "yStars", "zStars", "starStatMask"),
    "galaxies": ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask"),
    "unknown": ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask"),
    "any": ("x", "y", "z", "statMask"),
}


class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the vector keys needed for the configured ``plotTypes``.

        For every recognised entry of ``plotTypes`` (stars, galaxies,
        unknown, any) the x, y, z and statistics-mask keys are added, in
        that order, each typed as `Vector`.
        """
        base = []
        for plotType, keys in _POPULATION_KEYS.items():
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in keys)
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then plot it."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key required by the schema is present in ``data``.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or if a key that should be a
            vector holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarise.
        mask : `numpy.ndarray`, optional
            Selection applied to ``arr`` before the median and sigma MAD
            are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the selected values.
        sigMad : `float`
            Sigma MAD of the selected values, as computed by ``nansigmaMad``.
        statsText : `str`
            Three-line summary for the plot. The point count is the length
            of ``arr`` before ``mask`` is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _collectPopulations(self, fig, ax, data):
        """Sort each requested population, draw its statistics box and
        return the arrays to plot.

        Returns
        -------
        toPlotList : `list` [`tuple`]
            One ``(xs, ys, colorVals, cmap, label)`` tuple per population
            requested in ``plotTypes``, ordered galaxies, stars, unknown,
            any.
        """
        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Every other population's statistics box is drawn at x=0.8, so the
        # galaxy box is shifted left whenever one of them is also plotted;
        # otherwise e.g. plotTypes=["stars", "galaxies"] gives two boxes
        # drawn on top of each other.
        sharesBox = any(t in self.plotTypes for t in ("stars", "unknown", "any"))  # type: ignore
        galaxyBoxLoc = 0.63 if sharesBox else 0.8

        # (plotType, box colour, box alpha, box x position, colormap, label).
        # The order fixes the drawing order of the text boxes and the
        # left-to-right order of the colorbars.
        styles = (
            ("galaxies", "lemonchiffon", 0.5, galaxyBoxLoc, redPurple, "Galaxies"),
            ("stars", "paleturquoise", 0.5, 0.8, blueGreen, "Stars"),
            ("unknown", "green", 0.2, 0.8, "viridis", "Unknown"),
            ("any", "purple", 0.2, 0.8, orangeBlue, "All"),
        )

        toPlotList = []
        for plotType, boxColor, boxAlpha, boxLoc, cmap, label in styles:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            xKey, yKey, zKey, maskKey = _POPULATION_KEYS[plotType]
            # z is passed first; presumably sortAllArrays orders every array
            # by it, which is what isSorted=True relies on later.
            colorVals, xs, ys, statMask = sortAllArrays([data[zKey], data[xKey], data[yKey], data[maskKey]])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))
        return toPlotList

    def _plotOutlines(self, ax, sumStats):
        """Draw and label the outline of every region in ``sumStats``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `Mapping`
            Mapping from a region identifier to a two-element tuple whose
            first element is that region's corners. Each corner is indexed
            as ``corner[0]`` (RA) and ``corner[1]`` (Dec), both providing
            ``asDegrees()``.

        Returns
        -------
        tractBounds : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` of the corners of the
            ``"tract"`` entry, or `None` if there is no such entry.
        """
        tractBounds = None
        # NOTE(review): these rectangles are built but never added to the
        # axes; confirm whether they should be drawn or removed.
        patches = []
        for dataId, (corners, _) in sumStats.items():
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            # Corners 0 and 2 are treated as opposite corners of the region.
            ra, dec = ras[0], decs[0]
            width = ras[2] - ra
            height = decs[2] - dec
            patches.append(Rectangle((ra, dec), width, height, alpha=0.3))
            # Close the outline by repeating the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                ax.annotate(
                    dataId,
                    (ra + width / 2, dec + height / 2),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return tractBounds

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`
            Mapping from a region identifier (e.g. a patch id, or
            ``"tract"``) to a two-element tuple whose first element is the
            list of that region's corners; only the corners are used here.
            If a ``"tract"`` entry is present and ``plotOutlines`` is set,
            its extent is used as the binning range.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        toPlotList = self._collectPopulations(fig, ax, data)

        # Outline plot of the patches. When a "tract" entry is present its
        # extent fixes the binning range for every population.
        tractBounds = self._plotOutlines(ax, sumStats) if self.plotOutlines else None
        minRa = maxRa = minDec = maxDec = None
        if tractBounds is not None:
            minRa, maxRa, minDec, maxDec = tractBounds

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            # Drop points without a usable position. colorVals must be
            # filtered with the same mask so all three arrays stay aligned.
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0

            if len(xs) < 5:
                continue
            if tractBounds is None:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One narrow colorbar axes per population, placed side by side
            # to the right of the main axes.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits. In both branches larger x values end
        # up on the left. The explicit limits need bounds, which are unset
        # when no tract was given and every population was skipped.
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if minRa is not None and lenXs and max(lenXs) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig