Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%

167 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-02 11:54 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ColorColorFitPlot",) 

25 

26from typing import Mapping, cast 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.analysis.tools import PlotAction 

32from lsst.pex.config import Field, ListField 

33from matplotlib.figure import Figure 

34from matplotlib.patches import Rectangle 

35from sklearn.neighbors import KernelDensity 

36 

37from ...interfaces import KeyedData, KeyedDataSchema, Scalar, Vector 

38from ...statistics import sigmaMad 

39from .plotUtils import addPlotInfo, mkColormap, perpDistance 

40 

41 

class ColorColorFitPlot(PlotAction):
    """Make a color-color plot and overplot prefitted lines on it.

    This is mostly used for the stellar locus plots and also includes
    panels that illustrate the goodness of the given fit.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys, and their expected types, required in the data.

        Returns
        -------
        schema : `list` [`tuple`]
            Pairs of key name and expected type (``Vector`` or ``Scalar``).
            The four statistics keys are prefixed with ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = [
            ("x", Vector),
            ("y", Vector),
            ("mag", Vector),
            ("approxMagDepth", Scalar),
            (f"{self.plotName}_sigmaMAD", Scalar),
            (f"{self.plotName}_median", Scalar),
            (f"{self.plotName}_hardwired_sigmaMAD", Scalar),
            (f"{self.plotName}_hardwired_median", Scalar),
            ("xMin", Scalar),
            ("xMax", Scalar),
            ("yMin", Scalar),
            ("yMax", Scalar),
            ("mHW", Scalar),
            ("bHW", Scalar),
            ("mODR", Scalar),
            ("bODR", Scalar),
            ("yBoxMin", Scalar),
            ("yBoxMax", Scalar),
            ("bPerpMin", Scalar),
            ("bPerpMax", Scalar),
            ("mODR2", Scalar),
            ("bODR2", Scalar),
            ("mPerp", Scalar),
        ]

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        Parameters
        ----------
        data : `KeyedData`
            The data to validate and plot.
        **kwargs
            Forwarded to `_validateInput` and `makePlot`.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure returned by `makePlot`.

        Raises
        ------
        ValueError
            Raised by `_validateInput` if ``data`` does not match the schema.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key named by the input schema.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a value
            is a scalar where the schema does not ask for a ``Scalar``.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            isScalar = issubclass(colType, Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot the points from, for more information
            please see the notes section.
        plotInfo : `dict`
            A dictionary of information about the data being plotted
            with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If there are no finite points, or no
            points fall inside the fit box, the figure is returned with
            nothing beyond its empty axes.

        Notes
        -----
        The axis labels are given by `self.xAxisLabel` and
        `self.yAxisLabel`. The perpendicular distance of the points to
        the fit line is given in a histogram in the second panel.

        For the code to work it expects various quantities to be
        present in the 'data' that it is given.

        The quantities that are expected to be present are:

         * Statistics that are shown on the plot or used by the plotting code:
            * ``approxMagDepth``
                The approximate magnitude corresponding to the SN cut used.
            * ``f"{self.plotName}_sigmaMAD"``
                The sigma mad of the distances to the line fit.
            * ``f"{self.plotName}_median"``
                The median of the distances to the line fit.
            * ``f"{self.plotName}_hardwired_sigmaMAD"``
                The sigma mad of the distances to the initial fit.
            * ``f"{self.plotName}_hardwired_median"``
                The median of the distances to the initial fit.


         * Parameters from the fitting code that are illustrated on the plot:
            * ``"bHW"``
                The hardwired intercept to fall back on.
            * ``"bODR"``
                The intercept calculated by the orthogonal distance
                regression fitting.
            * ``"bODR2"``
                The intercept calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"mHW"``
                The hardwired gradient to fall back on.
            * ``"mODR"``
                The gradient calculated by the orthogonal distance
                regression fitting.
            * ``"mODR2"``
                The gradient calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"xMin"``
                The x minimum of the box used in the fit.
            * ``"xMax"``
                The x maximum of the box used in the fit.
            * ``"yMin"``
                The y minimum of the box used in the fit.
            * ``"yMax"``
                The y maximum of the box used in the fit.
            * ``"mPerp"``
                The gradient of the line perpendicular to the line from
                the second ODR fit.
            * ``"bPerpMin"``
                The intercept of the perpendicular line that goes through
                xMin.
            * ``"bPerpMax"``
                The intercept of the perpendicular line that goes through
                xMax.

         * The main inputs to plot:
            x, y, mag

        The input schema also requires ``"yBoxMin"`` and ``"yBoxMax"``,
        but this method does not read them.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/stellarLocusExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Check for nans/infs
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # TODO: Make a no data fig function and use here
        if len(xs) == 0 or len(ys) == 0:
            return fig

        # Points to use for the fit
        # type ignore because Vector needs a prototype interface
        # Keep the boolean mask as well as the integer indices: the
        # complement of the selection has to come from the mask, because
        # ``~`` on integer indices yields ``-(i + 1)``, not "the others".
        inFitBox = np.asarray(
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        fitPoints = np.where(inFitBox)[0]
        notFitPoints = ~inFitBox

        # The colorbar ticks and the distance statistics all reduce over
        # the fit points, so with none selected there is nothing to draw;
        # fall back to the bare figure as in the no data case above.
        if len(fitPoints) == 0:
            return fig

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        # NOTE(review): the text labels this as a magnitude limit but it is
        # the median magnitude; ``approxMagDepth`` is never read - confirm.
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(len(fitPoints), len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Calculate the density of the points
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        ax.scatter(xs[notFitPoints], ys[notFitPoints], c=z[notFitPoints], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitPoints], ys[fitPoints], c=z[fitPoints], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        # Use the figure directly rather than pyplot's "current figure".
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitPoints]), np.max(z[fitPoints])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines
        # For a steep hardwired line the end points are defined by the y
        # limits of the fit box, otherwise by the x limits.
        steep = np.fabs(data["mHW"]) > 1
        if steep:
            ysFitLineHW = np.array([data["yMin"], data["yMax"]])
            xsFitLineHW = (ysFitLineHW - data["bHW"]) / data["mHW"]
            ysFitLine = np.array([data["yMin"], data["yMax"]])
            xsFitLine = (ysFitLine - data["bODR"]) / data["mODR"]
            ysFitLine2 = np.array([data["yMin"], data["yMax"]])
            xsFitLine2 = (ysFitLine2 - data["bODR2"]) / data["mODR2"]

        else:
            xsFitLineHW = np.array([data["xMin"], data["xMax"]])
            ysFitLineHW = data["mHW"] * xsFitLineHW + data["bHW"]  # type: ignore
            xsFitLine = np.array([data["xMin"], data["xMax"]])
            ysFitLine = np.array(
                [
                    data["mODR"] * xsFitLine[0] + data["bODR"],
                    data["mODR"] * xsFitLine[1] + data["bODR"],
                ]
            )
            xsFitLine2 = np.array([data["xMin"], data["xMax"]])
            ysFitLine2 = np.array(
                [
                    data["mODR2"] * xsFitLine2[0] + data["bODR2"],
                    data["mODR2"] * xsFitLine2[1] + data["bODR2"],
                ]
            )

        # Each line is drawn twice: a wider white line underneath so the
        # dashed line stays visible on top of the points.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances to that line
        # Need two points to characterise the lines we want
        # to get the distances to
        # NOTE(review): these are the end points of the first ODR line
        # ("Initial"), yet the histogram below is labelled "Refit" -
        # confirm which line is intended.
        p1 = np.array([xsFitLine[0], ysFitLine[0]])
        p2 = np.array([xsFitLine[1], ysFitLine[1]])

        # The hardwired line must use its own x values: in the steep case
        # they differ from those of the ODR line.
        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(xs[fitPoints], ys[fitPoints]))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(xs[fitPoints], ys[fitPoints]))) * 1000

        # Now we have the information for the perpendicular line we
        # can use it to calculate the points at the ends of the
        # perpendicular lines that intersect at the box edges
        perpBoundaries = (
            (data["xMin"], data["yMin"], data["bPerpMin"]),
            (data["xMax"], data["yMax"], data["bPerpMax"]),
        )
        for xEdge, yEdge, bPerp in perpBoundaries:
            if steep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                xsPerp = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                xsPerp = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            ysPerp = data["mPerp"] * xsPerp + bPerp
            ax.plot(xsPerp, ysPerp, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy patches so the histogram legend shows outlined boxes
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence
        # of the distance to the fit
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        # Use a range symmetric about zero, set by whichever of the 4th and
        # 96th percentiles is closer to zero.
        percsDists = np.nanpercentile(dists, [4, 96])
        minXs = -1 * np.min(np.fabs(percsDists))
        maxXs = np.min(np.fabs(percsDists))
        plotPoints = (dists < maxXs) & (dists > minXs)
        distsToPlot = np.array(dists)[plotPoints]
        magsToPlot = cast(Vector, cast(Vector, mags)[cast(Vector, fitPoints)])[cast(Vector, plotPoints)]
        H, xEdges, yEdges = np.histogram2d(distsToPlot, magsToPlot, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        # Contour at the bin centres rather than the bin edges
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        # Pass the figure built here rather than pyplot's current figure.
        fig = addPlotInfo(fig, plotInfo)

        return fig