Coverage for python/lsst/analysis/tools/actions/plot/skyPlot.py: 10%

188 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-06 02:44 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24from typing import Mapping, Optional 

25 

26import matplotlib.patheffects as pathEffects 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from lsst.pex.config import Field, ListField 

30from matplotlib.figure import Figure 

31from matplotlib.patches import Rectangle 

32from scipy.stats import binned_statistic_2d 

33 

34from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

35from ...statistics import nansigmaMad, sigmaMad 

36from .plotUtils import addPlotInfo, extremaSort, mkColormap 

37 

38# from .plotUtils import generateSummaryStats, parsePlotInfo 

39 

40 

class SkyPlot(PlotAction):
    """Plot a quantity as a function of position on the sky.

    Each requested object type (stars, galaxies, unknowns, any) is drawn
    with its own colormap and colorbar, and gets a text box giving the
    median, sigma MAD and number of points of the plotted values.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the ``(key, type)`` pairs needed for the configured plot types.

        Each plot type requires x, y and z vectors (keys suffixed with the
        type name, no suffix for ``"any"``) plus a statistics mask vector.
        """
        # (plot type, suffix of its x/y/z keys, statistics mask key)
        schemaKeys = (
            ("stars", "Stars", "starStatMask"),
            ("galaxies", "Galaxies", "galaxyStatMask"),
            ("unknown", "Unknowns", "unknownStatMask"),
            ("any", "", "statMask"),
        )
        base = []
        for plotType, suffix, maskKey in schemaKeys:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((f"{axis}{suffix}", Vector) for axis in ("x", "y", "z"))
                base.append((maskKey, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` and return the sky plot made from it."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key from `getInputSchema` is present in ``data``.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a value is a Scalar where the
            schema asks for a non-Scalar type.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def sortAllArrays(self, arrsToSort):
        """Reorder all arrays by the ordering `extremaSort` gives the first.

        The input list is modified in place (each element replaced by its
        reordered version) and also returned.
        """
        ids = extremaSort(arrsToSort[0])
        for (i, arr) in enumerate(arrsToSort):
            arrsToSort[i] = arr[ids]
        return arrsToSort

    def statsAndText(self, arr, mask=None):
        """Return the median, sigma MAD and a text summary of ``arr``.

        The median and sigma MAD are computed (ignoring NaNs) over the
        elements selected by ``mask`` when it is given; the point count in
        the text is the length of ``arr`` before masking.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = np.nanmedian(arr)
        sigMad = nansigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    @staticmethod
    def _dataExtent(xs, ys):
        """Return ``(minX, maxX, minY, maxY)`` of the points, never zero-width.

        Identical end points cause problems in binning, so a degenerate
        range is widened by an (arbitrary) 1e-5.
        """
        minRa, maxRa = np.min(xs), np.max(xs)
        minDec, maxDec = np.min(ys), np.max(ys)
        if minRa == maxRa:
            maxRa += 1e-5  # There is no reason to pick this number in particular
        if minDec == maxDec:
            maxDec += 1e-5  # There is no reason to pick this number in particular
        return minRa, maxRa, minDec, maxDec

    @staticmethod
    def _plotOutlines(ax, sumStats):
        """Draw the outline of every region in ``sumStats`` on ``ax``.

        Every region except ``"tract"`` is labelled with its key at the
        midpoint of its corners 0 and 2 (assumed to be opposite corners).

        Returns
        -------
        tractExtent : `tuple` or `None`
            ``(minRa, maxRa, minDec, maxDec)`` in degrees of the ``"tract"``
            entry's corners, or `None` if there is no such entry.
        """
        tractExtent = None
        for dataId, (corners, _) in sumStats.items():
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            # Close the outline by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractExtent = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
            else:
                ax.annotate(
                    dataId,
                    ((ras[0] + ras[2]) / 2, (decs[0] + decs[2]) / 2),
                    color="k",
                    fontsize=5,
                    ha="center",
                    va="center",
                    path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
                )
        return tractExtent

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a plot showing the value at given points on the sky.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must contain the keys listed by
            `getInputSchema` for the configured ``plotTypes``.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted, passed
            to `addPlotInfo`, with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            A dictionary keyed by region id (e.g. patch ids, or ``"tract"``)
            whose values are ``(corners, stats)`` pairs, where ``corners``
            holds the R.A. and dec of the corners of the region. Only used
            when ``plotOutlines`` is set.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Points are color coded by their z value, on a scale of the median
        plus or minus twice the sigma MAD (made symmetric about zero if
        ``fixAroundZero`` is set). Types with more than 5000 points are
        shown as a 45x45 image of the binned median, overplotted with the
        most extreme 15% of the points; smaller types are scattered. When a
        ``"tract"`` outline is plotted its extent sets the binning range and
        the axis limits; otherwise the extent of the data is used.
        """
        fig = plt.figure(dpi=300)
        ax = fig.add_subplot(111)

        if sumStats is None:
            sumStats = {}
        if plotInfo is None:
            plotInfo = {}

        # Make divergent colormaps for stars, galaxies and all the points
        blueGreen = mkColormap(["midnightblue", "lightcyan", "darkgreen"])
        redPurple = mkColormap(["indigo", "lemonchiffon", "firebrick"])
        orangeBlue = mkColormap(["darkOrange", "thistle", "midnightblue"])

        # Per plot type, in drawing order: suffix of its x/y/z keys, stats
        # mask key, colormap, colorbar label, stats box color and alpha.
        typeSpecs = (
            ("galaxies", "Galaxies", "galaxyStatMask", redPurple, "Galaxies", "lemonchiffon", 0.5),
            ("stars", "Stars", "starStatMask", blueGreen, "Stars", "paleturquoise", 0.5),
            ("unknown", "Unknowns", "unknownStatMask", "viridis", "Unknown", "green", 0.2),
            ("any", "", "statMask", orangeBlue, "All", "purple", 0.2),
        )

        toPlotList = []
        statsBoxes = []
        for plotType, suffix, maskKey, cmap, label, boxColor, boxAlpha in typeSpecs:
            if plotType not in self.plotTypes:
                continue
            colorVals, xs, ys, statMask = self.sortAllArrays(
                [data[f"z{suffix}"], data[f"x{suffix}"], data[f"y{suffix}"], data[maskKey]]
            )
            _, _, statsText = self.statsAndText(colorVals, mask=statMask)
            statsBoxes.append((statsText, dict(facecolor=boxColor, alpha=boxAlpha, edgecolor="none")))
            toPlotList.append((xs, ys, colorVals, cmap, label))

        # Lay the statistics boxes out right to left so none of them overlap;
        # the last type drawn takes the right-most slot (so with stars and
        # galaxies the galaxy box sits to the left of the star box).
        for slot, (statsText, bbox) in enumerate(reversed(statsBoxes)):
            ax.text(0.8 - 0.17 * slot, 0.91, statsText, transform=fig.transFigure, fontsize=8, bbox=bbox)

        # Outlines of the regions in sumStats; a tract outline fixes the extent.
        tractExtent = self._plotOutlines(ax, sumStats) if self.plotOutlines else None

        extents = []
        for (i, (xs, ys, colorVals, cmap, label)) in enumerate(toPlotList):
            extent = tractExtent if tractExtent is not None else self._dataExtent(xs, ys)
            extents.append(extent)
            minRa, maxRa, minDec, maxDec = extent

            med = np.nanmedian(colorVals)
            mad = sigmaMad(colorVals, nan_policy="omit")
            vmin = med - 2 * mad
            vmax = med + 2 * mad
            if self.fixAroundZero:
                scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
                vmin = -1 * scaleEnd
                vmax = scaleEnd

            if len(xs) > 5000:
                nBins = 45
                xBinEdges = np.linspace(minRa, maxRa, nBins + 1)
                yBinEdges = np.linspace(minDec, maxDec, nBins + 1)
                binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
                    xs, ys, colorVals, statistic="median", bins=(xBinEdges, yBinEdges)
                )
                s = 500 / (len(xs) ** 0.5)
                lw = (s**0.5) / 10
                plotOut = ax.imshow(
                    binnedStats.T,
                    cmap=cmap,
                    extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
                    vmin=vmin,
                    vmax=vmax,
                )
                # Overplot the most extreme 15% of points; the arrays were
                # reordered by extremaSort, which puts these last. Integer
                # arithmetic keeps the cut at exactly 85%.
                extremes = (len(xs) * 85) // 100
                ax.scatter(
                    xs[extremes:],
                    ys[extremes:],
                    c=colorVals[extremes:],
                    s=s,
                    cmap=cmap,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=lw,
                )
            else:
                plotOut = ax.scatter(
                    xs,
                    ys,
                    c=colorVals,
                    cmap=cmap,
                    s=7,
                    vmin=vmin,
                    vmax=vmax,
                    edgecolor="white",
                    linewidths=0.2,
                )

            # One colorbar per type, side by side to the right of the plot.
            cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
            fig.colorbar(plotOut, cax=cax, extend="both")
            text = cax.text(
                0.5,
                0.5,
                f"{self.zAxisLabel}: {label}",
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])
            cax.tick_params(labelsize=7)

            # Keep the first colorbar's ticks clear of its neighbour.
            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)
        ax.tick_params(axis="x", labelrotation=25)
        ax.tick_params(labelsize=7)

        ax.set_aspect("equal")
        plt.draw()

        # Find some useful axis limits, covering every plotted type.
        if toPlotList and max(len(xs) for xs, *_ in toPlotList) > 1000:
            minRa = np.min([extent[0] for extent in extents])
            maxRa = np.max([extent[1] for extent in extents])
            minDec = np.min([extent[2] for extent in extents])
            maxDec = np.max([extent[3] for extent in extents])
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            # Larger x (R.A.) values go on the left.
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.85)
        return addPlotInfo(fig, plotInfo)