Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%

166 statements  

« prev     ^ index     » next       coverage.py v7.2.1, created at 2023-03-12 21:46 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from typing import Mapping, Optional, cast 

25 

26import matplotlib.patheffects as pathEffects 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from lsst.analysis.tools import PlotAction 

30from lsst.pex.config import Field, ListField 

31from matplotlib.figure import Figure 

32from matplotlib.patches import Rectangle 

33from sklearn.neighbors import KernelDensity 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector 

36from ...statistics import sigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, perpDistance 

38 

39 

class ColorColorFitPlot(PlotAction):
    """Plot a color-color diagram together with pre-computed stellar locus fits.

    The main panel shows the points color coded by their local number density
    (points inside the fit box in blue, the rest in grey) with the hardwired,
    initial and refit lines overplotted. The upper right panel shows a
    histogram of the distances of the fitted points to the fit line and the
    lower right panel shows those distances against magnitude as contours.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their types) this action reads from its input.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            Pairs of (key, `Vector` or `Scalar`). Some keys are prefixed by
            ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key in the input schema is present in ``data``.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key that should be
            a `Vector` holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    @staticmethod
    def _lineEnds(slope, intercept, data, steep):
        """Return the x and y end points of ``y = slope * x + intercept``.

        If ``steep`` is true the line is parameterised by y and its ends lie
        at ``data["yMin"]`` and ``data["yMax"]``; otherwise it is
        parameterised by x with ends at ``data["xMin"]`` and ``data["xMax"]``.
        """
        if steep:
            yEnds = np.array([data["yMin"], data["yMax"]])
            xEnds = (yEnds - intercept) / slope
        else:
            xEnds = np.array([data["xMin"], data["xMax"]])
            yEnds = slope * xEnds + intercept
        return xEnds, yEnds

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. Must contain the vectors ``"x"``, ``"y"`` and
            ``"mag"`` and the scalars:

            ``"xMin"``, ``"xMax"``, ``"yMin"``, ``"yMax"``
                The limits of the box used to select the points for the fit.
            ``"mHW"``, ``"bHW"``
                The hardwired gradient and intercept.
            ``"mODR"``, ``"bODR"``
                The gradient and intercept of the initial orthogonal
                distance regression fit.
            ``"mODR2"``, ``"bODR2"``
                The gradient and intercept of the refit.
            ``"mPerp"``, ``"bPerpMin"``, ``"bPerpMax"``
                The gradient and the two intercepts of the perpendicular
                lines drawn at the ends of the fit box.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted, passed
            to ``addPlotInfo``, with keys such as ``"run"``, ``"skymap"``,
            ``"filter"`` and ``"tract"``.
        **kwargs
            Not used by this method.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no points
            inside the fit box, the partially drawn figure is returned early.

        Notes
        -----
        Makes a color-color plot of ``data["x"]`` against ``data["y"]``; the
        points are color coded by their number density, estimated with a
        Gaussian kernel density estimate. The hardwired, initial and refit
        stellar locus lines are then overplotted. The axis labels are given
        by ``self.xAxisLabel`` and ``self.yAxisLabel``. The distances of the
        fitted points to the fit line are shown as a histogram in the upper
        right panel and against magnitude as contours in the lower right
        panel.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Boolean mask of the points inside the box used for the fit. A mask
        # is required (rather than integer indices from np.where) so that
        # ~fitMask selects the complementary points; ~ on an integer index
        # array is a bitwise NOT and would select the wrong points.
        # type ignore because Vector needs a prototype interface
        fitMask = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFit = int(np.count_nonzero(fitMask))

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFit, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Everything below (colorbar ticks, distance histogram, contours)
        # needs at least one point inside the fit box; without one,
        # np.min of an empty array and NaN axis limits would raise.
        if nFit == 0:
            return fig

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        # Points outside the fit box in greyscale, points used for the fit in blue
        ax.scatter(xs[~fitMask], ys[~fitMask], c=z[~fitMask], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitMask], ys[fitMask], c=z[fitMask], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        plt.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitMask]), np.max(z[fitMask])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits: the central 99% of points plus a margin
        # of a fifth of that range on each side
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. When the hardwired gradient is steeper than 1
        # the lines are parameterised by y so they span the box's y range,
        # otherwise by x so they span its x range.
        steep = np.fabs(data["mHW"]) > 1
        xsFitLineHW, ysFitLineHW = self._lineEnds(data["mHW"], data["bHW"], data, steep)
        xsFitLine, ysFitLine = self._lineEnds(data["mODR"], data["bODR"], data, steep)
        xsFitLine2, ysFitLine2 = self._lineEnds(data["mODR2"], data["bODR2"], data, steep)

        # Each line is drawn twice: a wider white line underneath so the
        # dashed colored line stands out against the points.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to the lines. Two points are needed to
        # characterise each line.
        # NOTE(review): ``dists`` is measured to the initial ODR line
        # (mODR/bODR) but is labelled "Refit" in the histogram legend below;
        # confirm whether the refit line (mODR2/bODR2) was intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # The hardwired points must use the hardwired line's own x values:
        # the ODR x values only coincide with them when the lines are
        # parameterised by x.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        fitXs = xs[fitMask]
        fitYs = ys[fitMask]
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(fitXs, fitYs))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(fitXs, fitYs))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        for yEdge, xEdge, bPerp in (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        ):
            if steep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Outline-only proxy patches so the histogram legend matches the
        # step histograms drawn above
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependance
        # of the distance to the fit
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        # Restrict to a symmetric distance range set by the smaller of the
        # absolute 4th/96th percentiles
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourDists = np.array(dists)[plotPoints]
        contourMags = cast(Vector, cast(Vector, mags)[fitMask])[plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourDists, contourMags, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig