Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

165 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-04 11:09 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

36from ...statistics import nansigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays 

38 

39# from .plotUtils import generateSummaryStats, parsePlotInfo 

40 

41 

class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    # NOTE(review): "mag" is accepted by this doc string but neither
    # getInputSchema nor makePlot handles it.
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the (key, type) pairs required for the configured plotTypes.

        Every key is typed as `Vector`. The keys for each selected type are
        returned in the order stars, galaxies, unknown, any.
        """
        # Data keys required by each supported plot type.
        keysByType = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        return [
            (key, Vector)
            for plotType, keys in keysByType
            if plotType in self.plotTypes  # type: ignore
            for key in keys
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key in the input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a required key holds
            a Scalar while the schema asks for a different type.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like, optional
            Selection applied to ``arr`` before the median and sigma MAD
            are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma MAD of the (masked) values.
        statsText : `str`
            Multi-line label text. Note that the point count is the length
            of ``arr`` *before* the mask is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _drawOutlines(self, ax, sumStats: Mapping) -> tuple[float, float, float, float] | None:
        """Draw the outline of every entry in ``sumStats`` on ``ax``.

        Each entry except ``"tract"`` is labelled with its key at the centre
        of the box spanned by its corners 0 and 2.

        Returns
        -------
        tractLimits : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` of the ``"tract"`` entry's
            corners in degrees, or `None` if there is no ``"tract"`` entry.
        """
        tractLimits = None
        for dataId, (corners, _) in sumStats.items():
            # Corners are (RA, Dec) pairs exposing asDegrees() — presumably
            # lsst.geom.Angle; TODO confirm against the producer of sumStats.
            ras = [cornerRa.asDegrees() for (cornerRa, _cornerDec) in corners]
            decs = [cornerDec.asDegrees() for (_cornerRa, cornerDec) in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue
            cenX = ras[0] + (ras[2] - ras[0]) / 2
            cenY = decs[0] + (decs[2] - decs[0]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary keyed by patch id (plus optionally ``"tract"``).
            Each value is unpacked as ``(corners, _)``, where ``corners`` is
            a sequence of (R.A., dec) corner pairs of the region.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points.
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Per plot type: (type, data-key suffix, stat-mask key, colormap,
        # colorbar label, stats-box style). This order fixes both the drawing
        # order and the left-to-right order of the colorbars.
        typeSpecs = (
            (
                "galaxies",
                "Galaxies",
                "galaxyStatMask",
                redPurple,
                "Galaxies",
                dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none"),
            ),
            (
                "stars",
                "Stars",
                "starStatMask",
                blueGreen,
                "Stars",
                dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none"),
            ),
            (
                "unknown",
                "Unknowns",
                "unknownStatMask",
                "viridis",
                "Unknown",
                dict(facecolor="green", alpha=0.2, edgecolor="none"),
            ),
            (
                "any",
                "",
                "statMask",
                orangeBlue,
                "All",
                dict(facecolor="purple", alpha=0.2, edgecolor="none"),
            ),
        )
        # Every stats box except the galaxy one is drawn at x=0.8, so the
        # galaxy box moves left whenever one of those is also shown.
        otherBoxShown = any(plotType in self.plotTypes for plotType in ("stars", "unknown", "any"))

        toPlotList = []
        for plotType, suffix, maskKey, cmap, label, bbox in typeSpecs:
            if plotType not in self.plotTypes:
                continue
            colorVals, xs, ys, statMask = sortAllArrays(
                [data[f"z{suffix}"], data[f"x{suffix}"], data[f"y{suffix}"], data[maskKey]]
            )
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            boxLoc = 0.63 if plotType == "galaxies" and otherBoxShown else 0.8
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outline the patches; a "tract" outline, if present, fixes the
        # binning extent used for every dataset.
        tractLimits = self._drawOutlines(ax, sumStats) if self.plotOutlines else None

        # Union of the per-dataset extents, used for the final axis limits.
        allMinRa = allMinDec = np.inf
        allMaxRa = allMaxDec = -np.inf
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if tractLimits is None:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            else:
                minRa, maxRa, minDec, maxDec = tractLimits
            # Avoid identical end points which causes problems in binning.
            # There is no reason to pick this offset in particular.
            if minRa == maxRa:
                maxRa += 1e-5
            if minDec == maxDec:
                maxDec += 1e-5
            allMinRa, allMaxRa = min(allMinRa, minRa), max(allMaxRa, maxRa)
            allMinDec, allMaxDec = min(allMinDec, minDec), max(allMaxDec, maxDec)

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One colorbar per dataset, placed side by side right of the axes.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            text = cax.text(
                0.5,
                0.5,
                f"{self.zAxisLabel}: {label}",
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # Move the first colorbar's ticks to the left so they do not
            # collide with the neighbouring colorbar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # For large datasets pad the extent by 10% on each side; in both
        # branches the x axis ends up decreasing to the right (RA convention).
        if toPlotList and max(len(entry[0]) for entry in toPlotList) > 1000:
            padRa = (allMaxRa - allMinRa) / 10
            padDec = (allMaxDec - allMinDec) / 10
            ax.set_xlim(allMaxRa + padRa, allMinRa - padRa)
            ax.set_ylim(allMinDec - padDec, allMaxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot.
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        return addPlotInfo(fig, plotInfo)