Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 14%

165 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-20 09:54 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from typing import Mapping, Optional, cast 

24 

25import matplotlib.patheffects as pathEffects 

26import matplotlib.pyplot as plt 

27import numpy as np 

28from lsst.analysis.tools import PlotAction 

29from lsst.pex.config import Field, ListField 

30from matplotlib.figure import Figure 

31from matplotlib.patches import Rectangle 

32from scipy.stats import median_absolute_deviation as sigmaMad 

33from sklearn.neighbors import KernelDensity 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector 

36from .plotUtils import addPlotInfo, mkColormap, perpDistance 

37 

38 

class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram with pre-computed line fits overlaid.

    The main panel shows the points colored by their local number density,
    the box used for the fit, and the hardwired, initial and refit lines.
    Two side panels show a histogram of the distances of the points inside
    the fit box to the fit lines, and a contour plot of that distance
    against magnitude.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the (key, type) pairs this action expects in its input.

        Keys that depend on ``plotName`` are built with it as a prefix.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        base.append(("x", Vector))
        base.append(("y", Vector))
        base.append(("mag", Vector))
        base.append(("approxMagDepth", Scalar))
        base.append((f"{self.plotName}_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_median", Scalar))
        base.append((f"{self.plotName}_hardwired_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_hardwired_median", Scalar))
        base.append(("xMin", Scalar))
        base.append(("xMax", Scalar))
        base.append(("yMin", Scalar))
        base.append(("yMax", Scalar))
        base.append(("mHW", Scalar))
        base.append(("bHW", Scalar))
        base.append(("mODR", Scalar))
        base.append(("bODR", Scalar))
        base.append(("yBoxMin", Scalar))
        base.append(("yBoxMax", Scalar))
        base.append(("bPerpMin", Scalar))
        base.append(("bPerpMax", Scalar))
        base.append(("mODR2", Scalar))
        base.append(("bODR2", Scalar))
        base.append(("mPerp", Scalar))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key from `getInputSchema`.

        Keys are formatted with ``kwargs`` before being looked up.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or if a key that must be a Vector
            holds a Scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _fitLineEnds(slope, intercept, xRange, yRange, steep):
        """Return the two end points of the line ``y = slope * x + intercept``.

        If ``steep`` is true, the ends are placed at the two y values in
        ``yRange``. Otherwise they are placed at the two x values in ``xRange``.
        Returns ``(xs, ys)`` as two-element arrays.
        """
        if steep:
            ys = np.array(yRange)
            xs = (ys - intercept) / slope
        else:
            xs = np.array(xRange)
            ys = slope * xs + intercept
        return xs, ys

    @staticmethod
    def _plotDistanceHistogram(axHist, dists, distsHW):
        """Histogram the fit and hardwired distances on ``axHist``.

        Returns the median and sigma-MAD marker lines so they can be added
        to the figure legend.
        """
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.median(dists)
        madDists = sigmaMad(dists)
        meanDists = np.mean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the legend shows outlined boxes matching the
        # step histograms.
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        return lineMedian, lineMad

    def _plotDistanceContour(self, axContour, dists, magsFit, cmap):
        """Contour the 2D histogram of distance to the fit vs. magnitude.

        Only distances inside a symmetric window around zero are used. The
        window's half-width is the smaller absolute value of the 4th and
        96th percentiles of ``dists``.
        """
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        H, xEdges, yEdges = np.histogram2d(dists[plotPoints], magsFit[plotPoints], bins=(11, 11))
        # Contour at the bin centres rather than the bin edges.
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=cmap
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The input data. The keys used here are:

            ``"x"``, ``"y"``
                Values plotted on the x and y axes (`Vector`).
            ``"mag"``
                Magnitudes. Used for the median magnitude in the info box
                and for the y axis of the contour panel (`Vector`).
            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The box used for the fit. Points strictly inside it are
                treated as "Used for Fit" (`Scalar`).
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept to fall back on
                (`Scalar`).
            ``"mODR"``, ``"bODR"``
                The gradient and intercept of the initial orthogonal
                distance regression fit (`Scalar`).
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept of the refit (`Scalar`).
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient and the two intercepts of the perpendicular
                lines drawn at the lower and upper ends of the fit box
                (`Scalar`).

            The other keys in `getInputSchema` are not used by this method.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted. It is
            passed to ``addPlotInfo``, and usually has the keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no input points, or none lie
            inside the fit box, the figure is returned without its panels
            populated.

        Notes
        -----
        Makes a color-color plot of ``data["x"]`` against ``data["y"]``.
        The points are colored by their local number density, estimated
        with a Gaussian kernel density estimate. The hardwired, initial and
        refit lines are overplotted. The axis labels are given by
        ``self.xAxisLabel`` and ``self.yAxisLabel``. The distances of the
        points inside the fit box to the fit lines are histogrammed in one
        side panel. A contour of that distance against magnitude is drawn
        in the other.
        """
        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels: the main color-color panel, a
        # contour panel and a histogram panel.
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])
        xs = cast(Vector, data["x"])
        ys = cast(Vector, data["y"])
        mags = data["mag"]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        xMin, xMax = data["xMin"], data["xMax"]
        yMin, yMax = data["yMin"], data["yMax"]

        # Points to use for the fit, as a boolean mask. A mask, rather than
        # the integer indices from np.where, is required so that ~fitMask
        # selects the complementary points. ~ applied to integer indices is
        # a bitwise NOT, which produces negative indices instead.
        # type ignore because Vector needs a prototype interface
        fitMask = np.asarray(
            (xs > xMin)  # type: ignore
            & (xs < xMax)  # type: ignore
            & (ys > yMin)  # type: ignore
            & (ys < yMax)  # type: ignore
        )
        nFit = int(np.count_nonzero(fitMask))

        # Without any point in the fit box the colorbar limits and the
        # distance statistics below are undefined. Follow the no-data
        # convention above.
        # TODO: Make a no data fig function and use here
        if nFit == 0:
            return fig

        # Plot the initial fit box
        ax.plot([xMin, xMax, xMax, xMin, xMin], [yMin, yMin, yMax, yMax, yMin], "k", alpha=0.3)

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.median(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        # Points outside the box in greyscale, points used for the fit in
        # the blue colormap.
        ax.scatter(xs[~fitMask], ys[~fitMask], c=z[~fitMask], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitMask], ys[fitMask], c=z[fitMask], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitMask]), np.max(z[fitMask])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits: the central 99% of the points, padded by
        # a fifth of that range on each side.
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. For a steep hardwired gradient the line ends
        # are placed at the box's y limits, otherwise at its x limits.
        steep = np.fabs(data["mHW"]) > 1
        xRange = (xMin, xMax)
        yRange = (yMin, yMax)
        xsFitLineHW, ysFitLineHW = self._fitLineEnds(data["mHW"], data["bHW"], xRange, yRange, steep)
        xsFitLine, ysFitLine = self._fitLineEnds(data["mODR"], data["bODR"], xRange, yRange, steep)
        xsFitLine2, ysFitLine2 = self._fitLineEnds(data["mODR2"], data["bODR2"], xRange, yRange, steep)

        # Each line is drawn twice: a wider white line underneath gives
        # the dashed line a halo against the scatter points.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to the lines. Each line is characterised
        # by its two end points computed above.
        # NOTE(review): ``dists`` is measured to the initial ODR line but is
        # labelled "Refit" in the histogram; confirm which fit is intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        fitXs = xs[fitMask]
        fitYs = ys[fitMask]
        distsHW = np.asarray(perpDistance(p1HW, p2HW, zip(fitXs, fitYs)))
        dists = np.asarray(perpDistance(p1, p2, zip(fitXs, fitYs)))

        # Draw short segments of the perpendicular lines at the lower and
        # upper ends of the fit box.
        perpEnds = (
            (yMin, xMin, data["bPerpMin"]),
            (yMax, xMax, data["bPerpMax"]),
        )
        for yEdge, xEdge, bPerp in perpEnds:
            if steep:
                # Centre the segment where the refit line meets this y edge.
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                perpXs = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                perpXs = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            perpYs = data["mPerp"] * perpXs + bPerp
            ax.plot(perpXs, perpYs, "k--", alpha=0.7)

        # Add a histogram
        lineMedian, lineMad = self._plotDistanceHistogram(axHist, dists, distsHW)

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        # Add a contour plot showing the magnitude dependence of the
        # distance to the fit
        magsFit = np.asarray(cast(Vector, mags))[fitMask]
        self._plotDistanceContour(axContour, dists, magsFit, newBlues)

        fig = addPlotInfo(fig, plotInfo)

        return fig