Coverage for python / lsst / analysis / tools / actions / plot / skyPlot.py: 12%

210 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-07 08:53 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("SkyPlot",) 

25 

26from collections.abc import Mapping 

27 

28import matplotlib.patheffects as pathEffects 

29import numpy as np 

30from matplotlib.figure import Figure 

31from matplotlib.patches import Rectangle 

32 

33from lsst.pex.config import Field, ListField 

34from lsst.pex.config.configurableActions import ConfigurableActionField 

35from lsst.utils.plotting import ( 

36 divergent_cmap, 

37 galaxies_cmap, 

38 galaxies_color, 

39 make_figure, 

40 set_rubin_plotstyle, 

41 stars_cmap, 

42 stars_color, 

43) 

44 

45from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector, VectorAction 

46from ...math import nanMedian, nanSigmaMad 

47from .calculateRange import Med2Mad 

48from .plotUtils import addPlotInfo, generateSummaryStats, plotProjectionWithBinning, sortAllArrays 

49 

50 

class SkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Plots the values of the parameter given for the z axis
    according to the positions given for x and y. Optimised
    for use with RA and Dec. Also calculates some basic
    statistics and includes those on the plot.

    The plotting of patch outlines requires patch information
    to be included as an additional parameter.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis.", optional=False)
    zAxisLabel = Field[str](doc="Label to use for the z axis.", optional=False)

    plotOutlines = Field[bool](
        doc="Plot the outlines of the ccds/patches?",
        default=True,
    )

    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any.",
        optional=False,
    )

    plotName = Field[str](doc="The name for the plot.", optional=False)

    alpha = Field[float](
        doc="Transparency for scatter plot points.",
        default=1.0,
    )

    scatPtSize = Field[float](
        doc="Marker size for scatter plot points.",
        default=7,
    )

    fixAroundZero = Field[bool](
        doc="Fix the colorbar to be symmetric around zero.",
        default=False,
    )

    doBinning = Field[bool](
        doc="Spatially bin the data? Extreme outliers can be shown using" " `showExtremeOutliers`.",
        optional=True,
        default=True,
    )

    colorbarRange = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max of the colorbar range.",
        default=Med2Mad,
    )

    showExtremeOutliers = Field[bool](
        doc="Show the x-y positions of extreme outlier values as overlaid scatter points.",
        default=True,
    )

    publicationStyle = Field[bool](
        doc="Make a simplified plot for publication use.",
        default=False,
    )

    divergent = Field[bool](
        doc="Use a divergent colormap?",
        default=False,
    )

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Return the ``(key, type)`` pairs this action reads from its input.

        Each of ``stars``, ``galaxies``, ``unknown`` and ``any`` present in
        ``plotTypes`` adds four `Vector` keys (x, y, z and a statistics
        mask). Any other entry in ``plotTypes`` adds no keys.
        """
        # Ordered so the schema lists populations in the same order as before.
        keysByType = (
            ("stars", ("xStars", "yStars", "zStars", "starStatMask")),
            ("galaxies", ("xGalaxies", "yGalaxies", "zGalaxies", "galaxyStatMask")),
            ("unknown", ("xUnknowns", "yUnknowns", "zUnknowns", "unknownStatMask")),
            ("any", ("x", "y", "z", "statMask")),
        )
        base = []
        for plotType, keys in keysByType:
            if plotType in self.plotTypes:  # type: ignore
                base.extend((key, Vector) for key in keys)
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key required by the schema.

        Keys are passed through ``str.format(**kwargs)`` before comparison.

        Raises
        ------
        ValueError
            If a required key is missing from ``data``, or if a key whose
            schema type is not `Scalar` holds a `Scalar`.

        Notes
        -----
        This can currently only check that something is not a Scalar; it
        does not check that the data is consistent with Vector.
        """
        needed = self.getInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def statsAndText(self, arr, mask=None):
        """Calculate the median and sigma_MAD of an array and format them.

        Parameters
        ----------
        arr : array-like
            Values to summarise.
        mask : array-like, optional
            Selection applied to ``arr`` before the statistics are computed.

        Returns
        -------
        med : `float`
            NaN-ignoring median of the (masked) values.
        sigMad : `float`
            NaN-ignoring sigma_MAD of the (masked) values.
        statsText : `str`
            Text for the on-plot stats box. The quoted number of points is
            the length of ``arr`` *before* ``mask`` is applied.
        """
        numPoints = len(arr)
        if mask is not None:
            arr = arr[mask]
        med = nanMedian(arr)
        sigMad = nanSigmaMad(arr)

        statsText = (
            f"Median: {med:0.2f}\n"
            + r"$\sigma_{MAD}$: "
            + f"{sigMad:0.2f}\n"
            + r"n$_{points}$: "
            + f"{numPoints}"
        )

        return med, sigMad, statsText

    def _plotOutlines(self, ax, sumStats):
        """Draw patch/tract outlines and label each patch with its id.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes to draw on.
        sumStats : `Mapping`
            Mapping from an id to a ``(corners, _)`` pair, where ``corners``
            is a sequence of (RA, Dec) pairs whose elements provide
            ``asDegrees()``. Corners 0 and 2 are used as opposite corners
            when placing the label.

        Returns
        -------
        tractBounds : `tuple` or `None`
            ``(minRa, maxRa, minDec, maxDec)`` of the corners of the
            ``"tract"`` entry, or `None` if there is no such entry.
        """
        tractBounds = None
        for dataId, (corners, _) in sumStats.items():
            ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
            decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
            # Close the outline by returning to the first corner.
            ax.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if dataId == "tract":
                tractBounds = (np.min(ras), np.max(ras), np.min(decs), np.max(decs))
                continue
            # Label the patch midway between corners 0 and 2.
            cenX = ras[0] + (ras[2] - ras[0]) / 2
            cenY = decs[0] + (decs[2] - decs[0]) / 2
            ax.annotate(
                dataId,
                (cenX, cenY),
                color="k",
                fontsize=5,
                ha="center",
                va="center",
                path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
            )
        return tractBounds

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str] | None = None,
        sumStats: Mapping | None = None,
        **kwargs,
    ) -> Figure:
        """Make a skyPlot of the given data.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).

            `None` is treated as an empty dictionary. If ``plotName`` is set
            it is written into this dictionary under ``"plotName"``.
        sumStats : `dict`, optional
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch. If `None`, it is generated
            with ``generateSummaryStats`` when ``plotOutlines`` is set and
            ``data`` has a ``"patch"`` column; otherwise no outlines are drawn.
        **kwargs
            Must contain ``"skymap"`` when ``sumStats`` has to be generated.

        Returns
        -------
        skyPlot : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Expects the data to contain slightly different things
        depending on the types specified in plotTypes. This
        is handled automatically if you go through the pipetask
        framework but if you call this method separately then you
        need to make sure that data contains what the code is expecting.

        If stars is in the plot types given then it is expected that
        data contains: xStars, yStars, zStars and starStatMask.

        If galaxies is present: xGalaxies, yGalaxies, zGalaxies and
        galaxyStatMask.

        If unknown is present: xUnknowns, yUnknowns, zUnknowns and
        unknownStatMask.

        If any is specified: x, y, z, statMask.

        These options are not exclusive and multiple can be specified
        and thus need to be present in data.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/skyPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        set_rubin_plotstyle()

        # Normalise plotInfo before anything writes to or reads from it.
        if plotInfo is None:
            plotInfo = {}

        # 'plotName' by default is constructed from the attribute specified in
        # 'atools.<attribute>' in the pipeline YAML. If the atool sets
        # self.produce.plot.plotName, it will override this default.
        if self.plotName:
            plotInfo["plotName"] = self.plotName

        fig = make_figure()
        ax = fig.add_subplot(111)

        if sumStats is None:
            if self.plotOutlines and "patch" in data.keys():
                sumStats = generateSummaryStats(data, kwargs["skymap"], plotInfo)
            else:
                sumStats = {}

        # Default colormaps for stars and galaxies; swapped for a divergent
        # colormap below when ``divergent`` is set.
        starsCmap = stars_cmap()
        galsCmap = galaxies_cmap()

        toPlotList = []
        # For galaxies
        if "galaxies" in self.plotTypes:
            colorValsGalaxies, xsGalaxies, ysGalaxies, statGalaxies = sortAllArrays(
                [data["zGalaxies"], data["xGalaxies"], data["yGalaxies"], data["galaxyStatMask"]]
            )
            _, _, galStatsText = self.statsAndText(colorValsGalaxies, mask=statGalaxies)
            if not self.publicationStyle:
                # Add statistics
                bbox = dict(facecolor=galaxies_color(), alpha=0.5, edgecolor="none")
                # The stars, unknown and any stats boxes are all drawn at
                # x=0.8; if one of them is shown, shift the galaxy box left
                # so that both can be seen.
                othersShown = any(t in self.plotTypes for t in ("stars", "unknown", "any"))
                boxLoc = 0.63 if othersShown else 0.8
                ax.text(boxLoc, 0.91, galStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            if self.divergent:
                galsCmap = divergent_cmap()
            toPlotList.append((xsGalaxies, ysGalaxies, colorValsGalaxies, galsCmap, "Galaxies"))

        # For stars
        if "stars" in self.plotTypes:
            colorValsStars, xsStars, ysStars, statStars = sortAllArrays(
                [data["zStars"], data["xStars"], data["yStars"], data["starStatMask"]]
            )
            _, _, starStatsText = self.statsAndText(colorValsStars, mask=statStars)
            if not self.publicationStyle:
                # Add statistics
                bbox = dict(facecolor=stars_color(), alpha=0.5, edgecolor="none")
                ax.text(0.8, 0.91, starStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            if self.divergent:
                starsCmap = divergent_cmap()
            toPlotList.append((xsStars, ysStars, colorValsStars, starsCmap, "Stars"))

        # For unknowns
        if "unknown" in self.plotTypes:
            colorValsUnknowns, xsUnknowns, ysUnknowns, statUnknowns = sortAllArrays(
                [data["zUnknowns"], data["xUnknowns"], data["yUnknowns"], data["unknownStatMask"]]
            )
            _, _, unknownStatsText = self.statsAndText(colorValsUnknowns, mask=statUnknowns)
            if not self.publicationStyle:
                bbox = dict(facecolor="green", alpha=0.2, edgecolor="none")
                ax.text(0.8, 0.91, unknownStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsUnknowns, ysUnknowns, colorValsUnknowns, "viridis", "Unknown"))

        if "any" in self.plotTypes:
            colorValsAny, xsAny, ysAny, statAny = sortAllArrays(
                [data["z"], data["x"], data["y"], data["statMask"]]
            )
            _, _, anyStatsText = self.statsAndText(colorValsAny, mask=statAny)
            if not self.publicationStyle:
                bbox = dict(facecolor="#bab0ac", alpha=0.2, edgecolor="none")
                ax.text(0.8, 0.91, anyStatsText, transform=fig.transFigure, fontsize=8, bbox=bbox)
            toPlotList.append((xsAny, ysAny, colorValsAny, "viridis", ""))

        # Outline the patches (and tract); if the tract is present its
        # corners fix the RA/Dec range used for binning and axis limits.
        tractBounds = self._plotOutlines(ax, sumStats) if self.plotOutlines else None

        # Loop-invariant plotting options.
        showExtremeOutliers = False if self.publicationStyle else self.showExtremeOutliers
        # If transparency is being used, point edgecolor matches facecolor
        edgecolor = "white" if self.alpha == 1.0 else "face"

        # (minRa, maxRa, minDec, maxDec) from the last population that reached
        # the range calculation; None if none did and there is no tract.
        limits = tractBounds
        for i, (xs, ys, colorVals, cmap, label) in enumerate(toPlotList):
            finite = np.isfinite(xs) & np.isfinite(ys)
            xs = xs[finite]
            ys = ys[finite]
            colorVals = colorVals[finite]
            n_xs = len(xs)
            if n_xs == 0:
                continue

            # colorVal column is unusable so zero it out
            # This should be obvious on the plot
            if not np.any(np.isfinite(colorVals)):
                colorVals[:] = 0
            minColorVal, maxColorVal = self.colorbarRange(colorVals)

            if tractBounds is None:
                minRa, maxRa = np.min(xs), np.max(xs)
                minDec, maxDec = np.min(ys), np.max(ys)
            else:
                minRa, maxRa, minDec, maxDec = tractBounds
            # Avoid identical end points which causes problems in binning
            if minRa == maxRa:
                maxRa += 1e-5  # There is no reason to pick this number in particular
            if minDec == maxDec:
                maxDec += 1e-5  # There is no reason to pick this number in particular
            limits = (minRa, maxRa, minDec, maxDec)
            if n_xs < 5:
                continue

            if self.doBinning:
                nPointBinThresh = 5000
            else:  # Make a true scatter plot (plot all the points)
                nPointBinThresh = n_xs + 1

            plotOut = plotProjectionWithBinning(
                ax,
                xs,
                ys,
                colorVals,
                cmap,
                minRa,
                maxRa,
                minDec,
                maxDec,
                vmin=minColorVal,
                vmax=maxColorVal,
                alpha=self.alpha,
                edgecolor=edgecolor,
                fixAroundZero=self.fixAroundZero,
                nPointBinThresh=nPointBinThresh,
                isSorted=True,
                showExtremeOutliers=showExtremeOutliers,
                scatPtSize=self.scatPtSize,
            )
            ax.set_aspect("equal")
            if not self.publicationStyle:
                # One colorbar per population, stacked to the right.
                cax = fig.add_axes([0.87 + i * 0.04, 0.11, 0.04, 0.77])
                fig.colorbar(plotOut, cax=cax, extend="both")
            else:
                fig.subplots_adjust(wspace=0.0, hspace=0.0, right=0.95, bottom=0.15)
                axBbox = ax.get_position()
                cax = fig.add_axes([axBbox.x1, axBbox.y0, 0.04, axBbox.y1 - axBbox.y0])
                fig.colorbar(plotOut, cax=cax)

            colorBarLabel = f"{self.zAxisLabel}: {label}" if label else f"{self.zAxisLabel}"
            text = cax.text(
                0.5,
                0.5,
                colorBarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

            if i == 0 and len(toPlotList) > 1:
                cax.yaxis.set_ticks_position("left")

        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        fig.canvas.draw()

        # Find some useful axis limits; RA increases to the left.
        lenXs = [len(xs) for (xs, _, _, _, _) in toPlotList]
        if limits is not None and lenXs and np.max(lenXs) > 1000:
            minRa, maxRa, minDec, maxDec = limits
            padRa = (maxRa - minRa) / 10
            padDec = (maxDec - minDec) / 10
            ax.set_xlim(maxRa + padRa, minRa - padRa)
            ax.set_ylim(minDec - padDec, maxDec + padDec)
        else:
            ax.invert_xaxis()

        # Add useful information to the plot
        if not self.publicationStyle:
            fig.subplots_adjust(wspace=0.0, hspace=0.0)
            fig = addPlotInfo(fig, plotInfo)

        return fig