Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%
166 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-04 03:16 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-04 03:16 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from typing import Mapping, Optional, cast
26import matplotlib.patheffects as pathEffects
27import matplotlib.pyplot as plt
28import numpy as np
29from lsst.analysis.tools import PlotAction
30from lsst.pex.config import Field, ListField
31from matplotlib.figure import Figure
32from matplotlib.patches import Rectangle
33from sklearn.neighbors import KernelDensity
35from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector
36from ...statistics import sigmaMad
37from .plotUtils import addPlotInfo, mkColormap, perpDistance
class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram together with pre-computed stellar locus
    fits.

    The figure has a main color-color panel (points shaded by local number
    density, with the fit box, the hardwired / initial / refit lines and the
    perpendicular lines at the box edges overplotted), a histogram of the
    perpendicular distance of the fitted points to the fit line, and a
    contour plot of that distance against magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    # Used as the y-axis label of the distance-vs-magnitude contour panel.
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    # NOTE(review): not read anywhere in this class.
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys, and their expected types, this action needs.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            ``(key, type)`` pairs; ``type`` is `Vector` or `Scalar`.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key of the input schema.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a key that must hold a `Vector`
            holds a `Scalar` instead.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _lineEndPoints(data: KeyedData, slopeKey: str, interceptKey: str, steep: bool):
        """Return the two end points of the line ``y = m * x + b``.

        ``m`` and ``b`` are read from ``data[slopeKey]`` and
        ``data[interceptKey]``. When ``steep`` is true the line is drawn
        between ``y = yMin`` and ``y = yMax``; otherwise between
        ``x = xMin`` and ``x = xMax``.

        Returns
        -------
        xsLine, ysLine : `numpy.ndarray`
            The x and y coordinates of the two end points.
        """
        slope = data[slopeKey]
        intercept = data[interceptKey]
        if steep:
            ysLine = np.array([data["yMin"], data["yMax"]])
            xsLine = (ysLine - intercept) / slope
        else:
            xsLine = np.array([data["xMin"], data["xMax"]])
            ysLine = slope * xsLine + intercept  # type: ignore
        return xsLine, ysLine

    @staticmethod
    def _perpendicularSegment(data: KeyedData, edge: str, steep: bool):
        """Return a short segment of the line perpendicular to the refit.

        ``edge`` is ``"Min"`` or ``"Max"`` and selects the box edge
        (``yMin``/``xMin`` or ``yMax``/``xMax``) and the intercept
        (``bPerpMin`` or ``bPerpMax``). The segment has gradient ``mPerp``.
        When ``steep`` is true it is centred on the x at which the refit line
        (``mODR2``, ``bODR2``) reaches the y edge, spanning +/- 0.5 in x.
        Otherwise it is centred on the x edge, spanning +/- 0.2 in x.

        Returns
        -------
        xsSeg, ysSeg : `numpy.ndarray`
            Three x and y coordinates along the segment.
        """
        bPerp = data[f"bPerp{edge}"]
        if steep:
            xMid = (data[f"y{edge}"] - data["bODR2"]) / data["mODR2"]
            xsSeg = np.array([xMid - 0.5, xMid, xMid + 0.5])
        else:
            xEdge = data[f"x{edge}"]
            xsSeg = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
        ysSeg = data["mPerp"] * xsSeg + bPerp
        return xsSeg, ysSeg

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. The vectors ``"x"``, ``"y"`` and ``"mag"`` give
            the two colors and the magnitude of each point. The scalars
            describe the pre-computed fit:

            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The edges of the box used to select points for the fit.
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept to fall back on.
            ``"mODR"``, ``"bODR"``
                The gradient and intercept of the initial orthogonal
                distance regression fit.
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept of the refit.
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient and the two intercepts of the lines
                perpendicular to the fit that are drawn at the box edges.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted, passed
            to ``addPlotInfo``, with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If no point has finite ``x``, ``y`` and
            ``mag``, or none of those points lies inside the fit box, the
            figure is returned with empty panels.

        Notes
        -----
        Makes a color-color plot of ``data["x"]`` against ``data["y"]``.
        The points are shaded by their local number density, estimated with
        a Gaussian kernel density estimate. Points inside the fit box use a
        blue colormap and the rest a greyscale one. The hardwired, initial
        and refit lines and the perpendicular lines at the box edges are
        overplotted. The axis labels are given by ``self.xAxisLabel`` and
        ``self.yAxisLabel``. A second panel shows a histogram of the
        perpendicular distance of the fitted points to the fit line. A third
        panel shows a contour plot of that distance against magnitude, whose
        y axis is labelled with ``self.magLabel``.
        """
        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels (the colorbar axis is added later)
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit. This is kept as a boolean mask so that
        # ``~inFitBox`` selects exactly the complementary points; negating
        # the integer indices from ``np.where`` would not (``~i == -i - 1``).
        # type ignore because Vector needs a prototype interface
        inFitBox = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = np.count_nonzero(inFitBox)

        # The colorbar ticks and the distance statistics below reduce over
        # the fitted points and would fail on an empty selection.
        if nFit == 0:
            return fig

        fitXs = xs[inFitBox]
        fitYs = ys[inFitBox]
        fitMags = cast(Vector, mags)[inFitBox]

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.median(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))
        fitZ = z[inFitBox]

        ax.scatter(xs[~inFitBox], ys[~inFitBox], c=z[~inFitBox], cmap="binary", s=0.3)
        fitScatter = ax.scatter(fitXs, fitYs, c=fitZ, cmap=newBlues, label="Used for Fit", s=0.3)

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        plt.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(fitZ), np.max(fitZ)], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. The hardwired gradient alone decides whether
        # all three lines (and the perpendicular segments below) are drawn
        # between the y edges of the box or between its x edges.
        steep = np.fabs(data["mHW"]) > 1
        xsFitLineHW, ysFitLineHW = self._lineEndPoints(data, "mHW", "bHW", steep)
        xsFitLine, ysFitLine = self._lineEndPoints(data, "mODR", "bODR", steep)
        xsFitLine2, ysFitLine2 = self._lineEndPoints(data, "mODR2", "bODR2", steep)

        # Each line is drawn twice: a wider white line underneath the dashed
        # colored one.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to the lines. Two points on each line are
        # needed to characterise it.
        # NOTE(review): ``dists`` is measured from the initial ODR line
        # (mODR, bODR) but is labelled "Refit" in the histogram legend;
        # confirm which line is intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # Use the hardwired line's own x coordinates so both points lie on
        # it; for steep lines they differ from the ODR line's.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        distsHW = perpDistance(p1HW, p2HW, zip(fitXs, fitYs))
        dists = perpDistance(p1, p2, zip(fitXs, fitYs))

        # Draw the perpendicular lines that intersect the box edges
        for edge in ("Min", "Max"):
            xsPerp, ysPerp = self._perpendicularSegment(data, edge, steep)
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.median(dists)
        madDists = sigmaMad(dists)
        meanDists = np.mean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependance of the distance
        # to the fit, restricted to a symmetric window around zero distance.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourXs = np.array(dists)[plotPoints]
        contourYs = cast(Vector, fitMags)[cast(Vector, plotPoints)]
        H, xEdges, yEdges = np.histogram2d(contourXs, contourYs, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig