Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 14%
165 statements
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-06 09:59 +0000
« prev ^ index » next coverage.py v6.4.2, created at 2022-08-06 09:59 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23from typing import Mapping, Optional, cast
25import matplotlib.patheffects as pathEffects
26import matplotlib.pyplot as plt
27import numpy as np
28from lsst.analysis.tools import PlotAction
29from lsst.pex.config import Field, ListField
30from matplotlib.figure import Figure
31from matplotlib.patches import Rectangle
32from scipy.stats import median_absolute_deviation as sigmaMad
33from sklearn.neighbors import KernelDensity
35from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector
36from .plotUtils import addPlotInfo, mkColormap, perpDistance
class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram together with pre-computed stellar locus
    fits.

    The main panel shows the points, the fit box and the hardwired, initial
    and refit lines. A histogram panel shows the perpendicular distances of
    the points in the fit box to the fit line. A contour panel shows those
    distances against magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    # Used as the y-axis label of the distance-vs-magnitude contour panel.
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their expected kinds) this action reads."""
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        base.append(("x", Vector))
        base.append(("y", Vector))
        base.append(("mag", Vector))
        base.append(("approxMagDepth", Scalar))
        base.append((f"{self.plotName}_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_median", Scalar))
        base.append((f"{self.plotName}_hardwired_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_hardwired_median", Scalar))
        base.append(("xMin", Scalar))
        base.append(("xMax", Scalar))
        base.append(("yMin", Scalar))
        base.append(("yMax", Scalar))
        base.append(("mHW", Scalar))
        base.append(("bHW", Scalar))
        base.append(("mODR", Scalar))
        base.append(("bODR", Scalar))
        base.append(("yBoxMin", Scalar))
        base.append(("yBoxMax", Scalar))
        base.append(("bPerpMin", Scalar))
        base.append(("bPerpMax", Scalar))
        base.append(("mODR2", Scalar))
        base.append(("bODR2", Scalar))
        base.append(("mPerp", Scalar))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key required by
        `getInputSchema`.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or if a key expected to be a
            vector holds a scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _lineEnds(m, b, data: KeyedData, steep: bool) -> tuple[np.ndarray, np.ndarray]:
        """Return the x and y end points of the line ``y = m * x + b``.

        When ``steep`` is true the line is evaluated at the fit box's y
        edges (``yMin``, ``yMax``); otherwise at its x edges (``xMin``,
        ``xMax``).
        """
        if steep:
            ys = np.array([data["yMin"], data["yMax"]])
            xs = (ys - b) / m
        else:
            xs = np.array([data["xMin"], data["xMax"]])
            ys = m * xs + b  # type: ignore
        return xs, ys

    @staticmethod
    def _perpTickEnds(data: KeyedData, steep: bool, edge: str) -> tuple[np.ndarray, np.ndarray]:
        """Return the points of a short line of gradient ``mPerp`` drawn at
        one edge of the fit region.

        ``edge`` is ``"Min"`` or ``"Max"`` and selects ``xMin``/``yMin``/
        ``bPerpMin`` or ``xMax``/``yMax``/``bPerpMax``. For steep lines the
        tick is centred where the refit line crosses the box's y edge.
        Otherwise it is centred on the box's x edge.
        """
        if steep:
            xMid = (data[f"y{edge}"] - data["bODR2"]) / data["mODR2"]
            xs = np.array([xMid - 0.5, xMid, xMid + 0.5])
        else:
            xEdge = data[f"x{edge}"]
            xs = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
        ys = data["mPerp"] * xs + data[f"bPerp{edge}"]
        return xs, ys

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre-fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must contain the keys from `getInputSchema`.
            The ones used here are:

            ``"x"``, ``"y"``
                The colors plotted on the x and y axes (`Vector`).
            ``"mag"``
                The magnitudes of the points (`Vector`).
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The edges of the box selecting the points used in the fit.
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept.
            ``"mODR"``, ``"bODR"``
                The gradient and intercept of the initial orthogonal
                distance regression fit.
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept of the refit.
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient and intercepts used to draw short
                perpendicular lines at the edges of the fit region.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted,
            passed to ``addPlotInfo``, with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Not used.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If ``x`` or ``y`` is empty, the empty
            three-panel figure is returned.

        Notes
        -----
        Makes a color-color plot of ``x`` against ``y``. The points are
        colored by a Gaussian kernel density estimate of their density.
        Points inside the fit box are drawn in blue, the others in grey.
        The hardwired, initial and refit lines are overplotted. The axis
        labels are given by ``self.xAxisLabel`` and ``self.yAxisLabel``.
        The distance of the points in the fit box to the fit line is
        shown as a histogram in a second panel. Its dependence on
        magnitude is shown as a contour plot in a third panel. If no
        points fall inside the fit box, the distance panels are left
        empty.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])
        xs = cast(Vector, data["x"])
        ys = cast(Vector, data["y"])
        mags = cast(Vector, data["mag"])

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit, as a boolean mask. A mask (rather than
        # the integer indices returned by np.where) is required so that
        # ``~fitMask`` selects exactly the points outside the box.
        # type ignore because Vector needs a prototype interface
        fitMask = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = int(np.count_nonzero(fitMask))

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.median(mags)

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points.
        # NOTE(review): no bandwidth is passed, so KernelDensity's default
        # is used; confirm it suits the typical color range.
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitMask], ys[~fitMask], c=z[~fitMask], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitMask], ys[fitMask], c=z[fitMask], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        if nFit > 0:
            # np.min/np.max raise on an empty selection.
            cbAx.set_xticks([np.min(z[fitMask]), np.max(z[fitMask])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. All three lines are evaluated at the box's y
        # edges when the hardwired line is steep, else at its x edges.
        steep = bool(np.fabs(data["mHW"]) > 1)
        xsFitLineHW, ysFitLineHW = self._lineEnds(data["mHW"], data["bHW"], data, steep)
        xsFitLine, ysFitLine = self._lineEnds(data["mODR"], data["bODR"], data, steep)
        xsFitLine2, ysFitLine2 = self._lineEnds(data["mODR2"], data["bODR2"], data, steep)

        # Each line is drawn twice: a wider white line underneath gives
        # the dashed colored line a halo.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to the lines. Two points on each line
        # characterise it. Both hardwired points must lie on the hardwired
        # line, so they use that line's own x values.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        xsFit = xs[fitMask]
        ysFit = ys[fitMask]
        distsHW = np.asarray(perpDistance(p1HW, p2HW, zip(xsFit, ysFit)))
        # NOTE(review): these distances are to the initial fit line
        # (mODR/bODR), although the histogram below labels them "Refit";
        # confirm which line is intended.
        dists = np.asarray(perpDistance(p1, p2, zip(xsFit, ysFit)))

        # Draw the short perpendicular lines at the edges of the fit region
        for edge in ("Min", "Max"):
            perpXs, perpYs = self._perpTickEnds(data, steep, edge)
            ax.plot(perpXs, perpYs, "k--", alpha=0.7)

        if nFit == 0:
            # No points in the fit box: there are no distances to show in
            # the histogram and contour panels.
            fig.legend(
                handles=[lineHW, lineInit, lineRefit],
                fontsize=8,
                bbox_to_anchor=(1.0, 0.99),
                bbox_transform=fig.transFigure,
                ncol=2,
            )
            return addPlotInfo(fig, plotInfo)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.median(dists)
        madDists = sigmaMad(dists)
        meanDists = np.mean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Build the histogram legend from unfilled proxy rectangles whose
        # alphas match the two step histograms.
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence
        # of the distance to the fit
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        # Keep a window symmetric about zero, set by the smaller of the
        # absolute 4th/96th percentile distances.
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourXs = dists[plotPoints]
        contourYs = cast(Vector, cast(Vector, mags)[fitMask])[plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourXs, contourYs, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig