Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%

167 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-28 03:16 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ColorColorFitPlot",) 

25 

26from typing import Mapping, cast 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.analysis.tools import PlotAction 

32from lsst.pex.config import Field, ListField 

33from matplotlib.figure import Figure 

34from matplotlib.patches import Rectangle 

35from sklearn.neighbors import KernelDensity 

36 

37from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector 

38from ...statistics import sigmaMad 

39from .plotUtils import addPlotInfo, mkColormap, perpDistance 

40 

41 

class ColorColorFitPlot(PlotAction):
    """Plot a colour-colour diagram with pre-computed stellar locus fits.

    The figure has three panels: the colour-colour scatter with the fit
    lines overplotted, a histogram of the perpendicular distances of the
    fitted points to the fit line, and a contour plot of those distances
    against magnitude. All fit parameters are read from the input data;
    no fitting is done here.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys, and their expected types, required in the input.

        Returns
        -------
        schema : `KeyedDataSchema`
            ``(key, type)`` pairs. The four summary statistic keys are
            prefixed with ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key named by the input schema.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a key
            that the schema does not declare as `Scalar` holds a scalar.
        """
        needed = self.getInputSchema(**kwargs)
        requiredKeys = {key.format(**kwargs) for key, _ in needed}
        availableKeys = {key.format(**kwargs) for key in data.keys()}
        if remainder := requiredKeys - availableKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            isScalar = issubclass(colType, Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The points to plot and the parameters of the fits, which are
            calculated elsewhere. The entries read by this method are:

            ``"x"``, ``"y"``, ``"mag"``
                The points to plot and their magnitudes. Entries where
                any of the three is not finite are dropped.
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The edges of the box used in the fit; points strictly
                inside it are shown as used for the fit.
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept.
            ``"mODR"``, ``"bODR"``
                The gradient and intercept of the line labelled
                "Initial" (orthogonal distance regression).
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept of the line labelled "Refit".
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient and the two intercepts of the dashed lines
                drawn at the two ends of the fit box.
        plotInfo : `dict`
            A dictionary of information about the data being plotted,
            passed on to ``addPlotInfo``, with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Unused.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If no point has finite ``x``, ``y`` and
            ``mag`` the figure is returned with empty axes.

        Notes
        -----
        Makes a color-color plot of ``data["x"]`` against ``data["y"]``;
        the points are color coded by their number density, estimated
        with a gaussian kernel density estimate. The hardwired, initial
        and refit lines are then overplotted. The axis labels are given
        by ``self.xAxisLabel`` and ``self.yAxisLabel``. The distance of
        the fitted points to the fit line is given in a histogram in the
        second panel, and against magnitude (labelled with
        ``self.magLabel``) as contours in the third panel.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit. This is kept as a boolean mask, not an
        # array of indices, so that ``~fitMask`` really is the complement
        # (``~`` applied to integer indices gives ``-(i + 1)`` instead).
        # type ignore because Vector needs a prototype interface
        fitMask = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = np.count_nonzero(fitMask)

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        # NOTE(review): the value shown after "Mag <~" is the median
        # magnitude of the finite points; "approxMagDepth" is required by
        # the schema but never read here - confirm which one is intended.
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = f"N Used: {nFit}\nN Total: {len(xs)}\nS/N cut: {SN}\n"
        infoText += r"Mag $\lesssim$: " + f"{medMag:0.2f}"
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        # Points outside the fit box in grey, points used for the fit in blue
        ax.scatter(xs[~fitMask], ys[~fitMask], c=z[~fitMask], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitMask], ys[fitMask], c=z[fitMask], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitMask]), np.max(z[fitMask])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. Lines steeper than 45 degrees are evaluated
        # at the y edges of the fit box, shallower ones at the x edges.
        steep = np.fabs(data["mHW"]) > 1
        if steep:
            ysFitLineHW = np.array([data["yMin"], data["yMax"]])
            xsFitLineHW = (ysFitLineHW - data["bHW"]) / data["mHW"]
            ysFitLine = np.array([data["yMin"], data["yMax"]])
            xsFitLine = (ysFitLine - data["bODR"]) / data["mODR"]
            ysFitLine2 = np.array([data["yMin"], data["yMax"]])
            xsFitLine2 = (ysFitLine2 - data["bODR2"]) / data["mODR2"]
        else:
            xsFitLineHW = np.array([data["xMin"], data["xMax"]])
            ysFitLineHW = data["mHW"] * xsFitLineHW + data["bHW"]  # type: ignore
            xsFitLine = np.array([data["xMin"], data["xMax"]])
            ysFitLine = data["mODR"] * xsFitLine + data["bODR"]  # type: ignore
            xsFitLine2 = np.array([data["xMin"], data["xMax"]])
            ysFitLine2 = data["mODR2"] * xsFitLine2 + data["bODR2"]  # type: ignore

        # Each line is drawn over a wider white line so it stands out
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to that line
        # Need two points to characterise the lines we want
        # to get the distances to
        # NOTE(review): these are the end points of the line labelled
        # "Initial" (mODR/bODR), while the histogram below is labelled
        # "Refit" - confirm which fit the distances should refer to.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # Both coordinates must come from the hardwired line; taking x from
        # the ODR line makes these coincide with p1/p2 for steep lines.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xs[fitMask], ys[fitMask]))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xs[fitMask], ys[fitMask]))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        if steep:
            xMid = (data["yMin"] - data["bODR2"]) / data["mODR2"]
            xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
        else:
            xsPerp = np.array([data["xMin"] - 0.2, data["xMin"], data["xMin"] + 0.2])
        ysPerp = data["mPerp"] * xsPerp + data["bPerpMin"]
        ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        if steep:
            xMid = (data["yMax"] - data["bODR2"]) / data["mODR2"]
            xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
        else:
            xsPerp = np.array([data["xMax"] - 0.2, data["xMax"], data["xMax"] + 0.2])
        ysPerp = data["mPerp"] * xsPerp + data["bPerpMax"]
        ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label=f"Median: {medDists:0.3f}")
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label=f"sigma MAD: {madDists:0.3f}"
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Outline-only patches so the legend matches the step histograms
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence
        # of the distance to the fit
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        # Use a range symmetric about zero, set by whichever of the 4th and
        # 96th percentile distances is closer to zero
        percsDists = np.nanpercentile(dists, [4, 96])
        minXs = -1 * np.min(np.fabs(percsDists))
        maxXs = np.min(np.fabs(percsDists))
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourDists = dists[plotPoints]
        contourMags = cast(Vector, cast(Vector, mags)[fitMask])[plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourDists, contourMags, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        # Contour at the bin centres
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig