Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%
167 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-04 23:12 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ColorColorFitPlot",)
26from typing import Mapping, cast
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
34from sklearn.neighbors import KernelDensity
36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
37from ...statistics import sigmaMad
38from ..keyedData.stellarLocusFit import perpDistance
39from .plotUtils import addPlotInfo, mkColormap
class ColorColorFitPlot(PlotAction):
    """Makes a color-color plot and overplots a
    prefited line to the specified area of the plot.
    This is mostly used for the stellar locus plots
    and also includes panels that illustrate the
    goodness of the given fit.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their expected kinds) that ``makePlot``
        requires to be present in the input data.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key in the input schema is present in ``data``.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or if a key that the schema
            declares as a `Vector` holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _lineEndPoints(
        slope: float, intercept: float, isSteep: bool, data: KeyedData
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return the two end points of the line ``y = slope * x + intercept``.

        A steep line spans ``[yMin, yMax]``, with the x values solved from
        the line. Otherwise the line spans ``[xMin, xMax]``, with the y
        values evaluated on the line. The box edges are read from ``data``.
        """
        if isSteep:
            ys = np.array([data["yMin"], data["yMax"]])
            xs = (ys - intercept) / slope
        else:
            xs = np.array([data["xMin"], data["xMax"]])
            ys = slope * xs + intercept
        return xs, ys

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot the points from, for more information
            please see the notes section.
        plotInfo : `dict`
            A dictionary of information about the data being plotted
            with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no
            points fall inside the fit box, the figure is returned early
            with only the panels that could be drawn.

        Notes
        -----
        The axis labels are given by ``self.xAxisLabel`` and
        ``self.yAxisLabel``. The perpendicular distances (in millimags) of
        the points inside the fit box to the fit line are shown as a
        histogram in the upper right panel. A contour plot of those
        distances against magnitude (labelled by ``self.magLabel``) is
        shown in the lower right panel.

        For the code to work it expects various quantities to be
        present in the 'data' that it is given.

        The quantities that are expected to be present are:

        * Statistics that are shown on the plot or used by the plotting code:
            * ``approxMagDepth``
                The approximate magnitude corresponding to the SN cut used.
            * ``f"{self.plotName}_sigmaMAD"``
                The sigma mad of the distances to the line fit.
            * ``f"{self.plotName}_median"``
                The median of the distances to the line fit.
            * ``f"{self.plotName}_hardwired_sigmaMAD"``
                The sigma mad of the distances to the initial fit.
            * ``f"{self.plotName}_hardwired_median"``
                The median of the distances to the initial fit.


        * Parameters from the fitting code that are illustrated on the plot:
            * ``"bHW"``
                The hardwired intercept to fall back on.
            * ``"bODR"``
                The intercept calculated by the orthogonal distance
                regression fitting.
            * ``"bODR2"``
                The intercept calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"mHW"``
                The hardwired gradient to fall back on.
            * ``"mODR"``
                The gradient calculated by the orthogonal distance
                regression fitting.
            * ``"mODR2"``
                The gradient calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"xMin"``
                The x minimum of the box used in the fit.
            * ``"xMax"``
                The x maximum of the box used in the fit.
            * ``"yMin"``
                The y minimum of the box used in the fit.
            * ``"yMax"``
                The y maximum of the box used in the fit.
            * ``"mPerp"``
                The gradient of the line perpendicular to the line from
                the second ODR fit.
            * ``"bPerpMin"``
                The intercept of the perpendicular line that goes through
                xMin.
            * ``"bPerpMax"``
                The intercept of the perpendicular line that goes through
                xMax.

        * The main inputs to plot:
            x, y, mag (points where any of these is non-finite are dropped)

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/stellarLocusExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Boolean mask of the points used for the fit (those strictly inside
        # the fit box). A mask, rather than the integer indices returned by
        # np.where, is required so that ``~fitPoints`` selects the
        # complementary points; bitwise NOT of integer indices would instead
        # produce negative indices that count from the end of the array.
        # type ignore because Vector needs a prototype interface
        fitPoints = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFitPoints = int(np.count_nonzero(fitPoints))

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFitPoints, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        if nFitPoints == 0:
            # Nothing lies inside the fit box, so the density colorbar, the
            # distance histogram and the contour panel cannot be computed
            # (e.g. the min/max of an empty array raises).
            # TODO: Make a no data fig function and use here
            return fig

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitPoints], ys[~fitPoints], c=z[~fitPoints], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitPoints], ys[fitPoints], c=z[fitPoints], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitPoints]), np.max(z[fitPoints])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. The hardwired gradient decides, for all of
        # the lines, whether they are drawn between yMin/yMax (steep) or
        # between xMin/xMax (shallow).
        isSteep = np.fabs(data["mHW"]) > 1
        xsFitLineHW, ysFitLineHW = self._lineEndPoints(data["mHW"], data["bHW"], isSteep, data)
        xsFitLine, ysFitLine = self._lineEndPoints(data["mODR"], data["bODR"], isSteep, data)
        xsFitLine2, ysFitLine2 = self._lineEndPoints(data["mODR2"], data["bODR2"], isSteep, data)

        # Each line is drawn twice: a wide white line underneath gives the
        # dashed coloured line a halo so it stands out against the points.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to that line
        # Need two points to characterise the lines we want
        # to get the distances to.
        # NOTE(review): these points come from the initial ODR line
        # (mODR/bODR), although the histogram below labels the result
        # "Refit" (the mODR2/bODR2 line) -- confirm which line is intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # The hardwired line must use its own x values; for steep lines
        # these differ from the ODR line's x values.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        xsFit = xs[fitPoints]
        ysFit = ys[fitPoints]
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xsFit, ysFit))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xsFit, ysFit))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        for yEdge, xEdge, bPerp in (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        ):
            if isSteep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the histogram legend shows outlined boxes whose
        # alphas match the two step histograms.
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence
        # of the distance to the fit. Only distances within a symmetric
        # window set by the 4th/96th percentiles are included.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourXs = dists[plotPoints]
        contourYs = cast(Vector, mags[fitPoints])[plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourXs, contourYs, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig