Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 10%
188 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 10:53 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-08 10:53 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from typing import Mapping, Optional
26import matplotlib.patheffects as pathEffects
27import matplotlib.pyplot as plt
28import numpy as np
29from lsst.pex.config import Field, ListField
30from matplotlib.figure import Figure
31from matplotlib.patches import Rectangle
32from scipy.stats import binned_statistic_2d
34from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
35from ...statistics import nansigmaMad, sigmaMad
36from .plotUtils import addPlotInfo, extremaSort, mkColormap
38# from .plotUtils import generateSummaryStats, parsePlotInfo
class SkyPlot(PlotAction):
    """Plot action that draws points on the sky, colour coded by a third
    quantity.

    The object types listed in ``plotTypes`` select which groups of input
    vectors are required (see `getInputSchema`) and drawn (see `makePlot`).
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the input keys required for the configured ``plotTypes``.

        Returns
        -------
        schema : `list` [`tuple`]
            ``(key, Vector)`` pairs; four keys (x, y, z and a statistics
            mask) for every handled entry found in ``self.plotTypes``.
        """
        keysByType = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        base = []
        for plotType, keys in keysByType:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in keys)
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot.

        ``kwargs`` are forwarded to both `_validateInput` and `makePlot`.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by the schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a
            column is a `Scalar` where the schema asks for another type.
        """
        needed = self.getInputSchema(**kwargs)
        neededKeys = {key.format(**kwargs) for key, _ in needed}
        availableKeys = {key.format(**kwargs) for key in data.keys()}
        if remainder := neededKeys - availableKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def sortAllArrays(self, arrsToSort):
        """Sort one array and then return all the others in
        the associated order.

        Parameters
        ----------
        arrsToSort : `list`
            Arrays to reorder. The ordering is computed by ``extremaSort``
            from the first array and applied to every array.

        Returns
        -------
        arrsToSort : `list`
            The same list object, with each element replaced by its
            reordered version (the input list is updated in place).
        """
        ids = extremaSort(arrsToSort[0])
        arrsToSort[:] = [arr[ids] for arr in arrsToSort]
        return arrsToSort

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like, optional
            If given, ``arr[mask]`` is used for the median and sigma MAD.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            Result of ``nansigmaMad`` on the (masked) values.
        statsText : `str`
            Text block with the median, sigma MAD and number of points.
        """
        # The point count is taken before the mask is applied, so the text
        # reports the size of the full input rather than the masked subset.
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a generic plot showing the value at given points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The x, y, z and statistics-mask vectors for every entry of
            ``self.plotTypes``, keyed as listed by `getInputSchema`.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            It is passed unchanged to ``addPlotInfo``; defaults to an empty
            dictionary.
        sumStats : `dict`, optional
            A dictionary keyed by data ID whose values are ``(corners, _)``
            pairs; ``corners`` is a sequence of (R.A., dec) pairs whose
            elements provide ``asDegrees()``. The key ``"tract"`` is
            special: its outline sets the binning range and axis limits and
            it gets no text label. Defaults to an empty dictionary.
        **kwargs
            Ignored.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        For each plotted type the colour scale spans the median plus or
        minus two sigma MAD of the z values, made symmetric around zero if
        ``self.fixAroundZero`` is set. Types with more than 5000 points are
        drawn as a 45 by 45 image of binned medians with the last 15% of the
        sorted points over-plotted; smaller sets are drawn as a plain
        scatter plot. One colorbar is added per plotted type.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxes and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # (plot type, z key, x key, y key, mask key, stats box colour,
        #  stats box alpha, colormap, colorbar label), in drawing order.
        plotSpecs = (
            ("galaxies", "zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask",
             "lemonchiffon", 0.5, redPurple, "Galaxies"),
            ("stars", "zStars", "xStars", "yStars", "starStatMask",
             "paleturquoise", 0.5, blueGreen, "Stars"),
            ("unknown", "zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask",
             "green", 0.2, "viridis", "Unknown"),
            ("any", "z", "x", "y", "statMask",
             "purple", 0.2, orangeBlue, "All"),
        )

        # The galaxy statistics box is moved aside whenever another type is
        # plotted as well, so that it does not sit on top of that type's
        # box (the other boxes all share the same location).
        galaxiesShareFigure = bool({"stars", "unknown", "any"}.intersection(self.plotTypes))

        toPlotList = []
        for plotType, zKey, xKey, yKey, maskKey, boxColor, boxAlpha, cmap, label in plotSpecs:
            if plotType not in self.plotTypes:
                continue
            colorVals, xs, ys, statMask = self.sortAllArrays(
                [data[zKey], data[xKey], data[yKey], data[maskKey]]
            )
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            # Add statistics
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            boxLoc = 0.63 if (plotType == "galaxies" and galaxiesShareFigure) else 0.8
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outlines of the patches (labelled with their ID) and of the tract
        tractBounds = None
        if self.plotOutlines:
            for dataId, (corners, _) in sumStats.items():
                ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
                decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
                # Repeat the first corner to close the outline
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                else:
                    # Label the outline midway between corners 0 and 2
                    cenX = ras[0] + (ras[2] - ras[0]) / 2
                    cenY = decs[0] + (decs[2] - decs[0]) / 2
                    ax.annotate(
                        dataId,
                        (cenX, cenY),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        # Extent used for every plotted dataset, for the final axis limits
        raLimits = []
        decLimits = []
        nBins = 45
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if tractBounds is None:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            else:
                minRa, maxRa, minDec, maxDec = tractBounds
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            raLimits.extend((minRa, maxRa))
            decLimits.extend((minDec, maxDec))

            med = np.nanmedian(colorVals)
            mad = sigmaMad(colorVals, nan_policy="omit")
            vmin = med - 2 * mad
            vmax = med + 2 * mad
            if self.fixAroundZero:
                scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
                vmin = -1 * scaleEnd
                vmax = scaleEnd
            xBinEdges = np.linspace(minRa, maxRa, nBins + 1)
            yBinEdges = np.linspace(minDec, maxDec, nBins + 1)
            binnedStats, xEdges, yEdges, binNums = binned_statistic_2d(
                xs, ys, colorVals, statistic="median", bins=(xBinEdges, yBinEdges)
            )

            if len(xs) > 5000:
                s = 500 / (len(xs) ** 0.5)
                lw = (s**0.5) / 10
                plotOut = ax.imshow(
                    binnedStats.T,
                    cmap=cmap,
                    extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
                    vmin=vmin,
                    vmax=vmax,
                )
                # find the most extreme 15% of points, because the list
                # is ordered by the distance from the median this is just
                # the final 15% of points
                extremes = int(np.floor((len(xs) / 100)) * 85)
                ax.scatter(
                    xs[extremes:],
                    ys[extremes:],
                    c=colorVals[extremes:],
                    s=s,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=lw,
                )

            else:
                plotOut = ax.scatter(
                    xs,
                    ys,
                    c=colorVals,
                    cmap=cmap,
                    s=7,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=0.2,
                )

            # One colorbar per dataset, placed side by side
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs != [] and np.max(lenXs) > 1000:
            # Cover every plotted dataset, not only the last one drawn
            minRa, maxRa = np.min(raLimits), np.max(raLimits)
            minDec, maxDec = np.min(decLimits), np.max(decLimits)
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # The x limits are given high-to-low so the x axis is inverted
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig