Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%
172 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-30 14:25 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("ColorColorFitPlot",)
26from typing import Mapping, cast
28import matplotlib.patheffects as pathEffects
29import matplotlib.pyplot as plt
30import numpy as np
31from lsst.pex.config import Field, ListField, RangeField
32from matplotlib.figure import Figure
33from matplotlib.patches import Rectangle
34from sklearn.neighbors import KernelDensity
36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
37from ...statistics import sigmaMad
38from ..keyedData.stellarLocusFit import perpDistance
39from .plotUtils import addPlotInfo, mkColormap
class ColorColorFitPlot(PlotAction):
    """Makes a color-color plot and overplots a
    prefited line to the specified area of the plot.
    This is mostly used for the stellar locus plots
    and also includes panels that illustrate the
    goodness of the given fit.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)
    minPointsForFit = RangeField[int](
        doc="Minimum number of valid objects to bother attempting a fit.",
        default=5,
        min=1,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the (key, type) pairs this action requires in its input.

        The per-plot statistics keys are prefixed with ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key in the input schema is present in ``data``.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a key that should be a Vector
            holds a Scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def _makeNoDataFig(self, nPoints: int, plotInfo: Mapping[str, str]) -> Figure:
        """Make a placeholder figure saying too few points were available."""
        fig = plt.figure(dpi=120)
        noDataText = (
            "Number of objects after cuts ({}) is less than the\nminimum required by "
            "minPointsForFit ({})".format(nPoints, self.minPointsForFit)
        )
        fig.text(0.5, 0.5, noDataText, ha="center", va="center")
        return addPlotInfo(fig, plotInfo)

    @staticmethod
    def _fitLineEnds(slope, intercept, steep, xLims, yLims):
        """Return end points of the line ``y = slope * x + intercept``.

        For a steep line the end points are placed at ``yLims`` and x is
        solved for; otherwise they are placed at ``xLims`` and y is evaluated.
        """
        if steep:
            ysLine = np.array(yLims)
            xsLine = (ysLine - intercept) / slope
        else:
            xsLine = np.array(xLims)
            ysLine = slope * xsLine + intercept
        return xsLine, ysLine

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot the points from, for more information
            please see the notes section.
        plotInfo : `dict`
            A dictionary of information about the data being plotted
            with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If fewer than ``self.minPointsForFit``
            points fall inside the fit box, this is a placeholder figure
            containing only an explanatory message.

        Notes
        -----
        The axis labels are given by ``self.xAxisLabel`` and
        ``self.yAxisLabel``. The perpendicular distance of the points to
        the fit line is given in a histogram in the second panel, and its
        magnitude dependence in a contour plot in the third panel.

        For the code to work it expects various quantities to be
        present in the 'data' that it is given.

        The quantities that are expected to be present are:

        * Statistics required by the input schema (not all of them are
          currently drawn on the plot):
            * ``approxMagDepth``
                The approximate magnitude corresponding to the SN cut used.
            * ``f"{self.plotName}_sigmaMAD"``
                The sigma mad of the distances to the line fit.
            * ``f"{self.plotName}_median"``
                The median of the distances to the line fit.
            * ``f"{self.plotName}_hardwired_sigmaMAD"``
                The sigma mad of the distances to the initial fit.
            * ``f"{self.plotName}_hardwired_median"``
                The median of the distances to the initial fit.


        * Parameters from the fitting code that are illustrated on the plot:
            * ``"bHW"``
                The hardwired intercept to fall back on.
            * ``"bODR"``
                The intercept calculated by the orthogonal distance
                regression fitting.
            * ``"bODR2"``
                The intercept calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"mHW"``
                The hardwired gradient to fall back on.
            * ``"mODR"``
                The gradient calculated by the orthogonal distance
                regression fitting.
            * ``"mODR2"``
                The gradient calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"xMin"``
                The x minimum of the box used in the fit.
            * ``"xMax"``
                The x maximum of the box used in the fit.
            * ``"yMin"``
                The y minimum of the box used in the fit.
            * ``"yMax"``
                The y maximum of the box used in the fit.
            * ``"mPerp"``
                The gradient of the line perpendicular to the line from
                the second ODR fit.
            * ``"bPerpMin"``
                The intercept of the perpendicular line that goes through
                xMin.
            * ``"bPerpMax"``
                The intercept of the perpendicular line that goes through
                xMax.

        * The main inputs to plot:
            x, y, mag

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/stellarLocusExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # Points to use for the fit, as a boolean mask. A mask (rather than
        # the integer indices from np.where) is required so that ~fitMask is
        # the complementary selection; bit-inverting integer indices yields
        # negative indices, which select the wrong points.
        # type ignore because Vector needs a prototype interface
        fitMask = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFitPoints = int(np.count_nonzero(fitMask))

        # Check before creating the main figure so it is not left open
        # (unreferenced but still registered with pyplot) on this path.
        if nFitPoints < self.minPointsForFit:
            return self._makeNoDataFig(nFitPoints, plotInfo)

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFitPoints, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitMask], ys[~fitMask], c=z[~fitMask], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitMask], ys[fitMask], c=z[fitMask], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        plt.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitMask]), np.max(z[fitMask])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. The hardwired gradient decides for all three
        # lines whether the end points sit on the y or the x box edges.
        steep = np.fabs(data["mHW"]) > 1
        xLims = [data["xMin"], data["xMax"]]
        yLims = [data["yMin"], data["yMax"]]
        xsFitLineHW, ysFitLineHW = self._fitLineEnds(data["mHW"], data["bHW"], steep, xLims, yLims)
        xsFitLine, ysFitLine = self._fitLineEnds(data["mODR"], data["bODR"], steep, xLims, yLims)
        xsFitLine2, ysFitLine2 = self._fitLineEnds(data["mODR2"], data["bODR2"], steep, xLims, yLims)

        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to the lines. Each line is characterised
        # by two points taken from that same line's end points.
        # NOTE(review): ``dists`` is measured to the initial ODR line
        # (mODR/bODR) but is labelled "Refit" in the histogram below;
        # confirm whether the second ODR fit (mODR2/bODR2) was intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        xsFit = xs[fitMask]
        ysFit = ys[fitMask]
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xsFit, ysFit))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xsFit, ysFit))) * 1000

        # Draw short segments of the perpendicular lines at the two box
        # edges (lower/min edge first, then upper/max edge).
        for xEdge, yEdge, bPerp in (
            (data["xMin"], data["yMin"], data["bPerpMin"]),
            (data["xMax"], data["yMax"], data["bPerpMax"]),
        ):
            if steep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependance
        # of the distance to the fit, restricted to a symmetric window
        # around zero set by the smaller of the 4th/96th percentiles.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourDists = dists[plotPoints]
        contourMags = cast(Vector, mags)[fitMask][plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourDists, contourMags, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig