Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 12%

185 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-09-27 02:41 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from typing import Mapping 

25 

26import matplotlib.patheffects as pathEffects 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from lsst.pex.config import Field, ListField 

30from matplotlib.figure import Figure 

31from matplotlib.patches import Rectangle 

32from scipy.stats import binned_statistic_2d 

33 

34from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

35from ...statistics import nansigmaMad, sigmaMad 

36from .plotUtils import addPlotInfo, extremaSort, mkColormap 

37 

38# from .plotUtils import generateSummaryStats, parsePlotInfo 

39 

40 

class SkyPlot(PlotAction):
    """Plot the on-sky distribution of a quantity, colour coded by value.

    One layer is drawn for each requested entry of ``plotTypes``. Large
    samples (more than 5000 points) are shown as an image of the binned
    median overlaid with the most extreme points. Smaller samples are shown
    as a plain scatter plot. Each layer gets its own colorbar and a
    statistics text box.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    # NOTE: this field was previously declared twice in this class; the
    # later declaration (this doc string) was the one in effect.
    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the (key, type) pairs required for the configured ``plotTypes``.

        Each plotted type needs x, y and z vectors plus a boolean mask
        selecting the points used for the statistics box.
        """
        base = []
        if "stars" in self.plotTypes:  # type: ignore
            base.append(("xStars", Vector))
            base.append(("yStars", Vector))
            base.append(("zStars", Vector))
            base.append(("starStatMask", Vector))
        if "galaxies" in self.plotTypes:  # type: ignore
            base.append(("xGalaxies", Vector))
            base.append(("yGalaxies", Vector))
            base.append(("zGalaxies", Vector))
            base.append(("galaxyStatMask", Vector))
        if "unknown" in self.plotTypes:  # type: ignore
            base.append(("xUnknowns", Vector))
            base.append(("yUnknowns", Vector))
            base.append(("zUnknowns", Vector))
            base.append(("unknownStatMask", Vector))
        if "any" in self.plotTypes:  # type: ignore
            base.append(("x", Vector))
            base.append(("y", Vector))
            base.append(("z", Vector))
            base.append(("statMask", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key required by `getInputSchema`.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key that must hold a
            vector holds a `Scalar`.

        Notes
        -----
        This can currently only check that something is not a `Scalar`; it
        cannot check that the data is consistent with `Vector`.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def sortAllArrays(self, arrsToSort):
        """Sort the first array and reorder all the others to match.

        Parameters
        ----------
        arrsToSort : `list` of array-like
            Arrays of equal length; the ordering is computed from the first
            one with `extremaSort`.

        Returns
        -------
        sortedArrs : `list` of array-like
            New list of the reordered arrays, in the same order as the input.
            The input list is not modified.
        """
        ids = extremaSort(arrsToSort[0])
        return [arr[ids] for arr in arrsToSort]

    def statsAndText(self, arr, mask=None):
        """Calculate the median and sigma_MAD of an array and format them.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like of `bool`, optional
            If given, only ``arr[mask]`` is used for the statistics.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma_MAD of the (masked) values.
        statsText : `str`
            Multi-line text for the plot. The reported number of points is
            the length of ``arr`` before masking.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _colorLimits(self, colorVals):
        """Return ``(vmin, vmax)`` for the colour scale of one plotted type.

        The limits are median -/+ 2 sigma_MAD of the finite values (NaNs and
        infinities would otherwise make both limits NaN). If
        ``fixAroundZero`` is set, they are made symmetric about zero.
        """
        finite = np.isfinite(colorVals)
        med = np.median(colorVals[finite])
        mad = sigmaMad(colorVals[finite])
        vmin = med - 2 * mad
        vmax = med + 2 * mad
        if self.fixAroundZero:
            scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
            vmin = -1 * scaleEnd
            vmax = scaleEnd
        return vmin, vmax

    def _drawOutlines(self, ax, sumStats):
        """Draw the outline of every region in ``sumStats`` on ``ax``.

        Every region except the ``"tract"`` entry is labelled with its key.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `Mapping`
            Values are ``(corners, _)`` pairs. ``corners`` is a sequence of
            ``(ra, dec)`` objects providing ``asDegrees()``.

        Returns
        -------
        tractLimits : `tuple` of `float` or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees of the ``"tract"``
            corners, or `None` if ``sumStats`` has no ``"tract"`` entry.
        """
        tractLimits = None
        for dataId, (corners, _) in sumStats.items():
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            # Close the outline by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue
            # Label at the midpoint between corners 0 and 2 (presumably
            # opposite corners of the region).
            cenX = ras[0] + (ras[2] - ras[0]) / 2
            cenY = decs[0] + (decs[2] - decs[0]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str] | None = None,
        sumStats: Mapping | None = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the value of ``z`` at given points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            Must contain the keys returned by `getInputSchema` for the
            configured ``plotTypes``.
        plotInfo : `dict`, optional
            Information about the data being plotted, passed to
            `addPlotInfo`, with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Keyed by region id (e.g. patch id, or ``"tract"``). Each value is
            a ``(corners, _)`` pair giving the R.A./dec corners of the region.
            This is only used when ``plotOutlines`` is set. If a ``"tract"``
            entry is present, its corners set the binning range and the axis
            limits.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The x and y axes are labelled with ``xAxisLabel`` and ``yAxisLabel``.
        Each colorbar is labelled with ``zAxisLabel`` and the type name.
        """
        if sumStats is None:
            sumStats = {}

        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        # Make divergent colormaps for stars, galaxies and all the points.
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # All boxed types other than galaxies place their statistics box at
        # x=0.8, so shift the galaxy box left when any of them is also
        # plotted, keeping both readable.
        galaxyBoxLoc = 0.63 if {"stars", "unknown", "any"}.intersection(self.plotTypes) else 0.8

        # Per-type settings: (plotType, data keys in z/x/y/mask order,
        # colormap, colorbar label, stats box style, stats box x position).
        typeSpecs = [
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                redPurple,
                "Galaxies",
                dict(facecolor="lemonchiffon", alpha=0.5, edgecolor="none"),
                galaxyBoxLoc,
            ),
            (
                "stars",
                ("zStars", "xStars", "yStars", "starStatMask"),
                blueGreen,
                "Stars",
                dict(facecolor="paleturquoise", alpha=0.5, edgecolor="none"),
                0.8,
            ),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "viridis",
                "Unknown",
                dict(facecolor="green", alpha=0.2, edgecolor="none"),
                0.8,
            ),
            (
                "any",
                ("z", "x", "y", "statMask"),
                orangeBlue,
                "All",
                dict(facecolor="purple", alpha=0.2, edgecolor="none"),
                0.8,
            ),
        ]

        toPlotList = []
        for plotType, keys, cmap, label, bbox, boxLoc in typeSpecs:
            if plotType not in self.plotTypes:
                continue
            colorVals, xs, ys, statMask = self.sortAllArrays([data[key] for key in keys])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outline of each patch/ccd; a tract outline also fixes the limits.
        tractLimits = self._drawOutlines(ax, sumStats) if self.plotOutlines else None
        if tractLimits is not None:
            minRa, maxRa, minDec, maxDec = tractLimits

        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if len(xs) == 0:
                # Nothing to draw, and np.min/np.max would raise on it.
                continue
            if tractLimits is None:
                minRa = np.min(xs)
                maxRa = np.max(xs)
                minDec = np.min(ys)
                maxDec = np.max(ys)
            # Avoid identical end points which cause problems in binning.
            # The offset size is arbitrary.
            if minRa == maxRa:
                maxRa += 1e-5
            if minDec == maxDec:
                maxDec += 1e-5
            vmin, vmax = self._colorLimits(colorVals)

            if len(xs) > 5000:
                nBins = 45
                xBinEdges = np.linspace(minRa, maxRa, nBins + 1)
                yBinEdges = np.linspace(minDec, maxDec, nBins + 1)
                binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
                    xs, ys, colorVals, statistic="median", bins=(xBinEdges, yBinEdges)
                )
                s = 500 / (len(xs) ** 0.5)
                lw = (s**0.5) / 10
                plotOut = ax.imshow(
                    binnedStats.T,
                    cmap=cmap,
                    extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
                    vmin=vmin,
                    vmax=vmax,
                )
                # Overplot the most extreme 15% of points. The arrays are
                # ordered by distance from the median, so these are the
                # final 15% of entries.
                extremes = int(np.floor(len(xs) * 0.85))
                ax.scatter(
                    xs[extremes:],
                    ys[extremes:],
                    c=colorVals[extremes:],
                    s=s,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=lw,
                )
            else:
                plotOut = ax.scatter(
                    xs,
                    ys,
                    c=colorVals,
                    cmap=cmap,
                    s=7,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=0.2,
                )

            # One colorbar per type, stacked to the right of the main axes.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            text = cax.text(
                0.5,
                0.5,
                f"{self.zAxisLabel}: {label}",
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        fig.canvas.draw_idle()

        # Find some useful axis limits (R.A. increases to the left).
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs and max(lenXs) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot.
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        fig = addPlotInfo(fig, plotInfo)

        return fig