Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%
165 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-03 02:52 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
36from ...statistics import nansigmaMad
37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays
39# from .plotUtils import generateSummaryStats, parsePlotInfo
# Input keys (x, y, z, statistics mask) read for each supported plot type.
# The insertion order is the order in which `SkyPlot.getInputSchema` lists
# the keys.
_PLOT_TYPE_KEYS = {
    "stars": ("xStars", "yStars", "zStars", "starStatMask"),
    "galaxies": ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask"),
    "unknown": ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask"),
    "any": ("x", "y", "z", "statMask"),
}


class SkyPlot(PlotAction):
    """Plot points on the sky, color coded by a third quantity.

    Each selected entry of ``plotTypes`` ("stars", "galaxies", "unknown",
    "any") reads its own x, y, z and statistics-mask vectors (see
    `getInputSchema`) and is drawn with its own colormap, colorbar and
    statistics text box. Optionally the outlines of the regions described
    by ``sumStats`` are drawn as well.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the input keys required by the configured ``plotTypes``.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            ``(key, Vector)`` pairs: the x, y, z and statistics-mask keys of
            every selected plot type.
        """
        schema = []
        for plotType, keys in _PLOT_TYPE_KEYS.items():
            if plotType in self.plotTypes:  # type: ignore
                schema.extend((key, Vector) for key in keys)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` and make the sky plot from it.

        ``kwargs`` are used to format the schema keys during validation and
        are then forwarded to `makePlot`.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` contains every key this action needs.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a column that must be
            a `Vector` is a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        neededKeys = {key.format(**kwargs) for key, _ in needed}
        if remainder := neededKeys - {key.format(**kwargs) for key in data.keys()}:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate the median and sigma MAD of an array plus a text summary.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarize.
        mask : `numpy.ndarray`, optional
            Selection applied to ``arr`` before the statistics are computed.

        Returns
        -------
        med : `float`
            ``numpy.nanmedian`` of the selected values.
        sigMad : `float`
            ``nansigmaMad`` of the selected values.
        statsText : `str`
            Three-line summary of the median, sigma MAD and number of points.
            The number of points is ``len(arr)`` before ``mask`` is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        # The LaTeX fragments are raw strings; the newlines must not be.
        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    @staticmethod
    def _dataLimits(xs, ys):
        """Return ``(minRa, maxRa, minDec, maxDec)`` spanned by the points.

        A zero-width range is widened by 1e-5 because identical end points
        cause problems in binning.
        """
        minRa, maxRa = np.min(xs), np.max(xs)
        minDec, maxDec = np.min(ys), np.max(ys)
        if minRa == maxRa:
            maxRa += 1e-5  # There is no reason to pick this number in particular
        if minDec == maxDec:
            maxDec += 1e-5  # There is no reason to pick this number in particular
        return minRa, maxRa, minDec, maxDec

    @staticmethod
    def _plotOutlines(ax, sumStats):
        """Draw the outline of every entry of ``sumStats`` onto ``ax``.

        Every entry except ``"tract"`` is also labelled at its center.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `dict`
            Mapping of identifier to ``(corners, stats)``. Only ``corners`` is
            used: a sequence of ``(ra, dec)`` pairs whose elements provide
            ``asDegrees()``. Corners 0 and 2 are treated as opposite corners.

        Returns
        -------
        tractLimits : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` of the ``"tract"`` corners, or
            `None` if ``sumStats`` has no ``"tract"`` entry.
        """
        tractLimits = None
        patches = []
        for dataId, (corners, _) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            # NOTE(review): these rectangles are never added to the axes.
            patches.append(Rectangle((ra, dec), width, height, alpha=0.3))
            ras = [cornerRa.asDegrees() for cornerRa, _ in corners]
            decs = [cornerDec.asDegrees() for _, cornerDec in corners]
            # Close the polygon by repeating the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                ax.annotate(
                    dataId,
                    (ra + width / 2, dec + height / 2),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the z value at given points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. Must contain the keys listed by
            `getInputSchema` for the configured ``plotTypes``: the x and y
            positions, the values used to color the points, and a mask
            selecting the points used for the statistics text.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted, passed
            to `addPlotInfo`, with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Mapping of patch/ccd identifier (or ``"tract"``) to a
            ``(corners, stats)`` pair. Only ``corners`` is used, to draw the
            outlines when ``plotOutlines`` is set. If a ``"tract"`` entry is
            present, its corners also set the plot limits.
        **kwargs
            Not used.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Every statistics box except the galaxy one is drawn at x = 0.8, so
        # move the galaxy box left whenever one of those is drawn too.
        # NOTE(review): the stars, unknown and any boxes all share x = 0.8
        # and still overlap one another when combined.
        otherBoxTypes = ("stars", "unknown", "any")
        galaxyBoxLoc = 0.63 if any(t in self.plotTypes for t in otherBoxTypes) else 0.8

        # (plotType, colormap, colorbar label, box x position, box color,
        # box alpha). The order is the drawing order, which fixes where each
        # colorbar goes.
        drawOrder = (
            ("galaxies", redPurple, "Galaxies", galaxyBoxLoc, "lemonchiffon", 0.5),
            ("stars", blueGreen, "Stars", 0.8, "paleturquoise", 0.5),
            ("unknown", "viridis", "Unknown", 0.8, "green", 0.2),
            ("any", orangeBlue, "All", 0.8, "purple", 0.2),
        )

        toPlotList = []
        for plotType, cmap, label, boxLoc, boxColor, boxAlpha in drawOrder:
            if plotType not in self.plotTypes:
                continue
            xKey, yKey, zKey, maskKey = _PLOT_TYPE_KEYS[plotType]
            # z is passed first; the sorted arrays are later plotted with
            # isSorted=True (presumably sortAllArrays sorts by the first array).
            colorVals, xs, ys, statMask = sortAllArrays([data[zKey], data[xKey], data[yKey], data[maskKey]])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outlines of the regions in sumStats; the tract, if present, fixes
        # the plot limits for every data set.
        tractLimits = self._plotOutlines(ax, sumStats) if self.plotOutlines else None

        limits = tractLimits
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if tractLimits is None:
                limits = self._dataLimits(xs, ys)
            minRa, maxRa, minDec, maxDec = limits

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One colorbar per data set, side by side to the right of the plot.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            text = cax.text(
                0.5,
                0.5,
                f"{self.zAxisLabel}: {label}",
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # Presumably moved so the first colorbar's tick labels do not
            # collide with the neighbouring colorbar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits: with many points pad the limits by
        # 10%, otherwise keep the automatic ones. The x axis always runs with
        # R.A. increasing to the left.
        if toPlotList and max(len(xs) for xs, *_ in toPlotList) > 1000:
            minRa, maxRa, minDec, maxDec = limits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        return addPlotInfo(fig, plotInfo)