Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

164 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-30 01:58 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from typing import Mapping, Optional 

25 

26import matplotlib.patheffects as pathEffects 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from lsst.pex.config import Field, ListField 

30from matplotlib.figure import Figure 

31from matplotlib.patches import Rectangle 

32 

33from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

34from ...statistics import nansigmaMad 

35from .plotUtils import addPlotInfo, mkColormap, plotProjectionWithBinning, sortAllArrays 

36 

37# from .plotUtils import generateSummaryStats, parsePlotInfo 

38 

39 

class SkyPlot(PlotAction):
    """Plot the on-sky distribution of a quantity, colour coded by its value.

    Each selected dataset (galaxies, stars, unknowns and/or all objects) is
    drawn with its own colormap and colour bar, and gets a text box with the
    median, sigma MAD and number of points used for its statistics.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # TODO: validate the entries (e.g. with an itemCheck) against the
        # types actually handled by getInputSchema/makePlot.
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their types) this action needs in its input.

        Each selected entry of ``plotTypes`` requires x, y and z vectors plus
        a mask selecting the points used for the summary statistics.
        """
        base = []
        if "stars" in self.plotTypes:  # type: ignore
            base.append(("xStars", Vector))
            base.append(("yStars", Vector))
            base.append(("zStars", Vector))
            base.append(("starStatMask", Vector))
        if "galaxies" in self.plotTypes:  # type: ignore
            base.append(("xGalaxies", Vector))
            base.append(("yGalaxies", Vector))
            base.append(("zGalaxies", Vector))
            base.append(("galaxyStatMask", Vector))
        if "unknown" in self.plotTypes:  # type: ignore
            base.append(("xUnknowns", Vector))
            base.append(("yUnknowns", Vector))
            base.append(("zUnknowns", Vector))
            base.append(("unknownStatMask", Vector))
        if "any" in self.plotTypes:  # type: ignore
            base.append(("x", Vector))
            base.append(("y", Vector))
            base.append(("z", Vector))
            base.append(("statMask", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` contains every key required by the schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a key that must not be a
            scalar holds a `Scalar`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate summary statistics of an array and format them as text.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarise.
        mask : `numpy.ndarray`, optional
            Selection (e.g. a boolean array) applied to ``arr`` before the
            statistics are calculated.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the selected values.
        sigMad : `float`
            Sigma MAD of the selected values (via `nansigmaMad`).
        statsText : `str`
            Text giving the median, sigma MAD and number of selected points.
        """
        if mask is not None:
            arr = arr[mask]
        # Count after masking so that n_points describes the same sample as
        # the median and sigma MAD reported next to it.
        numPoints = len(arr)
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a sky plot of each dataset selected in ``plotTypes``.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must contain the keys returned by
            `getInputSchema`.
        plotInfo : `dict`, optional
            Information about the data being plotted, passed to
            `addPlotInfo`, with keys such as:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Regions to outline when ``plotOutlines`` is set. Each value is a
            ``(corners, stats)`` pair where ``corners`` is a sequence of
            (R.A., dec) angles (objects with ``asDegrees()``); ``stats`` is
            not used here. Keys are used as labels, except ``"tract"``,
            whose outline also sets the plot extent.
        **kwargs
            Not used.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The x and y values of each dataset are plotted colour coded by its
        z values, using `plotProjectionWithBinning`. The statistics in the
        text boxes only use the points selected by that dataset's stat mask.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        toPlotList = self._prepareDatasets(data, fig, ax)

        # The tract outline, when drawn, defines the plot extent; otherwise
        # use the extent of every point that is going to be plotted.
        bounds = self._drawOutlines(ax, sumStats) if self.plotOutlines else None
        if bounds is None and toPlotList:
            bounds = self._dataBounds(toPlotList)
        if bounds is not None:
            minRa, maxRa, minDec, maxDec = bounds
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
            )
            self._addColorBar(fig, plotOut, label, i, len(toPlotList))

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        plt.draw()

        # For large datasets set padded limits explicitly; R.A. increases to
        # the left in both branches.
        if toPlotList and max(len(xs) for xs, _, _, _, _ in toPlotList) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig

    def _prepareDatasets(self, data, fig, ax):
        """Sort each selected dataset and draw its summary statistics box.

        Returns a list of ``(xs, ys, colorVals, cmap, label)`` tuples in the
        order galaxies, stars, unknown, any (restricted to ``plotTypes``).
        """
        # (plot type, data keys in sortAllArrays order (z, x, y, stat mask),
        #  colormap colours or name, stats box colour, box alpha, label)
        datasetSpecs = [
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                ["indigo", "lemonchiffon", "firebrick"],
                "lemonchiffon",
                0.5,
                "Galaxies",
            ),
            (
                "stars",
                ("zStars", "xStars", "yStars", "starStatMask"),
                ["midnightblue", "lightcyan", "darkgreen"],
                "paleturquoise",
                0.5,
                "Stars",
            ),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "viridis",
                "green",
                0.2,
                "Unknown",
            ),
            (
                "any",
                ("z", "x", "y", "statMask"),
                ["darkOrange", "thistle", "midnightblue"],
                "purple",
                0.2,
                "All",
            ),
        ]
        selected = [spec for spec in datasetSpecs if spec[0] in self.plotTypes]

        toPlotList = []
        for k, (_, keys, cmapSpec, boxColor, boxAlpha, label) in enumerate(selected):
            colorVals, xs, ys, statMask = sortAllArrays([data[key] for key in keys])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            # Lay the boxes out side by side, the last one at x=0.8 and each
            # earlier one 0.17 further left, so that they do not overlap.
            boxLoc = 0.8 - 0.17 * (len(selected) - 1 - k)
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            # Divergent colormaps are built from the listed colours; a plain
            # string is used as a named matplotlib colormap.
            cmap = cmapSpec if isinstance(cmapSpec, str) else mkColormap(cmapSpec)
            toPlotList.append((xs, ys, colorVals, cmap, label))
        return toPlotList

    @staticmethod
    def _drawOutlines(ax, sumStats):
        """Draw the outline of every region in ``sumStats`` and label patches.

        Returns the ``(minRa, maxRa, minDec, maxDec)`` extent, in degrees, of
        the ``"tract"`` entry, or `None` if ``sumStats`` has no such entry.
        """
        tractBounds = None
        for dataId, (corners, _) in sumStats.items():
            ras = [ra.asDegrees() for ra, _ in corners]
            decs = [dec.asDegrees() for _, dec in corners]
            # Close the polygon by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                # Label the region at the midpoint of its first and third
                # corners.
                cenX = (ras[0] + ras[2]) / 2
                cenY = (decs[0] + decs[2]) / 2
                ax.annotate(
                    dataId,
                    (cenX, cenY),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return tractBounds

    @staticmethod
    def _dataBounds(toPlotList):
        """Return the ``(minRa, maxRa, minDec, maxDec)`` extent of all datasets."""
        minRa = min(np.min(xs) for xs, _, _, _, _ in toPlotList)
        maxRa = max(np.max(xs) for xs, _, _, _, _ in toPlotList)
        minDec = min(np.min(ys) for _, ys, _, _, _ in toPlotList)
        maxDec = max(np.max(ys) for _, ys, _, _, _ in toPlotList)
        return minRa, maxRa, minDec, maxDec

    def _addColorBar(self, fig, plotOut, label, index, nPlots):
        """Add the labelled colour bar for dataset number ``index``."""
        # Each colour bar is 0.04 wide and offset by 0.04, so they abut.
        cax = fig.add_axes([0.87 + index * 0.04, 0.11, 0.04, 0.77])
        fig.colorbar(plotOut, cax=cax, extend="both")
        colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
        text = cax.text(
            0.5,
            0.5,
            colorBarLabel,
            color="k",
            rotation="vertical",
            transform=cax.transAxes,
            ha="center",
            va="center",
            fontsize=10,
        )
        text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
        cax.tick_params(labelsize=7)

        # With abutting bars, put the first bar's ticks on its left side so
        # they are not drawn against the next bar.
        if index == 0 and nPlots > 1:
            cax.yaxis.set_ticks_position("left")