Coverage for python / lsst / analysis / tools / actions / plot / wholeSkyPlot.py: 14%

241 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 08:45 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("WholeSkyPlot",) 

25 

26import importlib.resources as importResources 

27import json 

28from collections.abc import Mapping 

29 

30import matplotlib.patheffects as pathEffects 

31import numpy as np 

32import yaml 

33from matplotlib import gridspec 

34from matplotlib.collections import PatchCollection 

35from matplotlib.colors import CenteredNorm 

36from matplotlib.figure import Figure 

37from matplotlib.patches import Patch, Polygon 

38 

39import lsst.analysis.tools 

40from lsst.pex.config import ChoiceField, Field, ListField 

41from lsst.utils.plotting import ( 

42 accent_color, 

43 divergent_cmap, 

44 make_figure, 

45 set_rubin_plotstyle, 

46 stars_cmap, 

47 stars_color, 

48) 

49 

50from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector 

51from ...math import nanSigmaMad 

52from ...utils import getTractCorners 

53from .plotUtils import addPlotInfo 

54 

55 

class AnnotatedFigure(Figure):
    """A matplotlib `Figure` that additionally carries a ``metadata`` dict.

    `WholeSkyPlot.makePlot` assigns this attribute with a ``"label"`` string
    and a ``"boxes"`` JSON string describing the plotted regions.
    """

    # Metadata attached to the figure after plotting.
    metadata: dict

58 

59 

class WholeSkyPlot(PlotAction):
    """Plots the on sky distribution of a parameter.

    Each tract listed in the input is drawn as a polygon, using the tract
    corners looked up in the skymap, and is colored by the value supplied
    for the z axis. Optimised for use with RA and Dec. Also calculates some
    basic statistics and includes those on the plot.

    The default axes limits and figure size were chosen to plot HSC PDR2.
    """

    xAxisLabel = Field[str](doc="Label to use for the x axis.", default="RA (degrees)")
    yAxisLabel = Field[str](doc="Label to use for the y axis.", default="Dec (degrees)")
    zAxisLabel = Field[str](doc="Label to use for the z axis.", default="")
    # ``autoAxesLimits`` used to be declared twice with identical arguments;
    # a single declaration is kept.
    autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True)
    xLimits = ListField[float](doc="Plotting limits for the x axis.", default=[-5.0, 365.0])
    yLimits = ListField[float](doc="Plotting limits for the y axis.", default=[-10.0, 60.0])
    colorBarMin = Field[float](doc="The minimum value of the color bar.", optional=True)
    colorBarMax = Field[float](doc="The maximum value of the color bar.", optional=True)
    colorBarRange = Field[float](
        doc="The multiplier for the color bar range. The max/min range values are: median +/- N * sigmaMad"
        ", where N is this config value.",
        default=3.0,
    )
    # Only the two types listed in ``allowed`` are handled when plotting, so
    # the doc string advertises exactly those.
    colorMapType = ChoiceField[str](
        doc="Type of color map to use for the color bar. Options: sequential, divergent.",
        allowed={cmType: cmType for cmType in ("sequential", "divergent")},
        default="divergent",
    )
    colorMap = ListField[str](
        doc="List of hexadecimal colors for a user-defined color map.",
        optional=True,
    )
    showOutliers = Field[bool](
        doc="Show the outliers on the plot. "
        "Outliers are values whose absolute value is > colorBarRange * sigmaMAD.",
        default=True,
    )
    showNaNs = Field[bool](doc="Show the NaNs on the plot.", default=True)
    labelTracts = Field[bool](doc="Label the tracts.", default=False)

    addThresholds = Field[bool](
        doc="Read in the predefined thresholds and indicate them on the histogram.",
        default=True,
    )

106 

def getInputSchema(self, **kwargs) -> KeyedDataSchema:
    """Return the keys this action reads from the input data.

    Returns
    -------
    schema : `list` [`tuple`]
        ``(key, type)`` pairs: the metric values ``"z"`` and the ``"tract"``
        identifiers, both required to be `Vector`.
    """
    return [(key, Vector) for key in ("z", "tract")]

112 

def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
    """Validate ``data`` against the input schema, then build the plot.

    All keyword arguments are forwarded unchanged to both
    ``_validateInput`` and ``makePlot``.
    """
    # Validation raises on bad input, so the plot is only attempted
    # once the required keys are known to be present.
    self._validateInput(data, **kwargs)
    figure = self.makePlot(data, **kwargs)
    return figure

116 

def _validateInput(self, data: KeyedData, **kwargs) -> None:
    """Check that ``data`` provides every key in the input schema.

    NOTE currently can only check that something is not a Scalar, not
    check that the data is consistent with Vector.

    Raises
    ------
    ValueError
        Raised if a required key is absent from ``data``, or if a column
        holds a `Scalar` where the schema asks for another type.
    """
    schema = self.getInputSchema(**kwargs)

    # Both the schema keys and the data keys are run through ``format``
    # with the keyword arguments before being compared.
    requiredKeys = {key.format(**kwargs) for key, _ in schema}
    availableKeys = {key.format(**kwargs) for key in data.keys()}
    remainder = requiredKeys - availableKeys
    if remainder:
        raise ValueError(f"Task needs keys {remainder} but they were not found in input")

    for name, typ in schema:
        colType = type(data[name.format(**kwargs)])
        if issubclass(colType, Scalar) and typ != Scalar:
            raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

130 

131 def _getAxesLimits(self, xs: list, ys: list) -> tuple(list, list): 

132 """Get the x and y axes limits in degrees. 

133 

134 Parameters 

135 ---------- 

136 xs : `list` 

137 X coordinates for the tracts to plot. 

138 ys : `list` 

139 Y coordinates for the tracts to plot. 

140 

141 Returns 

142 ------- 

143 xlim : `list` 

144 Minimun and maximum x axis values. 

145 ylim : `list` 

146 Minimun and maximum y axis values. 

147 """ 

148 

149 # Add some blank space on the edges of the plot. 

150 xlim = [np.nanmin(xs) - 5, np.nanmax(xs) + 5] 

151 ylim = [np.nanmin(ys) - 5, np.nanmax(ys) + 5] 

152 

153 # Limit to only show real RA/Dec values. 

154 if xlim[0] < 0.0: 

155 xlim[0] = 0.0 

156 if xlim[1] > 360.0: 

157 xlim[1] = 360.0 

158 if ylim[0] < -90.0: 

159 ylim[0] = -90.0 

160 if ylim[1] > 90.0: 

161 ylim[1] = 90.0 

162 

163 return (xlim, ylim) 

164 

165 def _getMaxOutlierVals(self, multiplier: float, tracts: list, values: list, outlierInds: list) -> str: 

166 """Get the 5 largest outlier values in a string. 

167 

168 Parameters 

169 ---------- 

170 multiplier : `float` 

171 Select values whose absolute value is > multiplier * sigmaMAD. 

172 tracts : `list` 

173 All the tracts. 

174 values : `list` 

175 All the metric values. 

176 outlierInds : `list` 

177 Indicies of outlier values. 

178 

179 Returns 

180 ------- 

181 text : `str` 

182 A string containing the 10 tracts with the largest outlier values. 

183 """ 

184 if self.addThresholds: 

185 text = "Tracts with value outside thresholds: " 

186 else: 

187 text = f"Tracts with |value| > {multiplier}" + r"$\sigma_{MAD}$" + ": " 

188 if len(outlierInds) > 0: 

189 outlierValues = np.array(values)[outlierInds] 

190 outlierTracts = np.array(tracts)[outlierInds] 

191 # Sort values in descending (-) absolute value order discounting 

192 # NaNs. 

193 maxInds = np.argsort(-np.abs(outlierValues)) 

194 # Show up to ten values on the plot. 

195 for ind in maxInds[:10]: 

196 val = outlierValues[ind] 

197 tract = outlierTracts[ind] 

198 text += f"{tract}, {val:.3}; " 

199 # Remove the final trailing comma and whitespace. 

200 text = text[:-2] 

201 else: 

202 text += "None" 

203 

204 return text 

205 

206 def makePlot( 

207 self, 

208 data: KeyedData, 

209 plotInfo: Mapping[str, str] | None = None, 

210 **kwargs, 

211 ) -> AnnotatedFigure: 

212 """Make a WholeSkyPlot of the given data. 

213 

214 Parameters 

215 ---------- 

216 data : `KeyedData` 

217 The catalog to plot the points from. 

218 plotInfo : `dict` 

219 A dictionary of information about the data being plotted with keys: 

220 

221 ``"run"`` 

222 The output run for the plots (`str`). 

223 ``"skymap"`` 

224 The type of skymap used for the data (`str`). 

225 ``"filter"`` 

226 The filter used for this data (`str`). 

227 ``"tract"`` 

228 The tract that the data comes from (`str`). 

229 

230 Returns 

231 ------- 

232 `pipeBase.Struct` containing: 

233 skyPlot : `matplotlib.figure.Figure` 

234 The resulting figure. 

235 

236 

237 Examples 

238 -------- 

239 An example of the plot produced from this code is here: 

240 

241 .. image:: /_static/analysis_tools/wholeSkyPlotExample.png 

242 

243 For a detailed example of how to make a plot from the command line 

244 please see the 

245 :ref:`getting started guide<analysis-tools-getting-started>`. 

246 """ 

247 skymap = kwargs["skymap"] 

248 if plotInfo is None: 

249 plotInfo = {} 

250 

251 if self.addThresholds: 

252 metricThresholdFile = importResources.read_text(lsst.analysis.tools, "metricInformation.yaml") 

253 metricDefs = yaml.safe_load(metricThresholdFile) 

254 

255 # Prevent Bands in the plot info showing a list of bands. 

256 # If bands is a list, it implies that parameterizedBand=False, 

257 # and that the metric is not band-specific. 

258 if "bands" in plotInfo: 

259 if isinstance(plotInfo["bands"], list): 

260 plotInfo["bands"] = "N/A" 

261 

262 colorMap = self.colorMap 

263 match self.colorMapType: 

264 case "sequential": 

265 if colorMap is None: 

266 colorMap = stars_cmap() 

267 outlierColor = "red" 

268 norm = None 

269 case "divergent": 

270 if colorMap is None: 

271 colorMap = divergent_cmap() 

272 outlierColor = "fuchsia" 

273 norm = CenteredNorm() 

274 

275 # Create patches using the corners of each tract. 

276 patches = [] 

277 colBarVals = [] 

278 tracts = [] 

279 ras = [] 

280 decs = [] 

281 mid_ras = [] 

282 mid_decs = [] 

283 for i, tract in enumerate(data["tract"]): 

284 corners = getTractCorners(skymap, tract) 

285 patches.append(Polygon(corners, closed=True)) 

286 colBarVals.append(data["z"][i]) 

287 tracts.append(tract) 

288 ras.append(corners[0][0]) 

289 decs.append(corners[0][1]) 

290 mid_ras.append((corners[0][0] + corners[1][0]) / 2) 

291 mid_decs.append((corners[0][1] + corners[2][1]) / 2) 

292 

293 # Setup figure. 

294 fig: AnnotatedFigure = make_figure(dpi=300, figsize=(12, 3.5)) 

295 set_rubin_plotstyle() 

296 gs = gridspec.GridSpec(1, 4) 

297 ax = fig.add_subplot(gs[:3]) 

298 # Add colored patches showing tract metric values. 

299 patchCollection = PatchCollection(patches, cmap=colorMap, norm=norm) 

300 ax.add_collection(patchCollection) 

301 

302 # Define color bar range. 

303 if np.sum(np.isfinite(colBarVals)) > 0: 

304 med = np.nanmedian(colBarVals) 

305 else: 

306 med = np.nan 

307 sigmaMad = nanSigmaMad(colBarVals) 

308 if self.colorBarMin is not None: 

309 vmin = np.float64(self.colorBarMin) 

310 else: 

311 vmin = med - self.colorBarRange * sigmaMad 

312 if self.colorBarMax is not None: 

313 vmax = np.float64(self.colorBarMax) 

314 else: 

315 vmax = med + self.colorBarRange * sigmaMad 

316 

317 dataName = self.zAxisLabel.format_map(kwargs) 

318 colBarVals = np.array(colBarVals) 

319 if self.addThresholds and dataName in metricDefs: 

320 if "lowThreshold" in metricDefs[dataName].keys(): 

321 lowThreshold = metricDefs[dataName]["lowThreshold"] 

322 else: 

323 lowThreshold = np.nan 

324 if "highThreshold" in metricDefs[dataName].keys(): 

325 highThreshold = metricDefs[dataName]["highThreshold"] 

326 else: 

327 highThreshold = np.nan 

328 outlierInds = np.where((colBarVals < lowThreshold) | (colBarVals > highThreshold))[0] 

329 else: 

330 # Note tracts with metrics outside (vmin, vmax) as outliers. 

331 outlierInds = np.where((colBarVals < vmin) | (colBarVals > vmax))[0] 

332 

333 # Initialize legend handles. 

334 handles = [] 

335 

336 if self.showOutliers: 

337 # Plot the outlier patches. 

338 outlierPatches = [] 

339 if len(outlierInds) > 0: 

340 for ind in outlierInds: 

341 outlierPatches.append(patches[ind]) 

342 outlierPatchCollection = PatchCollection( 

343 outlierPatches, 

344 cmap=colorMap, 

345 norm=norm, 

346 facecolors="none", 

347 edgecolors=outlierColor, 

348 linewidths=0.5, 

349 zorder=100, 

350 ) 

351 ax.add_collection(outlierPatchCollection) 

352 # Add legend information. 

353 outlierPatch = Patch( 

354 facecolor="none", 

355 edgecolor=outlierColor, 

356 linewidth=0.5, 

357 label="Outlier", 

358 ) 

359 handles.append(outlierPatch) 

360 

361 if self.showNaNs: 

362 # Plot tracts with NaN metric values. 

363 nanInds = np.where(~np.isfinite(colBarVals))[0] 

364 nanPatches = [] 

365 if len(nanInds) > 0: 

366 for ind in nanInds: 

367 nanPatches.append(patches[ind]) 

368 nanPatchCollection = PatchCollection( 

369 nanPatches, 

370 cmap=None, 

371 norm=norm, 

372 facecolors="white", 

373 edgecolors="grey", 

374 linestyles="dotted", 

375 linewidths=0.5, 

376 zorder=100, 

377 ) 

378 ax.add_collection(nanPatchCollection) 

379 # Add legend information. 

380 nanPatch = Patch( 

381 facecolor="white", 

382 edgecolor="grey", 

383 linestyle="dotted", 

384 linewidth=0.5, 

385 label="NaN", 

386 ) 

387 handles.append(nanPatch) 

388 

389 if len(handles) > 0: 

390 fig.legend(handles=handles) 

391 

392 if self.labelTracts: 

393 # Label the tracts 

394 for i, tract in enumerate(tracts): 

395 ax.text( 

396 mid_ras[i], 

397 mid_decs[i], 

398 f"{tract}", 

399 ha="center", 

400 va="center", 

401 fontsize=2, 

402 alpha=0.7, 

403 zorder=100, 

404 ) 

405 

406 ax.set_aspect("equal") 

407 axPos = ax.get_position() 

408 ax1 = fig.add_axes([0.73, 0.25, 0.20, 0.47]) 

409 

410 if np.sum(np.isfinite(data["z"])) > 0: 

411 ax1.hist(data["z"], bins=len(data["z"] / 10), color=stars_color(), histtype="step") 

412 else: 

413 ax1.text(0.5, 0.5, "Data all NaN/Inf") 

414 ax1.set_xlabel("Metric Values") 

415 ax1.set_ylabel("Number") 

416 ax1.yaxis.set_label_position("right") 

417 ax1.yaxis.tick_right() 

418 

419 if self.addThresholds and dataName in metricDefs: 

420 # Check the thresholds are finite and set them to 

421 # the min/max of the data if they aren't to calculate 

422 # the x range of the plot 

423 if np.isfinite(lowThreshold): 

424 ax1.axvline(lowThreshold, color=accent_color()) 

425 else: 

426 lowThreshold = np.nanmin(colBarVals) 

427 if np.isfinite(highThreshold): 

428 ax1.axvline(highThreshold, color=accent_color()) 

429 else: 

430 highThreshold = np.nanmax(colBarVals) 

431 

432 widthThreshold = highThreshold - lowThreshold 

433 upperLim = highThreshold + 0.5 * widthThreshold 

434 lowerLim = lowThreshold - 0.5 * widthThreshold 

435 ax1.set_xlim(lowerLim, upperLim) 

436 numOutside = np.sum((data["z"] > upperLim) | (data["z"] < lowerLim)) 

437 ax1.set_title("Outside plot limits: " + str(numOutside)) 

438 

439 else: 

440 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax): 

441 ax1.set_xlim(vmin, vmax) 

442 

443 if self.autoAxesLimits: 

444 xlim, ylim = self._getAxesLimits(ras, decs) 

445 else: 

446 xlim, ylim = self.xLimits, self.yLimits 

447 ax.set_xlim(xlim) 

448 ax.set_ylim(ylim) 

449 ax.set_xlabel(self.xAxisLabel) 

450 ax.set_ylabel(self.yAxisLabel) 

451 ax.invert_xaxis() 

452 

453 if self.showOutliers: 

454 # Add text boxes to show the number of tracts, number of NaNs, 

455 # median, sigma MAD, and the five largest outlier values. 

456 outlierText = self._getMaxOutlierVals(self.colorBarRange, tracts, colBarVals, outlierInds) 

457 # Make vertical text spacing readable for different figure sizes. 

458 multiplier = 3.5 / fig.get_size_inches()[1] 

459 verticalSpacing = 0.028 * multiplier 

460 fig.text( 

461 0.01, 

462 0.01 + 3 * verticalSpacing, 

463 f"Num tracts: {len(tracts)}", 

464 transform=fig.transFigure, 

465 fontsize=8, 

466 alpha=0.7, 

467 ) 

468 if self.showNaNs: 

469 fig.text( 

470 0.01, 

471 0.01 + 2 * verticalSpacing, 

472 f"Num nans: {len(nanInds)}", 

473 transform=fig.transFigure, 

474 fontsize=8, 

475 alpha=0.7, 

476 ) 

477 fig.text( 

478 0.01, 

479 0.01 + verticalSpacing, 

480 f"Median: {med:.3f}; " + r"$\sigma_{MAD}$" + f": {sigmaMad:.3f}", 

481 transform=fig.transFigure, 

482 fontsize=8, 

483 alpha=0.7, 

484 ) 

485 if self.showOutliers: 

486 fig.text(0.01, 0.01, outlierText, transform=fig.transFigure, fontsize=8, alpha=0.7) 

487 

488 # Truncate the color range to (vmin, vmax). 

489 if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax): 

490 colBarVals = np.clip(np.array(colBarVals), vmin, vmax) 

491 patchCollection.set_array(colBarVals) 

492 # Make the color bar with a metric label. 

493 axPos = ax.get_position() 

494 cax = fig.add_axes([0.084, axPos.y1 + 0.02, 0.62, 0.07]) 

495 fig.colorbar( 

496 patchCollection, 

497 cax=cax, 

498 shrink=0.7, 

499 extend="both", 

500 location="top", 

501 orientation="horizontal", 

502 ) 

503 cbarText = "Metric Values" 

504 

505 text = cax.text( 

506 0.5, 

507 0.5, 

508 cbarText, 

509 transform=cax.transAxes, 

510 ha="center", 

511 va="center", 

512 fontsize=10, 

513 zorder=100, 

514 ) 

515 text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()]) 

516 

517 # Finalize plot appearance. 

518 ax.grid() 

519 ax.set_axisbelow(True) 

520 addPlotInfo(fig, plotInfo) 

521 fig.subplots_adjust(left=0.08, right=0.92, top=0.8, bottom=0.17, wspace=0.05) 

522 titleText = self.zAxisLabel.format_map(kwargs) 

523 if "zUnit" in data and data["zUnit"] != "": 

524 titleText += f" ({data['zUnit']})" 

525 fig.suptitle("Metric: " + titleText, fontsize=20) 

526 

527 # This saves metadata in the PNG that allows the plot-navigator 

528 # to provide tract numbers and metric values on mouseover. 

529 # 

530 # PNG metadata is a set of string keys and string values. 

531 # The WholeSkyPlot stores two keys: 

532 # - label: the string describing the regions ('tract') 

533 # - boxes, JSON string of a list of per-region dictionaries, 

534 # where each dictionary has fields: 

535 # - min_x, max_x, min_y, max_y for the pixel coordinates of 

536 # the four corners of the region 

537 # - id: the identifier of the region (e.g. tract number) 

538 # - value: the region's metric, as a string. 

539 # 

540 def make_patch_md(patch, id_field, value, ax): 

541 path = ax.transData.transform_path(patch.get_path()) 

542 x_path = [int(x) for x in path.vertices[:, 0].tolist()] 

543 y_path = [int(y) for y in path.vertices[:, 1].tolist()] 

544 return { 

545 "min_x": min(x_path), 

546 "max_x": max(x_path), 

547 "min_y": min(y_path), 

548 "max_y": max(y_path), 

549 "id": f"{id_field}", 

550 "value": f"{value:.3}", 

551 } 

552 

553 # After ax.set_aspect(), the figure needs to be drawn for the axes 

554 # transformations to be updated to the right values. 

555 fig.canvas.draw_idle() 

556 

557 patch_coordinate_entries = [ 

558 make_patch_md(patch, tract, value, ax) 

559 for (patch, tract, value) in zip(patches, tracts, colBarVals) 

560 ] 

561 

562 fig.metadata = {"label": "Tract", "boxes": json.dumps(patch_coordinate_entries)} 

563 

564 return fig