Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

165 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-28 03:55 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field, ListField 

32from matplotlib.figure import Figure 

33from matplotlib.patches import Rectangle 

34 

35from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

36from ...statistics import nansigmaMad 

37from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays 

38 

39# from .plotUtils import generateSummaryStats, parsePlotInfo 

40 

41 

class SkyPlot(PlotAction):
    """Plot a quantity as a function of position on the sky.

    Each object type requested in ``plotTypes`` (``stars``, ``galaxies``,
    ``unknown`` and/or ``any``) is drawn with its own colormap, colorbar and
    box of summary statistics. If ``plotOutlines`` is set, the regions given
    in ``sumStats`` are outlined and labelled.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the input keys required for the configured ``plotTypes``.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            ``(key, Vector)`` pairs giving the x, y, z and statistics-mask
            columns of each requested type, in the order stars, galaxies,
            unknown, any. ``mag`` is accepted by ``plotTypes`` but adds no
            keys here.
        """
        # (plot type, x key, y key, z key, statistics-mask key)
        typeKeys = (
            ("stars", "xStars", "yStars", "zStars", "starStatMask"),
            ("galaxies", "xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask"),
            ("unknown", "xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask"),
            ("any", "x", "y", "z", "statMask"),
        )
        base = []
        for plotType, *keys in typeKeys:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in keys)
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against `getInputSchema` and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` contains everything this action needs.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a key
            that is required to be a `Vector` holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarise.
        mask : `numpy.ndarray`, optional
            Selection applied to ``arr`` before the median and sigma MAD are
            computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the selected values.
        sigMad : `float`
            Sigma MAD of the selected values (from `nansigmaMad`).
        statsText : `str`
            Multi-line text giving the median, sigma MAD and number of points.

        Notes
        -----
        The number of points reported is ``len(arr)`` *before* ``mask`` is
        applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    def _drawOutlines(self, ax, sumStats):
        """Outline every region in ``sumStats`` and label all but the tract.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `dict`
            Mapping of region id to ``(corners, stats)``; ``corners`` is a
            sequence of ``(ra, dec)`` angle pairs (objects with
            ``asDegrees()``). Only the corners are used.

        Returns
        -------
        tractBounds : `tuple` [`float`] or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees of the corners of
            the ``"tract"`` entry, or `None` if there is no such entry.
        """
        tractBounds = None
        for dataId, (corners, _) in sumStats.items():
            ras = [ra.asDegrees() for ra, _ in corners]
            decs = [dec.asDegrees() for _, dec in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue
            # Label the region at the midpoint of its first and third corners
            # (presumably opposite corners of the region -- confirm).
            cenX = (ras[0] + ras[2]) / 2
            cenY = (decs[0] + decs[2]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractBounds

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the value of a quantity at points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot. For every entry of ``plotTypes`` the x, y, z
            and statistics-mask vectors listed by `getInputSchema` must be
            present. Points are colour coded by their z value.
        plotInfo : `dict`, optional
            Information about the data being plotted, passed to
            `addPlotInfo`, e.g. with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Mapping of region id (e.g. a patch id, or ``"tract"``) to a
            ``(corners, stats)`` tuple, where ``corners`` holds the R.A. and
            dec of the corners of the region. Only used when
            ``plotOutlines`` is set; a ``"tract"`` entry then also sets the
            binning range and axis limits.
        **kwargs
            Unused.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.
        """
        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Per type: (plot type, data keys as (z, x, y, stat mask), colormap,
        # colorbar label, stats box facecolor, stats box alpha). The order
        # here is the drawing order and the left-to-right colorbar order.
        typeSpecs = (
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                redPurple,
                "Galaxies",
                "lemonchiffon",
                0.5,
            ),
            ("stars", ("zStars", "xStars", "yStars", "starStatMask"), blueGreen, "Stars", "paleturquoise", 0.5),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "viridis",
                "Unknown",
                "green",
                0.2,
            ),
            ("any", ("z", "x", "y", "statMask"), orangeBlue, "All", "purple", 0.2),
        )

        # Stats boxes are placed right to left in this order so that they
        # never overlap; stars and galaxies keep their usual x positions of
        # 0.8 and 0.63 when both are plotted.
        boxOrder = [t for t in ("stars", "galaxies", "unknown", "any") if t in self.plotTypes]

        toPlotList = []
        for plotType, keys, cmap, label, boxColor, boxAlpha in typeSpecs:
            if plotType not in self.plotTypes:
                continue
            # sortAllArrays is given z first, matching the (z, x, y, mask) order.
            colorVals, xs, ys, statMask = sortAllArrays([data[key] for key in keys])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            boxLoc = 0.8 - 0.17 * boxOrder.index(plotType)
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outlines of the regions (e.g. patches) in sumStats.
        tractBounds = self._drawOutlines(ax, sumStats) if self.plotOutlines else None

        # Binning range actually used for each plotted type; their union
        # sets the axis limits below.
        typeBounds = []
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if tractBounds is not None:
                minRa, maxRa, minDec, maxDec = tractBounds
            else:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            typeBounds.append((minRa, maxRa, minDec, maxDec))

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            # One colorbar per type, placed side by side to the right.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = f"{self.zAxisLabel}: {label}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # With several colorbars, put the first one's ticks on its left
            # so they do not collide with its neighbour.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits covering every plotted type.
        if toPlotList and max(len(xs) for xs, *_ in toPlotList) > 1000:
            minRa = min(bounds[0] for bounds in typeBounds)
            maxRa = max(bounds[1] for bounds in typeBounds)
            minDec = min(bounds[2] for bounds in typeBounds)
            maxDec = max(bounds[3] for bounds in typeBounds)
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # Larger x on the left, matching invert_xaxis() in the other branch.
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig