Coverage for python/lsst/analysis/tools/actions/plot/matrixPlot.py: 24%

153 statements  

« prev     ^ index     » next       coverage.py v7.4.3, created at 2024-02-24 11:17 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Names exported by ``from <this module> import *``.
__all__ = ("MatrixPlot",)

25 

26from typing import TYPE_CHECKING, Any, Mapping 

27 

28import astropy.visualization as apViz 

29import matplotlib.patheffects as mpl_path_effects 

30import matplotlib.pyplot as plt 

31import numpy as np 

32from astropy.visualization.mpl_normalize import ImageNormalize 

33from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, ListField 

34 

35from ...interfaces import PlotAction, Vector 

36from .plotUtils import addPlotInfo 

37 

38if TYPE_CHECKING: 38 ↛ 39line 38 didn't jump to line 39, because the condition on line 38 was never true

39 from matplotlib.figure import Figure 

40 

41 from ...interfaces import KeyedData, KeyedDataSchema 

42 

43 

class GuideLinesConfig(Config):
    """Configuration of a set of labelled guide lines drawn on a matrix plot.

    Instances are the values of ``MatrixPlot.guideLines``. There, the key 'x'
    produces vertical lines (``axvline``) and the key 'y' produces horizontal
    lines (``axhline``).
    """

    # Maps each line position (in the plot's x or y coordinate) to the text
    # drawn next to that line.
    lines = DictField[float, str](
        doc=("Dictionary of x/y-values and the labels where vertical/horizontal lines are drawn."),
        optional=False,
    )

    # Color of the thin foreground line and of the label text.
    color = Field[str](
        doc="The color of the lines and labels.",
        default="red",
    )

    # Color of the wider background line and of the stroke drawn around the
    # label text.
    outlineColor = Field[str](
        doc="The color of the outline around the lines and labels.",
        default="white",
    )

    # Matplotlib linestyle applied to the foreground line only.
    linestyle = Field[str](
        doc="The style of the lines.",
        default="--",
    )

64 

65 

class MatrixPlot(PlotAction):
    """Make the plot of a matrix (2D array).

    Notes
    -----
    The `xAxisTickLabels` and `yAxisTickLabels` attributes of this class serve
    as dictionaries to map axis tick positions to their corresponding labels.
    If any positions do not align with major ticks (either provided by
    `x/yAxisTickValues` or automatically set by matplotlib), they will be
    designated as minor ticks. Thus, these tick labels operate independently,
    meaning their corresponding positions do not need to match those in
    `x/yAxisTickValues` or anything else. The code automatically adjusts to
    handle any overlaps caused by user input and across various plotting
    scenarios.

    Note that when `component1Key` and `component2Key` are specified, the x and
    y tick values and labels will be dynamically configured, thereby
    eliminating the need for providing `x/yAxisTickValues` and
    `x/yAxisTickLabels`. The x-axis labels are taken from the first row of the
    `component1Key` data (one name per matrix column). The y-axis labels are
    taken from the first column of the `component2Key` data (one name per
    matrix row). A plain 1D sequence of names is used directly.
    """

    inputDim = ChoiceField[int](
        doc="The dimensionality of the input data.",
        default=1,
        allowed={
            1: "1D inputs are automatically reshaped into square 2D matrices.",
            2: "2D inputs are directly utilized as is.",
        },
        optional=True,
    )

    matrixKey = Field[str](
        doc="The key for the input matrix.",
        default="matrix",
    )

    matrixOrigin = ChoiceField[str](
        doc="Determines the starting corner ('upper', 'lower') for matrix plots.",
        default="upper",
        allowed={
            "upper": "The origin is at the upper left corner.",
            "lower": "The origin is at the lower left corner.",
        },
        optional=True,
    )

    component1Key = Field[str](
        doc="The key to access a list of names for the first component set in a correlation analysis. This "
        "will be used to determine x-axis tick values and labels.",
        default=None,
        optional=True,
    )

    # Must be optional like ``component1Key``. Without ``default``/``optional``
    # a pex_config Field is required, so configs that do not use components
    # would fail validation.
    component2Key = Field[str](
        doc="The key to access a list of names for the second component set in a correlation analysis. This "
        "will be used to determine y-axis tick values and labels.",
        default=None,
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="",
        optional=True,
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="",
        optional=True,
    )

    axisLabelFontSize = Field[float](
        doc="The font size for the axis labels.",
        default=9,
        optional=True,
    )

    colorbarLabel = Field[str](
        doc="The label to use for the colorbar.",
        default="",
        optional=True,
    )

    colorbarLabelFontSize = Field[float](
        doc="The font size for the colorbar label.",
        default=10,
        optional=True,
    )

    colorbarTickLabelFontSize = Field[float](
        doc="The font size for the colorbar tick labels.",
        default=8,
        optional=True,
    )

    colorbarCmap = ChoiceField[str](
        doc="The colormap to use for the colorbar.",
        default="viridis",
        allowed={name: name for name in plt.colormaps()},
        optional=True,
    )

    vmin = Field[float](
        doc="The vmin value for the colorbar.",
        default=None,
        optional=True,
    )

    vmax = Field[float](
        doc="The vmax value for the colorbar.",
        default=None,
        optional=True,
    )

    figsize = ListField[float](
        doc="The size of the figure.",
        default=[5, 5],
        maxLength=2,
        optional=True,
    )

    title = Field[str](
        doc="The title of the figure.",
        default="",
        optional=True,
    )

    titleFontSize = Field[float](
        doc="The font size for the title.",
        default=10,
        optional=True,
    )

    xAxisTickValues = ListField[float](
        doc="List of x-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    xAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping x-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    yAxisTickValues = ListField[float](
        doc="List of y-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    yAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping y-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    tickLabelsFontSize = Field[float](
        doc="The font size for the tick labels.",
        default=8,
        optional=True,
    )

    tickLabelsRotation = Field[float](
        doc="The rotation of the tick labels.",
        default=0,
        optional=True,
    )

    setPositionsAtPixelBoundaries = Field[bool](
        doc="Whether to consider the positions at the pixel boundaries rather than the center of the pixel.",
        default=False,
        optional=True,
    )

    hideMajorTicks = ListField[str](
        doc="List of axis names for which to hide the major ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of major tick 'labels'. For example, setting this "
        "field to ['x', 'y'] will hide both major ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    hideMinorTicks = ListField[str](
        doc="List of axis names for which to hide the minor ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of minor tick labels. For example, setting this "
        "field to ['x', 'y'] will hide both minor ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    dpi = Field[int](
        doc="The resolution of the figure.",
        default=300,
        optional=True,
    )

    guideLines = ConfigDictField[str, GuideLinesConfig](
        doc="Dictionary of guide lines for the x and y axes. The keys are 'x' and 'y', and the values are "
        "instances of `GuideLinesConfig`.",
        default={},
        dictCheck=lambda d: all([k in ["x", "y"] for k in d]),
        optional=True,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys this action reads from its input data.

        Returns
        -------
        schema : `~lsst.analysis.tools.interfaces.KeyedDataSchema`
            The matrix key, plus the two component keys when configured.
        """
        base: list[tuple[str, type[Vector]]] = [(self.matrixKey, Vector)]
        # ``makePlot`` reads the component keys when they are configured,
        # so they are required inputs as well.
        for key in (self.component1Key, self.component2Key):
            if key is not None:
                base.append((key, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs: Any) -> None:
        """Check the configuration, the input data and the keyword arguments.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to be plotted.
        **kwargs
            Keyword arguments that will be forwarded to `makePlot`.

        Raises
        ------
        ValueError
            Raised if any of the following holds:

            - only one of the component keys is set;
            - the component keys are combined with explicit tick
              values/labels;
            - an unsupported keyword argument is given;
            - a required key is missing from ``data``;
            - the matrix is not a non-empty numpy array compatible with
              ``inputDim``.
        """
        # Configuration checks come first. They do not depend on the data,
        # and they give clearer messages than the missing-key check would.
        # Check that if one component key is provided, the other must be too.
        if (self.component1Key is None) != (self.component2Key is None):
            raise ValueError(
                "Both 'component1Key' and 'component2Key' must be provided together if either is provided."
            )
        # Check that if component keys are provided, none of the tick values
        # or labels are, since those are derived from the components.
        if self.component1Key is not None and (
            self.xAxisTickValues is not None
            or self.yAxisTickValues is not None
            or self.xAxisTickLabels is not None
            or self.yAxisTickLabels is not None
        ):
            raise ValueError(
                "If 'component1Key' and 'component2Key' are provided, 'xAxisTickValues', "
                "'yAxisTickValues', 'xAxisTickLabels', and 'yAxisTickLabels' should not be "
                "provided as they will be dynamically configured."
            )
        # Check that the keyword arguments are valid.
        acceptableKwargs = {"plotInfo", "skymap", "band", "metric_tags", "fig"}
        if not set(kwargs).issubset(acceptableKwargs):
            raise ValueError(
                f"Only the following keyword arguments are allowed: {acceptableKwargs}. Got: {kwargs}"
            )
        # Check that the input data contains all the required keys.
        needed = {key for key, _ in self.getInputSchema()}
        if not needed.issubset(data.keys()):
            raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")
        # Check that the matrix is an array that can be shown as a 2D image
        # under the configured ``inputDim``.
        matrix = data[self.matrixKey]
        if not isinstance(matrix, np.ndarray):
            raise ValueError(f"Input data is not a numpy array: {type(matrix)}")
        if matrix.size == 0:
            raise ValueError("Input data is an empty array.")
        if self.inputDim == 2 and matrix.ndim != 2:
            raise ValueError(f"Input data is not a 2d array (ndim={matrix.ndim}).")
        if self.inputDim == 1:
            side = int(round(np.sqrt(matrix.size)))
            if side * side != matrix.size:
                raise ValueError(
                    f"Input data of size {matrix.size} cannot be reshaped into a square matrix."
                )

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """
        Plot a matrix of values.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted.
        **kwargs
            Additional keyword arguments to pass to the plot. A ``fig`` entry,
            if present, is drawn into (on its current axes) instead of
            creating a new figure.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.
        """
        matrix, comp1, comp2 = self._prepareArrays(data)
        vmin, vmax = self._getValueRange(matrix)

        # Allow for the figure object to be passed in.
        fig = kwargs.get("fig")
        if fig is None:
            fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        if self.title:
            ax.set_title(self.title, fontsize=self.titleFontSize)
        if self.xAxisLabel:
            ax.set_xlabel(self.xAxisLabel, fontsize=self.axisLabelFontSize)
        if self.yAxisLabel:
            ax.set_ylabel(self.yAxisLabel, fontsize=self.axisLabelFontSize)

        # Set the colorbar and draw the image.
        norm = ImageNormalize(vmin=vmin, vmax=vmax)
        img = ax.imshow(
            matrix, interpolation="none", norm=norm, origin=self.matrixOrigin, cmap=self.colorbarCmap
        )

        # Scale the colorbar fraction by the image aspect ratio (rows/columns)
        # so that it stays flush with the image axis.
        ratio = matrix.shape[0] / matrix.shape[1]
        cbar = fig.colorbar(img, ax=ax, fraction=0.0457 * ratio, pad=0.04)
        cbar.set_label(self.colorbarLabel, fontsize=self.colorbarLabelFontSize)
        cbar.ax.tick_params(labelsize=self.colorbarTickLabelFontSize)

        # If requested, we shift all the positions by 0.5 considering the
        # zero-point at a pixel boundary rather than the center of the pixel.
        shift = 0.5 if self.setPositionsAtPixelBoundaries else 0

        xTickValues, yTickValues, xTickLabels, yTickLabels = self._resolveTickSettings(
            matrix, comp1, comp2, shift
        )

        # Build both specs before modifying either axis, so that the
        # automatic tick locations and the limits come from untouched axes.
        tickSpecs = {
            "x": self._buildTickSpec(ax.xaxis, ax.get_xlim(), xTickValues, xTickLabels, shift),
            "y": self._buildTickSpec(ax.yaxis, ax.get_ylim(), yTickValues, yTickLabels, shift),
        }
        for dim, axis in (("x", ax.xaxis), ("y", ax.yaxis)):
            majorValues, positions, labels = tickSpecs[dim]
            self._applyTicks(ax, axis, dim, majorValues, positions, labels, shift)

        self._drawGuideLines(ax, shift)

        # Add plot info if provided.
        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig

    def _prepareArrays(
        self, data: KeyedData
    ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
        """Fetch the matrix and, if configured, the component-name arrays.

        With ``inputDim == 1`` the matrix is reshaped into a square. Component
        arrays with as many elements as the matrix are reshaped the same way.
        Other component arrays (e.g. a 1D list with one name per row/column)
        are returned unchanged.
        """
        matrix = data[self.matrixKey]
        comp1 = comp2 = None
        if self.component1Key is not None and self.component2Key is not None:
            comp1 = np.asarray(data[self.component1Key])
            comp2 = np.asarray(data[self.component2Key])

        if self.inputDim == 1:
            side = int(round(np.sqrt(matrix.size)))
            matrix = matrix.reshape(side, side)
            if comp1 is not None and comp1.size == matrix.size:
                comp1 = comp1.reshape(side, side)
            if comp2 is not None and comp2.size == matrix.size:
                comp2 = comp2.reshape(side, side)

        return matrix, comp1, comp2

    def _getValueRange(self, matrix: np.ndarray) -> tuple[float, float]:
        """Return ``(vmin, vmax)`` for the color scale.

        Configured ``vmin``/``vmax`` values take precedence. Any unset limit
        comes from astropy's ``PercentileInterval(98.0)`` computed on the
        absolute values of the matrix.
        """
        if self.vmin is not None and self.vmax is not None:
            return self.vmin, self.vmax
        defaultMin, defaultMax = apViz.PercentileInterval(98.0).get_limits(np.abs(matrix.flatten()))
        return (
            defaultMin if self.vmin is None else self.vmin,
            defaultMax if self.vmax is None else self.vmax,
        )

    def _resolveTickSettings(
        self,
        matrix: np.ndarray,
        comp1: np.ndarray | None,
        comp2: np.ndarray | None,
        shift: float,
    ) -> tuple[Any, Any, Any, Any]:
        """Return the x/y tick values and the x/y tick-label mappings to use.

        Without components these come straight from the configuration.
        Otherwise they are derived from the component names. The x axis runs
        along the matrix columns and the y axis along its rows.
        """
        if comp1 is None or comp2 is None:
            return self.xAxisTickValues, self.yAxisTickValues, self.xAxisTickLabels, self.yAxisTickLabels

        nRows, nCols = matrix.shape
        # Names come from the first row/column of a 2D grid, or from the
        # sequence itself when it is 1D.
        xNames = comp1[0, :] if comp1.ndim == 2 else comp1
        yNames = comp2[:, 0] if comp2.ndim == 2 else comp2
        # After ``shift`` is subtracted later, ``np.arange(n + shift)`` gives:
        #   - n ticks at pixel centers when shift == 0;
        #   - n + 1 ticks at the pixel boundaries when shift == 0.5.
        xTickValues = np.arange(nCols + shift)
        yTickValues = np.arange(nRows + shift)
        # Label keys are offset by ``shift`` so that they land on pixel
        # centers once the shift is subtracted.
        xTickLabels = {index + shift: str(name) for index, name in zip(range(nCols), xNames)}
        yTickLabels = {index + shift: str(name) for index, name in zip(range(nRows), yNames)}
        return xTickValues, yTickValues, xTickLabels, yTickLabels

    @staticmethod
    def _buildTickSpec(
        axis: Any, limits: tuple[float, float], tickValues: Any, tickLabels: Any, shift: float
    ) -> tuple[np.ndarray, np.ndarray | None, list[str] | None]:
        """Return ``(major tick positions, label positions, label texts)`` for one axis.

        All returned positions are already converted to plot coordinates,
        i.e. ``shift`` has been subtracted.
        """
        # If the tick values are not provided, retrieve them from the axis.
        source = tickValues if tickValues is not None else axis.get_ticklocs()
        ticks = np.asarray(list(source), dtype=float)
        # Keep only tick locations that fall within the visible range.
        low, high = min(limits), max(limits)
        ticks = ticks[(ticks - shift >= low) & (ticks - shift <= high)]

        if tickLabels:
            positions = np.array(list(tickLabels.keys())) - shift
            labels = list(tickLabels.values())
        else:
            positions = labels = None
        return ticks - shift, positions, labels

    @staticmethod
    def _formatTickValue(value: float) -> str:
        """Format a tick value, dropping the decimal part for integers."""
        return f"{value:.0f}" if value.is_integer() else f"{value}"

    def _applyTicks(
        self,
        ax: Any,
        axis: Any,
        dim: str,
        majorValues: np.ndarray,
        positions: np.ndarray | None,
        labels: list[str] | None,
        shift: float,
    ) -> None:
        """Set major/minor ticks, their labels and tick styling on one axis."""
        # Set major ticks.
        axis.set_ticks(majorValues, minor=False)

        # Label major ticks with their values before the shift was applied.
        axis.set_ticklabels(
            [self._formatTickValue(float(tick) + shift) for tick in axis.get_ticklocs()],
            fontsize=self.tickLabelsFontSize,
        )

        if positions is not None:
            # Assign the labelled positions as minor ticks.
            axis.set_ticks(positions, minor=True)

            if labels is not None:
                # Labels whose position coincides with a major tick are shown
                # on that major tick. All others are shown on minor ticks.
                onMajor = {p: lab for p, lab in zip(positions, labels) if p in majorValues}
                majorLabels = [onMajor.get(m, "") for m in majorValues]
                minorLabels = ["" if p in majorValues else lab for p, lab in zip(positions, labels)]

                if any(majorLabels):
                    axis.set_ticklabels(majorLabels, minor=False, fontsize=self.tickLabelsFontSize)
                else:
                    # No label falls on a major tick: clear the numeric ones.
                    axis.set_ticklabels([])

                if any(minorLabels):
                    axis.set_ticklabels(minorLabels, minor=True, fontsize=self.tickLabelsFontSize)

        if dim in self.hideMajorTicks:
            # Remove major tick marks for aesthetic reasons.
            axis.set_tick_params(which="major", length=0)

        if dim in self.hideMinorTicks:
            # Remove minor tick marks for aesthetic reasons.
            axis.set_tick_params(which="minor", length=0)

        # Rotate the tick labels by the specified angle.
        ax.tick_params(axis=dim, rotation=self.tickLabelsRotation)

    @staticmethod
    def _outlineEffects(color: str) -> list:
        """Return path effects that stroke text with ``color`` for visibility."""
        return [
            mpl_path_effects.Stroke(linewidth=2, foreground=color, alpha=0.8),
            mpl_path_effects.Normal(),
        ]

    def _drawGuideLines(self, ax: Any, shift: float) -> None:
        """Draw the configured vertical ('x') and horizontal ('y') guide lines.

        Each line is drawn twice: first a wider, semi-transparent line in the
        outline color, then the styled line on top. Its label gets an outline
        stroke so it stays readable over any background.
        """
        if "x" in self.guideLines:
            xLines = self.guideLines["x"]
            for x, text in xLines.lines.items():
                xPos = x - shift
                ax.axvline(x=xPos, color=xLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axvline(x=xPos, color=xLines.color, linestyle=xLines.linestyle, linewidth=1, alpha=0.85)
                label = ax.text(
                    xPos,
                    0.03,
                    text,
                    rotation=90,
                    color=xLines.color,
                    transform=ax.get_xaxis_transform(),
                    horizontalalignment="right",
                    alpha=0.9,
                )
                label.set_path_effects(self._outlineEffects(xLines.outlineColor))

        if "y" in self.guideLines:
            yLines = self.guideLines["y"]
            for y, text in yLines.lines.items():
                yPos = y - shift
                ax.axhline(y=yPos, color=yLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axhline(y=yPos, color=yLines.color, linestyle=yLines.linestyle, linewidth=1, alpha=0.85)
                label = ax.text(
                    0.03,
                    yPos,
                    text,
                    color=yLines.color,
                    transform=ax.get_yaxis_transform(),
                    verticalalignment="bottom",
                    alpha=0.9,
                )
                label.set_path_effects(self._outlineEffects(yLines.outlineColor))