Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%
167 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-08 04:01 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-02-08 04:01 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ColorColorFitPlot",)
26from typing import Mapping, cast
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.analysis.tools import PlotAction
32from lsst.pex.config import Field, ListField
33from matplotlib.figure import Figure
34from matplotlib.patches import Rectangle
35from sklearn.neighbors import KernelDensity
37from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector
38from ...statistics import sigmaMad
39from .plotUtils import addPlotInfo, mkColormap, perpDistance
class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram together with pre-computed line fits.

    The figure produced by `makePlot` has three panels: a scatter plot of
    ``y`` against ``x`` colored by a kernel density estimate with the
    hardwired, initial and refit lines overplotted; a histogram of the
    perpendicular distances of the fit-box points to the lines; and a
    contour plot of those distances against magnitude.

    All fit parameters are read from the input data; nothing is fitted
    here.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    # NOTE(review): plotTypes is not read by any method of this class as
    # shown here; confirm whether a base class or caller consumes it.
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys, and their kinds, that this action requires.

        Returns
        -------
        schema : `KeyedDataSchema`
            Pairs of ``(key, type)``. The statistic keys are prefixed with
            ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            # Per-object quantities.
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            # Summary statistics, keyed by the plot name.
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            # Corners of the box used to select points for the fit.
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            # Gradients (m*) and intercepts (b*) of the lines to draw.
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        ``kwargs`` are forwarded unchanged to both `_validateInput` and
        `makePlot`.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key of the input schema.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a
            value is a scalar where the schema does not ask for `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        # Both the schema keys and the data keys are treated as format
        # templates and filled in from kwargs before being compared.
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make a color-color plot using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The points to plot and the parameters of the fits. The keys
            read here are:

            ``"x"``, ``"y"``, ``"mag"``
                Vectors holding the two plotted quantities and the
                magnitudes; entries that are non finite in any of the
                three are dropped.
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The edges of the box used to select the points used in the
                fit (exclusive comparisons).
            ``"mHW"``, ``"bHW"``
                Gradient and intercept of the line drawn as "Hardwired".
            ``"mODR"``, ``"bODR"``
                Gradient and intercept of the line drawn as "Initial".
            ``"mODR2"``, ``"bODR2"``
                Gradient and intercept of the line drawn as "Refit".
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                Gradient and the two intercepts of the dashed guide lines
                drawn at the two ends of the fit box.
        plotInfo : `dict`
            Information about the data being plotted; handed unchanged to
            ``addPlotInfo``. NOTE(review): earlier documentation listed
            the keys ``"run"``, ``"skymap"``, ``"filter"`` and ``"tract"``
            - confirm against ``addPlotInfo``.
        **kwargs
            Accepted for interface compatibility; not used.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no
            points fall inside the fit box, the figure is returned with
            empty axes.

        Notes
        -----
        The main panel shows ``y`` against ``x`` with the points colored by
        a Gaussian kernel density estimate; points inside the fit box use a
        blue colormap, the rest a grey one. The axis labels are given by
        ``self.xAxisLabel`` and ``self.yAxisLabel``. The distances of the
        fit-box points to the "Initial" and "Hardwired" lines, multiplied
        by 1000, are shown as histograms in the upper right panel, and the
        "Initial" distances are shown against magnitude (axis labelled with
        ``self.magLabel``) as contours in the lower right panel.
        """
        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit. This is kept as a boolean mask (not an
        # index array) so that ``~fitPoints`` really is the complement.
        # type ignore because Vector needs a prototype interface
        fitPoints = np.asarray(
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"]),  # type: ignore
            dtype=bool,
        )
        nFit = int(np.count_nonzero(fitPoints))

        # Nothing to measure distances for; the statistics and color bar
        # limits below are undefined for an empty selection.
        if nFit == 0:
            return fig

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitPoints], ys[~fitPoints], c=z[~fitPoints], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitPoints], ys[fitPoints], c=z[fitPoints], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        plt.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitPoints]), np.max(z[fitPoints])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. For steep hardwired lines the end points are
        # derived from the y edges of the box, otherwise from the x edges.
        steep = np.fabs(data["mHW"]) > 1
        if steep:
            ysFitLineHW = np.array([data["yMin"], data["yMax"]])
            xsFitLineHW = (ysFitLineHW - data["bHW"]) / data["mHW"]
            ysFitLine = np.array([data["yMin"], data["yMax"]])
            xsFitLine = (ysFitLine - data["bODR"]) / data["mODR"]
            ysFitLine2 = np.array([data["yMin"], data["yMax"]])
            xsFitLine2 = (ysFitLine2 - data["bODR2"]) / data["mODR2"]
        else:
            xsFitLineHW = np.array([data["xMin"], data["xMax"]])
            ysFitLineHW = data["mHW"] * xsFitLineHW + data["bHW"]  # type: ignore
            xsFitLine = np.array([data["xMin"], data["xMax"]])
            ysFitLine = np.array(
                [
                    data["mODR"] * xsFitLine[0] + data["bODR"],
                    data["mODR"] * xsFitLine[1] + data["bODR"],
                ]
            )
            xsFitLine2 = np.array([data["xMin"], data["xMax"]])
            ysFitLine2 = np.array(
                [
                    data["mODR2"] * xsFitLine2[0] + data["bODR2"],
                    data["mODR2"] * xsFitLine2[1] + data["bODR2"],
                ]
            )

        # Each line is drawn twice: a wider white line underneath so the
        # dashed colored line stays visible on top of the scatter.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to that line
        # Need two points to characterise the lines we want
        # to get the distances to
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # Both coordinates must come from the hardwired line; its x end
        # points differ from the ODR line's when the lines are steep.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xs[fitPoints], ys[fitPoints]))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xs[fitPoints], ys[fitPoints]))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        for yKey, xKey, bKey in (("yMin", "xMin", "bPerpMin"), ("yMax", "xMax", "bPerpMax")):
            if steep:
                xMid = (data[yKey] - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([data[xKey] - 0.2, data[xKey], data[xKey] + 0.2])
            ysPerp = data["mPerp"] * xsPerp + data[bKey]
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy patches so the legend shows the two histogram opacities.
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependance
        # of the distance to the fit
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        # Use a range symmetric about zero, set by whichever of the 4th and
        # 96th percentiles is closer to zero.
        percsDists = np.nanpercentile(dists, [4, 96])
        minXs = -1 * np.min(np.fabs(percsDists))
        maxXs = np.min(np.fabs(percsDists))
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourDists = np.array(dists)[plotPoints]
        contourMags = cast(Vector, cast(Vector, mags)[cast(Vector, fitPoints)])[cast(Vector, plotPoints)]
        H, xEdges, yEdges = np.histogram2d(contourDists, contourMags, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        # Contour at the bin centers; H is transposed because histogram2d
        # puts x along the first axis.
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        # Annotate the figure built here rather than whichever figure
        # happens to be current in pyplot.
        fig = addPlotInfo(fig, plotInfo)

        return fig