Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%
167 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-31 12:04 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-31 12:04 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
# Public names exported by ``from ... import *`` for this module.
__all__ = ("ColorColorFitPlot",)
26from typing import Mapping, cast
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.analysis.tools import PlotAction
32from lsst.pex.config import Field, ListField
33from matplotlib.figure import Figure
34from matplotlib.patches import Rectangle
35from sklearn.neighbors import KernelDensity
37from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector
38from ...statistics import sigmaMad
39from .plotUtils import addPlotInfo, mkColormap, perpDistance
class ColorColorFitPlot(PlotAction):
    """Make a color-color plot overlaid with pre-computed stellar locus fits.

    The main panel shows the points colored by their local number density,
    highlighting the points inside the fit box, and overplots the fit box,
    the hardwired, initial and refit lines, and the perpendicular lines at
    the ends of the fit region. Two side panels show a histogram of the
    perpendicular distances of the fitted points to the refit line (and to
    the hardwired line) and a contour plot of that distance against
    magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their expected types) this action needs."""
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by the input schema.

        Keys are formatted with ``kwargs`` before being compared.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key that is not
            declared as `Scalar` holds a `Scalar`.

        Notes
        -----
        This can currently only check that something is not a scalar; it
        does not check that the data is consistent with `Vector`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _lineEnds(slope, intercept, data: KeyedData, steep):
        """Return the x and y end points of ``y = slope * x + intercept``.

        If ``steep`` is true, the line is evaluated at ``data["yMin"]`` and
        ``data["yMax"]``. Otherwise it is evaluated at ``data["xMin"]`` and
        ``data["xMax"]``.
        """
        if steep:
            ysLine = np.array([data["yMin"], data["yMax"]])
            xsLine = (ysLine - intercept) / slope
        else:
            xsLine = np.array([data["xMin"], data["xMax"]])
            ysLine = slope * xsLine + intercept  # type: ignore
        return xsLine, ysLine

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. The keys read here are:

            ``"x"``, ``"y"``
                The colors plotted on the x and y axes (`Vector`).
            ``"mag"``
                The magnitudes of the points, used for the contour panel
                (`Vector`).
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The limits of the box used to select the points for the fit.
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept to fall back on.
            ``"mODR"``, ``"bODR"``
                The gradient and intercept of the initial orthogonal
                distance regression fit.
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept of the refit.
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient and intercepts of the perpendicular lines drawn
                at the edges of the fit region.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no points
            inside the fit box, the figure is returned with empty panels.

        Notes
        -----
        Makes a color-color plot of ``data["x"]`` against ``data["y"]``,
        with the points color coded by their local number density. The
        pre-computed stellar locus fits are then overplotted. The axis labels
        are given by `xAxisLabel` and `yAxisLabel`. The distance of the points
        in the fit box to the refit line is shown as a histogram in a second
        panel. A third panel shows contours of that distance against
        magnitude.
        """
        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0:
            return fig

        # Points to use for the fit, as a boolean mask. This must be a mask
        # and not integer indices from np.where: ``~`` on an integer index
        # array gives -(i + 1), which selects points from the end of the
        # array instead of the complement of the fit points.
        # type ignore because Vector needs a prototype interface
        fitPoints = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = np.count_nonzero(fitPoints)
        # With no points in the fit box, the min/max and statistics below
        # would fail on empty arrays, so treat this like the no-data case.
        if nFit == 0:
            return fig

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitPoints], ys[~fitPoints], c=z[~fitPoints], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitPoints], ys[fitPoints], c=z[fitPoints], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitPoints]), np.max(z[fitPoints])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. The steepness of the hardwired line decides,
        # for all three lines, whether the end points come from the y limits
        # or the x limits of the fit box.
        steep = np.fabs(data["mHW"]) > 1
        xsFitLineHW, ysFitLineHW = self._lineEnds(data["mHW"], data["bHW"], data, steep)
        xsFitLine, ysFitLine = self._lineEnds(data["mODR"], data["bODR"], data, steep)
        xsFitLine2, ysFitLine2 = self._lineEnds(data["mODR2"], data["bODR2"], data, steep)

        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances of the fitted points to the refit line (the
        # histogram below labels them "Refit") and to the hardwired line.
        # Two points on each line are needed to characterise it.
        p1 = np.array([xsFitLine2[0], ysFitLine2[0]])
        p2 = np.array([xsFitLine2[1], ysFitLine2[1]])

        # Both coordinates must come from the hardwired line itself. In the
        # steep case its x end points differ from those of the other fits.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        fitXs = xs[fitPoints]
        fitYs = ys[fitPoints]
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(fitXs, fitYs))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(fitXs, fitYs))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        for xEdge, yEdge, bPerp in (
            (data["xMin"], data["yMin"], data["bPerpMin"]),
            (data["xMax"], data["yMax"], data["bPerpMax"]),
        ):
            if steep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the legend shows outlined boxes matching the
        # two step histograms.
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependance
        # of the distance to the fit. The distance range is made symmetric
        # about zero, using the smaller magnitude of the 4th/96th percentiles.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourXs = dists[plotPoints]
        contourYs = cast(Vector, cast(Vector, mags)[fitPoints])[plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourXs, contourYs, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig