Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 12%

185 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-18 12:39 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from typing import Mapping 

24 

25import matplotlib.patheffects as pathEffects 

26import matplotlib.pyplot as plt 

27import numpy as np 

28from lsst.pex.config import Field, ListField 

29from matplotlib.figure import Figure 

30from matplotlib.patches import Rectangle 

31from scipy.stats import binned_statistic_2d 

32from scipy.stats import median_absolute_deviation as sigmaMad 

33 

34from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

35from .plotUtils import addPlotInfo, extremaSort, mkColormap 

36 

37# from .plotUtils import generateSummaryStats, parsePlotInfo 

38 

39 

class SkyPlot(PlotAction):
    """Plot values at their x/y positions, colour coded by a third (z) value.

    Each entry of ``plotTypes`` selects one sample (stars, galaxies, unknown
    or any); every selected sample is drawn on the same axes with its own
    colormap, colour bar and summary-statistics text box.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    # This field used to be declared twice in the class body; the later
    # declaration silently replaced the earlier one, so only its doc string
    # is kept here.
    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their types) that the plot needs in ``data``.

        Four `Vector` keys (x, y, z and a statistics mask) are required for
        every plot type present in ``self.plotTypes``.

        Returns
        -------
        schema : `list` [`tuple`]
            ``(key, type)`` pairs, grouped in the order stars, galaxies,
            unknown, any.
        """
        keysByType = {
            "stars": ("xStars", "yStars", "zStars", "starStatMask"),
            "galaxies": ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask"),
            "unknown": ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask"),
            "any": ("x", "y", "z", "statMask"),
        }
        base = []
        for plotType, keys in keysByType.items():
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in keys)

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key in the input schema is present in ``data``.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a value is a `Scalar`
            where the schema asks for another type.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def sortAllArrays(self, arrsToSort: list) -> list:
        """Sort one array and then return all the others in
        the associated order.

        Parameters
        ----------
        arrsToSort : `list`
            Arrays of equal length. The first one is passed to `extremaSort`
            and the indices it returns are applied to every array.

        Returns
        -------
        arrsToSort : `list`
            The same list object, with each entry replaced by its reordered
            array (the list is modified in place).
        """
        ids = extremaSort(arrsToSort[0])
        for i, arr in enumerate(arrsToSort):
            arrsToSort[i] = arr[ids]
        return arrsToSort

    def statsAndText(self, arr, mask=None):
        """Calculate some stats from an array and return them
        and some text.

        Parameters
        ----------
        arr : array-like
            The values to summarise.
        mask : array-like, optional
            Selection applied to ``arr`` before the median and sigma MAD are
            computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            Sigma MAD of the (masked) values, with NaNs omitted.
        statsText : `str`
            Multi-line summary of the above for display on the plot.

        Notes
        -----
        The point count in ``statsText`` is the length of ``arr`` *before*
        the mask is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        # NOTE(review): sigmaMad is scipy.stats.median_absolute_deviation,
        # which newer SciPy releases no longer provide - confirm the
        # supported SciPy version or switch the file-level import to
        # median_abs_deviation(..., scale="normal").
        sigMad = sigmaMad(arr, nan_policy="omit")

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str] | None = None,
        sumStats: Mapping | None = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the z value at given points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The x, y, z and statistics-mask vectors for every type listed in
            ``self.plotTypes``; see `getInputSchema` for the key names.
        plotInfo : `dict`, optional
            Information about the data being plotted, passed unchanged to
            `addPlotInfo`.
        sumStats : `dict`, optional
            Maps an identifier to a ``(corners, stat)`` pair, where
            ``corners`` is a sequence of (R.A., dec) angle pairs providing
            ``asDegrees()``. The entry keyed ``"tract"``, when present,
            defines the extent of the plot; every other entry is outlined
            and labelled with its identifier. Only used when
            ``self.plotOutlines`` is set.
        **kwargs
            Unused.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Samples with more than 5000 points are shown as a 45x45 grid of
        binned medians with the last 15% of points (after `extremaSort`
        ordering) scattered on top; smaller samples are shown as a plain
        scatter plot. The colour scale spans the median +/- 2 sigma MAD of
        each sample and is made symmetric about zero if
        ``self.fixAroundZero`` is set.
        """
        # Build a fresh mapping per call instead of sharing one mutable
        # default between calls.
        if sumStats is None:
            sumStats = {}

        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        xCol = self.xAxisLabel
        yCol = self.yAxisLabel
        zCol = self.zAxisLabel

        # Every stats box is placed at x=0.8 by default. When galaxies are
        # plotted alongside another type the galaxy box is moved left so
        # that both can be seen. The old test (> 2 types) left the boxes
        # overlapping when exactly two types were plotted.
        galaxyBoxLoc = 0.63 if len(self.plotTypes) > 1 else 0.8

        # (plot type, (z, x, y, mask) keys, colormap, label,
        #  stats box colour, stats box alpha, stats box x position)
        sampleSpecs = [
            (
                "galaxies",
                ("zGalaxies", "xGalaxies", "yGalaxies", "galaxyStatMask"),
                redPurple,
                "Galaxies",
                "lemonchiffon",
                0.5,
                galaxyBoxLoc,
            ),
            ("stars", ("zStars", "xStars", "yStars", "starStatMask"), blueGreen, "Stars", "paleturquoise", 0.5, 0.8),
            (
                "unknown",
                ("zUnknowns", "xUnknowns", "yUnknowns", "unknownStatMask"),
                "viridis",
                "Unknown",
                "green",
                0.2,
                0.8,
            ),
            ("any", ("z", "x", "y", "statMask"), orangeBlue, "All", "purple", 0.2, 0.8),
        ]

        toPlotList = []
        for plotType, keys, cmap, label, boxColor, boxAlpha, boxLoc in sampleSpecs:
            if plotType not in self.plotTypes:
                continue
            # z goes first so that all arrays are ordered by the z values.
            colorVals, xs, ys, statMask = self.sortAllArrays([data[key] for key in keys])
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            # Add statistics
            bbox = dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")
            ax.text(boxLoc, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Outline each region in sumStats and label all but the tract
        if self.plotOutlines:
            for dataId, (corners, _) in sumStats.items():
                ras = [corner[0].asDegrees() for corner in corners]
                decs = [corner[1].asDegrees() for corner in corners]
                ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
                if dataId == "tract":
                    minRa = np.min(ras)
                    minDec = np.min(decs)
                    maxRa = np.max(ras)
                    maxDec = np.max(decs)
                else:
                    # Label position: midway between corners 0 and 2.
                    cenX = ras[0] + (ras[2] - ras[0]) / 2
                    cenY = decs[0] + (decs[2] - decs[0]) / 2
                    ax.annotate(
                        dataId,
                        (cenX, cenY),
                        color="k",
                        fontsize=5,
                        ha="center",
                        va="center",
                        path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                    )

        useTractBounds = self.plotOutlines and "tract" in sumStats
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            if not useTractBounds:
                # Without a tract outline the extent comes from the points
                # themselves; the axis limits set at the end therefore
                # reflect the last sample plotted.
                minRa = np.min(xs)
                maxRa = np.max(xs)
                minDec = np.min(ys)
                maxDec = np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular

            # NaN-safe, matching statsAndText; a single NaN would otherwise
            # turn both colour limits into NaN.
            med = np.nanmedian(colorVals)
            mad = sigmaMad(colorVals, nan_policy="omit")
            vmin = med - 2 * mad
            vmax = med + 2 * mad
            if self.fixAroundZero:
                scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
                vmin = -1 * scaleEnd
                vmax = scaleEnd

            if len(xs) > 5000:
                # The binned medians are only needed for large samples.
                nBins = 45
                xBinEdges = np.linspace(minRa, maxRa, nBins + 1)
                yBinEdges = np.linspace(minDec, maxDec, nBins + 1)
                binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
                    xs, ys, colorVals, statistic="median", bins=(xBinEdges, yBinEdges)
                )

                s = 500 / (len(xs) ** 0.5)
                lw = (s**0.5) / 10
                plotOut = ax.imshow(
                    binnedStats.T,
                    cmap=cmap,
                    extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
                    vmin=vmin,
                    vmax=vmax,
                )
                # find the most extreme 15% of points, because the list
                # is ordered by the distance from the median this is just
                # the final 15% of points
                extremes = int(np.floor((len(xs) / 100)) * 85)
                ax.scatter(
                    xs[extremes:],
                    ys[extremes:],
                    c=colorVals[extremes:],
                    s=s,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=lw,
                )
            else:
                plotOut = ax.scatter(
                    xs,
                    ys,
                    c=colorVals,
                    cmap=cmap,
                    s=7,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=0.2,
                )

            # One colour bar per sample, stacked side by side on the right.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            colorBarLabel = f"{zCol}: {label}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # Put the first bar's ticks on its left so they do not collide
            # with the next bar.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(xCol)
        ax.set_ylabel(yCol)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        plt.draw()

        # Find some useful axis limits; the x axis is always reversed.
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if lenXs and max(lenXs) > 1000:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot. Use the figure created above
        # rather than whatever pyplot currently considers "current".
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        # NOTE(review): plotInfo is forwarded as given, including None -
        # confirm that addPlotInfo accepts that.
        fig = addPlotInfo(fig, plotInfo)

        return fig