Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%
166 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-15 18:40 -0800
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-15 18:40 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24from typing import Mapping, Optional, cast
26import matplotlib.patheffects as pathEffects
27import matplotlib.pyplot as plt
28import numpy as np
29from lsst.analysis.tools import PlotAction
30from lsst.pex.config import Field, ListField
31from matplotlib.figure import Figure
32from matplotlib.patches import Rectangle
33from sklearn.neighbors import KernelDensity
35from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector
36from ...statistics import sigmaMad
37from .plotUtils import addPlotInfo, mkColormap, perpDistance
class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram with pre-computed stellar locus fits.

    The main panel shows the points colored by their estimated number
    density, the box used to select the points for the fit, the hardwired,
    initial and refit lines, and perpendicular lines at the ends of the fit
    region. A histogram panel shows the distances of the fitted points to
    the fit line, and a contour panel shows that distance against magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    # Used as the y-axis label of the distance/magnitude contour panel.
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    # NOTE(review): not read anywhere in this class; presumably consumed by
    # the surrounding tool configuration -- confirm.
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    # Prefix of the ``*_sigmaMAD`` / ``*_median`` input keys.
    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their expected kinds) this action reads."""
        vectorKeys = ("x", "y", "mag")
        scalarKeys = (
            "approxMagDepth",
            f"{self.plotName}_sigmaMAD",
            f"{self.plotName}_median",
            f"{self.plotName}_hardwired_sigmaMAD",
            f"{self.plotName}_hardwired_median",
            "xMin",
            "xMax",
            "yMin",
            "yMax",
            "mHW",
            "bHW",
            "mODR",
            "bODR",
            "yBoxMin",
            "yBoxMax",
            "bPerpMin",
            "bPerpMax",
            "mODR2",
            "bODR2",
            "mPerp",
        )
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [(key, Vector) for key in vectorKeys]
        base.extend((key, Scalar) for key in scalarKeys)
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key required by the schema.

        Raises
        ------
        ValueError
            Raised if a required key (after formatting with ``kwargs``) is
            missing from ``data``, or if a key the schema declares as a
            `Vector` holds a `Scalar`.

        Notes
        -----
        This can currently only check that something is not a scalar; it
        cannot check that the data is consistent with `Vector`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _fitLineEnds(data: KeyedData, m: Scalar, b: Scalar, steep: bool) -> tuple[np.ndarray, np.ndarray]:
        """Return the x and y end points of the line ``y = m * x + b``.

        For steep lines the ends are placed at the fit box's y limits
        (``yMin``/``yMax``), otherwise at its x limits (``xMin``/``xMax``).
        """
        if steep:
            ysLine = np.array([data["yMin"], data["yMax"]])
            xsLine = (ysLine - b) / m
        else:
            xsLine = np.array([data["xMin"], data["xMax"]])
            ysLine = m * xsLine + b  # type: ignore
        return xsLine, ysLine

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make a stellar locus plot using pre-fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. It must contain every key from
            `getInputSchema`. The keys read here are:

            ``"x"``, ``"y"``
                The colors plotted on the x and y axes (`Vector`).
            ``"mag"``
                Magnitudes used for the contour panel's y axis and for the
                median magnitude in the info box (`Vector`).
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                Edges of the box that selects the points used in the fit.
            ``"mHW"``, ``"bHW"``
                Gradient and intercept of the hardwired line.
            ``"mODR"``, ``"bODR"``
                Gradient and intercept of the initial orthogonal distance
                regression fit.
            ``"mODR2"``, ``"bODR2"``
                Gradient and intercept of the refit.
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                Gradient and the two intercepts of the perpendicular lines
                drawn at the ends of the fit region.
        plotInfo : `dict`, optional
            Information about the data being plotted, passed on to
            `addPlotInfo` (e.g. ``"run"``, ``"skymap"``, ``"filter"``,
            ``"tract"``).
        **kwargs
            Not used by this method.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no
            points inside the fit box, the figure is returned with its
            panels left empty.

        Notes
        -----
        The ``x`` and ``y`` colors are scattered and colored by a Gaussian
        kernel density estimate of the number density. Points inside the fit
        box use a blue colormap; the rest use a grey one. The hardwired,
        initial and refit lines are overplotted. The distance (in millimags)
        of the fitted points to the fit line is shown in a histogram panel
        and, against magnitude, in a contour panel.
        """
        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Boolean mask of the points used for the fit. It must be boolean
        # (not the integer indices from ``np.where``) so that ``~`` selects
        # the complementary points; ``~`` on integer indices is a bitwise NOT
        # and would pick unrelated points counted from the end of the array.
        # type ignore because Vector needs a prototype interface
        inFitBox = np.asarray(
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = int(np.count_nonzero(inFitBox))
        # The colorbar ticks and the distance statistics below need at least
        # one fitted point; they fail on empty arrays.
        # TODO: Make a no data fig function and use here
        if nFit == 0:
            return fig
        fitXs = xs[inFitBox]
        fitYs = ys[inFitBox]

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Estimate the density of the points with a Gaussian kernel
        # (score_samples returns the log density, hence the exp).
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))
        zFit = z[inFitBox]

        ax.scatter(xs[~inFitBox], ys[~inFitBox], c=z[~inFitBox], cmap="binary", s=0.3)
        fitScatter = ax.scatter(fitXs, fitYs, c=zFit, cmap=newBlues, label="Used for Fit", s=0.3)

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(zFit), np.max(zFit)], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Pad the 0.5-99.5 percentile range by a fifth of its width per side
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. When the hardwired gradient is steeper than 1
        # the lines span the box's y limits, otherwise its x limits.
        steep = bool(np.fabs(data["mHW"]) > 1)
        xsFitLineHW, ysFitLineHW = self._fitLineEnds(data, data["mHW"], data["bHW"], steep)
        xsFitLine, ysFitLine = self._fitLineEnds(data, data["mODR"], data["bODR"], steep)
        xsFitLine2, ysFitLine2 = self._fitLineEnds(data, data["mODR2"], data["bODR2"], steep)

        # Each line is drawn over a wider white line so it stands out.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to the lines. Two points characterise each
        # line for perpDistance.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # Use the hardwired line's own x ends. For steep lines they differ
        # from the ODR line's, so mixing them would give points off the line.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        # NOTE(review): ``dists`` is measured to the initial ODR line, while
        # its histogram is labelled "Refit" -- confirm which line is intended.
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(fitXs, fitYs))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(fitXs, fitYs))) * 1000

        # Draw the perpendicular lines that mark the ends of the fit region,
        # one at the low edge and one at the high edge of the box.
        for yEdge, xEdge, bPerp in (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        ):
            if steep:
                # Centre on where the refit line crosses this y edge.
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                perpXs = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                perpXs = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            perpYs = data["mPerp"] * perpXs + bPerp
            ax.plot(perpXs, perpYs, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the legend shows outlined boxes matching the
        # step histograms.
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence of the distance
        # to the fit, restricted to a symmetric range set by the smaller of the
        # absolute 4th/96th distance percentiles.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourXs = dists[plotPoints]
        contourYs = mags[inFitBox][plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourXs, contourYs, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig