Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%

167 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-03 02:52 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ColorColorFitPlot",) 

25 

26from typing import Mapping, cast 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.analysis.tools import PlotAction 

32from lsst.pex.config import Field, ListField 

33from matplotlib.figure import Figure 

34from matplotlib.patches import Rectangle 

35from sklearn.neighbors import KernelDensity 

36 

37from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector 

38from ...statistics import sigmaMad 

39from .plotUtils import addPlotInfo, mkColormap, perpDistance 

40 

41 

class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram together with pre-computed line fits.

    The main panel shows the color-color points colored by a kernel density
    estimate of the local point density, the box of points used in the
    fit, the hardwired / initial / refit lines, and short perpendicular
    segments marking the ends of the fit region. Two side panels show a
    histogram of the perpendicular distances to the fit line and a contour
    plot of those distances against magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their types) this action reads from ``data``.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            Pairs of ``(key, Vector or Scalar)``. Four of the keys are
            prefixed with `plotName`.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key from `getInputSchema`.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a key that must be a Vector is
            a Scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _lineEndpoints(
        data: KeyedData, slope: float, intercept: float, steep: bool
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return the two end points of the line ``y = slope * x + intercept``.

        Steep lines are evaluated at the y limits of the fit box
        (``yMin``, ``yMax``) so they span it vertically; other lines are
        evaluated at the x limits (``xMin``, ``xMax``).

        Returns
        -------
        xsLine, ysLine : `numpy.ndarray`
            The x and y coordinates of the two end points.
        """
        if steep:
            ysLine = np.array([data["yMin"], data["yMax"]])
            xsLine = (ysLine - intercept) / slope
        else:
            xsLine = np.array([data["xMin"], data["xMax"]])
            ysLine = slope * xsLine + intercept
        return xsLine, ysLine

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; it must contain the keys from
            `getInputSchema`. The ones read here are:

            ``"x"``, ``"y"``
                The colors plotted on the x and y axes.
            ``"mag"``
                The magnitudes shown on the y axis of the contour panel.
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The limits of the box selecting the points used in the fit.
            ``"mHW"``, ``"bHW"``
                Gradient and intercept of the hardwired line.
            ``"mODR"``, ``"bODR"``
                Gradient and intercept of the initial orthogonal distance
                regression fit.
            ``"mODR2"``, ``"bODR2"``
                Gradient and intercept of the refit.
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                Gradient and intercepts of the perpendicular lines marking
                the ends of the fit region.
        plotInfo : `dict`
            Information about the data being plotted (e.g. run, skymap,
            filter, tract), passed to ``addPlotInfo``.
        **kwargs
            Unused here.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. It is returned without the side panels
            filled in if there are no finite points, or no points inside
            the fit box.

        Notes
        -----
        Makes a color-color plot of ``data["x"]`` against ``data["y"]``,
        with the points colored by their local number density. The
        hardwired, initial and refit lines are overplotted. The axis labels
        come from `xAxisLabel` and `yAxisLabel`. The distance of the fitted
        points to the fit line is shown as a histogram in a second panel and
        against magnitude as a contour plot in a third.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Boolean mask of the points inside the fit box. A mask (not the
        # index array from np.where) is needed so that ``~fitMask`` selects
        # the complement; bitwise-not of integer indices gives negative
        # indices that count back from the end of the array.
        # type ignore because Vector needs a prototype interface
        fitMask = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = int(np.count_nonzero(fitMask))

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Nothing was fitted: the colorbar ticks, distances and statistics
        # below all need at least one fitted point (np.min of an empty
        # array raises).
        # TODO: Make a no data fig function and use here
        if nFit == 0:
            return fig

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        outsideMask = ~fitMask
        if np.any(outsideMask):
            ax.scatter(xs[outsideMask], ys[outsideMask], c=z[outsideMask], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitMask], ys[fitMask], c=z[fitMask], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitMask]), np.max(z[fitMask])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. Steep lines (judged by the hardwired gradient)
        # are evaluated at the y limits of the box, the rest at the x limits.
        steep = np.fabs(data["mHW"]) > 1
        xsFitLineHW, ysFitLineHW = self._lineEndpoints(data, data["mHW"], data["bHW"], steep)
        xsFitLine, ysFitLine = self._lineEndpoints(data, data["mODR"], data["bODR"], steep)
        xsFitLine2, ysFitLine2 = self._lineEndpoints(data, data["mODR2"], data["bODR2"], steep)

        # Each line is drawn twice: a wide white underlay, then the dashed line
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to the lines. Two points are needed to
        # characterise each line, and both must lie on that line.
        # NOTE(review): ``dists`` is measured to the *initial* ODR line,
        # although its histogram is labelled "Refit" — confirm the intent.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        xsFit = xs[fitMask]
        ysFit = ys[fitMask]
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xsFit, ysFit))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xsFit, ysFit))) * 1000

        # Draw short perpendicular segments where the perpendicular lines
        # cross the box edges, one at each end of the fit region.
        for yEdgeKey, xEdgeKey, bPerpKey in (
            ("yMin", "xMin", "bPerpMin"),
            ("yMax", "xMax", "bPerpMax"),
        ):
            if steep:
                xMid = (data[yEdgeKey] - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([data[xEdgeKey] - 0.2, data[xEdgeKey], data[xEdgeKey] + 0.2])
            ysPerp = data["mPerp"] * xsPerp + data[bPerpKey]
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the legend shows step outlines matching the histograms
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependance
        # of the distance to the fit, restricted to a symmetric window
        # around zero set by the 4th/96th distance percentiles.
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourDists = np.array(dists)[plotPoints]
        contourMags = cast(Vector, cast(Vector, mags)[fitMask])[cast(Vector, plotPoints)]
        H, xEdges, yEdges = np.histogram2d(contourDists, contourMags, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig