Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

165 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-01 03:17 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

36from ...statistics import nansigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays 

38 

39# from .plotUtils import generateSummaryStats, parsePlotInfo 

40 

41 

class SkyPlot(PlotAction):
    """Plot the value of a quantity at given points on the sky.

    Points are drawn at their (x, y) positions and colour coded by their
    z value. Each requested object type (see ``plotTypes``) gets its own
    colour map, colour bar and statistics text box.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their types) this action reads from its input.

        Four `Vector` keys (x, y, z and a statistics mask) are required for
        every entry of ``plotTypes`` that this action knows how to plot.

        Returns
        -------
        schema : `list` [`tuple`]
            ``(key, type)`` pairs, grouped by plot type in the order stars,
            galaxies, unknown, any.
        """
        keysByType = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        base = []
        for plotType, keys in keysByType:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in keys)

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must contain the keys given by
            `getInputSchema`.
        **kwargs
            Forwarded to `_validateInput` and `makePlot`.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            Raised by `_validateInput` if ``data`` is missing a required key
            or holds a scalar where a vector is required.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key in the input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a key
            holds a `Scalar` where the schema asks for another type.
        """
        needed = self.getInputSchema(**kwargs)
        requiredKeys = {key.format(**kwargs) for key, _ in needed}
        availableKeys = {key.format(**kwargs) for key in data.keys()}
        if remainder := requiredKeys - availableKeys:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            colType = type(data[name.format(**kwargs)])
            if issubclass(colType, Scalar) and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            The values to summarise.
        mask : array-like, optional
            If given, ``arr[mask]`` is used for the median and sigma MAD.

        Returns
        -------
        med
            ``np.nanmedian`` of the (masked) values.
        sigMad
            ``nansigmaMad`` of the (masked) values.
        statsText : `str`
            Text for the plot giving the median, sigma MAD and the number of
            points.

        Notes
        -----
        The number of points reported is the length of ``arr`` *before* the
        mask is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        # Raw strings hold the mathtext so its braces need no escaping.
        statsText = (
            f"Median: {med:0.2f}\n"
            r"$\sigma_{MAD}$: "
            f"{sigMad:0.2f}\n"
            r"n$_{points}$: "
            f"{numPoints}"
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the value at given points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. For each entry of ``plotTypes`` the keys
            listed by `getInputSchema` are read.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with
            keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

            It is passed on to ``addPlotInfo``; an empty dict is used if
            `None`.
        sumStats : `dict`, optional
            A dictionary where the patchIds are the keys which store the
            R.A. and dec of the corners of the patch. Each value is a pair
            whose first element is the sequence of ``(ra, dec)`` corners
            (objects with an ``asDegrees`` method); the second element is
            not used here. The key ``"tract"``, if present, gives the
            outline used to set the plot extent. An empty dict is used if
            `None`.
        **kwargs
            Accepted so that `__call__` can forward its keyword arguments;
            not used by this method.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The x and y axis labels come from ``xAxisLabel`` and ``yAxisLabel``
        and each colour bar is labelled with ``zAxisLabel`` and the object
        type. Outlines are only drawn if ``plotOutlines`` is set. The x axis
        is inverted.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # One row per plottable type, in the order the colour bars are laid
        # out: (plot type, (z, x, y, mask) keys, colour map, label,
        # stats box colour, stats box alpha).
        plotSpecs = (
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                redPurple,
                "Galaxies",
                "lemonchiffon",
                0.5,
            ),
            ("stars", ("zStars", "xStars", "yStars", "starStatMask"), blueGreen, "Stars", "paleturquoise", 0.5),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "viridis",
                "Unknown",
                "green",
                0.2,
            ),
            ("any", ("z", "x", "y", "statMask"), orangeBlue, "All", "purple", 0.2),
        )

        # The galaxy statistics box is moved to the left whenever another
        # statistics box is drawn, so that both can be seen.
        otherStatsBox = any(plotType in self.plotTypes for plotType in ("stars", "unknown", "any"))

        toPlotList = []
        for plotType, (zKey, xKey, yKey, maskKey), cmap, label, faceColor, alpha in plotSpecs:
            if plotType not in self.plotTypes:
                continue
            colorVals, xs, ys, statMask = sortAllArrays([data[zKey], data[xKey], data[yKey], data[maskKey]])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            # Add statistics
            boxLoc = 0.63 if plotType == "galaxies" and otherStatsBox else 0.8
            bbox = dict(facecolor=faceColor, alpha=alpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Draw the outline of every entry in sumStats, labelling each one
        # except the tract, whose outline defines the plot extent instead.
        tractBounds = None
        if self.plotOutlines:
            for dataId, (corners, _) in sumStats.items():
                ras = [ra.asDegrees() for (ra, dec) in corners]
                decs = [dec.asDegrees() for (ra, dec) in corners]
                # Repeat the first corner to close the outline.
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                else:
                    # Centre taken midway between the first and third corners.
                    cenX = ras[0] + (ras[2] - ras[0]) / 2
                    cenY = decs[0] + (decs[2] - decs[0]) / 2
                    ax.annotate(
                        dataId,
                        (cenX, cenY),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        # Extent used for the final axis limits: the tract outline if it was
        # drawn, otherwise the union of the extents of all plotted datasets.
        axisBounds = tractBounds
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if tractBounds is not None:
                minRa, maxRa, minDec, maxDec = tractBounds
            else:
                minRa = np.min(xs)
                maxRa = np.max(xs)
                minDec = np.min(ys)
                maxDec = np.max(ys)
                # Avoid identical end points which causes problems in binning
                if minRa == maxRa:
                    maxRa += 1e-5  # There is no reason to pick this number in particular
                if minDec == maxDec:
                    maxDec += 1e-5  # There is no reason to pick this number in particular
                if axisBounds is None:
                    axisBounds = (minRa, maxRa, minDec, maxDec)
                else:
                    axisBounds = (
                        min(axisBounds[0], minRa),
                        max(axisBounds[1], maxRa),
                        min(axisBounds[2], minDec),
                        max(axisBounds[3], maxDec),
                    )

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # Colour bars are stacked side by side to the right of the plot.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = f"{self.zAxisLabel}: {label}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # With several colour bars, put the first one's ticks on its
            # left so they do not sit under the next bar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        # Same as plt.draw(), but on this figure rather than the current one.
        fig.canvas.draw_idle()

        # Find some useful axis limits
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs and np.max(lenXs) > 1000:
            minRa, maxRa, minDec, maxDec = axisBounds
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # Limits given high-to-low so that the x axis is inverted.
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig