Coverage for python/lsst/analysis/tools/actions/plot/propertyMapPlot.py: 18%

162 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-26 04:07 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("PropertyMapPlot",) 

24 

25import logging 

26from typing import Iterable, Mapping, Union 

27 

28import lsst.pex.config as pexConfig 

29import matplotlib.patheffects as mpl_path_effects 

30import matplotlib.pyplot as plt 

31import numpy as np 

32import skyproj 

33from healsparse.healSparseMap import HealSparseMap 

34from lsst.analysis.tools.tasks.propertyMapTractAnalysis import PropertyMapTractAnalysisConfig 

35from lsst.skymap.tractInfo import ExplicitTractInfo 

36from matplotlib.figure import Figure 

37from matplotlib.legend_handler import HandlerTuple 

38 

39from ...interfaces import KeyedData, PlotAction 

40 

41_LOG = logging.getLogger(__name__) 

42 

43 

class CustomHandler(HandlerTuple):
    """Legend handler that stacks the artists of a tuple handle.

    `HandlerTuple` places the members of a tuple handle side-by-side
    within a single legend entry. This subclass is meant for the case
    where those members (e.g., two patches) should instead be drawn on top
    of one another in that entry.

    Methods
    -------
    create_artists:
        Delegates to `HandlerTuple.create_artists` and then re-applies the
        legend transform to each returned artist so that they overlay one
        another.

    Examples
    --------
    # Plot some data.
    line, = ax.plot(x, y, label="Sample Line")

    # Combine two patches into one overlaid legend entry; regular handles
    # such as the line can still be listed alongside it.
    handles = [(patch1, patch2), line]
    labels = ['Overlaid Patches', line.get_label()]
    leg = ax.legend(
        handles, labels, handler_map={tuple: CustomHandler()}, loc="best"
    )
    """

    def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
        # Let the parent class build the per-member legend artists first.
        overlaid = super().create_artists(
            legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans
        )
        # Give every artist the same transform so they sit on top of each
        # other rather than next to each other.
        for artist in overlaid:
            artist.set_transform(trans)
        return overlaid

82 

83 

84class PropertyMapPlot(PlotAction): 

85 plotName = pexConfig.Field[str](doc="The name for the plotting task.", optional=True) 

86 

def __call__(
    self,
    data: KeyedData,
    tractInfo: ExplicitTractInfo,
    plotConfig: PropertyMapTractAnalysisConfig,
    plotInfo: Mapping[str, Union[Mapping[str, str], str, int]],
    **kwargs,
) -> Mapping[str, Figure]:
    """Validate the inputs and produce the property map figures.

    Parameters
    ----------
    data : `KeyedData`
        The inputs to plot, keyed by map name.
    tractInfo : `~lsst.skymap.tractInfo.ExplicitTractInfo`
        The tract info object.
    plotConfig :
        `~lsst.analysis.tools.tasks.propertyMapTractAnalysis.
        PropertyMapTractAnalysisConfig`
        The configuration for the plot.
    plotInfo : `dict`
        A dictionary of information about the data being plotted.
    **kwargs
        Accepted but not used by this action.

    Returns
    -------
    figDict : `dict` [`~matplotlib.figure.Figure`]
        The figures returned by `makePlot`.
    """
    # Fail early on malformed inputs before any plotting work is done.
    self._validateInput(data, tractInfo, plotConfig, plotInfo)
    figDict = self.makePlot(data, tractInfo, plotConfig, plotInfo)
    return figDict

97 

def _validateInput(
    self,
    data: KeyedData,
    tractInfo: ExplicitTractInfo,
    plotConfig: PropertyMapTractAnalysisConfig,
    plotInfo: Mapping[str, Union[Mapping[str, str], str, int]],
) -> None:
    """Validate the input data.

    Parameters
    ----------
    data : `KeyedData`
        The inputs to plot, keyed by map name. Each value must expose
        ``ref.datasetType.storageClass.pytype``.
    tractInfo : `~lsst.skymap.tractInfo.ExplicitTractInfo`
        The tract info object.
    plotConfig :
        `~lsst.analysis.tools.tasks.propertyMapTractAnalysis.
        PropertyMapTractAnalysisConfig`
        The configuration for the plot.
    plotInfo : `dict`
        A dictionary of information about the data being plotted.

    Raises
    ------
    TypeError
        Raised if ``tractInfo``, ``plotConfig`` or ``plotInfo`` has the
        wrong type, if ``plotConfig.zoomFactors`` is not a two-element
        list of floats greater than 1, or if any entry of ``data`` does
        not have `HealSparseMap` as its storage class python type.
    ValueError
        Raised if ``nBinsHist`` of any configured property is not a
        positive integer.
    """

    if not isinstance(tractInfo, ExplicitTractInfo):
        raise TypeError(f"Input `tractInfo` type must be {ExplicitTractInfo} not {type(tractInfo)}.")

    if not isinstance(plotConfig, PropertyMapTractAnalysisConfig):
        raise TypeError(
            "`plotConfig` must be a "
            "`lsst.analysis.tools.tasks.propertyMapTractAnalysis.PropertyMapTractAnalysisConfig`."
        )

    if not isinstance(plotInfo, dict):
        # The message must name the argument that actually failed.
        raise TypeError("`plotInfo` must be a dictionary.")

    zoomFactors = plotConfig.zoomFactors
    isListOfFloats = isinstance(zoomFactors, pexConfig.listField.List) and all(
        isinstance(zf, float) for zf in zoomFactors
    )
    # The `any(...)` comparison is only reached once the type and length
    # checks have passed, so it never compares non-float entries.
    if not (isListOfFloats and len(zoomFactors) == 2) or any(zf <= 1 for zf in zoomFactors):
        raise TypeError(
            "`zoomFactors` must be a two-element `lsst.pex.config.listField.List` of floats > 1."
        )

    for prop, propConfig in plotConfig.properties.items():
        if not isinstance(propConfig.nBinsHist, int) or propConfig.nBinsHist <= 0:
            raise ValueError(
                f"`nBinsHist` for property `{prop}` must be a positive integer, not "
                f"{propConfig.nBinsHist}."
            )

    # Identify any invalid entries in `data`.
    invalidEntries = {}
    for key, handle in data.items():
        pytype = handle.ref.datasetType.storageClass.pytype
        if pytype != HealSparseMap:
            invalidEntries[key] = pytype

    # If any invalid entries are found, raise a TypeError with details.
    if invalidEntries:
        errorMessage = "; ".join(
            f"`{key}` should be {HealSparseMap}, got {type_}" for key, type_ in invalidEntries.items()
        )
        raise TypeError(f"Invalid input types found in `data`: {errorMessage}")

148 

def addPlotInfo(
    self,
    fig: Figure,
    plotInfo: Mapping[str, Union[Mapping[str, str], str, int]],
    mapName: str,
) -> Figure:
    """Add useful information to the plot.

    Two text blocks are drawn near the top-left of the figure: a title
    made of the plot name and the property, and a fainter multi-line
    block with the run, table name, map geometry and data ID details.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to add the information to.
    plotInfo : `dict`
        A dictionary of the plot information. The keys read here are
        ``run``, ``tableNames``, ``tract``, ``band``, ``property``,
        ``operation``, ``coaddName``, ``valid_area``, ``nside`` and
        ``plotName``.
    mapName : `str`
        The name of the map being plotted; used as the key into
        ``plotInfo["tableNames"]``.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the information added.
    """

    run = plotInfo["run"]
    tableType = f"\nTable: {plotInfo['tableNames'][mapName]}"

    dataIdText = f"Tract: {plotInfo['tract']}, Band: {plotInfo['band']}"
    mapText = (
        f", Property: {plotInfo['property']}, "
        f"Operation: {plotInfo['operation']}, "
        f"Coadd: {plotInfo['coaddName']}"
    )
    geomText = f", Valid area: {plotInfo['valid_area']:.2f} sq. deg., NSIDE: {plotInfo['nside']}"
    infoText = f"\n{dataIdText}{mapText}"

    # Title line: plot name followed by the title-cased property name.
    propertyTitle = plotInfo["property"].title().replace("Psf", "PSF")
    fig.text(
        0.04,
        0.965,
        f'{plotInfo["plotName"]}: {propertyTitle}',
        fontsize=19,
        transform=fig.transFigure,
        ha="left",
        va="top",
    )
    # Secondary, semi-transparent block with the provenance details.
    t = fig.text(
        0.04,
        0.942,
        f"{run}{tableType}{geomText}{infoText}",
        fontsize=15,
        transform=fig.transFigure,
        alpha=0.6,
        ha="left",
        va="top",
    )
    t.set_linespacing(1.4)

    return fig

206 

def makePlot(
    self,
    data: KeyedData,
    tractInfo: ExplicitTractInfo,
    plotConfig: PropertyMapTractAnalysisConfig,
    plotInfo: Mapping[str, Union[Mapping[str, str], str, int]],
) -> Mapping[str, Figure]:
    """Make the survey property map plot.

    One figure is made per entry in ``data``. Each figure has four
    panels: the full-tract map (top left), normalized histograms of the
    map values (top right), and two zoomed-in maps (bottom row), one per
    entry of ``plotConfig.zoomFactors``.

    Parameters
    ----------
    data : `KeyedData`
        The HealSparseMap to plot the points from.
    tractInfo: `~lsst.skymap.tractInfo.ExplicitTractInfo`
        The tract info object.
    plotConfig :
        `~lsst.analysis.tools.tasks.propertyMapTractAnalysis.
        PropertyMapTractAnalysisConfig`
        The configuration for the plot.
    plotInfo : `dict`
        A dictionary of information about the data being plotted. It is
        updated in place with the ``plotName`` (if configured),
        ``coaddName``, ``operation``, ``property``, ``nside`` and
        ``valid_area`` entries.

    Returns
    -------
    figDict : `dict` [`~matplotlib.figure.Figure`]
        The resulting figures.

    Raises
    ------
    ValueError
        Raised if a map name does not end with one of the recognized
        operation names.
    """

    # 'plotName' defaults to the attribute specified in
    # 'atools.<attribute>' in the pipeline YAML. If it is explicitly
    # set in `~lsst.analysis.tools.atools.propertyMap.PropertyMapTool`,
    # it will override this default.
    if self.plotName:
        # Set the plot name using 'produce.plot.plotName' from
        # PropertyMapTool's instance.
        plotInfo["plotName"] = self.plotName

    figDict: dict[str, Figure] = {}

    # Plotting customization.
    colorbarTickLabelSize = 14
    colorBarAspect = 16
    rcparams = {
        "axes.labelsize": 18,
        "axes.linewidth": 1.8,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
    }
    zoomFactors = plotConfig.zoomFactors
    knownOperations = ["min", "max", "mean", "weighted_mean", "sum"]

    # Muted green for the full map, and muted red and blue for the two
    # zoomed-in maps.
    histColors = ["#265D40", "#8B0000", "#00008B"]

    with plt.rc_context(rcparams):
        for mapName, mapDataHandle in data.items():
            mapData = mapDataHandle.get()
            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 16))

            # Reduce whitespace but leave some room at the top for
            # plotInfo.
            plt.subplots_adjust(left=0.064, right=0.96, top=0.855, bottom=0.07, wspace=0.18, hspace=0.24)

            # Get the values for the valid pixels of the full tract.
            values = mapData[mapData.valid_pixels]
            goodValues = np.isfinite(values)
            values = values[goodValues]  # As a precaution.

            # Make a concise human-readable label for the plot.
            plotInfo["coaddName"] = mapName.split("Coadd_")[0]
            operation = self.getLongestSuffixMatch(mapName, knownOperations)
            if operation is None:
                # Without this guard the failure would surface as an
                # opaque AttributeError on `None.replace`.
                raise ValueError(
                    f"Map name `{mapName}` does not end with a recognized operation {knownOperations}."
                )
            plotInfo["operation"] = operation.replace("_", " ")
            propertyName = mapName[
                len(f"{plotInfo['coaddName']}Coadd_") : -len(plotInfo["operation"])
            ].strip("_")
            plotInfo["property"] = propertyName.replace("_", " ")

            nBinsHist = plotConfig.properties[propertyName].nBinsHist
            fullExtent = None
            zoomIdx = []
            for ax, zoom, zoomFactor, histColor in zip(
                [ax1, ax3, ax4], [True, False, False], [None, *zoomFactors], histColors
            ):
                extent = self.getZoomedExtent(fullExtent, zoomFactor)
                sp = skyproj.GnomonicSkyproj(
                    ax=ax,
                    lon_0=tractInfo.ctr_coord.getRa().asDegrees(),
                    lat_0=tractInfo.ctr_coord.getDec().asDegrees(),
                    extent=extent,
                    rcparams=rcparams,
                )
                sp.draw_hspmap(mapData, zoom=zoom)
                sp.set_xlabel("RA")
                sp.set_ylabel("Dec")
                cbar = sp.draw_colorbar(location="right", fraction=0.15, aspect=colorBarAspect, pad=0)
                cbar.ax.tick_params(labelsize=colorbarTickLabelSize)
                cbarText = (
                    "Full Tract" if zoomFactor is None else f"{self.prettyPrintFloat(zoomFactor)}x Zoom"
                )
                self.addTextToColorbar(cbar, cbarText, color=histColor)
                if zoomFactor is None:
                    # Save the skyproj object of the full-tract plot.
                    # Will be used in drawing zoom rectangles etc.
                    spf = sp
                    # Get the extent of the full tract.
                    fullExtent = spf.get_extent()
                else:
                    # Create a rectangle for the zoomed-in region.
                    x0, x1, y0, y1 = extent
                    for c, ls, lw in zip(["white", histColor], ["solid", "dashed"], [3.5, 1.5]):
                        spf.draw_polygon(
                            [x0, x0, x1, x1],
                            [y0, y1, y1, y0],
                            facecolor="none",
                            edgecolor=c,
                            linestyle=ls,
                            linewidth=lw,
                            alpha=0.8,
                        )
                    zoomText = spf.ax.text(
                        (x0 + x1) / 2,
                        y0,
                        f"{self.prettyPrintFloat(zoomFactor)}x",
                        color=histColor,
                        fontsize=14,
                        fontweight="bold",
                        alpha=0.8,
                        ha="center",
                        va="bottom",
                    )
                    # Add a distinct outline around the text for better
                    # visibility in various backgrounds.
                    zoomText.set_path_effects(
                        [
                            mpl_path_effects.Stroke(linewidth=4, foreground="white", alpha=0.8),
                            mpl_path_effects.Normal(),
                        ]
                    )
                    # Get the indices of pixels in the zoomed-in region.
                    pos = mapData.valid_pixels_pos()
                    # Reversed axes consideration.
                    xmin, xmax = sorted([x0, x1])
                    idx = (pos[0] > xmin) & (pos[0] < xmax) & (pos[1] > y0) & (pos[1] < y1)
                    # Apply the same finite-value mask as `values` so the
                    # index lines up with it.
                    zoomIdx.append(idx[goodValues])

            # Calculate weights for each bin to ensure that the peak of the
            # histogram reaches 1.
            weights = np.ones_like(values) / np.histogram(values, bins=nBinsHist)[0].max()

            # Compute full-tract histogram and get its bins.
            # NOTE: `exposure_time` histograms are quantized and look more
            # bar-like, so they look better with fewer bins.
            bins = ax2.hist(
                values,
                bins=nBinsHist,
                label="Full Tract",
                color=histColors[0],
                weights=weights,
                alpha=0.7,
            )[1]

            # Align the histogram (top right panel) with the skyproj plots.
            pos1 = spf.ax.get_position()  # Top left.
            pos4 = sp.ax.get_position()  # Bottom right.
            cbarWidth = cbar.ax.get_position().height / colorBarAspect
            # NOTE: cbarWidth != cbar.ax.get_position().width
            ax2.set_position([pos4.x0, pos1.y0, pos4.width + cbarWidth, pos1.height])

            # Overplot the histograms for the zoomed-in plots.
            for zoomFactor, zidx, color, linestyle, hatch in zip(
                zoomFactors, zoomIdx, histColors[1:], ["solid", "dotted"], ["//", "xxxx"]
            ):
                weights = np.ones_like(values[zidx]) / np.histogram(values[zidx], bins=bins)[0].max()
                histLabel = f"{self.prettyPrintFloat(zoomFactor)}x Zoom"
                histValues = ax2.hist(
                    values[zidx],
                    bins=bins,
                    label=histLabel,
                    color=color,
                    weights=weights,
                    histtype="step",
                    linewidth=2,
                    linestyle=linestyle,
                    alpha=0.6,
                )[0]
                # Fill the area under the step. With `step="post"` each
                # y-value spans from its x to the next x, so all bin edges
                # are passed and the last height is repeated; otherwise
                # the final bin would be left unfilled.
                ax2.fill_between(
                    bins,
                    np.append(histValues, histValues[-1]),
                    step="post",
                    color=color,
                    alpha=0.2,
                    hatch=hatch,
                    label="hidden",
                )

            # Set labels and legend.
            ax2.set_xlabel(plotInfo["property"].title().replace("Psf", "PSF"))
            ax2.set_ylabel("Normalized Count")

            # Get handles and labels from the axis.
            handles, labels = ax2.get_legend_handles_labels()

            # Add a legend with custom handler that combines the handle
            # pairs for the zoomed-in cases. The fixed indices rely on
            # there being exactly two zoom factors.
            handles = [handles[0], (handles[1], handles[2]), (handles[3], handles[4])]
            labels = [label for label in labels if label != "hidden"]
            legend = ax2.legend(
                handles,
                labels,
                handler_map={tuple: CustomHandler()},
                loc="best",
                frameon=False,
                fontsize=15,
            )

            for line, text in zip(handles, legend.get_texts()):
                if isinstance(line, tuple):
                    # Use the first handle to get the color.
                    line = line[0]
                # Fall back to the edge color when the face is fully
                # transparent.
                color = line.get_edgecolor() if line.get_facecolor()[-1] == 0 else line.get_facecolor()
                text.set_color(color)

            # Add extra info to plotInfo.
            plotInfo["nside"] = mapData.nside_sparse
            plotInfo["valid_area"] = mapData.get_valid_area()

            # Add useful information to the plot.
            figDict[mapName] = self.addPlotInfo(fig, plotInfo, mapName)

            _LOG.info(
                "Made property map plot for dataset type %s, tract: %s, band: '%s'.",
                mapName,
                plotInfo["tract"],
                plotInfo["band"],
            )

    return figDict

445 

def getOutputNames(self, config=None) -> Iterable[str]:
    # Docstring inherited.

    # One output connection name is needed for every configured
    # (property, operation) pair, in configuration order.
    return tuple(
        f"{propertyConfig.coaddName}Coadd_{propertyName}_{operationName}"
        for propertyName, propertyConfig in config.properties.items()
        for operationName in propertyConfig.operations
    )

458 

459 @staticmethod 

460 def getZoomedExtent(fullExtent, n): 

461 """Get zoomed extent centered on the original full plot. 

462 

463 Parameters 

464 ---------- 

465 fullExtent : `tuple` [`float`] 

466 The full extent defined by (lon_min, lon_max, lat_min, lat_max): 

467 

468 * ``lon_min`` 

469 Minimum longitude of the original extent (`float`). 

470 * ``"lon_max"`` 

471 Maximum longitude of the original extent (`float`). 

472 * ``lat_min`` 

473 Minimum latitude of the original extent (`float`). 

474 * ``"lat_max"`` 

475 Maximum latitude of the original extent (`float`). 

476 

477 n : `float`, optional 

478 Zoom factor; for instance, n=2 means zooming in 2 times at the 

479 center. If None, the function returns None. 

480 

481 Returns 

482 ------- 

483 `tuple` [`float`] 

484 New extent as (new_lon_min, new_lon_max, new_lat_min, new_lat_max). 

485 """ 

486 if n is None: 

487 return None 

488 lon_min, lon_max, lat_min, lat_max = fullExtent 

489 lon_center, lat_center = (lon_min + lon_max) / 2, (lat_min + lat_max) / 2 

490 half_lon = (lon_max - lon_min) * np.cos(np.radians(lat_center)) / (2 * n) 

491 half_lat = (lat_max - lat_min) / (2 * n) 

492 return lon_center - half_lon, lon_center + half_lon, lat_center - half_lat, lat_center + half_lat 

493 

494 @staticmethod 

495 def prettyPrintFloat(n): 

496 if n.is_integer(): 

497 return str(int(n)) 

498 return str(n) 

499 

@staticmethod
def addTextToColorbar(
    cb, text, orientation="vertical", color="black", fontsize=14, fontweight="bold", alpha=0.8
):
    """Helper method to add outlined text at the middle of a colorbar.

    Parameters
    ----------
    cb : `~matplotlib.colorbar.Colorbar`
        The colorbar object.
    text : `str`
        The text to add.
    orientation : `str`, optional
        The orientation of the colorbar. Can be either "vertical" or
        "horizontal". It selects the text position and is also used as
        the text rotation.
    color : `str`, optional
        The color of the text. `None` is treated as "black".
    fontsize : `int`, optional
        The fontsize of the text.
    fontweight : `str`, optional
        The fontweight of the text.
    alpha : `float`, optional
        The alpha value of the text.

    Returns
    -------
    `None`
        The text is added to the colorbar in place.
    """
    textColor = "black" if color is None else color

    # Midpoint of the colorbar's value range; it is placed on the y-axis
    # for a vertical colorbar and on the x-axis for a horizontal one.
    midValue = (cb.vmin + cb.vmax) / 2
    anchorByOrientation = {"vertical": (0.5, midValue), "horizontal": (midValue, 0.5)}
    xText, yText = anchorByOrientation[orientation]

    label = cb.ax.text(
        xText,
        yText,
        text,
        color=textColor,
        va="center",
        ha="center",
        fontsize=fontsize,
        fontweight=fontweight,
        rotation=orientation,
        alpha=alpha,
    )

    # Add a distinct outline around the text for better visibility in
    # various backgrounds.
    outline = mpl_path_effects.Stroke(linewidth=4, foreground="white", alpha=0.8)
    label.set_path_effects([outline, mpl_path_effects.Normal()])

547 

548 @staticmethod 

549 def getLongestSuffixMatch(s, options): 

550 """Find the longest suffix in the provided list that matches the end of 

551 the given string. 

552 

553 Parameters 

554 ---------- 

555 s : `str` 

556 The target string for which we want to find a matching suffix. 

557 options : `list` [`str`] 

558 A list of potential suffix strings to match against the target 

559 string `s`. 

560 

561 Returns 

562 ------- 

563 `str` 

564 The longest matching suffix from the `options` list. If no match is 

565 found, returns `None`. 

566 """ 

567 return next((opt for opt in sorted(options, key=len, reverse=True) if s.endswith(opt)), None)