Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

173 statements  

« prev     ^ index     » next       coverage.py v7.2.6, created at 2023-05-24 02:36 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

36from ...statistics import nansigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays 

38 

39# from .plotUtils import generateSummaryStats, parsePlotInfo 

40 

41 

class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the input keys required by the configured ``plotTypes``.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            ``(key, Vector)`` pairs. For each of ``"stars"``, ``"galaxies"``,
            ``"unknown"`` and ``"any"`` present in ``plotTypes`` (checked in
            that order) the x, y and z keys and the statistics-mask key for
            that type are appended. Other plot types (e.g. ``"mag"``) add no
            keys.
        """
        # (plot type, suffix of its x/y/z keys, its statistics-mask key)
        keySpecs = (
            ("stars", "Stars", "starStatMask"),
            ("galaxies", "Galaxies", "galaxyStatMask"),
            ("unknown", "Unknowns", "unknownStatMask"),
            ("any", "", "statMask"),
        )
        base = []
        for plotType, suffix, maskKey in keySpecs:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((f"{axis}{suffix}", Vector) for axis in ("x", "y", "z"))
                base.append((maskKey, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot.

        ``kwargs`` are used to format the schema keys during validation and
        are forwarded unchanged to `makePlot`.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by the schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key (after formatting with ``kwargs``) is missing,
            or if a column the schema requires to be non-scalar is a
            `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarise.
        mask : `numpy.ndarray`, optional
            Selection applied to ``arr`` before computing the median and
            sigma MAD.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma MAD of the (masked) values.
        statsText : `str`
            Text summarising ``med``, ``sigMad`` and the number of points.
            The point count is the length of ``arr`` *before* masking.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # One entry per supported plot type:
        # (plot type, data keys in (z, x, y, statMask) order, colormap,
        #  colour-bar label, stats-box face colour, stats-box alpha).
        # The order fixes the order of the colour bars and stats boxes.
        datasetSpecs = (
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                redPurple,
                "Galaxies",
                "lemonchiffon",
                0.5,
            ),
            (
                "stars",
                ("zStars", "xStars", "yStars", "starStatMask"),
                blueGreen,
                "Stars",
                "paleturquoise",
                0.5,
            ),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "viridis",
                "Unknown",
                "green",
                0.2,
            ),
            (
                "any",
                ("z", "x", "y", "statMask"),
                orangeBlue,
                "All",
                "purple",
                0.2,
            ),
        )

        toPlotList = []
        statsBoxes = []
        for plotType, keys, cmap, label, boxColor, boxAlpha in datasetSpecs:
            if plotType not in self.plotTypes:
                continue
            # sortAllArrays orders all arrays by the first one (the z values),
            # which lets plotProjectionWithBinning be called with isSorted=True.
            colorVals, xs, ys, statMask = sortAllArrays([data[key] for key in keys])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            statsBoxes.append((statsText, boxColor, boxAlpha))
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Place the statistics boxes right to left so they never overlap: the
        # last box sits at x=0.8 (figure fraction) and each earlier one is
        # shifted 0.17 further left (galaxies at 0.63 next to stars at 0.8).
        nBoxes = len(statsBoxes)
        for boxIndex, (statsText, boxColor, boxAlpha) in enumerate(statsBoxes):
            boxLoc = 0.8 - (nBoxes - 1 - boxIndex) * 0.17
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)

        # Corner plot of patches showing summary stat in each
        tractLimits = self._drawOutlines(ax, sumStats) if self.plotOutlines else None

        # Binning limits actually used for each plotted dataset; their union
        # sets the final axis limits.
        usedLimits = []
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            # Keep the colour values aligned with the surviving positions.
            colorVals = colorVals[finite]
            if len(xs) < 5:
                continue
            # colorVal column is unusable so zero it out.
            # This should be obvious on the plot.
            if not np.any(np.isfinite(colorVals)):
                colorVals = np.zeros_like(colorVals)

            if tractLimits is None:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            else:
                minRa, maxRa, minDec, maxDec = tractLimits
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            usedLimits.append((minRa, maxRa, minDec, maxDec))

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            self._addColorBar(fig, plotOut, label, i, len(toPlotList))

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits: the union of everything plotted, or
        # the tract outline if nothing was plotted.
        if usedLimits:
            axisLimits = (
                min(lim[0] for lim in usedLimits),
                max(lim[1] for lim in usedLimits),
                min(lim[2] for lim in usedLimits),
                max(lim[3] for lim in usedLimits),
            )
        else:
            axisLimits = tractLimits
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if axisLimits is not None and lenXs and max(lenXs) > 1000:
            minRa, maxRa, minDec, maxDec = axisLimits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # RA increases to the left, hence the reversed x limits.
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig

    def _drawOutlines(self, ax, sumStats):
        """Outline every region in ``sumStats`` and label the non-tract ones.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `dict`
            Maps a data ID to a pair whose first element is the region's
            corners as ``(ra, dec)`` pairs of objects with ``asDegrees()``.
            The second element is ignored here.

        Returns
        -------
        tractLimits : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees of the ``"tract"``
            entry's corners, or `None` if there is no ``"tract"`` entry.
        """
        tractLimits = None
        # NOTE(review): these Rectangles are collected but never added to the
        # axes; only the outlines drawn with ax.plot below appear.
        patches = []
        for dataId, (corners, _) in sumStats.items():
            ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
            decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
            # Corner 0 and corner 2 are taken as opposite corners of the box.
            ra, dec = ras[0], decs[0]
            width = ras[2] - ra
            height = decs[2] - dec
            patches.append(Rectangle((ra, dec), width, height, alpha=0.3))
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                ax.annotate(
                    dataId,
                    (ra + width / 2, dec + height / 2),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return tractLimits

    def _addColorBar(self, fig, mappable, label, index, nDatasets):
        """Add a labelled vertical colour bar for one dataset.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure`
            Figure to add the colour-bar axes to.
        mappable
            Return value of ``plotProjectionWithBinning`` for this dataset.
        label : `str`
            Dataset label, shown after the z-axis label.
        index : `int`
            Position of the dataset in the plot list; each successive colour
            bar is offset 0.04 (figure fraction) further right.
        nDatasets : `int`
            Number of datasets in the plot list.
        """
        cax = fig.add_axes([0.87 + index * 0.04, 0.11, 0.04, 0.77])
        fig.colorbar(mappable, cax=cax, extend="both")
        colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
        text = cax.text(
            0.5,
            0.5,
            colorBarLabel,
            color="k",
            rotation="vertical",
            transform=cax.transAxes,
            ha="center",
            va="center",
            fontsize=10,
        )
        text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cax.tick_params(labelsize=7)

        # With several colour bars side by side, put the first one's ticks on
        # the left, presumably so they don't overlap the neighbouring bar.
        if index == 0 and nDatasets > 1:
            cax.yaxis.set_ticks_position("left")