Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

173 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-10 10:36 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

36from ...statistics import nansigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays 

38 

39# from .plotUtils import generateSummaryStats, parsePlotInfo 

40 

41 

class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    # NOTE(review): "mag" is accepted here but neither getInputSchema nor
    # makePlot handles it; only stars, galaxies, unknown and any are drawn.
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def _inputKeys(self) -> dict[str, tuple[str, str, str, str]]:
        """Return the ``(x, y, z, statMask)`` data keys read for each plot
        type that this action knows how to draw.
        """
        return {
            "stars": ("xStars", "yStars", "zStars", "starStatMask"),
            "galaxies": ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask"),
            "unknown": ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask"),
            "any": ("x", "y", "z", "statMask"),
        }

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the vector keys required by the configured ``plotTypes``.

        For every selected type the x, y, z and stat-mask keys are
        requested, in that order; types are visited in the order
        stars, galaxies, unknown, any.
        """
        keysByType = self._inputKeys()
        base = []
        for plotType in ("stars", "galaxies", "unknown", "any"):
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in keysByType[plotType])
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key the schema requires.

        Keys are formatted with ``kwargs`` before comparison.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a required non-Scalar entry
            holds a Scalar.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like, optional
            Selection applied to ``arr`` before the median and sigma MAD
            are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the selected values.
        sigMad : `float`
            NaN-ignoring sigma MAD of the selected values.
        statsText : `str`
            Multi-line summary for display on the plot. The reported point
            count is the length of ``arr`` *before* ``mask`` is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            Mapping from a ccd/patch identifier (or ``"tract"``) to a pair
            ``(corners, stat)``. ``corners`` is a sequence of ``(ra, dec)``
            angle objects exposing ``asDegrees()``; corners 0 and 2 are
            treated as opposite corners of the box. The second element is
            not used here. Only read when ``plotOutlines`` is set; the
            ``"tract"`` entry, if present, also fixes the binning extent.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Every stats box except the galaxy one is drawn at x=0.8, so shift
        # the galaxy box left whenever one of those is also present.
        otherBoxDrawn = any(t in self.plotTypes for t in ("stars", "unknown", "any"))
        galaxyBoxLoc = 0.63 if otherBoxDrawn else 0.8

        # Per plot type, in drawing order: colormap, stats-box style,
        # stats-box x position (figure coords) and colorbar label.
        datasets = (
            (
                "galaxies",
                redPurple,
                dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none"),
                galaxyBoxLoc,
                "Galaxies",
            ),
            ("stars", blueGreen, dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none"), 0.8, "Stars"),
            ("unknown", "viridis", dict(facecolor="green", alpha=0.2, edgecolor="none"), 0.8, "Unknown"),
            ("any", orangeBlue, dict(facecolor="purple", alpha=0.2, edgecolor="none"), 0.8, "All"),
        )

        keysByType = self._inputKeys()
        toPlotList = []
        for plotType, cmap, bbox, boxLoc, label in datasets:
            if plotType not in self.plotTypes:
                continue
            xKey, yKey, zKey, maskKey = keysByType[plotType]
            # z comes first so the arrays are ordered consistently with the
            # isSorted=True passed to plotProjectionWithBinning below
            # (presumably sorted on z — see sortAllArrays).
            colorVals, xs, ys, statMask = sortAllArrays([data[zKey], data[xKey], data[yKey], data[maskKey]])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outline every ccd/patch and label it; the "tract" entry is only
        # outlined, and its extent is remembered for the plot limits.
        tractLimits = None
        if self.plotOutlines:
            for dataId, (corners, _) in sumStats.items():
                ras = [ra.asDegrees() for (ra, dec) in corners]
                decs = [dec.asDegrees() for (ra, dec) in corners]
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                    continue
                # Label at the centre of the box spanned by corners 0 and 2.
                cenX = ras[0] + (ras[2] - ras[0]) / 2
                cenY = decs[0] + (decs[2] - decs[0]) / 2
                ax.annotate(
                    dataId,
                    (cenX, cenY),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )

        # Extent used for the final axis limits; starts as the tract extent
        # (if any) and is updated by each dataset that is actually binned.
        limits = tractLimits
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            # Keep the colour values aligned with the surviving positions.
            colorVals = colorVals[finite]
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0

            if len(xs) < 5:
                continue
            if tractLimits is not None:
                minRa, maxRa, minDec, maxDec = tractLimits
            else:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            limits = (minRa, maxRa, minDec, maxDec)

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One colorbar per dataset, stacked left to right.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # Move the first colorbar's ticks out of the way of the next one.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits; RA increases to the left.
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if limits is not None and lenXs and max(lenXs) > 1000:
            minRa, maxRa, minDec, maxDec = limits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig