Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 12%

180 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-23 13:09 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from matplotlib.figure import Figure 

34from matplotlib.patches import Rectangle 

35 

36from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction 

37from ...math import nanMedian, nanSigmaMad 

38from .calculateRange import Med2Mad 

39from .plotUtils import addPlotInfo, generateSummaryStats, mkColormap, plotProjectionWithBinning, sortAllArrays 

40 

41 

class SkyPlot(PlotAction):
    """Plot the on-sky distribution of a parameter.

    Plots the values of the parameter given for the z axis according to the
    positions given for x and y. Optimised for use with RA and Dec. Also
    calculates some basic statistics (median, sigma MAD, number of points)
    and includes those on the plot.

    The plotting of patch outlines requires patch information to be included
    as an additional parameter (``sumStats``), or a ``"patch"`` column in the
    data together with a ``skymap`` keyword argument.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    showExtremeOutliers = Field[bool](
        doc="Show the x-y positions of extreme outlier values as overlaid scatter points.",
        default=True,
    )

    # Input keys (x, y, z, statistics mask) required for each supported plot
    # type. Shared by getInputSchema and makePlot so the two cannot drift
    # apart. Treat as read-only.
    _inputKeys = {
        "stars": ("xStars", "yStars", "zStars", "starStatMask"),
        "galaxies": ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask"),
        "unknown": ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask"),
        "any": ("x", "y", "z", "statMask"),
    }

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the input columns required by the configured plot types.

        For each of ``stars``, ``galaxies``, ``unknown`` and ``any`` that is
        present in ``plotTypes``, that type's x, y, z and statistics-mask
        keys are requested, all as `Vector`.
        """
        base = []
        for plotType in ("stars", "galaxies", "unknown", "any"):
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in self._inputKeys[plotType])
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key the input schema requires.

        Raises
        ------
        ValueError
            If a required key is missing, or a required key holds a `Scalar`
            where a non-scalar type is expected.

        Notes
        -----
        Currently this can only check that something is not a Scalar, not
        check that the data is consistent with `Vector`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them and some text.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarise.
        mask : `numpy.ndarray`, optional
            Selection applied to ``arr`` before computing the median and
            sigma MAD.

        Returns
        -------
        med : `float`
            Median of the selected values, computed with ``nanMedian``.
        sigMad : `float`
            Sigma MAD of the selected values, computed with ``nanSigmaMad``.
        statsText : `str`
            Multi-line text giving the median, sigma MAD and number of
            points. The number of points is ``len(arr)``, i.e. it is counted
            before ``mask`` is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = nanMedian(arr)
        sigMad = nanSigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _plotOutlines(self, ax, sumStats: Mapping) -> tuple[float, float, float, float] | None:
        """Draw the outline of every entry in ``sumStats`` on ``ax``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `dict`
            Mapping of data ID to ``(corners, stat)``. ``corners`` is a
            sequence of ``(ra, dec)`` angle pairs, and ``asDegrees()`` is
            called on each angle. Every entry except ``"tract"`` is labelled
            with its key at the centre of the box spanned by corners 0 and 2.

        Returns
        -------
        tractLimits : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees over the corners of
            the ``"tract"`` entry, or `None` if ``sumStats`` has no such entry.
        """
        tractLimits = None
        for dataId, (corners, _) in sumStats.items():
            ras = [ra.asDegrees() for (ra, dec) in corners]
            decs = [dec.asDegrees() for (ra, dec) in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue
            cenX = ras[0] + (ras[2] - ras[0]) / 2
            cenY = decs[0] + (decs[2] - decs[0]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary keyed by patch ID (plus optionally ``"tract"``) whose
            values are ``(corners, stat)`` tuples, where ``corners`` holds the
            R.A. and dec of the corners of the patch. If not given and
            ``plotOutlines`` is set and ``data`` has a ``"patch"`` column, it
            is generated from ``data`` and the ``skymap`` keyword argument.
        **kwargs
            Additional keyword arguments. ``skymap`` is required when
            ``sumStats`` has to be generated.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Points whose x or y value is not finite are dropped, together with
        their z value, before plotting. A type with fewer than 5 remaining
        points is not plotted.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points.
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Per plot type, in drawing order (which also sets the colorbar
        # order): colormap, colorbar label suffix, and the face colour and
        # alpha of its statistics text box.
        plotStyles = (
            ("galaxies", redPurple, "Galaxies", "lemonchiffon", 0.5),
            ("stars", blueGreen, "Stars", "paleturquoise", 0.5),
            ("unknown", "viridis", "Unknown", "green", 0.2),
            ("any", orangeBlue, "All", "purple", 0.2),
        )
        # Every statistics box except the galaxy one is drawn at x=0.8, so
        # shift the galaxy box left whenever one of those is also drawn.
        otherBoxDrawn = any(t in self.plotTypes for t in ("stars", "unknown", "any"))
        galaxyBoxLoc = 0.63 if otherBoxDrawn else 0.8

        toPlotList = []
        for plotType, cmap, label, faceColor, alpha in plotStyles:
            if plotType not in self.plotTypes:
                continue
            xKey, yKey, zKey, maskKey = self._inputKeys[plotType]
            colorVals, xs, ys, statMask = sortAllArrays([data[zKey], data[xKey], data[yKey], data[maskKey]])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            boxLoc = galaxyBoxLoc if plotType == "galaxies" else 0.8
            bbox = dict(facecolor=faceColor, alpha=alpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outline each entry of sumStats; the tract outline, if present,
        # also fixes the binning range.
        tractLimits = self._plotOutlines(ax, sumStats) if self.plotOutlines else None
        # Limits used for the final zoom: the tract, or else the last
        # dataset actually plotted. Stays None if nothing sets it.
        axisLimits = tractLimits

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            # Drop the same points from x, y and colour so they stay aligned.
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            if len(xs) < 5:
                continue
            # colorVal column is unusable so zero it out.
            # This should be obvious on the plot. colorVals is a fresh copy
            # here (boolean indexing), so the caller's data is not modified.
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if tractLimits is not None:
                minRa, maxRa, minDec, maxDec = tractLimits
            else:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            # Avoid identical end points which causes problems in binning.
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            axisLimits = (minRa, maxRa, minDec, maxDec)

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
                showExtremeOutliers=self.showExtremeOutliers,
            )
            # One colorbar per entry of toPlotList, stacked left to right.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = f"{self.zAxisLabel}: {label}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # With many points, zoom to the limits found above plus a 10% pad.
        # Otherwise keep the autoscaled limits. Either way x increases to
        # the left.
        maxLen = max((len(entry[0]) for entry in toPlotList), default=0)
        if axisLimits is not None and maxLen > 1000:
            minRa, maxRa, minDec, maxDec = axisLimits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot.
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig