Coverage for python/lsst/analysis/tools/actions/plot/colorColorFitPlot.py: 13%

172 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-29 11:31 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ColorColorFitPlot",) 

25 

26from typing import Mapping, cast 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField, RangeField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34from sklearn.neighbors import KernelDensity 

35 

36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

37from ...statistics import sigmaMad 

38from ..keyedData.stellarLocusFit import perpDistance 

39from .plotUtils import addPlotInfo, mkColormap 

40 

41 

class ColorColorFitPlot(PlotAction):
    """Makes a color-color plot and overplots a
    prefited line to the specified area of the plot.
    This is mostly used for the stellar locus plots
    and also includes panels that illustrate the
    goodness of the given fit.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used to color code by", optional=False)

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        default=["stars"],
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)
    minPointsForFit = RangeField[int](
        doc="Minimum number of valid objects to bother attempting a fit.",
        default=5,
        min=1,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their expected types) this action reads.

        The four statistic keys are prefixed with ``self.plotName``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        base.append(("x", Vector))
        base.append(("y", Vector))
        base.append(("mag", Vector))
        base.append(("approxMagDepth", Scalar))
        base.append((f"{self.plotName}_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_median", Scalar))
        base.append((f"{self.plotName}_hardwired_sigmaMAD", Scalar))
        base.append((f"{self.plotName}_hardwired_median", Scalar))
        base.append(("xMin", Scalar))
        base.append(("xMax", Scalar))
        base.append(("yMin", Scalar))
        base.append(("yMax", Scalar))
        base.append(("mHW", Scalar))
        base.append(("bHW", Scalar))
        base.append(("mODR", Scalar))
        base.append(("bODR", Scalar))
        base.append(("yBoxMin", Scalar))
        base.append(("yBoxMax", Scalar))
        base.append(("bPerpMin", Scalar))
        base.append(("bPerpMax", Scalar))
        base.append(("mODR2", Scalar))
        base.append(("bODR2", Scalar))
        base.append(("mPerp", Scalar))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key from `getInputSchema` is present in ``data``.

        NOTE currently can only check that something is not a scalar, not
        check that data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or if a key that must be a Vector
            holds a Scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def _makeNoDataFigure(self, nFitPoints: int, plotInfo: Mapping[str, str]) -> Figure:
        """Return a placeholder figure explaining that too few points survived
        the fit-box cuts for a fit to be plotted.
        """
        fig = plt.figure(dpi=120)
        noDataText = (
            "Number of objects after cuts ({}) is less than the\nminimum required by "
            "minPointsForFit ({})".format(nFitPoints, self.minPointsForFit)
        )
        fig.text(0.5, 0.5, noDataText, ha="center", va="center")
        return addPlotInfo(fig, plotInfo)

    @staticmethod
    def _lineEndPoints(
        data: KeyedData, slope: Scalar, intercept: Scalar, steep: bool
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return the two end points of the line ``y = slope * x + intercept``.

        If ``steep`` is true, the end points are placed at the ``yMin`` and
        ``yMax`` edges of the fit box, and the x values are solved for.
        Otherwise they are placed at ``xMin`` and ``xMax``.

        Returns
        -------
        xs, ys : `numpy.ndarray`
            Two-element arrays of the end point coordinates.
        """
        if steep:
            ys = np.array([data["yMin"], data["yMax"]])
            xs = (ys - intercept) / slope
        else:
            xs = np.array([data["xMin"], data["xMax"]])
            ys = slope * xs + intercept  # type: ignore
        return xs, ys

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Make stellar locus plots using pre fitted values.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot the points from, for more information
            please see the notes section.
        plotInfo : `dict`
            A dictionary of information about the data being plotted
            with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If fewer than ``self.minPointsForFit``
            finite points lie strictly inside the fit box, a placeholder
            figure containing an explanatory message is returned instead.

        Notes
        -----
        The axis labels of the main panel are given by ``self.xAxisLabel``
        and ``self.yAxisLabel``. The upper-right panel shows a histogram of
        the perpendicular distances, in millimags, from the points used in
        the fit to the refit line and to the hardwired line. The lower-right
        panel shows contours of those distances (refit line) against
        magnitude, with the magnitude axis labelled by ``self.magLabel``.

        For the code to work it expects various quantities to be
        present in the 'data' that it is given.

        The quantities that are expected to be present are:

        * Statistics that are shown on the plot or used by the plotting code:
            * ``approxMagDepth``
                The approximate magnitude corresponding to the SN cut used.
            * ``f"{self.plotName}_sigmaMAD"``
                The sigma mad of the distances to the line fit.
            * ``f"{self.plotName}_median"``
                The median of the distances to the line fit.
            * ``f"{self.plotName}_hardwired_sigmaMAD"``
                The sigma mad of the distances to the initial fit.
            * ``f"{self.plotName}_hardwired_median"``
                The median of the distances to the initial fit.


        * Parameters from the fitting code that are illustrated on the plot:
            * ``"bHW"``
                The hardwired intercept to fall back on.
            * ``"bODR"``
                The intercept calculated by the orthogonal distance
                regression fitting.
            * ``"bODR2"``
                The intercept calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"mHW"``
                The hardwired gradient to fall back on.
            * ``"mODR"``
                The gradient calculated by the orthogonal distance
                regression fitting.
            * ``"mODR2"``
                The gradient calculated by the second iteration of
                orthogonal distance regression fitting.
            * ``"xMin"``
                The x minimum of the box used in the fit.
            * ``"xMax"``
                The x maximum of the box used in the fit.
            * ``"yMin"``
                The y minimum of the box used in the fit.
            * ``"yMax"``
                The y maximum of the box used in the fit.
            * ``"mPerp"``
                The gradient of the line perpendicular to the line from
                the second ODR fit.
            * ``"bPerpMin"``
                The intercept of the perpendicular line that goes through
                xMin.
            * ``"bPerpMax"``
                The intercept of the perpendicular line that goes through
                xMax.

        * The main inputs to plot:
            x, y, mag

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/stellarLocusExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        # Drop any point with a non-finite color or magnitude.
        goodPoints = np.isfinite(data["x"]) & np.isfinite(data["y"]) & np.isfinite(data["mag"])
        xs = cast(Vector, data["x"])[goodPoints]
        ys = cast(Vector, data["y"])[goodPoints]
        mags = cast(Vector, data["mag"])[goodPoints]

        # Boolean mask of the points strictly inside the fit box. A mask is
        # required, rather than the integer indices from np.where, so that
        # ``~fitPoints`` selects the complementary set of points.
        # type ignore because Vector needs a prototype interface
        fitPoints = (
            (xs > data["xMin"])  # type: ignore
            & (xs < data["xMax"])  # type: ignore
            & (ys > data["yMin"])  # type: ignore
            & (ys < data["yMax"])  # type: ignore
        )
        nFitPoints = int(np.count_nonzero(fitPoints))

        # Bail out before creating the main figure so that no unused,
        # still-open pyplot figure is left behind.
        if nFitPoints < self.minPointsForFit:
            return self._makeNoDataFigure(nFitPoints, plotInfo)

        # Define a new colormap
        newBlues = mkColormap(["paleturquoise", "midnightblue"])

        # Make a figure with three panels
        fig = plt.figure(dpi=300)
        ax = fig.add_axes([0.12, 0.25, 0.43, 0.60])
        axContour = fig.add_axes([0.65, 0.11, 0.3, 0.31])
        axHist = fig.add_axes([0.65, 0.51, 0.3, 0.31])

        # Plot the initial fit box
        ax.plot(
            [data["xMin"], data["xMax"], data["xMax"], data["xMin"], data["xMin"]],
            [data["yMin"], data["yMin"], data["yMax"], data["yMax"], data["yMin"]],
            "k",
            alpha=0.3,
        )

        # Add some useful information to the plot
        bbox = dict(alpha=0.9, facecolor="white", edgecolor="none")
        medMag = np.nanmedian(cast(Vector, mags))

        # TODO: GET THE SN FROM THE EARLIER PREP STEP
        SN = "-"
        infoText = "N Used: {}\nN Total: {}\nS/N cut: {}\n".format(nFitPoints, len(xs), SN)
        infoText += r"Mag $\lesssim$: " + "{:0.2f}".format(medMag)
        ax.text(0.05, 0.78, infoText, color="k", transform=ax.transAxes, fontsize=8, bbox=bbox)

        # Estimate the local density of every point, used to color the scatter.
        xy = np.vstack([xs, ys]).T
        kde = KernelDensity(kernel="gaussian").fit(xy)
        z = np.exp(kde.score_samples(xy))

        # Points outside the fit box in grey, points used for the fit in blue.
        ax.scatter(xs[~fitPoints], ys[~fitPoints], c=z[~fitPoints], cmap="binary", s=0.3)
        fitScatter = ax.scatter(
            xs[fitPoints], ys[fitPoints], c=z[fitPoints], cmap=newBlues, label="Used for Fit", s=0.3
        )

        # Add colorbar
        cbAx = fig.add_axes([0.12, 0.08, 0.43, 0.04])
        fig.colorbar(fitScatter, cax=cbAx, orientation="horizontal")
        cbText = cbAx.text(
            0.5,
            0.5,
            "Number Density",
            color="k",
            rotation="horizontal",
            transform=cbAx.transAxes,
            ha="center",
            va="center",
            fontsize=8,
        )
        cbText.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cbAx.set_xticks([np.min(z[fitPoints]), np.max(z[fitPoints])], labels=["Less", "More"])

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        # Set useful axis limits: the central 99% of points padded by a fifth
        # of that range on each side.
        percsX = np.nanpercentile(xs, [0.5, 99.5])
        percsY = np.nanpercentile(ys, [0.5, 99.5])
        x5 = (percsX[1] - percsX[0]) / 5
        y5 = (percsY[1] - percsY[0]) / 5
        ax.set_xlim(percsX[0] - x5, percsX[1] + x5)
        ax.set_ylim(percsY[0] - y5, percsY[1] + y5)

        # Plot the fit lines. For steep hardwired lines the end points are
        # placed on the y edges of the box, otherwise on the x edges.
        steep = np.fabs(data["mHW"]) > 1
        xsFitLineHW, ysFitLineHW = self._lineEndPoints(data, data["mHW"], data["bHW"], steep)
        xsFitLine, ysFitLine = self._lineEndPoints(data, data["mODR"], data["bODR"], steep)
        xsFitLine2, ysFitLine2 = self._lineEndPoints(data, data["mODR2"], data["bODR2"], steep)

        # Each line is drawn twice: a wide white underlay for contrast, then
        # the dashed colored line itself.
        ax.plot(xsFitLineHW, ysFitLineHW, "w", lw=2)
        (lineHW,) = ax.plot(xsFitLineHW, ysFitLineHW, "g", lw=1, ls="--", label="Hardwired")

        ax.plot(xsFitLine, ysFitLine, "w", lw=2)
        (lineInit,) = ax.plot(xsFitLine, ysFitLine, "b", lw=1, ls="--", label="Initial")

        ax.plot(xsFitLine2, ysFitLine2, "w", lw=2)
        (lineRefit,) = ax.plot(xsFitLine2, ysFitLine2, "k", lw=1, ls="--", label="Refit")

        # Calculate the distances of the fit points to the hardwired line and
        # to the refit (second ODR) line. Each line is characterised by its
        # two end points; both coordinates must come from the same line.
        p1 = np.array([xsFitLine2[0], ysFitLine2[0]])
        p2 = np.array([xsFitLine2[1], ysFitLine2[1]])

        p1HW = np.array([xsFitLineHW[0], ysFitLineHW[0]])
        p2HW = np.array([xsFitLineHW[1], ysFitLineHW[1]])

        # Convert to millimags
        fitXs = xs[fitPoints]
        fitYs = ys[fitPoints]
        distsHW = np.array(perpDistance(p1HW, p2HW, zip(fitXs, fitYs))) * 1000
        dists = np.array(perpDistance(p1, p2, zip(fitXs, fitYs))) * 1000

        # Draw short segments of the perpendicular lines where they cross the
        # refit line at the box edges.
        perpEdges = (
            (data["yMin"], data["xMin"], data["bPerpMin"]),
            (data["yMax"], data["xMax"], data["bPerpMax"]),
        )
        for yEdge, xEdge, bPerp in perpEdges:
            if steep:
                xMid = (yEdge - data["bODR2"]) / data["mODR2"]
                perpXs = np.array([xMid - 0.5, xMid, xMid + 0.5])
            else:
                perpXs = np.array([xEdge - 0.2, xEdge, xEdge + 0.2])
            perpYs = data["mPerp"] * perpXs + bPerp
            ax.plot(perpXs, perpYs, "k--", alpha=0.7)

        # Add a histogram
        axHist.set_ylabel("Number")
        axHist.set_xlabel("Distance to Line Fit")
        medDists = np.nanmedian(dists)
        madDists = sigmaMad(dists, nan_policy="omit")
        meanDists = np.nanmean(dists)

        axHist.set_xlim(meanDists - 2.0 * madDists, meanDists + 2.0 * madDists)
        lineMedian = axHist.axvline(medDists, color="k", label="Median: {:0.3f}".format(medDists))
        lineMad = axHist.axvline(
            medDists + madDists, color="k", ls="--", label="sigma MAD: {:0.3f}".format(madDists)
        )
        axHist.axvline(medDists - madDists, color="k", ls="--")

        linesForLegend = [lineHW, lineInit, lineRefit, fitScatter, lineMedian, lineMad]
        fig.legend(
            handles=linesForLegend,
            fontsize=8,
            bbox_to_anchor=(1.0, 0.99),
            bbox_transform=fig.transFigure,
            ncol=2,
        )

        axHist.hist(dists, bins=100, histtype="step", label="Refit", color="C0")
        axHist.hist(distsHW, bins=100, histtype="step", label="HW", color="C0", alpha=0.5)

        # Proxy artists so the histogram legend shows outlined boxes.
        alphas = [1.0, 0.5]
        handles = [Rectangle((0, 0), 1, 1, color="none", ec="C0", alpha=a) for a in alphas]
        labels = ["Refit", "HW"]
        axHist.legend(handles, labels, fontsize=6, loc="upper right")

        # Add a contour plot showing the magnitude dependence
        # of the distance to the fit
        axContour.invert_yaxis()
        axContour.axvline(0.0, color="k", ls="--", zorder=-1)
        # Keep a symmetric window around zero bounded by the smaller of the
        # absolute 4th/96th percentile distances.
        percsDists = np.nanpercentile(dists, [4, 96])
        maxXs = np.min(np.fabs(percsDists))
        minXs = -1 * maxXs
        plotPoints = (dists < maxXs) & (dists > minXs)
        contourXs = dists[plotPoints]
        contourYs = cast(Vector, mags)[fitPoints][plotPoints]
        H, xEdges, yEdges = np.histogram2d(contourXs, contourYs, bins=(11, 11))
        xBinWidth = xEdges[1] - xEdges[0]
        yBinWidth = yEdges[1] - yEdges[0]
        axContour.contour(
            xEdges[:-1] + xBinWidth / 2, yEdges[:-1] + yBinWidth / 2, H.T, levels=7, cmap=newBlues
        )
        axContour.set_xlabel("Distance to Line Fit")
        axContour.set_ylabel(self.magLabel)

        fig = addPlotInfo(fig, plotInfo)

        return fig