Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%
166 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-07 02:02 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-07 02:02 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ColorColorFitPlot",)
26from typing import Mapping, cast
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
34from sklearn.neighbors import KernelDensity
36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
37from ...statistics import sigmaMad
38from .plotUtils import addPlotInfo, mkColormap, perpDistance
class ColorColorFitPlot(PlotAction):
    """Makes a color-color plot and overplots a
    prefited line to the specified area of the plot.
    This is mostly used for the stellar locus plots
    and also includes panels that illustrate the
    goodness of the given fit.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys, and their expected types, that this action reads.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            ``("x", Vector)``, ``("y", Vector)`` and ``("mag", Vector)``,
            followed by the scalar statistics and fit parameters, in a
            fixed order. The statistic keys are prefixed with
            ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
        ]
        scalarKeys = (
            "approxMagDepth",
            f"{self.plotName}_sigmaMAD",
            f"{self.plotName}_median",
            f"{self.plotName}_hardwired_sigmaMAD",
            f"{self.plotName}_hardwired_median",
            "xMin",
            "xMax",
            "yMin",
            "yMax",
            "mHW",
            "bHW",
            "mODR",
            "bODR",
            "yBoxMin",
            "yBoxMax",
            "bPerpMin",
            "bPerpMax",
            "mODR2",
            "bODR2",
            "mPerp",
        )
        base.extend((key, Scalar) for key in scalarKeys)
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key from `getInputSchema` is present in ``data``.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or if a key that should hold a
            vector holds a scalar instead.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot the points from, for more information
            please see the notes section.
        plotInfo : `dict`
            A dictionary of information about the data being plotted
            with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no
            points fall strictly inside the fit box, the figure is returned
            early with empty panels and without the plot information added.

        Notes
        -----
        The axis labels are given by ``self.xAxisLabel`` and
        ``self.yAxisLabel``. The perpendicular distance of the points to
        the fit line is given in a histogram in the second panel.

        For the code to work it expects various quantities to be
        present in the 'data' that it is given.

        The quantities that are expected to be present are:

        * Statistics required by the input schema (the histogram panel
          recomputes its own median and sigma MAD from the distances):
            * ``approxMagDepth``
                The approximate magnitude corresponding to the SN cut used.
            * ``f"{self.plotName}_sigmaMAD"``
                The sigma mad of the distances to the line fit.
            * ``f"{self.plotName}_median"``
                The median of the distances to the line fit.
            * ``f"{self.plotName}_hardwired_sigmaMAD"``
                The sigma mad of the distances to the initial fit.
            * ``f"{self.plotName}_hardwired_median"``
                The median of the distances to the initial fit.


        * Parameters from the fitting code that are illustrated on the plot:
            * ``"bHW"``
                The hardwired intercept to fall back on.
            * ``"bODR"``
                The intercept calculated by the orthogonal distance
                regression fitting.
            * ``"bODR2"``
                The intercept calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"mHW"``
                The hardwired gradient to fall back on.
            * ``"mODR"``
                The gradient calculated by the orthogonal distance
                regression fitting.
            * ``"mODR2"``
                The gradient calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"xMin"``
                The x minimum of the box used in the fit.
            * ``"xMax"``
                The x maximum of the box used in the fit.
            * ``"yMin"``
                The y minimum of the box used in the fit.
            * ``"yMax"``
                The y maximum of the box used in the fit.
            * ``"mPerp"``
                The gradient of the line perpendicular to the line from
                the second ODR fit.
            * ``"bPerpMin"``
                The intercept of the perpendicular line that goes through
                xMin.
            * ``"bPerpMax"``
                The intercept of the perpendicular line that goes through
                xMax.
            * ``"yBoxMin"``, ``"yBoxMax"``
                Required by `getInputSchema` but not used by this method.

        * The main inputs to plot:
            x, y, mag

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/stellarLocusExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit, as a boolean mask. It must not be an
        # integer index array: the points outside the box are selected with
        # ``~fitPoints`` below, and ``~`` on integer indices gives -i - 1
        # (negative indexing), not the complement of the selection.
        # type ignore because Vector needs a prototype interface
        fitPoints = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFitPoints = int(np.count_nonzero(fitPoints))

        # TODO: Make a no data fig function and use here
        # With nothing inside the fit box there is nothing to build the
        # colorbar ticks, distance histogram or contours from (np.min on an
        # empty selection would raise).
        if nFitPoints == 0:
            return fig

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(mags)

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFitPoints, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitPoints], ys[~fitPoints], c=z[~fitPoints], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitPoints], ys[fitPoints], c=z[fitPoints], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar (explicitly on this figure rather than pyplot's
        # "current" figure)
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitPoints]), np.max(z[fitPoints])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. For a steep hardwired line the end points are
        # taken at the y edges of the fit box; otherwise at the x edges.
        steepFit = np.fabs(data["mHW"]) > 1
        if steepFit:
            ysFitLineHW = np.array([data["yMin"], data["yMax"]])
            xsFitLineHW = (ysFitLineHW - data["bHW"]) / data["mHW"]
            ysFitLine = np.array([data["yMin"], data["yMax"]])
            xsFitLine = (ysFitLine - data["bODR"]) / data["mODR"]
            ysFitLine2 = np.array([data["yMin"], data["yMax"]])
            xsFitLine2 = (ysFitLine2 - data["bODR2"]) / data["mODR2"]
        else:
            xsFitLineHW = np.array([data["xMin"], data["xMax"]])
            ysFitLineHW = data["mHW"] * xsFitLineHW + data["bHW"]  # type: ignore
            xsFitLine = np.array([data["xMin"], data["xMax"]])
            ysFitLine = np.array(
                [
                    data["mODR"] * xsFitLine[0] + data["bODR"],
                    data["mODR"] * xsFitLine[1] + data["bODR"],
                ]
            )
            xsFitLine2 = np.array([data["xMin"], data["xMax"]])
            ysFitLine2 = np.array(
                [
                    data["mODR2"] * xsFitLine2[0] + data["bODR2"],
                    data["mODR2"] * xsFitLine2[1] + data["bODR2"],
                ]
            )

        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to that line
        # Need two points to characterise the lines we want
        # to get the distances to.
        # NOTE(review): ``dists`` is measured to the initial ODR line
        # (bODR/mODR) although the histogram labels it "Refit" — confirm
        # which line is intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # The hardwired line must use its own x values; the ODR x values
        # only coincide with them in the non-steep (x-edge) case.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        xsFit = xs[fitPoints]
        ysFit = ys[fitPoints]
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xsFit, ysFit))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xsFit, ysFit))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        for yEdge, xEdge, bPerp in (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        ):
            if steepFit:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependance
        # of the distance to the fit, restricted to a symmetric window
        # around zero set by the 4th/96th distance percentiles.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourDists = dists[plotPoints]
        contourMags = mags[fitPoints][plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourDists, contourMags, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig