Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 14%
165 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-01 03:10 -0700
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-01 03:10 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from typing import Mapping, Optional, cast
26import matplotlib.patheffects as pathEffects
27import matplotlib.pyplot as plt
28import numpy as np
29from lsst.analysis.tools import PlotAction
30from lsst.pex.config import Field, ListField
31from matplotlib.figure import Figure
32from matplotlib.patches import Rectangle
33from sklearn.neighbors import KernelDensity
35from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector
36from ...statistics import sigmaMad
37from .plotUtils import addPlotInfo, mkColormap, perpDistance
class ColorColorFitPlot(PlotAction):
    """Color-color plot of a stellar locus with pre-fitted lines.

    Draws the points on a color-color diagram together with the hardwired,
    initial and refit lines supplied in the input data, a histogram of the
    perpendicular distances of the fit points to the fit line, and a contour
    plot of those distances against magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys, and their expected types, this action requires.

        The four statistic keys are prefixed with ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key in the input schema is present in ``data``.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or if a key that should be a
            `Vector` holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _fitLineEnds(m, b, steep: bool, data: KeyedData) -> tuple[np.ndarray, np.ndarray]:
        """Return the x and y end points of the line ``y = m * x + b``.

        If ``steep``, the ends are placed at ``data["yMin"]`` and
        ``data["yMax"]`` and x is solved for, which keeps steep lines within
        the box's y range. Otherwise the ends are placed at ``data["xMin"]``
        and ``data["xMax"]``.
        """
        if steep:
            ysEnds = np.array([data["yMin"], data["yMax"]])
            xsEnds = (ysEnds - b) / m
        else:
            xsEnds = np.array([data["xMin"], data["xMax"]])
            ysEnds = m * xsEnds + b
        return xsEnds, ysEnds

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. The keys read here are:

            ``"x"``, ``"y"``
                The colors plotted on the x and y axes (`Vector`).
            ``"mag"``
                The magnitudes of the points (`Vector`). Their median is
                reported in the info box, and they form the y axis of the
                distance contour plot.
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The box used to select the points for the fit (`Scalar`).
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept to fall back on
                (`Scalar`).
            ``"mODR"``, ``"bODR"``
                The gradient and intercept of the initial orthogonal distance
                regression fit (`Scalar`).
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept of the refit (`Scalar`).
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient, and the intercepts at the lower and upper box
                edges, of the lines drawn perpendicular to the fit
                (`Scalar`).
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Not used by this method.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If ``data`` has no points, or no point lies
            inside the fit box, a partially drawn figure is returned early.

        Notes
        -----
        Makes a color-color plot of ``x`` against ``y``. The points are color
        coded by their local number density, from a Gaussian kernel density
        estimate. Points inside the fit box use a blue colormap and the rest
        use grey. The hardwired, initial and refit lines are overplotted.

        The perpendicular distances of the fit points to the initial fit line
        and to the hardwired line are shown as histograms in the upper right
        panel. A contour plot of distance against magnitude is shown in the
        lower right panel.
        """
        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels (a colorbar axis is added later)
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])
        xs = cast(Vector, data["x"])
        ys = cast(Vector, data["y"])
        mags = cast(Vector, data["mag"])

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Boolean mask of the points inside the fit box. A mask, rather than
        # the integer indices returned by np.where, is required so that
        # ~fitMask selects the complementary points: ~ applied to integer
        # indices gives negative indices, which select from the array's end.
        # type ignore because Vector needs a prototype interface
        fitMask = np.asarray(
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = int(np.count_nonzero(fitMask))

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.median(mags)

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # The rest of the figure reduces over the fit points. The colorbar
        # ticks use np.min/np.max, which raise on an empty array, so stop here
        # if the box is empty.
        # TODO: Make a no data fig function and use here
        if nFit == 0:
            return fig

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitMask], ys[~fitMask], c=z[~fitMask], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitMask], ys[fitMask], c=z[fitMask], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitMask]), np.max(z[fitMask])], labels=["Less", "More"])

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. Whether the hardwired gradient is steep decides
        # how all three lines (and the perpendicular ticks below) are drawn.
        steep = np.fabs(data["mHW"]) > 1
        xsFitLineHW, ysFitLineHW = self._fitLineEnds(data["mHW"], data["bHW"], steep, data)
        xsFitLine, ysFitLine = self._fitLineEnds(data["mODR"], data["bODR"], steep, data)
        xsFitLine2, ysFitLine2 = self._fitLineEnds(data["mODR2"], data["bODR2"], steep, data)

        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the perpendicular distances of the fit points to the
        # initial fit line and to the hardwired line. Each line is described
        # by two points that lie on it.
        # NOTE(review): the distance histogram is labelled "Refit", but these
        # distances are measured to the initial (mODR, bODR) line rather than
        # the refit (mODR2, bODR2) line. Confirm which is intended.
        fitXs = xs[fitMask]
        fitYs = ys[fitMask]
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # Both points must come from the hardwired line itself. Borrowing x
        # from the initial fit is only correct when both lines share the same
        # x end points, which is not the case for steep lines.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        distsHW = np.asarray(perpDistance(p1HW, p2HW, zip(fitXs, fitYs)))
        dists = np.asarray(perpDistance(p1, p2, zip(fitXs, fitYs)))

        # Draw short segments of the perpendicular lines where they cross the
        # lower and upper edges of the box.
        for yEdge, xEdge, bPerp in (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        ):
            if steep:
                # Centre the segment where the refit line meets this y edge.
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ax.plot(xsPerp, data["mPerp"] * xsPerp + bPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.median(dists)
        madDists = sigmaMad(dists)
        meanDists = np.mean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the legend shows the two step histograms by alpha
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence of the distance
        # to the fit, using only points within a symmetric range around zero
        # taken from the 4th/96th distance percentiles.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourDists = dists[plotPoints]
        contourMags = np.asarray(mags)[fitMask][plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourDists, contourMags, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig