Coverage for python/lsst/analysis/tools/actions/plot/matrixPlot.py: 24%

162 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-28 12:39 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("MatrixPlot",) 

25 

26from typing import TYPE_CHECKING, Any, Mapping 

27 

28import astropy.visualization as apViz 

29import matplotlib.patheffects as mpl_path_effects 

30import matplotlib.pyplot as plt 

31import numpy as np 

32from astropy.visualization.mpl_normalize import ImageNormalize 

33from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, ListField 

34 

35from ...interfaces import PlotAction, Vector 

36from .plotUtils import addPlotInfo 

37 

38if TYPE_CHECKING: 38 ↛ 39line 38 didn't jump to line 39, because the condition on line 38 was never true

39 from matplotlib.figure import Figure 

40 

41 from ...interfaces import KeyedData, KeyedDataSchema 

42 

43 

class GuideLinesConfig(Config):
    """Configuration for a set of guide lines drawn across a `MatrixPlot`.

    Each entry of ``lines`` places one line at the given coordinate and
    annotates it with the associated text label. `MatrixPlot` draws the
    entries configured for its ``"x"`` key as vertical lines and those for
    its ``"y"`` key as horizontal lines.
    """

    # Mapping of line coordinate -> text label; this field is required.
    lines = DictField[float, str](
        optional=False,
        doc="Dictionary of x/y-values and the labels where vertical/horizontal lines are drawn.",
    )

    # Color shared by each line and its label text.
    color = Field[str](
        default="red",
        doc="The color of the lines and labels.",
    )

    # Color of the wider line and text stroke drawn underneath for contrast.
    outlineColor = Field[str](
        default="white",
        doc="The color of the outline around the lines and labels.",
    )

    # Matplotlib line style passed through when drawing each line.
    linestyle = Field[str](
        default="--",
        doc="The style of the lines.",
    )

64 

65 

class MatrixPlot(PlotAction):
    """Make the plot of a matrix (2D array).

    Notes
    -----
    The `xAxisTickLabels` and `yAxisTickLabels` attributes of this class serve
    as dictionaries to map axis tick positions to their corresponding labels.
    If any positions do not align with major ticks (either provided by
    `x/yAxisTickValues` or automatically set by matplotlib), they will be
    designated as minor ticks. Thus, these tick labels operate independently,
    meaning their corresponding positions do not need to match those in
    `x/yAxisTickValues` or anything else. The code automatically adjusts to
    handle any overlaps caused by user input and across various plotting
    scenarios.

    Note that when `component1Key` and `component2Key` are specified, the x and
    y tick values and labels will be dynamically configured, thereby
    eliminating the need for providing `x/yAxisTickValues` and
    `x/yAxisTickLabels`. When `componentGroup1Key` and `componentGroup2Key` are
    specified, the x and y axis labels are dynamically updated to include the
    group names, prefixed by `xAxisLabel` and `yAxisLabel` for a more
    descriptive labeling.

    When `inputDim` is 1, the input matrix (and any component arrays) must
    contain a non-zero, perfect-square number of elements so that it can be
    reshaped into a square matrix. Otherwise the input must be a 2D array.
    """

    inputDim = ChoiceField[int](
        doc="The dimensionality of the input data.",
        default=1,
        allowed={
            1: "1D inputs are automatically reshaped into square 2D matrices.",
            2: "2D inputs are directly utilized as is.",
        },
        optional=True,
    )

    matrixKey = Field[str](
        doc="The key for the input matrix.",
        default="matrix",
    )

    matrixOrigin = ChoiceField[str](
        doc="Determines the starting corner ('upper', 'lower') for matrix plots. It only affects the visual "
        "appearance of the plot.",
        default="upper",
        allowed={
            "upper": "The origin is at the upper left corner.",
            "lower": "The origin is at the lower left corner.",
        },
        optional=True,
    )

    component1Key = Field[str](
        doc="The key to access a list of names for the first set of components in a correlation analysis. "
        "This will be used to determine x-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    # Optional, like ``component1Key``: without ``optional=True`` pex_config
    # treats a ``None`` value as a validation error, which would make this
    # optional feature mandatory.
    component2Key = Field[str](
        doc="The key to access a list of names for the second set of components in a correlation analysis. "
        "This will be used to determine y-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    componentGroup1Key = Field[str](
        doc="The key to access a list of group names for the first set of components in a correlation "
        "analysis. This will be used to determine the x-axis label.",
        default=None,
        optional=True,
    )

    componentGroup2Key = Field[str](
        doc="The key to access a list of group names for the second set of components in a correlation "
        "analysis. This will be used to determine the y-axis label.",
        default=None,
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="",
        optional=True,
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="",
        optional=True,
    )

    axisLabelFontSize = Field[float](
        doc="The font size for the axis labels.",
        default=9,
        optional=True,
    )

    colorbarLabel = Field[str](
        doc="The label to use for the colorbar.",
        default="",
        optional=True,
    )

    colorbarLabelFontSize = Field[float](
        doc="The font size for the colorbar label.",
        default=10,
        optional=True,
    )

    colorbarTickLabelFontSize = Field[float](
        doc="The font size for the colorbar tick labels.",
        default=8,
        optional=True,
    )

    colorbarCmap = ChoiceField[str](
        doc="The colormap to use for the colorbar.",
        default="viridis",
        allowed={name: name for name in plt.colormaps()},
        optional=True,
    )

    vmin = Field[float](
        doc="The vmin value for the colorbar.",
        default=None,
        optional=True,
    )

    vmax = Field[float](
        doc="The vmax value for the colorbar.",
        default=None,
        optional=True,
    )

    figsize = ListField[float](
        doc="The size of the figure.",
        default=[5, 5],
        maxLength=2,
        optional=True,
    )

    title = Field[str](
        doc="The title of the figure.",
        default="",
        optional=True,
    )

    titleFontSize = Field[float](
        doc="The font size for the title.",
        default=10,
        optional=True,
    )

    xAxisTickValues = ListField[float](
        doc="List of x-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    xAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping x-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    yAxisTickValues = ListField[float](
        doc="List of y-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    yAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping y-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    tickLabelsFontSize = Field[float](
        doc="The font size for the tick labels.",
        default=8,
        optional=True,
    )

    tickLabelsRotation = Field[float](
        doc="The rotation of the tick labels.",
        default=0,
        optional=True,
    )

    setPositionsAtPixelBoundaries = Field[bool](
        doc="Whether to consider the positions at the pixel boundaries rather than the center of the pixel.",
        default=False,
        optional=True,
    )

    hideMajorTicks = ListField[str](
        doc="List of axis names for which to hide the major ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of major tick 'labels'. For example, setting this "
        "field to ['x', 'y'] will hide both major ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    hideMinorTicks = ListField[str](
        doc="List of axis names for which to hide the minor ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of minor tick labels. For example, setting this "
        "field to ['x', 'y'] will hide both minor ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    dpi = Field[int](
        doc="The resolution of the figure.",
        default=300,
        optional=True,
    )

    guideLines = ConfigDictField[str, GuideLinesConfig](
        doc="Dictionary of guide lines for the x and y axes. The keys are 'x' and 'y', and the values are "
        "instances of `GuideLinesConfig`.",
        default={},
        dictCheck=lambda d: all([k in ["x", "y"] for k in d]),
        optional=True,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (key, type) pairs this action declares as inputs.

        Only ``matrixKey`` is declared. The optional component and component
        group keys are read directly from ``data`` when they are configured.
        """
        base: list[tuple[str, type[Vector]]] = []
        base.append((self.matrixKey, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        """Validate ``data`` and ``kwargs``, then return `makePlot`'s figure."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    @staticmethod
    def _squareSide(size: int) -> int | None:
        """Return the side length of a square holding ``size`` elements, or
        `None` if ``size`` is not a perfect square."""
        side = int(round(np.sqrt(size)))
        return side if side * side == size else None

    def _validateInput(self, data: KeyedData, **kwargs: Any) -> None:
        """Check that the input data and keyword arguments can be plotted.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to validate.
        **kwargs
            Keyword arguments that will be forwarded to `makePlot`.

        Raises
        ------
        ValueError
            Raised if:
            - a required or configured key is missing from ``data``;
            - the matrix has an unsuitable shape for ``inputDim``;
            - an unsupported keyword argument is given;
            - the component / tick configuration is inconsistent.
        """
        # Check that the input data contains all the required keys.
        needed = {key for key, _ in self.getInputSchema()}
        if not needed.issubset(data.keys()):
            raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")

        # Check the input can be plotted as a matrix. 1D inputs are reshaped
        # into a square in `makePlot`, so they need a non-zero perfect-square
        # number of elements. Any other input must already be a non-empty 2D
        # array. (The previous ``not isinstance(...) and ndim != 2`` test let
        # every ndarray through regardless of its shape.)
        matrix = data[self.matrixKey]
        if self.inputDim == 1:
            size = np.size(matrix)
            if size == 0 or self._squareSide(size) is None:
                raise ValueError(
                    f"Input data with {size} elements cannot be reshaped into a non-empty square matrix."
                )
        elif np.ndim(matrix) != 2:
            raise ValueError(f"Input data is not a 2d array; got an array of shape {np.shape(matrix)}.")
        elif np.size(matrix) == 0:
            raise ValueError("Input data is an empty 2d array.")

        # Check that the keyword arguments are valid.
        acceptableKwargs = {"plotInfo", "skymap", "band", "metric_tags", "fig"}
        if not set(kwargs).issubset(acceptableKwargs):
            raise ValueError(
                f"Only the following keyword arguments are allowed: {acceptableKwargs}. Got: {kwargs}"
            )

        # Check that if one component key is provided, the other must be too.
        if (self.component1Key is None) != (self.component2Key is None):
            raise ValueError(
                "Both 'component1Key' and 'component2Key' must be provided together if either is provided."
            )

        useComponents = self.component1Key is not None and self.component2Key is not None

        # Check that if component keys are provided, any of the tick values or
        # labels are not and vice versa.
        if useComponents and (
            self.xAxisTickValues is not None
            or self.yAxisTickValues is not None
            or self.xAxisTickLabels is not None
            or self.yAxisTickLabels is not None
        ):
            raise ValueError(
                "If 'component1Key' and 'component2Key' are provided, 'xAxisTickValues', "
                "'yAxisTickValues', 'xAxisTickLabels', and 'yAxisTickLabels' should not be "
                "provided as they will be dynamically configured."
            )

        # Configured component (group) keys are not part of the input schema,
        # so report them here instead of failing later with a bare KeyError.
        extraKeys: list[str] = []
        if useComponents:
            extraKeys += [self.component1Key, self.component2Key]
        if self.componentGroup1Key is not None and self.componentGroup2Key is not None:
            extraKeys += [self.componentGroup1Key, self.componentGroup2Key]
        missing = [key for key in extraKeys if key not in data]
        if missing:
            raise ValueError(f"Input data does not contain the configured component key(s): {missing}")

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """Plot a matrix of values.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot. Must contain ``matrixKey`` and, when configured,
            the component and component group keys.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted.
        **kwargs
            Additional keyword arguments. Only ``fig`` is read here: an
            existing `~matplotlib.figure.Figure` whose current axes are drawn
            into instead of creating a new figure.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            Raised if ``inputDim`` is 1 and the input cannot be reshaped into
            a square matrix, or if the component group columns do not each
            contain a single unique value.
        """
        # Retrieve the matrix from the input data. ``asanyarray`` accepts
        # array-likes while keeping ndarray subclasses (e.g. masked arrays).
        matrix = np.asanyarray(data[self.matrixKey])

        useComponents = self.component1Key is not None and self.component2Key is not None

        # Fetch the components between which the correlation is calculated.
        if useComponents:
            comp1 = np.asanyarray(data[self.component1Key])
            comp2 = np.asanyarray(data[self.component2Key])

        if self.inputDim == 1:
            # Reshape the flat input into a square array.
            squareSize = self._squareSide(matrix.size)
            if squareSize is None:
                raise ValueError(
                    f"Input data with {matrix.size} elements cannot be reshaped into a square matrix."
                )
            matrix = matrix.reshape(squareSize, squareSize)
            if useComponents:
                comp1 = comp1.reshape(squareSize, squareSize)
                comp2 = comp2.reshape(squareSize, squareSize)

        # Calculate default limits only if needed.
        if self.vmin is None or self.vmax is None:
            defaultLimits = apViz.PercentileInterval(98.0).get_limits(np.abs(matrix.flatten()))
        else:
            defaultLimits = (None, None)

        # Set the value range using overrides or defaults.
        vrange = (
            defaultLimits[0] if self.vmin is None else self.vmin,
            defaultLimits[1] if self.vmax is None else self.vmax,
        )

        # Allow for the figure object to be passed in.
        fig = kwargs.get("fig")
        if fig is None:
            fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        if self.title:
            ax.set_title(self.title, fontsize=self.titleFontSize)

        # Append the (single) group name to each axis label when both group
        # keys are configured.
        if self.componentGroup1Key is not None and self.componentGroup2Key is not None:
            componentGroup1 = set(data[self.componentGroup1Key])
            componentGroup2 = set(data[self.componentGroup2Key])
            if len(componentGroup1) != 1 or len(componentGroup2) != 1:
                raise ValueError(
                    f"Each column specified by {self.componentGroup1Key} and {self.componentGroup2Key} must "
                    "contain identical values within itself, but they do not."
                )
            xAxisLabel = self.xAxisLabel + str(componentGroup1.pop())
            yAxisLabel = self.yAxisLabel + str(componentGroup2.pop())
        else:
            xAxisLabel = self.xAxisLabel
            yAxisLabel = self.yAxisLabel

        ax.set_xlabel(xAxisLabel, fontsize=self.axisLabelFontSize)
        ax.set_ylabel(yAxisLabel, fontsize=self.axisLabelFontSize)

        # Set the colorbar and draw the image.
        norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1])
        img = ax.imshow(
            matrix, interpolation="none", norm=norm, origin=self.matrixOrigin, cmap=self.colorbarCmap
        )

        # Calculate the aspect ratio of the image.
        ratio = matrix.shape[0] / matrix.shape[1]

        # Add the colorbar flush with the image axis.
        cbar = fig.colorbar(img, fraction=0.0457 * ratio, pad=0.04)

        # Set the colorbar label and its font size.
        cbar.set_label(self.colorbarLabel, fontsize=self.colorbarLabelFontSize)

        # Set the colorbar tick label font size.
        cbar.ax.tick_params(labelsize=self.colorbarTickLabelFontSize)

        # If requested, we shift all the positions by 0.5 considering the
        # zero-point at a pixel boundary rather than the center of the pixel.
        shift = 0.5 if self.setPositionsAtPixelBoundaries else 0

        if useComponents:
            xAxisTickValues = np.arange(matrix.shape[1] + shift)
            yAxisTickValues = np.arange(matrix.shape[0] + shift)
            xAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[1]), comp1[0, :])}
            yAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[0]), comp2[:, 0])}
        else:
            xAxisTickValues = self.xAxisTickValues
            yAxisTickValues = self.yAxisTickValues
            xAxisTickLabels = self.xAxisTickLabels
            yAxisTickLabels = self.yAxisTickLabels

        # If the tick values are not provided, retrieve them from the axes.
        xticks = xAxisTickValues if xAxisTickValues is not None else ax.xaxis.get_ticklocs()
        yticks = yAxisTickValues if yAxisTickValues is not None else ax.yaxis.get_ticklocs()

        # Retrieve the current limits of the x and y axes.
        xlim, ylim = ax.get_xlim(), ax.get_ylim()

        # Keep only tick locations that fall within the current x/y-axis
        # limits, i.e. within the visible range.
        xticks = np.array([tick for tick in xticks if min(xlim) <= tick - shift <= max(xlim)])
        yticks = np.array([tick for tick in yticks if min(ylim) <= tick - shift <= max(ylim)])
        tickData = {
            "x": (
                xticks - shift,
                np.array(list(xAxisTickLabels.keys())) - shift if xAxisTickLabels else None,
                list(xAxisTickLabels.values()) if xAxisTickLabels else None,
            ),
            "y": (
                yticks - shift,
                np.array(list(yAxisTickLabels.keys())) - shift if yAxisTickLabels else None,
                list(yAxisTickLabels.values()) if yAxisTickLabels else None,
            ),
        }

        for dim, axis in [("x", ax.xaxis), ("y", ax.yaxis)]:
            # Get the major tick positions and labels.
            majorTickValues, positions, labels = tickData[dim]

            # Set major ticks.
            axis.set_ticks(majorTickValues, minor=False)

            # Set tick labels while compensating for the potential shift in
            # the tick positions, dropping the trailing ".0" for integer
            # values. ``float()`` makes ``is_integer`` available even when the
            # locations come back as numpy integers.
            axis.set_ticklabels(
                [
                    f"{tick + shift:.0f}" if float(tick + shift).is_integer() else f"{tick + shift}"
                    for tick in axis.get_ticklocs()
                ],
                fontsize=self.tickLabelsFontSize,
            )

            # Check if positions are provided.
            if positions is not None:
                # Assign specified positions as minor ticks.
                axis.set_ticks(positions, minor=True)

                # Conditionally assign labels to major and/or minor ticks.
                if labels is not None:
                    # Lookup from position to label, keeping labels only for
                    # positions that coincide with a major tick.
                    positionsLabelsLookup = {
                        p: lab if p in majorTickValues else "" for p, lab in zip(positions, labels)
                    }
                    # Labels for major ticks, blank where no position matches.
                    majorLabels = [positionsLabelsLookup.get(m, "") for m in majorTickValues]
                    # Labels for minor ticks, blank where the position is
                    # already labelled as a major tick.
                    minorLabels = ["" if p in majorTickValues else lab for p, lab in zip(positions, labels)]

                    # Apply labels to major ticks if any exist.
                    if any(majorLabels):
                        axis.set_ticklabels(majorLabels, minor=False, fontsize=self.tickLabelsFontSize)
                    else:
                        # If no major labels, clear major tick labels.
                        axis.set_ticklabels("")

                    # Apply labels to minor ticks if any exist.
                    if any(minorLabels):
                        axis.set_ticklabels(minorLabels, minor=True, fontsize=self.tickLabelsFontSize)

            if dim in self.hideMajorTicks:
                # Remove major tick marks for aesthetic reasons.
                axis.set_tick_params(which="major", length=0)

            if dim in self.hideMinorTicks:
                # Remove minor tick marks for aesthetic reasons.
                axis.set_tick_params(which="minor", length=0)

            # Rotate the tick labels by the specified angle.
            ax.tick_params(axis=dim, rotation=self.tickLabelsRotation)

        # Add vertical guide lines (x) if provided.
        if "x" in self.guideLines:
            xLines = self.guideLines["x"]
            for x, labelText in xLines.lines.items():
                # A wider outline line underneath improves contrast.
                ax.axvline(x=x - shift, color=xLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axvline(
                    x=x - shift, color=xLines.color, linestyle=xLines.linestyle, linewidth=1, alpha=0.85
                )
                text = ax.text(
                    x - shift,
                    0.03,
                    labelText,
                    rotation=90,
                    color=xLines.color,
                    transform=ax.get_xaxis_transform(),
                    horizontalalignment="right",
                    alpha=0.9,
                )
                # Add a distinct outline around the label for better visibility
                # in various backgrounds.
                text.set_path_effects(
                    [
                        mpl_path_effects.Stroke(linewidth=2, foreground=xLines.outlineColor, alpha=0.8),
                        mpl_path_effects.Normal(),
                    ]
                )

        # Add horizontal guide lines (y) if provided.
        if "y" in self.guideLines:
            yLines = self.guideLines["y"]
            for y, labelText in yLines.lines.items():
                # A wider outline line underneath improves contrast.
                ax.axhline(y=y - shift, color=yLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axhline(
                    y=y - shift, color=yLines.color, linestyle=yLines.linestyle, linewidth=1, alpha=0.85
                )
                text = ax.text(
                    0.03,
                    y - shift,
                    labelText,
                    color=yLines.color,
                    transform=ax.get_yaxis_transform(),
                    verticalalignment="bottom",
                    alpha=0.9,
                )
                # Add a distinct outline around the label for better visibility
                # in various backgrounds.
                text.set_path_effects(
                    [
                        mpl_path_effects.Stroke(linewidth=2, foreground=yLines.outlineColor, alpha=0.8),
                        mpl_path_effects.Normal(),
                    ]
                )

        # Add plot info if provided.
        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig