Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

173 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-23 09:42 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

36from ...statistics import nansigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays 

38 

39 

class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (all of type `Vector`) that ``data`` must contain.

        Each handled entry of ``plotTypes`` ("stars", "galaxies", "unknown",
        "any") requires an x, y and z column plus a statistics mask.
        """
        # (plot type, required keys) in the order they are reported.
        schemaKeys = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        return [
            (key, Vector)
            for plotType, keys in schemaKeys
            if plotType in self.plotTypes  # type: ignore
            for key in keys
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key in the input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or a key that the schema requires
            to be non-scalar holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like, optional
            Selection applied to ``arr`` before computing the median and
            sigma MAD.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma MAD of the (masked) values.
        statsText : `str`
            Text summarising the statistics. The number of points quoted is
            the length of ``arr`` before masking.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _drawOutlines(self, ax, sumStats):
        """Outline and label each region described in ``sumStats``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            The axes to draw on.
        sumStats : `dict`
            Mapping from a region identifier to a ``(corners, _)`` pair, where
            ``corners`` is a sequence of ``(ra, dec)`` pairs whose elements
            provide ``asDegrees()``.

        Returns
        -------
        tractLimits : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees of the ``"tract"``
            entry, or `None` if ``sumStats`` has no such entry.
        """
        tractLimits = None
        for dataId, (corners, _) in sumStats.items():
            cornerRas = [cornerRa.asDegrees() for cornerRa, _ in corners]
            cornerDecs = [cornerDec.asDegrees() for _, cornerDec in corners]
            # Close the outline by repeating the first corner.
            ax.plot(cornerRas + cornerRas[:1], cornerDecs + cornerDecs[:1], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (
                    np.min(cornerRas),
                    np.max(cornerRas),
                    np.min(cornerDecs),
                    np.max(cornerDecs),
                )
                continue
            # Label the region at the midpoint between corners 0 and 2.
            cenX = (cornerRas[0] + cornerRas[2]) / 2
            cenY = (cornerDecs[0] + cornerDecs[2]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`, optional
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch. If a ``"tract"`` key is
            present and ``plotOutlines`` is set, its corners define the
            binning range and the axis limits.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Per plot type: (keys in z, x, y, mask order, colormap, colorbar
        # label, bbox style of the statistics text box).
        typeSpecs = (
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                redPurple,
                "Galaxies",
                dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none"),
            ),
            (
                "stars",
                ("zStars", "xStars", "yStars", "starStatMask"),
                blueGreen,
                "Stars",
                dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none"),
            ),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "viridis",
                "Unknown",
                dict(facecolor="green", alpha=0.2, edgecolor="none"),
            ),
            (
                "any",
                ("z", "x", "y", "statMask"),
                orangeBlue,
                "All",
                dict(facecolor="purple", alpha=0.2, edgecolor="none"),
            ),
        )

        toPlotList = []
        statsBoxes = []
        for plotType, keys, cmap, label, bbox in typeSpecs:
            if plotType not in self.plotTypes:
                continue
            # NOTE(review): sortAllArrays presumably orders every array by the
            # first (z) array; isSorted=True below relies on that ordering.
            colorVals, xs, ys, statMask = sortAllArrays([data[key] for key in keys])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            statsBoxes.append((statsText, bbox))
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Lay the statistics boxes out right-to-left so several types can be
        # shown without overlapping; the last box sits at x=0.8.
        boxSpacing = 0.17
        nBoxes = len(statsBoxes)
        for j, (statsText, bbox) in enumerate(statsBoxes):
            boxLoc = 0.8 - boxSpacing * (nBoxes - 1 - j)
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)

        # Corner plot of patches showing summary stat in each
        tractLimits = self._drawOutlines(ax, sumStats) if self.plotOutlines else None

        plotLimits = None
        nPlotted = 0
        for xs, ys, colorVals, cmap, label in toPlotList:
            # Filter all three arrays with the same mask so colours stay
            # aligned with their positions. Boolean indexing copies, so the
            # zeroing below never modifies the input data.
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            if len(xs) < 5:
                continue
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0

            if tractLimits is not None:
                minRa, maxRa, minDec, maxDec = tractLimits
            else:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )

            # One colorbar per plotted type, placed side by side.
            cax = fig.add_axes([0.87 + nPlotted * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)
            if nPlotted == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

            # Accumulate the range covered by every plotted type.
            if plotLimits is None:
                plotLimits = [minRa, maxRa, minDec, maxDec]
            else:
                plotLimits = [
                    min(plotLimits[0], minRa),
                    max(plotLimits[1], maxRa),
                    min(plotLimits[2], minDec),
                    max(plotLimits[3], maxDec),
                ]
            nPlotted += 1

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits
        lenXs = [len(xs) for xs, *_ in toPlotList]
        if plotLimits is not None and lenXs and max(lenXs) > 1000:
            minRa, maxRa, minDec, maxDec = plotLimits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # The x axis runs from larger to smaller values (inverted).
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig