Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%

166 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-07 02:02 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ColorColorFitPlot",) 

25 

26from typing import Mapping, cast 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34from sklearn.neighbors import KernelDensity 

35 

36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

37from ...statistics import sigmaMad 

38from .plotUtils import addPlotInfo, mkColormap, perpDistance 

39 

40 

class ColorColorFitPlot(PlotAction):
    """Make a color-color plot and overplot a prefitted line on it.

    The fit is drawn over the region of the plot that was used to derive
    it. This is mostly used for the stellar locus plots, and the figure
    also includes panels that illustrate the goodness of the given fit.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys, and their types, required in the input data.

        Returns
        -------
        schema : `KeyedDataSchema`
            ``(key, type)`` pairs: the ``x``, ``y`` and ``mag`` vectors
            followed by the scalar statistics and fit parameters. The
            statistic keys are prefixed with ``self.plotName``.
        """
        vectorKeys = ("x", "y", "mag")
        scalarKeys = (
            "approxMagDepth",
            f"{self.plotName}_sigmaMAD",
            f"{self.plotName}_median",
            f"{self.plotName}_hardwired_sigmaMAD",
            f"{self.plotName}_hardwired_median",
            "xMin",
            "xMax",
            "yMin",
            "yMax",
            "mHW",
            "bHW",
            "mODR",
            "bODR",
            "yBoxMin",
            "yBoxMax",
            "bPerpMin",
            "bPerpMax",
            "mODR2",
            "bODR2",
            "mPerp",
        )

        base: list[tuple[str, type[Vector] | type[Scalar]]] = [(key, Vector) for key in vectorKeys]
        base.extend((key, Scalar) for key in scalarKeys)
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        Raises
        ------
        ValueError
            Raised by `_validateInput` if a required key is missing or a
            scalar was supplied where the schema asks for another type.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key named by `getInputSchema`.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is absent from ``data``, or if a value
            is a scalar while the schema requires a different type.
        """
        needed = self.getInputSchema(**kwargs)
        requiredKeys = {key.format(**kwargs) for key, _ in needed}
        availableKeys = {key.format(**kwargs) for key in data}
        if remainder := requiredKeys - availableKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")

        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot the points from, for more information
            please see the notes section.
        plotInfo : `dict`
            A dictionary of information about the data being plotted
            with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no
            points inside the fit box, the figure is returned with empty
            axes.

        Notes
        -----
        The axis labels are given by `self.xAxisLabel` and
        `self.yAxisLabel`. The perpendicular distance of the points to
        the fit line is given in a histogram in the second panel.

        For the code to work it expects various quantities to be
        present in the 'data' that it is given.

        The quantities that are expected to be present are:

         * Statistics required by `getInputSchema` (this method currently
           recomputes the median and sigma MAD from the distances rather
           than reading these keys):
            * ``approxMagDepth``
                The approximate magnitude corresponding to the SN cut used.
            * ``f"{self.plotName}_sigmaMAD"``
                The sigma mad of the distances to the line fit.
            * ``f"{self.plotName}_median"``
                The median of the distances to the line fit.
            * ``f"{self.plotName}_hardwired_sigmaMAD"``
                The sigma mad of the distances to the initial fit.
            * ``f"{self.plotName}_hardwired_median"``
                The median of the distances to the initial fit.


         * Parameters from the fitting code that are illustrated on the plot:
            * ``"bHW"``
                The hardwired intercept to fall back on.
            * ``"bODR"``
                The intercept calculated by the orthogonal distance
                regression fitting.
            * ``"bODR2"``
                The intercept calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"mHW"``
                The hardwired gradient to fall back on.
            * ``"mODR"``
                The gradient calculated by the orthogonal distance
                regression fitting.
            * ``"mODR2"``
                The gradient calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"xMin"``
                The x minimum of the box used in the fit.
            * ``"xMax"``
                The x maximum of the box used in the fit.
            * ``"yMin"``
                The y minimum of the box used in the fit.
            * ``"yMax"``
                The y maximum of the box used in the fit.
            * ``"mPerp"``
                The gradient of the line perpendicular to the line from
                the second ODR fit.
            * ``"bPerpMin"``
                The intercept of the perpendicular line that goes through
                xMin.
            * ``"bPerpMax"``
                The intercept of the perpendicular line that goes through
                xMax.

         * The main inputs to plot:
            x, y, mag

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/stellarLocusExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit, kept as a boolean mask so that
        # ``~fitPoints`` really is "everything not used for the fit".
        # (Bitwise NOT of an integer index array would give negative
        # indices, i.e. the wrong points.)
        # type ignore because Vector needs a prototype interface
        fitPoints = np.asarray(
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"]),  # type: ignore
            dtype=bool,
        )
        nFit = int(np.count_nonzero(fitPoints))

        # With nothing inside the fit box the colorbar limits and the
        # distance statistics below are undefined (np.min of an empty
        # array raises), so bail out the same way as for no data.
        if nFit == 0:
            return fig

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = f"N Used: {nFit}\nN Total: {len(xs)}\nS/N cut: {SN}\n"
        infoText += r"Mag $\lesssim$: " + f"{medMag:0.2f}"
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[~fitPoints], ys[~fitPoints], c=z[~fitPoints], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitPoints], ys[fitPoints], c=z[fitPoints], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        plt.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitPoints]), np.max(z[fitPoints])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. Steep lines (|gradient| > 1) are evaluated
        # between the y limits of the box, shallow ones between the x limits.
        steep = np.fabs(data["mHW"]) > 1
        if steep:
            yEnds = np.array([data["yMin"], data["yMax"]])
            ysFitLineHW = yEnds
            xsFitLineHW = (yEnds - data["bHW"]) / data["mHW"]
            ysFitLine = yEnds
            xsFitLine = (yEnds - data["bODR"]) / data["mODR"]
            ysFitLine2 = yEnds
            xsFitLine2 = (yEnds - data["bODR2"]) / data["mODR2"]
        else:
            xEnds = np.array([data["xMin"], data["xMax"]])
            xsFitLineHW = xEnds
            ysFitLineHW = data["mHW"] * xEnds + data["bHW"]  # type: ignore
            xsFitLine = xEnds
            ysFitLine = np.array(
                [
                    data["mODR"] * xEnds[0] + data["bODR"],
                    data["mODR"] * xEnds[1] + data["bODR"],
                ]
            )
            xsFitLine2 = xEnds
            ysFitLine2 = np.array(
                [
                    data["mODR2"] * xEnds[0] + data["bODR2"],
                    data["mODR2"] * xEnds[1] + data["bODR2"],
                ]
            )

        # Each line is drawn twice: a wider white line underneath so the
        # dashed coloured line stays visible over the scatter points.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to that line
        # Need two points to characterise the lines we want
        # to get the distances to
        # NOTE(review): these are the end points of the initial ODR fit,
        # while the histogram legend below calls the result "Refit" -
        # confirm which fit is intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # Use the hardwired line's own x values; in the steep case they
        # differ from the ODR ones and mixing them gives points that are
        # not on the hardwired line.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xs[fitPoints], ys[fitPoints]))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xs[fitPoints], ys[fitPoints]))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        for yEdge, xEdge, bPerp in (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        ):
            if steep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label=f"Median: {medDists:0.3f}")
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label=f"sigma MAD: {madDists:0.3f}"
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy patches so the legend shows the two histogram styles
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence
        # of the distance to the fit
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        percsDists = np.nanpercentile(dists, [4, 96])
        # Symmetric window about zero, set by the smaller of the two
        # percentile magnitudes
        maxDist = np.min(np.fabs(percsDists))
        plotPoints = (dists < maxDist) & (dists > -maxDist)
        contourXs = dists[plotPoints]
        contourYs = cast(Vector, cast(Vector, mags)[fitPoints])[plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourXs, contourYs, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        # Annotate the figure built here rather than whatever pyplot
        # currently considers the active figure
        fig = addPlotInfo(fig, plotInfo)

        return fig