Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%
165 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-07 03:55 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-07 03:55 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("SkyPlot",)
26from typing import Mapping, Optional
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
36from ...statistics import nansigmaMad
37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays
39# from .plotUtils import generateSummaryStats, parsePlotInfo
class SkyPlot(PlotAction):
    """Plot action that colour-codes a quantity (z) at positions (x, y).

    One layer is drawn per entry in ``plotTypes`` ("galaxies", "stars",
    "unknown", "any"); each layer gets its own colormap, colorbar and a
    text box of summary statistics. The drawing of the points themselves is
    delegated to `plotProjectionWithBinning`.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    # NOTE(review): the doc string advertises "mag" but nothing in this class
    # handles it; only stars, galaxies, unknown and any are acted upon.
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the ``(key, type)`` pairs required for the configured types.

        Every selected plot type contributes four `Vector` columns: the x, y
        and z values plus a mask selecting the points used for statistics.

        Returns
        -------
        schema : `list` [`tuple`]
            The required keys, grouped in the order stars, galaxies,
            unknown, any.
        """
        columnsByType = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        schema = []
        for plotType, columns in columnsByType:
            if plotType in self.plotTypes:  # type: ignore
                schema.extend((column, Vector) for column in columns)
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot.

        Keyword arguments are forwarded unchanged to both `_validateInput`
        and `makePlot`.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by the input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a
            column is a `Scalar` where the schema asks for another type.
        """
        needed = self.getInputSchema(**kwargs)
        # Both sides are formatted with kwargs so templated key names compare
        # equal to their expanded counterparts.
        requiredKeys = {key.format(**kwargs) for key, _ in needed}
        presentKeys = {key.format(**kwargs) for key in data.keys()}
        if remainder := requiredKeys - presentKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            The values to summarise.
        mask : array-like, optional
            Index or boolean mask applied to ``arr`` before the median and
            sigma MAD are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            Result of `nansigmaMad` on the (masked) values.
        statsText : `str`
            Multi-line label containing both statistics and the point count.

        Notes
        -----
        The point count in the label is taken *before* the mask is applied,
        so it is the size of the input array rather than the size of the
        sample the statistics were computed from.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        # The mathtext pieces stay as plain raw strings because their braces
        # would otherwise need escaping inside an f-string.
        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _drawOutlines(self, ax, sumStats):
        """Draw the outline of every region in ``sumStats`` onto ``ax``.

        Each value of ``sumStats`` is unpacked as ``(corners, _)`` where
        ``corners`` is a sequence of (ra, dec) pairs whose elements provide
        ``asDegrees()``. Every entry except ``"tract"`` is also annotated with
        its key at the midpoint between corners 0 and 2.

        Returns
        -------
        tractBounds : `tuple` or `None`
            ``(minRa, maxRa, minDec, maxDec)`` of the ``"tract"`` entry, or
            `None` if ``sumStats`` has no such entry.
        """
        tractBounds = None
        for dataId, (corners, _) in sumStats.items():
            ras = [ra.asDegrees() for (ra, dec) in corners]
            decs = [dec.asDegrees() for (ra, dec) in corners]
            # Repeat the first corner so that the outline is closed.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)

            if dataId == "tract":
                tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue

            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            ax.annotate(
                dataId,
                (ra + width / 2, dec + height / 2),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractBounds

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the z values at the given (x, y) points.

        Parameters
        ----------
        data : `KeyedData`
            The columns to plot; must contain the keys listed by
            `getInputSchema` for the configured ``plotTypes``.
        plotInfo : `dict`, optional
            Information about the data being plotted, handed to
            `addPlotInfo`. Expected keys (not checked here):

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Maps a region identifier to a ``(corners, stat)`` pair, where
            ``corners`` holds the R.A. and dec of the corners of the region.
            Only used when ``plotOutlines`` is set. An entry keyed
            ``"tract"``, if present, fixes the plotted extent for all layers.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Layers are drawn in the order galaxies, stars, unknown, any, and the
        colorbars are stacked left to right in that same order.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Every other stats box sits at x=0.8, so shift the galaxy one left as
        # soon as a second type is plotted; otherwise the boxes overlap.
        galaxyBoxLoc = 0.63 if len(self.plotTypes) > 1 else 0.8

        # (plotType, (z, x, y, statMask) keys, box colour, box alpha, box x,
        #  colormap, colorbar label)
        layerSpecs = (
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                "lemonchiffon",
                0.5,
                galaxyBoxLoc,
                redPurple,
                "Galaxies",
            ),
            (
                "stars",
                ("zStars", "xStars", "yStars", "starStatMask"),
                "paleturquoise",
                0.5,
                0.8,
                blueGreen,
                "Stars",
            ),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "green",
                0.2,
                0.8,
                "viridis",
                "Unknown",
            ),
            (
                "any",
                ("z", "x", "y", "statMask"),
                "purple",
                0.2,
                0.8,
                orangeBlue,
                "All",
            ),
        )

        toPlotList = []
        for plotType, keys, boxColor, boxAlpha, boxLoc, cmap, label in layerSpecs:
            if plotType not in self.plotTypes:
                continue
            colorVals, xs, ys, statMask = sortAllArrays([data[key] for key in keys])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            # Add statistics
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outlines of the regions in sumStats; a "tract" entry also defines
        # the extent used for binning and for the axis limits.
        tractBounds = self._drawOutlines(ax, sumStats) if self.plotOutlines else None

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if tractBounds is None:
                # No tract outline available: fall back to the data extent.
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            else:
                minRa, maxRa, minDec, maxDec = tractBounds
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # Each layer gets its own colorbar, offset to the right.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # With several colorbars, put the first one's ticks on the left so
            # they do not run into its neighbour.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits. The extent used is the one left over
        # from the last layer drawn; the x axis is reversed in both branches.
        if toPlotList and max(len(xs) for (xs, _, _, _, _) in toPlotList) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig