Coverage for python / lsst / analysis / tools / actions / plot / wholeSkyPlot.py: 13%

230 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 00:23 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("WholeSkyPlot",) 

25 

26import importlib.resources as importResources 

27from typing import Mapping, Optional 

28 

29import lsst.analysis.tools 

30import matplotlib.patheffects as pathEffects 

31import numpy as np 

32import yaml 

33from lsst.pex.config import ChoiceField, Field, ListField 

34from lsst.utils.plotting import ( 

35 accent_color, 

36 divergent_cmap, 

37 make_figure, 

38 set_rubin_plotstyle, 

39 stars_cmap, 

40 stars_color, 

41) 

42from matplotlib import gridspec 

43from matplotlib.collections import PatchCollection 

44from matplotlib.colors import CenteredNorm 

45from matplotlib.figure import Figure 

46from matplotlib.patches import Patch, Polygon 

47 

48from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

49from ...math import nanSigmaMad 

50from ...utils import getTractCorners 

51from .plotUtils import addPlotInfo 

52 

53 

54class WholeSkyPlot(PlotAction): 

55 """Plots the on sky distribution of a parameter. 

56 

57 Plots the values of the parameter given for the z axis 

58 according to the positions given for x and y. Optimised 

59 for use with RA and Dec. Also calculates some basic 

60 statistics and includes those on the plot. 

61 

62 The default axes limits and figure size were chosen to plot HSC PDR2. 

63 """ 

64 

65 xAxisLabel = Field[str](doc="Label to use for the x axis.", default="RA (degrees)") 

66 yAxisLabel = Field[str](doc="Label to use for the y axis.", default="Dec (degrees)") 

67 zAxisLabel = Field[str](doc="Label to use for the z axis.", default="") 

68 autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True) 

69 xLimits = ListField[float](doc="Plotting limits for the x axis.", default=[-5.0, 365.0]) 

70 yLimits = ListField[float](doc="Plotting limits for the y axis.", default=[-10.0, 60.0]) 

71 autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True) 

72 colorBarMin = Field[float](doc="The minimum value of the color bar.", optional=True) 

73 colorBarMax = Field[float](doc="The minimum value of the color bar.", optional=True) 

74 colorBarRange = Field[float]( 

75 doc="The multiplier for the color bar range. The max/min range values are: median +/- N * sigmaMad" 

76 ", where N is this config value.", 

77 default=3.0, 

78 ) 

79 colorMapType = ChoiceField[str]( 

80 doc="Type of color map to use for the color bar. Options: sequential, divergent, userDefined.", 

81 allowed={cmType: cmType for cmType in ("sequential", "divergent")}, 

82 default="divergent", 

83 ) 

84 colorMap = ListField[str]( 

85 doc="List of hexidecimal colors for a user-defined color map.", 

86 optional=True, 

87 ) 

88 showOutliers = Field[bool]( 

89 doc="Show the outliers on the plot. " 

90 "Outliers are values whose absolute value is > colorBarRange * sigmaMAD.", 

91 default=True, 

92 ) 

93 showNaNs = Field[bool](doc="Show the NaNs on the plot.", default=True) 

94 labelTracts = Field[bool](doc="Label the tracts.", default=False) 

95 

96 addThresholds = Field[bool]( 

97 doc="Read in the predefined thresholds and indicate them on the histogram.", 

98 default=True, 

99 ) 

100 

101 def getInputSchema(self, **kwargs) -> KeyedDataSchema: 

102 base = [] 

103 base.append(("z", Vector)) 

104 base.append(("tract", Vector)) 

105 return base 

106 

107 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure: 

108 self._validateInput(data, **kwargs) 

109 return self.makePlot(data, **kwargs) 

110 

111 def _validateInput(self, data: KeyedData, **kwargs) -> None: 

112 """NOTE currently can only check that something is not a Scalar, not 

113 check that the data is consistent with Vector 

114 """ 

115 needed = self.getInputSchema(**kwargs) 

116 if remainder := {key.format(**kwargs) for key, _ in needed} - { 

117 key.format(**kwargs) for key in data.keys() 

118 }: 

119 raise ValueError(f"Task needs keys {remainder} but they were not found in input") 

120 for name, typ in needed: 

121 isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar) 

122 if isScalar and typ != Scalar: 

123 raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}") 

124 

125 def _getAxesLimits(self, xs: list, ys: list) -> tuple(list, list): 

126 """Get the x and y axes limits in degrees. 

127 

128 Parameters 

129 ---------- 

130 xs : `list` 

131 X coordinates for the tracts to plot. 

132 ys : `list` 

133 Y coordinates for the tracts to plot. 

134 

135 Returns 

136 ------- 

137 xlim : `list` 

138 Minimun and maximum x axis values. 

139 ylim : `list` 

140 Minimun and maximum y axis values. 

141 """ 

142 

143 # Add some blank space on the edges of the plot. 

144 xlim = [np.nanmin(xs) - 5, np.nanmax(xs) + 5] 

145 ylim = [np.nanmin(ys) - 5, np.nanmax(ys) + 5] 

146 

147 # Limit to only show real RA/Dec values. 

148 if xlim[0] < 0.0: 

149 xlim[0] = 0.0 

150 if xlim[1] > 360.0: 

151 xlim[1] = 360.0 

152 if ylim[0] < -90.0: 

153 ylim[0] = -90.0 

154 if ylim[1] > 90.0: 

155 ylim[1] = 90.0 

156 

157 return (xlim, ylim) 

158 

159 def _getMaxOutlierVals(self, multiplier: float, tracts: list, values: list, outlierInds: list) -> str: 

160 """Get the 5 largest outlier values in a string. 

161 

162 Parameters 

163 ---------- 

164 multiplier : `float` 

165 Select values whose absolute value is > multiplier * sigmaMAD. 

166 tracts : `list` 

167 All the tracts. 

168 values : `list` 

169 All the metric values. 

170 outlierInds : `list` 

171 Indicies of outlier values. 

172 

173 Returns 

174 ------- 

175 text : `str` 

176 A string containing the 10 tracts with the largest outlier values. 

177 """ 

178 if self.addThresholds: 

179 text = "Tracts with value outside thresholds: " 

180 else: 

181 text = f"Tracts with |value| > {multiplier}" + r"$\sigma_{MAD}$" + ": " 

182 if len(outlierInds) > 0: 

183 outlierValues = np.array(values)[outlierInds] 

184 outlierTracts = np.array(tracts)[outlierInds] 

185 # Sort values in descending (-) absolute value order discounting 

186 # NaNs. 

187 maxInds = np.argsort(-np.abs(outlierValues)) 

188 # Show up to ten values on the plot. 

189 for ind in maxInds[:10]: 

190 val = outlierValues[ind] 

191 tract = outlierTracts[ind] 

192 text += f"{tract}, {val:.3}; " 

193 # Remove the final trailing comma and whitespace. 

194 text = text[:-2] 

195 else: 

196 text += "None" 

197 

198 return text 

199 

200 def makePlot( 

201 self, 

202 data: KeyedData, 

203 plotInfo: Optional[Mapping[str, str]] = None, 

204 **kwargs, 

205 ) -> Figure: 

206 """Make a WholeSkyPlot of the given data. 

207 

208 Parameters 

209 ---------- 

210 data : `KeyedData` 

211 The catalog to plot the points from. 

212 plotInfo : `dict` 

213 A dictionary of information about the data being plotted with keys: 

214 

215 ``"run"`` 

216 The output run for the plots (`str`). 

217 ``"skymap"`` 

218 The type of skymap used for the data (`str`). 

219 ``"filter"`` 

220 The filter used for this data (`str`). 

221 ``"tract"`` 

222 The tract that the data comes from (`str`). 

223 

224 Returns 

225 ------- 

226 `pipeBase.Struct` containing: 

227 skyPlot : `matplotlib.figure.Figure` 

228 The resulting figure. 

229 

230 

231 Examples 

232 -------- 

233 An example of the plot produced from this code is here: 

234 

235 .. image:: /_static/analysis_tools/wholeSkyPlotExample.png 

236 

237 For a detailed example of how to make a plot from the command line 

238 please see the 

239 :ref:`getting started guide<analysis-tools-getting-started>`. 

240 """ 

241 skymap = kwargs["skymap"] 

242 if plotInfo is None: 

243 plotInfo = {} 

244 

245 if self.addThresholds: 

246 metricThresholdFile = importResources.read_text(lsst.analysis.tools, "metricInformation.yaml") 

247 metricDefs = yaml.safe_load(metricThresholdFile) 

248 

249 # Prevent Bands in the plot info showing a list of bands. 

250 # If bands is a list, it implies that parameterizedBand=False, 

251 # and that the metric is not band-specific. 

252 if "bands" in plotInfo: 

253 if isinstance(plotInfo["bands"], list): 

254 plotInfo["bands"] = "N/A" 

255 

256 colorMap = self.colorMap 

257 match self.colorMapType: 

258 case "sequential": 

259 if colorMap is None: 

260 colorMap = stars_cmap() 

261 outlierColor = "red" 

262 norm = None 

263 case "divergent": 

264 if colorMap is None: 

265 colorMap = divergent_cmap() 

266 outlierColor = "fuchsia" 

267 norm = CenteredNorm() 

268 

269 # Create patches using the corners of each tract. 

270 patches = [] 

271 colBarVals = [] 

272 tracts = [] 

273 ras = [] 

274 decs = [] 

275 mid_ras = [] 

276 mid_decs = [] 

277 for i, tract in enumerate(data["tract"]): 

278 corners = getTractCorners(skymap, tract) 

279 patches.append(Polygon(corners, closed=True)) 

280 colBarVals.append(data["z"][i]) 

281 tracts.append(tract) 

282 ras.append(corners[0][0]) 

283 decs.append(corners[0][1]) 

284 mid_ras.append((corners[0][0] + corners[1][0]) / 2) 

285 mid_decs.append((corners[0][1] + corners[2][1]) / 2) 

286 

287 # Setup figure. 

288 fig = make_figure(dpi=300, figsize=(12, 3.5)) 

289 set_rubin_plotstyle() 

290 gs = gridspec.GridSpec(1, 4) 

291 ax = fig.add_subplot(gs[:3]) 

292 # Add colored patches showing tract metric values. 

293 patchCollection = PatchCollection(patches, cmap=colorMap, norm=norm) 

294 ax.add_collection(patchCollection) 

295 

296 # Define color bar range. 

297 if np.sum(np.isfinite(colBarVals)) > 0: 

298 med = np.nanmedian(colBarVals) 

299 else: 

300 med = np.nan 

301 sigmaMad = nanSigmaMad(colBarVals) 

302 if self.colorBarMin is not None: 

303 vmin = np.float64(self.colorBarMin) 

304 else: 

305 vmin = med - self.colorBarRange * sigmaMad 

306 if self.colorBarMax is not None: 

307 vmax = np.float64(self.colorBarMax) 

308 else: 

309 vmax = med + self.colorBarRange * sigmaMad 

310 

311 dataName = self.zAxisLabel.format_map(kwargs) 

312 colBarVals = np.array(colBarVals) 

313 if self.addThresholds and dataName in metricDefs: 

314 if "lowThreshold" in metricDefs[dataName].keys(): 

315 lowThreshold = metricDefs[dataName]["lowThreshold"] 

316 else: 

317 lowThreshold = np.nan 

318 if "highThreshold" in metricDefs[dataName].keys(): 

319 highThreshold = metricDefs[dataName]["highThreshold"] 

320 else: 

321 highThreshold = np.nan 

322 outlierInds = np.where((colBarVals < lowThreshold) | (colBarVals > highThreshold))[0] 

323 else: 

324 # Note tracts with metrics outside (vmin, vmax) as outliers. 

325 outlierInds = np.where((colBarVals < vmin) | (colBarVals > vmax))[0] 

326 

327 # Initialize legend handles. 

328 handles = [] 

329 

330 if self.showOutliers: 

331 # Plot the outlier patches. 

332 outlierPatches = [] 

333 if len(outlierInds) > 0: 

334 for ind in outlierInds: 

335 outlierPatches.append(patches[ind]) 

336 outlierPatchCollection = PatchCollection( 

337 outlierPatches, 

338 cmap=colorMap, 

339 norm=norm, 

340 facecolors="none", 

341 edgecolors=outlierColor, 

342 linewidths=0.5, 

343 zorder=100, 

344 ) 

345 ax.add_collection(outlierPatchCollection) 

346 # Add legend information. 

347 outlierPatch = Patch( 

348 facecolor="none", 

349 edgecolor=outlierColor, 

350 linewidth=0.5, 

351 label="Outlier", 

352 ) 

353 handles.append(outlierPatch) 

354 

355 if self.showNaNs: 

356 # Plot tracts with NaN metric values. 

357 nanInds = np.where(~np.isfinite(colBarVals))[0] 

358 nanPatches = [] 

359 if len(nanInds) > 0: 

360 for ind in nanInds: 

361 nanPatches.append(patches[ind]) 

362 nanPatchCollection = PatchCollection( 

363 nanPatches, 

364 cmap=None, 

365 norm=norm, 

366 facecolors="white", 

367 edgecolors="grey", 

368 linestyles="dotted", 

369 linewidths=0.5, 

370 zorder=100, 

371 ) 

372 ax.add_collection(nanPatchCollection) 

373 # Add legend information. 

374 nanPatch = Patch( 

375 facecolor="white", 

376 edgecolor="grey", 

377 linestyle="dotted", 

378 linewidth=0.5, 

379 label="NaN", 

380 ) 

381 handles.append(nanPatch) 

382 

383 if len(handles) > 0: 

384 fig.legend(handles=handles) 

385 

386 if self.labelTracts: 

387 # Label the tracts 

388 for i, tract in enumerate(tracts): 

389 ax.text( 

390 mid_ras[i], 

391 mid_decs[i], 

392 f"{tract}", 

393 ha="center", 

394 va="center", 

395 fontsize=2, 

396 alpha=0.7, 

397 zorder=100, 

398 ) 

399 

400 ax.set_aspect("equal") 

401 axPos = ax.get_position() 

402 ax1 = fig.add_axes([0.73, 0.25, 0.20, 0.47]) 

403 

404 if np.sum(np.isfinite(data["z"])) > 0: 

405 ax1.hist(data["z"], bins=len(data["z"] / 10), color=stars_color(), histtype="step") 

406 else: 

407 ax1.text(0.5, 0.5, "Data all NaN/Inf") 

408 ax1.set_xlabel("Metric Values") 

409 ax1.set_ylabel("Number") 

410 ax1.yaxis.set_label_position("right") 

411 ax1.yaxis.tick_right() 

412 

413 if self.addThresholds and dataName in metricDefs: 

414 # Check the thresholds are finite and set them to 

415 # the min/max of the data if they aren't to calculate 

416 # the x range of the plot 

417 if np.isfinite(lowThreshold): 

418 ax1.axvline(lowThreshold, color=accent_color()) 

419 else: 

420 lowThreshold = np.nanmin(colBarVals) 

421 if np.isfinite(highThreshold): 

422 ax1.axvline(highThreshold, color=accent_color()) 

423 else: 

424 highThreshold = np.nanmax(colBarVals) 

425 

426 widthThreshold = highThreshold - lowThreshold 

427 upperLim = highThreshold + 0.5 * widthThreshold 

428 lowerLim = lowThreshold - 0.5 * widthThreshold 

429 ax1.set_xlim(lowerLim, upperLim) 

430 numOutside = np.sum(((data["z"] > upperLim) | (data["z"] < lowerLim))) 

431 ax1.set_title("Outside plot limits: " + str(numOutside)) 

432 

433 else: 

434 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax): 

435 ax1.set_xlim(vmin, vmax) 

436 

437 if self.autoAxesLimits: 

438 xlim, ylim = self._getAxesLimits(ras, decs) 

439 else: 

440 xlim, ylim = self.xLimits, self.yLimits 

441 ax.set_xlim(xlim) 

442 ax.set_ylim(ylim) 

443 ax.set_xlabel(self.xAxisLabel) 

444 ax.set_ylabel(self.yAxisLabel) 

445 ax.invert_xaxis() 

446 

447 if self.showOutliers: 

448 # Add text boxes to show the number of tracts, number of NaNs, 

449 # median, sigma MAD, and the five largest outlier values. 

450 outlierText = self._getMaxOutlierVals(self.colorBarRange, tracts, colBarVals, outlierInds) 

451 # Make vertical text spacing readable for different figure sizes. 

452 multiplier = 3.5 / fig.get_size_inches()[1] 

453 verticalSpacing = 0.028 * multiplier 

454 fig.text( 

455 0.01, 

456 0.01 + 3 * verticalSpacing, 

457 f"Num tracts: {len(tracts)}", 

458 transform=fig.transFigure, 

459 fontsize=8, 

460 alpha=0.7, 

461 ) 

462 if self.showNaNs: 

463 fig.text( 

464 0.01, 

465 0.01 + 2 * verticalSpacing, 

466 f"Num nans: {len(nanInds)}", 

467 transform=fig.transFigure, 

468 fontsize=8, 

469 alpha=0.7, 

470 ) 

471 fig.text( 

472 0.01, 

473 0.01 + verticalSpacing, 

474 f"Median: {med:.3f}; " + r"$\sigma_{MAD}$" + f": {sigmaMad:.3f}", 

475 transform=fig.transFigure, 

476 fontsize=8, 

477 alpha=0.7, 

478 ) 

479 if self.showOutliers: 

480 fig.text(0.01, 0.01, outlierText, transform=fig.transFigure, fontsize=8, alpha=0.7) 

481 

482 # Truncate the color range to (vmin, vmax). 

483 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax): 

484 colBarVals = np.clip(np.array(colBarVals), vmin, vmax) 

485 patchCollection.set_array(colBarVals) 

486 # Make the color bar with a metric label. 

487 axPos = ax.get_position() 

488 cax = fig.add_axes([0.084, axPos.y1 + 0.02, 0.62, 0.07]) 

489 fig.colorbar( 

490 patchCollection, 

491 cax=cax, 

492 shrink=0.7, 

493 extend="both", 

494 location="top", 

495 orientation="horizontal", 

496 ) 

497 cbarText = "Metric Values" 

498 

499 text = cax.text( 

500 0.5, 

501 0.5, 

502 cbarText, 

503 transform=cax.transAxes, 

504 ha="center", 

505 va="center", 

506 fontsize=10, 

507 zorder=100, 

508 ) 

509 text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()]) 

510 

511 # Finalize plot appearance. 

512 ax.grid() 

513 ax.set_axisbelow(True) 

514 fig = addPlotInfo(fig, plotInfo) 

515 fig.subplots_adjust(left=0.08, right=0.92, top=0.8, bottom=0.17, wspace=0.05) 

516 titleText = self.zAxisLabel.format_map(kwargs) 

517 if "zUnit" in data and data["zUnit"] != "": 

518 titleText += f" ({data['zUnit']})" 

519 fig.suptitle("Metric: " + titleText, fontsize=20) 

520 

521 return fig