Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 12%
185 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-05 01:24 -0700
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-05 01:24 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23from typing import Mapping
25import matplotlib.patheffects as pathEffects
26import matplotlib.pyplot as plt
27import numpy as np
28from lsst.pex.config import Field, ListField
29from matplotlib.figure import Figure
30from matplotlib.patches import Rectangle
31from scipy.stats import binned_statistic_2d
32from scipy.stats import median_absolute_deviation as sigmaMad
34from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
35from .plotUtils import addPlotInfo, extremaSort, mkColormap
37# from .plotUtils import generateSummaryStats, parsePlotInfo
class SkyPlot(PlotAction):
    """Plot the value of a quantity at positions on the sky.

    For each object type selected by ``plotTypes`` the z values are drawn
    color coded at their x/y positions, with one colorbar and one box of
    summary statistics per type. When ``plotOutlines`` is set, the outlines
    of the regions passed to `makePlot` in ``sumStats`` are drawn and
    labelled.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys and types this action reads from its input.

        Four `Vector` keys (x, y, z and a statistics mask) are required for
        each of the ``"stars"``, ``"galaxies"``, ``"unknown"`` and ``"any"``
        entries in ``plotTypes``; other entries add no keys.
        """
        base = []
        if "stars" in self.plotTypes:  # type: ignore
            base += [("xStars", Vector), ("yStars", Vector), ("zStars", Vector), ("starStatMask", Vector)]
        if "galaxies" in self.plotTypes:  # type: ignore
            base += [
                ("xGalaxies", Vector),
                ("yGalaxies", Vector),
                ("zGalaxies", Vector),
                ("galaxyStatMask", Vector),
            ]
        if "unknown" in self.plotTypes:  # type: ignore
            base += [
                ("xUnknowns", Vector),
                ("yUnknowns", Vector),
                ("zUnknowns", Vector),
                ("unknownStatMask", Vector),
            ]
        if "any" in self.plotTypes:  # type: ignore
            base += [("x", Vector), ("y", Vector), ("z", Vector), ("statMask", Vector)]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then plot it."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by `getInputSchema`.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key whose required
            type is not `Scalar` holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def sortAllArrays(self, arrsToSort):
        """Order the first array with `extremaSort` and apply the same
        reordering to every array in ``arrsToSort``.

        The list is modified in place and also returned.
        """
        ids = extremaSort(arrsToSort[0])
        for (i, arr) in enumerate(arrsToSort):
            arrsToSort[i] = arr[ids]
        return arrsToSort

    def statsAndText(self, arr, mask=None):
        """Calculate the median and sigma_MAD of ``arr`` and format them.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like of `bool`, optional
            If given, only ``arr[mask]`` is used for the median and sigma_MAD.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma_MAD of the (masked) values.
        statsText : `str`
            Multi-line text with both statistics and the number of points.
            Note the point count is taken before the mask is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = sigmaMad(arr, nan_policy="omit")

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    def _sortAndAnnotate(self, data, keys, fig, ax, boxLoc, bbox):
        """Sort one object type's arrays and draw its statistics text box.

        Parameters
        ----------
        data : `KeyedData`
            The input data.
        keys : `list` [`str`]
            Keys of the z values, x positions, y positions and statistics
            mask, in that order.
        fig : `matplotlib.figure.Figure`
            Figure whose coordinates position the text box.
        ax : `matplotlib.axes.Axes`
            Axes the text box is added to.
        boxLoc : `float`
            Horizontal figure-fraction position of the text box.
        bbox : `dict`
            Box style passed to `Axes.text`.

        Returns
        -------
        xs, ys, colorVals : array-like
            Positions and z values, reordered by `sortAllArrays`.
        """
        colorVals, xs, ys, statMask = self.sortAllArrays([data[key] for key in keys])
        _, _, statsText = self.statsAndText(colorVals, mask=statMask)
        ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
        return xs, ys, colorVals

    def _plotOutlines(self, ax, sumStats):
        """Draw (and, except for ``"tract"``, label) the outline of each region.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `Mapping`
            Region identifier -> pair whose first element is a sequence of
            corners; each corner is a pair of objects providing
            ``asDegrees()``. ``corners[0]`` and ``corners[2]`` are treated as
            opposite corners when placing the label.

        Returns
        -------
        extent : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` of the ``"tract"`` entry, or
            `None` when ``sumStats`` has no ``"tract"`` entry.
        """
        extent = None
        for dataId, (corners, _) in sumStats.items():
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                extent = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                cenX = ras[0] + (ras[2] - ras[0]) / 2
                cenY = decs[0] + (decs[2] - decs[0]) / 2
                ax.annotate(
                    dataId,
                    (cenX, cenY),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return extent

    def _addColorbar(self, fig, plotOut, index, nTypes, label):
        """Add a labelled colorbar for one object type.

        Colorbars are placed side by side to the right of the main axes; the
        first one moves its ticks to the left when several are drawn.
        """
        cax = fig.add_axes([0.87 + index * 0.04, 0.11, 0.04, 0.77])
        fig.colorbar(plotOut, cax=cax, extend="both")
        colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
        text = cax.text(
            0.5,
            0.5,
            colorBarLabel,
            color="k",
            rotation="vertical",
            transform=cax.transAxes,
            ha="center",
            va="center",
            fontsize=10,
        )
        text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cax.tick_params(labelsize=7)

        if index == 0 and nTypes > 1:
            cax.yaxis.set_ticks_position("left")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str] | None = None,
        sumStats: Mapping | None = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the value of a quantity at positions on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must contain the keys returned by
            `getInputSchema` for the configured ``plotTypes``.
        plotInfo : `dict`, optional
            Information about the data being plotted, passed to
            `addPlotInfo` (the original documentation lists ``"run"``,
            ``"skymap"``, ``"filter"`` and ``"tract"``).
        sumStats : `Mapping`, optional
            Region identifier (e.g. a patch id, or ``"tract"``) -> pair whose
            first element is the region's corners. Only used when
            ``plotOutlines`` is `True`; if it has a ``"tract"`` entry, that
            region's extent sets the binning range and axis limits.
        **kwargs
            Ignored.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Types with more than 5000 points are shown as an image of the median
        z value in 45x45 bins, with the most extreme 15% of points drawn on
        top; smaller samples are shown as a scatter plot. The color scale
        spans the median +/- 2 sigma_MAD of the z values.
        """
        if sumStats is None:
            sumStats = {}

        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        toPlotList = []
        if "galaxies" in self.plotTypes:
            # When several types are requested, shift the galaxy box left so
            # it is less likely to sit on top of the other statistics boxes.
            boxLoc = 0.63 if len(self.plotTypes) > 2 else 0.8
            xs, ys, colorVals = self._sortAndAnnotate(
                data,
                ["zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"],
                fig,
                ax,
                boxLoc,
                dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none"),
            )
            toPlotList.append((xs, ys, colorVals, redPurple, "Galaxies"))

        if "stars" in self.plotTypes:
            xs, ys, colorVals = self._sortAndAnnotate(
                data,
                ["zStars", "xStars", "yStars", "starStatMask"],
                fig,
                ax,
                0.8,
                dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none"),
            )
            toPlotList.append((xs, ys, colorVals, blueGreen, "Stars"))

        if "unknown" in self.plotTypes:
            xs, ys, colorVals = self._sortAndAnnotate(
                data,
                ["zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"],
                fig,
                ax,
                0.8,
                dict(facecolor="green", alpha=0.2, edgecolor="none"),
            )
            toPlotList.append((xs, ys, colorVals, "viridis", "Unknown"))

        if "any" in self.plotTypes:
            xs, ys, colorVals = self._sortAndAnnotate(
                data,
                ["z", "x", "y", "statMask"],
                fig,
                ax,
                0.8,
                dict(facecolor="purple", alpha=0.2, edgecolor="none"),
            )
            toPlotList.append((xs, ys, colorVals, orangeBlue, "All"))

        tractExtent = self._plotOutlines(ax, sumStats) if self.plotOutlines else None

        for (i, (xs, ys, colorVals, cmap, label)) in enumerate(toPlotList):
            if tractExtent is not None:
                minRa, maxRa, minDec, maxDec = tractExtent
            else:
                minRa, maxRa = np.nanmin(xs), np.nanmax(xs)
                minDec, maxDec = np.nanmin(ys), np.nanmax(ys)
                # Avoid identical end points which cause problems in binning;
                # the size of the offset is arbitrary.
                if minRa == maxRa:
                    maxRa += 1e-5
                if minDec == maxDec:
                    maxDec += 1e-5

            # NaN-safe, consistent with statsAndText: a single NaN z value
            # must not turn the color limits into NaN.
            med = np.nanmedian(colorVals)
            mad = sigmaMad(colorVals, nan_policy="omit")
            vmin = med - 2 * mad
            vmax = med + 2 * mad
            if self.fixAroundZero:
                scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
                vmin = -1 * scaleEnd
                vmax = scaleEnd

            if len(xs) > 5000:
                nBins = 45
                xBinEdges = np.linspace(minRa, maxRa, nBins + 1)
                yBinEdges = np.linspace(minDec, maxDec, nBins + 1)
                binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
                    xs, ys, colorVals, statistic="median", bins=(xBinEdges, yBinEdges)
                )
                s = 500 / (len(xs) ** 0.5)
                lw = (s**0.5) / 10
                plotOut = ax.imshow(
                    binnedStats.T,
                    cmap=cmap,
                    extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
                    vmin=vmin,
                    vmax=vmax,
                )
                # Overplot the most extreme 15% of points. This assumes
                # extremaSort orders points by increasing distance from the
                # median, so they are the final 15%; integer arithmetic gives
                # exactly floor(0.85 * n).
                extremes = (85 * len(xs)) // 100
                ax.scatter(
                    xs[extremes:],
                    ys[extremes:],
                    c=colorVals[extremes:],
                    s=s,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=lw,
                )
            else:
                plotOut = ax.scatter(
                    xs,
                    ys,
                    c=colorVals,
                    cmap=cmap,
                    s=7,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=0.2,
                )

            self._addColorbar(fig, plotOut, i, len(toPlotList), label)

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Pad the limits of large samples (using the extent of the last type
        # plotted, or the tract); RA increases to the left in both branches.
        if toPlotList and max(len(entry[0]) for entry in toPlotList) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        return addPlotInfo(fig, plotInfo)