Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 11%

189 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-16 01:27 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from typing import Mapping, Optional 

25 

26import matplotlib.patheffects as pathEffects 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from lsst.pex.config import Field, ListField 

30from matplotlib.figure import Figure 

31from matplotlib.patches import Rectangle 

32from scipy.stats import binned_statistic_2d 

33 

34from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

35from ...statistics import nansigmaMad, sigmaMad 

36from .plotUtils import addPlotInfo, extremaSort, mkColormap 

37 

38# from .plotUtils import generateSummaryStats, parsePlotInfo 

39 

40 

class SkyPlot(PlotAction):
    """Plot the value of a quantity at positions on the sky.

    For every object type requested in ``plotTypes`` the action draws the
    points given by the x/y vectors, colour coded by the z vector, adds a
    colorbar per type and a text box with summary statistics, and optionally
    overlays the outlines supplied through ``sumStats``.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    # This field used to be declared twice in the class body; the later
    # declaration silently replaced the earlier one. A single declaration is
    # kept here, with the doc string of the one that was actually in effect.
    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their types) this action reads from its input.

        Returns
        -------
        schema : `list` of `tuple`
            ``(key, Vector)`` pairs; four keys (x, y, z and a statistics
            mask) for each of ``stars``, ``galaxies``, ``unknown`` and
            ``any`` that is present in ``plotTypes``.
        """
        keysByPlotType = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        return [
            (key, Vector)
            for plotType, keys in keysByPlotType
            if plotType in self.plotTypes  # type: ignore
            for key in keys
        ]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every needed key is present and is not a Scalar.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a key from the input schema is missing from ``data``,
            or if a key that should not hold a Scalar does.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def sortAllArrays(self, arrsToSort):
        """Sort the first array and put all the others in the same order.

        Parameters
        ----------
        arrsToSort : `list`
            Arrays of equal length. The ordering is obtained by passing the
            first one to ``extremaSort`` and is applied to every array. The
            list is updated in place.

        Returns
        -------
        arrsToSort : `list`
            The same list object, now holding the reordered arrays.
        """
        ids = extremaSort(arrsToSort[0])
        arrsToSort[:] = [arr[ids] for arr in arrsToSort]
        return arrsToSort

    def statsAndText(self, arr, mask=None):
        """Calculate some statistics of an array and format them as text.

        Parameters
        ----------
        arr : array-like
            The values to summarise.
        mask : array-like, optional
            Index or boolean mask applied to ``arr`` before the median and
            sigma MAD are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            Result of ``nansigmaMad`` on the (masked) values.
        statsText : `str`
            Text for a stats box. The number of points reported is the
            length of ``arr`` before the mask is applied.
        """
        # Counted before masking, so it reflects the number of plotted points.
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            r"$\sigma_{MAD}$: "
            f"{sigMad:0.2f}\n"
            r"n$_{points}$: "
            f"{numPoints}"
        )

        return med, sigMad, statsText

    def _sortedWithStats(self, data, zKey, xKey, yKey, maskKey):
        """Sort one object type's vectors and build its stats box text.

        Returns ``(xs, ys, colorVals, statsText)`` where the three vectors
        are in the order produced by `sortAllArrays` on the z values.
        """
        colorVals, xs, ys, statMask = self.sortAllArrays(
            [data[zKey], data[xKey], data[yKey], data[maskKey]]
        )
        _, _, statsText = self.statsAndText(colorVals, mask=statMask)
        return xs, ys, colorVals, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the value of a quantity at points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The vectors to plot; must contain the keys listed by
            `getInputSchema` for the configured ``plotTypes``.
        plotInfo : `dict`, optional
            Information about the data being plotted, handed unchanged to
            ``addPlotInfo``. An empty dict is used if `None`. The previous
            documentation listed these keys (not checked in this method):

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Outlines to draw when ``plotOutlines`` is set. Each key is a data
            ID (the key ``"tract"`` is treated specially) and each value is a
            two item tuple whose first item is a sequence of ``(ra, dec)``
            corners, each coordinate providing ``asDegrees()``. An empty dict
            is used if `None`.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Types with more than 5000 points are drawn as a 45x45 binned median
        image with only the last 15% of the sorted points overplotted
        individually; smaller sets are drawn as a plain scatter plot. The
        colour scale spans the median +/- 2 sigma MAD of the z values and is
        made symmetric about zero if ``fixAroundZero`` is set.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        xCol = self.xAxisLabel
        yCol = self.yAxisLabel

        toPlotList = []
        # For galaxies
        if "galaxies" in self.plotTypes:
            xs, ys, colorVals, statsText = self._sortedWithStats(
                data, "zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"
            )
            # Every other type puts its stats box at x=0.8, so shift the
            # galaxy box left whenever one of them is also being plotted.
            # (The old ``len(plotTypes) > 2`` test assumed a placeholder
            # entry that the string list no longer contains, which made the
            # star and galaxy boxes overlap.)
            sharesBoxSpot = any(other in self.plotTypes for other in ("stars", "unknown", "any"))
            boxLoc = 0.63 if sharesBoxSpot else 0.8
            bbox = dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, redPurple, "Galaxies"))

        # For stars
        if "stars" in self.plotTypes:
            xs, ys, colorVals, statsText = self._sortedWithStats(
                data, "zStars", "xStars", "yStars", "starStatMask"
            )
            bbox = dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none")
            ax.text(0.8, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, blueGreen, "Stars"))

        # For unknowns
        if "unknown" in self.plotTypes:
            xs, ys, colorVals, statsText = self._sortedWithStats(
                data, "zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"
            )
            bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, "viridis", "Unknown"))

        # For everything, regardless of type
        if "any" in self.plotTypes:
            xs, ys, colorVals, statsText = self._sortedWithStats(data, "z", "x", "y", "statMask")
            bbox = dict(facecolor="purple", alpha=0.2, edgecolor="none")
            ax.text(0.8, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, orangeBlue, "All"))

        # Draw the outline of each entry in sumStats and label it
        if self.plotOutlines:
            for dataId, (corners, _) in sumStats.items():
                ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
                decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
                # Close the polygon by repeating the first corner.
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    # The tract outline fixes the plotted/binned sky area.
                    minRa = np.min(ras)
                    minDec = np.min(decs)
                    maxRa = np.max(ras)
                    maxDec = np.max(decs)
                else:
                    # Label the outline at the midpoint of corners 0 and 2.
                    ra = corners[0][0].asDegrees()
                    dec = corners[0][1].asDegrees()
                    width = corners[2][0].asDegrees() - ra
                    height = corners[2][1].asDegrees() - dec
                    ax.annotate(
                        dataId,
                        (ra + width / 2, dec + height / 2),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if not self.plotOutlines or "tract" not in sumStats:
                minRa = np.min(xs)
                maxRa = np.max(xs)
                minDec = np.min(ys)
                maxDec = np.max(ys)
                # Avoid identical end points which causes problems in binning
                if minRa == maxRa:
                    maxRa += 1e-5  # There is no reason to pick this number in particular
                if minDec == maxDec:
                    maxDec += 1e-5  # There is no reason to pick this number in particular

            # NaN-ignoring statistics, matching statsAndText; a single NaN
            # would otherwise turn both colour limits into NaN.
            med = np.nanmedian(colorVals)
            mad = nansigmaMad(colorVals)
            vmin = med - 2 * mad
            vmax = med + 2 * mad
            if self.fixAroundZero:
                scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
                vmin = -1 * scaleEnd
                vmax = scaleEnd

            if len(xs) > 5000:
                # The binned image is only drawn for large data sets, so the
                # binning is only done here.
                nBins = 45
                xBinEdges = np.linspace(minRa, maxRa, nBins + 1)
                yBinEdges = np.linspace(minDec, maxDec, nBins + 1)
                binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
                    xs, ys, colorVals, statistic="median", bins=(xBinEdges, yBinEdges)
                )

                s = 500 / (len(xs) ** 0.5)
                lw = (s**0.5) / 10
                plotOut = ax.imshow(
                    binnedStats.T,
                    cmap=cmap,
                    extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
                    vmin=vmin,
                    vmax=vmax,
                )
                # find the most extreme 15% of points, because the list
                # is ordered by the distance from the median this is just
                # the final 15% of points
                extremes = int(np.floor((len(xs) / 100)) * 85)
                ax.scatter(
                    xs[extremes:],
                    ys[extremes:],
                    c=colorVals[extremes:],
                    s=s,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=lw,
                )

            else:
                plotOut = ax.scatter(
                    xs,
                    ys,
                    c=colorVals,
                    cmap=cmap,
                    s=7,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=0.2,
                )

            # One colorbar per plotted type, placed side by side.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # With several colorbars, put the first one's ticks on its left
            # so they do not run into the neighbouring bar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(xCol)
        ax.set_ylabel(yCol)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        plt.draw()

        # Find some useful axis limits
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs != [] and np.max(lenXs) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # x limits are given high-to-low so the x axis runs reversed.
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot. The figure created above is
        # used directly rather than whatever pyplot considers current.
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig