Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 12%

179 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-08 13:15 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from matplotlib.figure import Figure 

34from matplotlib.patches import Rectangle 

35 

36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction 

37from ...statistics import nansigmaMad 

38from .calculateRange import Med2Mad 

39from .plotUtils import addPlotInfo, generateSummaryStats, mkColormap, plotProjectionWithBinning, sortAllArrays 

40 

41 

class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.

    The plotting of patch outlines requires patch information
    to be included as an additional parameter.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the (key, type) pairs this action reads from its input.

        For every requested plot type, the x, y and z vectors plus that
        type's statistics mask are required. Keys are added in the order
        stars, galaxies, unknown, any.
        """
        base = []
        # (plot type, suffix appended to x/y/z, statistics mask key)
        for plotType, suffix, maskKey in (
            ("stars", "Stars", "starStatMask"),
            ("galaxies", "Galaxies", "galaxyStatMask"),
            ("unknown", "Unknowns", "unknownStatMask"),
            ("any", "", "statMask"),
        ):
            if plotType in self.plotTypes:  # type: ignore
                base.extend((f"{axis}{suffix}", Vector) for axis in "xyz")
                base.append((maskKey, Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by the schema.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key that should be
            a vector holds a `Scalar`.

        Notes
        -----
        This can currently only check that something is not a Scalar; it
        cannot check that the data is consistent with Vector.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate the median and sigma MAD of an array, plus a text label.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like of `bool`, optional
            If given, only ``arr[mask]`` is used for the median and sigma MAD.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma MAD of the (masked) values.
        statsText : `str`
            Multi-line text with the median, sigma MAD and number of points.
            The number of points is the length of ``arr`` before masking.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`, optional
            Maps an outline identifier (e.g. a patch id) to a pair whose
            first element holds the corners of that outline, each corner
            being an (R.A., dec) pair of angles with an ``asDegrees`` method.
            An entry keyed ``"tract"`` is drawn without a label, and its
            extent is used for the axis limits. If not given, it is generated
            from ``data`` when ``plotOutlines`` is set and ``data`` has a
            ``"patch"`` key; that requires a ``skymap`` keyword argument.
        **kwargs
            Additional keyword arguments; ``skymap`` is read when
            ``sumStats`` has to be generated.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Points whose x or y is non-finite are dropped before plotting, and
        plot types with fewer than five remaining points are skipped.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Each entry is (xs, ys, colour values, colormap, label). The arrays
        # come from sortAllArrays, presumably sorted on the first (z) array,
        # which is why isSorted=True is passed when binning below.
        toPlotList = []

        # For galaxies
        if "galaxies" in self.plotTypes:
            sortedArrs = sortAllArrays(
                [data["zGalaxies"], data["xGalaxies"], data["yGalaxies"], data["galaxyStatMask"]]
            )
            [colorValsGalaxies, xsGalaxies, ysGalaxies, statGalaxies] = sortedArrs
            _, _, galStatsText = self.statsAndText(colorValsGalaxies, mask=statGalaxies)
            # Add statistics
            bbox = dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none")
            # Move the galaxy text box left when more plot types are
            # requested, so it is not drawn on top of the box at x=0.8.
            # NOTE(review): plotTypes == ["stars", "galaxies"] has length 2 and
            # so still puts both boxes at 0.8 — confirm the intended threshold.
            if len(self.plotTypes) > 2:
                boxLoc = 0.63
            else:
                boxLoc = 0.8
            ax.text(boxLoc, 0.91, galStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsGalaxies, ysGalaxies, colorValsGalaxies, redPurple, "Galaxies"))

        # For stars
        if "stars" in self.plotTypes:
            sortedArrs = sortAllArrays([data["zStars"], data["xStars"], data["yStars"], data["starStatMask"]])
            [colorValsStars, xsStars, ysStars, statStars] = sortedArrs
            _, _, starStatsText = self.statsAndText(colorValsStars, mask=statStars)
            # Add statistics
            bbox = dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none")
            ax.text(0.8, 0.91, starStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsStars, ysStars, colorValsStars, blueGreen, "Stars"))

        # For unknowns
        if "unknown" in self.plotTypes:
            sortedArrs = sortAllArrays(
                [data["zUnknowns"], data["xUnknowns"], data["yUnknowns"], data["unknownStatMask"]]
            )
            [colorValsUnknowns, xsUnknowns, ysUnknowns, statUnknowns] = sortedArrs
            _, _, unknownStatsText = self.statsAndText(colorValsUnknowns, mask=statUnknowns)
            bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, unknownStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsUnknowns, ysUnknowns, colorValsUnknowns, "viridis", "Unknown"))

        # For all objects
        if "any" in self.plotTypes:
            sortedArrs = sortAllArrays([data["z"], data["x"], data["y"], data["statMask"]])
            [colorValsAny, xsAny, ysAny, statAny] = sortedArrs
            _, _, anyStatsText = self.statsAndText(colorValsAny, mask=statAny)
            bbox = dict(facecolor="purple", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, anyStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsAny, ysAny, colorValsAny, orangeBlue, "All"))

        # Plot extent used for the final axis limits. It is taken from the
        # tract outline if one is drawn, otherwise from the plotted points.
        # It stays None if neither source provides it.
        minRa = maxRa = minDec = maxDec = None

        # Corner plot of patches showing summary stat in each
        if self.plotOutlines:
            # NOTE(review): these rectangles are built but never added to the
            # axes; only the outlines drawn with ax.plot appear on the figure.
            patches = []
            for dataId, (corners, _) in sumStats.items():
                ra = corners[0][0].asDegrees()
                dec = corners[0][1].asDegrees()
                width = corners[2][0].asDegrees() - ra
                height = corners[2][1].asDegrees() - dec
                patches.append(Rectangle((ra, dec), width, height, alpha=0.3))
                ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
                decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
                # Close the outline by returning to the first corner.
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    minRa = np.min(ras)
                    minDec = np.min(decs)
                    maxRa = np.max(ras)
                    maxDec = np.max(decs)
                else:
                    ax.annotate(
                        dataId,
                        (ra + width / 2, dec + height / 2),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            # Drop points with non-finite positions. The colour values must be
            # masked the same way so the three arrays stay aligned.
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]

            # Too few points to bin meaningfully: skip this plot type.
            if len(xs) < 5:
                continue

            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if not self.plotOutlines or "tract" not in sumStats.keys():
                minRa = np.min(xs)
                maxRa = np.max(xs)
                minDec = np.min(ys)
                maxDec = np.max(ys)
                # Avoid identical end points which causes problems in binning
                if minRa == maxRa:
                    maxRa += 1e-5  # There is no reason to pick this number in particular
                if minDec == maxDec:
                    maxDec += 1e-5  # There is no reason to pick this number in particular

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One narrow colorbar axis per plot type, placed side by side.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # With several colorbars, put the first one's ticks on its left
            # so they do not clash with the neighbouring colorbar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits. For large samples, pad the known
        # extent by 10% (RA increasing to the left); otherwise just flip
        # the x axis.
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs and np.max(lenXs) > 1000 and minRa is not None:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig