Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 14%

165 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-17 01:32 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from typing import Mapping, Optional, cast 

25 

26import matplotlib.patheffects as pathEffects 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from lsst.analysis.tools import PlotAction 

30from lsst.pex.config import Field, ListField 

31from matplotlib.figure import Figure 

32from matplotlib.patches import Rectangle 

33from sklearn.neighbors import KernelDensity 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector 

36from ...statistics import sigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, perpDistance 

38 

39 

class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram with pre-computed stellar locus fits.

    The main panel shows ``y`` against ``x`` with each point colored by a
    Gaussian kernel density estimate; points inside the initial fit box use a
    blue colormap and the rest a greyscale one. The fit box, the hardwired,
    initial and refit lines, and short perpendicular lines at the ends of the
    fit are overplotted. Two side panels show a histogram of the distances
    (from `perpDistance`) of the fit points to the fit lines, and a contour
    plot of those distances against magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys and types this action requires in its input.

        Four scalar keys (``sigmaMAD``, ``median`` and their ``hardwired``
        variants) are prefixed with ``plotName``.
        NOTE(review): those four prefixed keys are required here but are not
        read by `makePlot`.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        base.append(("x", Vector))
        base.append(("y", Vector))
        base.append(("mag", Vector))
        base.append(("approxMagDepth", Scalar))
        base.append((f"{self.plotName}_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_median", Scalar))
        base.append((f"{self.plotName}_hardwired_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_hardwired_median", Scalar))
        base.append(("xMin", Scalar))
        base.append(("xMax", Scalar))
        base.append(("yMin", Scalar))
        base.append(("yMax", Scalar))
        base.append(("mHW", Scalar))
        base.append(("bHW", Scalar))
        base.append(("mODR", Scalar))
        base.append(("bODR", Scalar))
        base.append(("yBoxMin", Scalar))
        base.append(("yBoxMax", Scalar))
        base.append(("bPerpMin", Scalar))
        base.append(("bPerpMax", Scalar))
        base.append(("mODR2", Scalar))
        base.append(("bODR2", Scalar))
        base.append(("mPerp", Scalar))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` has every key required by `getInputSchema`.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key the schema
            declares as non-scalar holds a scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _lineEnds(data: KeyedData, slope, intercept, steep: bool) -> tuple[np.ndarray, np.ndarray]:
        """Return the x and y end points of the line ``y = slope * x + intercept``.

        If ``steep`` the segment spans the fit box's y range (``yMin`` to
        ``yMax``) and the x values are solved for; otherwise it spans the
        box's x range (``xMin`` to ``xMax``).
        """
        if steep:
            yEnds = np.array([data["yMin"], data["yMax"]])
            xEnds = (yEnds - intercept) / slope
        else:
            xEnds = np.array([data["xMin"], data["xMax"]])
            yEnds = slope * xEnds + intercept
        return xEnds, yEnds

    @staticmethod
    def _plotDistanceHistogram(axHist, dists, distsHW):
        """Histogram the distances of the fit points to the fit lines.

        Marks the median and median +/- sigma MAD of ``dists`` and limits the
        x range to mean +/- 2 sigma MAD. Returns the median line and the
        upper sigma MAD line so they can be added to the figure legend.
        """
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.median(dists)
        madDists = sigmaMad(dists)
        meanDists = np.mean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the legend shows outlined boxes with the same
        # alphas as the two step histograms
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        return lineMedian, lineMad

    def _plotDistanceContours(self, axContour, dists, fitMags, cmap) -> None:
        """Contour the distances to the fit line against magnitude.

        Only distances strictly inside +/- the smaller of the absolute 4th
        and 96th percentiles are binned (11 x 11 bins).
        """
        # Invert so that magnitude increases downwards
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        distsArr = np.asarray(dists)
        percsDists = np.nanpercentile(distsArr, [4, 96])
        maxDist = np.min(np.fabs(percsDists))
        plotPoints = (distsArr < maxDist) & (distsArr > -maxDist)
        H, xEdges, yEdges = np.histogram2d(
            distsArr[plotPoints], np.asarray(fitMags)[plotPoints], bins=(11, 11)
        )
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        # Contour at the bin centres
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=cmap
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot, containing every key listed by
            `getInputSchema`. The keys read here are:

            ``"x"``, ``"y"``
                The values plotted on the x and y axes (`Vector`).
            ``"mag"``
                The magnitudes of the points (`Vector`), used for the median
                magnitude label and the contour panel.
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The edges of the initial box used in the fit.
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept to fall back on.
            ``"mODR"``, ``"bODR"``
                The gradient and intercept from the initial orthogonal
                distance regression fit.
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept from the refit.
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient and the two intercepts of the perpendicular
                lines drawn at the ends of the fit.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted,
            passed to `addPlotInfo`, with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Unused.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If ``x`` or ``y`` is empty, or no point
            falls inside the fit box, the figure is returned without any
            data plotted.

        Notes
        -----
        Makes a color-color plot of ``x`` against ``y`` with the points
        color coded by their kernel density estimate. The pre-computed
        stellar locus fits are then overplotted. The axis labels are given by
        `xAxisLabel` and `yAxisLabel`. The distances of the points in the fit
        box to the fit line are shown as a histogram in a second panel, and
        against magnitude (labelled by `magLabel`) as contours in a third.
        """
        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels (a colorbar axis is added below)
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])
        xs = cast(Vector, data["x"])
        ys = cast(Vector, data["y"])
        mags = cast(Vector, data["mag"])

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit, as a boolean mask. A mask (rather than
        # the integer indices from np.where) is required so that ~inFitBox
        # selects the complementary points: bitwise NOT of integer indices
        # gives unrelated negative indices.
        # type ignore because Vector needs a prototype interface
        inFitBox = np.asarray(
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"]),  # type: ignore
            dtype=bool,
        )
        nFit = int(np.count_nonzero(inFitBox))

        # The colorbar ticks, histogram limits and contours below all need at
        # least one point in the fit box; treat this like the no-data case.
        if nFit == 0:
            return fig

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.median(mags)

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        # Points outside the fit box in greyscale, those inside in blue
        outFitBox = ~inFitBox
        ax.scatter(xs[outFitBox], ys[outFitBox], c=z[outFitBox], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[inFitBox], ys[inFitBox], c=z[inFitBox], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        plt.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[inFitBox]), np.max(z[inFitBox])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits: the 0.5-99.5 percentile range of the points
        # padded by a fifth of that range on each side
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. The hardwired slope decides, for all three
        # lines, whether they span the box in y (steep) or in x.
        steep = bool(np.fabs(data["mHW"]) > 1)
        xsFitLineHW, ysFitLineHW = self._lineEnds(data, data["mHW"], data["bHW"], steep)
        xsFitLine, ysFitLine = self._lineEnds(data, data["mODR"], data["bODR"], steep)
        xsFitLine2, ysFitLine2 = self._lineEnds(data, data["mODR2"], data["bODR2"], steep)

        # Each line is drawn as a wide white underlay plus a thin dashed line
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances of the fit points to the lines; two points
        # characterise each line.
        # NOTE(review): ``dists`` is measured to the initial ODR line
        # (mODR, bODR) but is labelled "Refit" in the histogram; confirm
        # whether the refit line (mODR2, bODR2) was intended.
        fitXs = xs[inFitBox]
        fitYs = ys[inFitBox]
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # Take both coordinates from the hardwired line itself: mixing in the
        # ODR line's x values would, for steep lines, put these points on the
        # ODR line and make distsHW identical to dists.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        distsHW = perpDistance(p1HW, p2HW, zip(fitXs, fitYs))
        dists = perpDistance(p1, p2, zip(fitXs, fitYs))

        # Draw the perpendicular lines at each end of the fit. Steep fits
        # centre them where the refit line crosses the box's y edges; others
        # centre them on the box's x edges.
        for yEdge, xEdge, bPerp in (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        ):
            if steep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        lineMedian, lineMad = self._plotDistanceHistogram(axHist, dists, distsHW)

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        # Add a contour plot showing the magnitude dependence
        # of the distance to the fit
        self._plotDistanceContours(axContour, dists, mags[inFitBox], newBlues)

        fig = addPlotInfo(fig, plotInfo)

        return fig