Coverage for python / lsst / analysis / tools / actions / plot / skyPlot.py: 12%

195 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 00:23 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.patheffects as pathEffects 

29import numpy as np 

30from lsst.pex.config import Field, ListField 

31from lsst.pex.config.configurableActions import ConfigurableActionField 

32from lsst.utils.plotting import ( 

33 divergent_cmap, 

34 galaxies_cmap, 

35 galaxies_color, 

36 make_figure, 

37 set_rubin_plotstyle, 

38 stars_cmap, 

39 stars_color, 

40) 

41from matplotlib.figure import Figure 

42from matplotlib.patches import Rectangle 

43 

44from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction 

45from ...math import nanMedian, nanSigmaMad 

46from .calculateRange import Med2Mad 

47from .plotUtils import addPlotInfo, generateSummaryStats, plotProjectionWithBinning, sortAllArrays 

48 

49 

class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.

    The plotting of patch outlines requires patch information
    to be included as an additional parameter.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
        # itemCheck=_validatePlotTypes,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    showExtremeOutliers = Field[bool](
        doc="Show the x-y positions of extreme outlier values as overlaid scatter points.",
        default=True,
    )

    publicationStyle = Field[bool](
        doc="Make a simplified plot for publication use.",
        default=False,
    )

    divergent = Field[bool](
        doc="Use a divergent colormap?",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the keys (and their expected types) that ``data`` must hold.

        Each selected plot type contributes an x, y and z vector plus a
        boolean mask selecting the points used for the summary statistics.
        """
        base = []
        if "stars" in self.plotTypes:  # type: ignore
            base.append(("xStars", Vector))
            base.append(("yStars", Vector))
            base.append(("zStars", Vector))
            base.append(("starStatMask", Vector))
        if "galaxies" in self.plotTypes:  # type: ignore
            base.append(("xGalaxies", Vector))
            base.append(("yGalaxies", Vector))
            base.append(("zGalaxies", Vector))
            base.append(("galaxyStatMask", Vector))
        if "unknown" in self.plotTypes:  # type: ignore
            base.append(("xUnknowns", Vector))
            base.append(("yUnknowns", Vector))
            base.append(("zUnknowns", Vector))
            base.append(("unknownStatMask", Vector))
        if "any" in self.plotTypes:  # type: ignore
            base.append(("x", Vector))
            base.append(("y", Vector))
            base.append(("z", Vector))
            base.append(("statMask", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` has every key required by the input schema.

        Keys are compared after ``str.format(**kwargs)`` substitution.

        NOTE: currently this can only check that a required entry is not a
        Scalar; it does not check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            If a required key is missing, or a required entry is a Scalar
            where a non-Scalar type is expected.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate summary statistics of an array and a text label for them.

        Parameters
        ----------
        arr : `numpy.ndarray`
            Values to summarise.
        mask : `numpy.ndarray`, optional
            Selection applied to ``arr`` before computing the median and
            sigma MAD.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma MAD of the (masked) values.
        statsText : `str`
            Multi-line label. Note that the quoted number of points is the
            length of ``arr`` *before* ``mask`` is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = nanMedian(arr)
        sigMad = nanSigmaMad(arr)

        statsText = (
            "Median: {:0.2f}\n".format(med)
            + r"$\sigma_{MAD}$: "
            + "{:0.2f}\n".format(sigMad)
            + r"n$_{points}$: "
            + "{}".format(numPoints)
        )

        return med, sigMad, statsText

    def _prepareType(self, data: KeyedData, xKey: str, yKey: str, zKey: str, maskKey: str):
        """Sort one object type's columns together and build its stats text.

        The arrays are handed to ``sortAllArrays`` with the z values first,
        matching the ``isSorted=True`` later passed to the binning routine.

        Returns
        -------
        xs, ys, colorVals : `numpy.ndarray`
            The sorted positions and z values.
        statsText : `str`
            Summary statistics label from `statsAndText`.
        """
        colorVals, xs, ys, statMask = sortAllArrays([data[zKey], data[xKey], data[yKey], data[maskKey]])
        _, _, statsText = self.statsAndText(colorVals, mask=statMask)
        return xs, ys, colorVals, statsText

    @staticmethod
    def _addStatsBox(fig: Figure, ax, xLoc: float, text: str, facecolor, alpha: float) -> None:
        """Draw a statistics text box at figure coordinates ``(xLoc, 0.91)``."""
        bbox = dict(facecolor=facecolor, alpha=alpha, edgecolor="none")
        ax.text(xLoc, 0.91, text, transform=fig.transFigure, fontsize=8, bbox=bbox)

    @staticmethod
    def _drawOutlines(ax, sumStats: Mapping) -> tuple[float, float, float, float] | None:
        """Outline every region in ``sumStats`` and label all but the tract.

        Each value of ``sumStats`` is unpacked as ``(corners, _)`` where
        ``corners`` is a sequence of ``(ra, dec)`` angle pairs (objects with
        an ``asDegrees`` method).

        Returns
        -------
        limits : `tuple` or `None`
            ``(minRa, maxRa, minDec, maxDec)`` of the ``"tract"`` entry's
            corners, or `None` if ``sumStats`` has no ``"tract"`` key.
        """
        tractLimits = None
        for dataId, (corners, _) in sumStats.items():
            ras = [ra.asDegrees() for ra, _ in corners]
            decs = [dec.asDegrees() for _, dec in corners]
            # Close the polygon by repeating the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractLimits = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue
            # Label at the midpoint of the diagonal from corner 0 to corner 2.
            cenX = ras[0] + (ras[2] - ras[0]) / 2
            cenY = decs[0] + (decs[2] - decs[0]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractLimits

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch. If not given and outlines
            are requested and ``data`` has a ``"patch"`` column, it is built
            with ``generateSummaryStats`` using ``kwargs["skymap"]``.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Object types with fewer than five points with finite positions are
        skipped (no binned map and no colorbar is drawn for them).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        set_rubin_plotstyle()
        fig = make_figure()
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        if plotInfo is None:
            plotInfo = {}

        showStats = not self.publicationStyle

        # Collect (xs, ys, colorVals, cmap, label) for each requested type.
        toPlotList = []

        if "galaxies" in self.plotTypes:
            xs, ys, colorVals, statsText = self._prepareType(
                data, "xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask"
            )
            if showStats:
                # Every other stats box is drawn at x=0.8; shift the galaxy
                # box left whenever one of them will also be drawn so that
                # both remain readable.
                sharesCorner = any(t in self.plotTypes for t in ("stars", "unknown", "any"))
                boxLoc = 0.63 if sharesCorner else 0.8
                self._addStatsBox(fig, ax, boxLoc, statsText, galaxies_color(), 0.5)
            cmap = divergent_cmap() if self.divergent else galaxies_cmap()
            toPlotList.append((xs, ys, colorVals, cmap, "Galaxies"))

        if "stars" in self.plotTypes:
            xs, ys, colorVals, statsText = self._prepareType(data, "xStars", "yStars", "zStars", "starStatMask")
            if showStats:
                self._addStatsBox(fig, ax, 0.8, statsText, stars_color(), 0.5)
            cmap = divergent_cmap() if self.divergent else stars_cmap()
            toPlotList.append((xs, ys, colorVals, cmap, "Stars"))

        if "unknown" in self.plotTypes:
            xs, ys, colorVals, statsText = self._prepareType(
                data, "xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask"
            )
            if showStats:
                self._addStatsBox(fig, ax, 0.8, statsText, "green", 0.2)
            toPlotList.append((xs, ys, colorVals, "viridis", "Unknown"))

        if "any" in self.plotTypes:
            xs, ys, colorVals, statsText = self._prepareType(data, "x", "y", "z", "statMask")
            if showStats:
                self._addStatsBox(fig, ax, 0.8, statsText, "#bab0ac", 0.2)
            toPlotList.append((xs, ys, colorVals, "viridis", "All"))

        # Axis limits: taken from the tract outline when available,
        # otherwise from the data of each plotted type (see loop below).
        # None means "no limits known yet".
        minRa = maxRa = minDec = maxDec = None
        useTractLimits = False
        if self.plotOutlines:
            tractLimits = self._drawOutlines(ax, sumStats)
            if tractLimits is not None:
                minRa, maxRa, minDec, maxDec = tractLimits
                useTractLimits = True

        # Drop non-finite positions and types with too few points to bin.
        plottable = []
        for xs, ys, colorVals, cmap, label in toPlotList:
            finite = np.isfinite(xs) & np.isfinite(ys)
            if np.count_nonzero(finite) < 5:
                continue
            plottable.append((xs[finite], ys[finite], colorVals[finite], cmap, label))

        showExtremeOutliers = self.showExtremeOutliers and not self.publicationStyle

        for i, (xs, ys, colorVals, cmap, label) in enumerate(plottable):
            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if not useTractLimits:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                fixAroundZero=self.fixAroundZero,
                isSorted=True,
                showExtremeOutliers=showExtremeOutliers,
            )
            ax.set_aspect("equal")
            if self.publicationStyle:
                fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.95, bottom=0.15)
                axBbox = ax.get_position()
                cax = fig.add_axes([axBbox.x1, axBbox.y0, 0.04, axBbox.y1 - axBbox.y0])
                fig.colorbar(plotOut, cax=cax)
            else:
                # One colorbar slot per type actually drawn, left to right.
                cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
                fig.colorbar(plotOut, cax=cax, extend="both")

            colorBarLabel = "{}: {}".format(self.zAxisLabel, label)
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

            # With several adjacent colorbars, move the first one's ticks to
            # the left so they do not collide with the next bar.
            if i == 0 and len(plottable) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        fig.canvas.draw()

        # Find some useful axis limits. RA increases to the left, hence the
        # reversed x limits / inverted x axis.
        lenXs = [len(xs) for xs, *_ in toPlotList]
        if lenXs and max(lenXs) > 1000 and minRa is not None:
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        if not self.publicationStyle:
            fig.subplots_adjust(wspace=0.0, hspace=0.0)
            fig = addPlotInfo(fig, plotInfo)

        return fig