Coverage for python / lsst / analysis / tools / actions / plot / matrixPlot.py: 25%

160 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-18 09:19 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("MatrixPlot",) 

25 

26from collections.abc import Mapping 

27from typing import TYPE_CHECKING, Any 

28 

29import astropy.visualization as apViz 

30import matplotlib.patheffects as mpl_path_effects 

31import matplotlib.pyplot as plt 

32import numpy as np 

33from astropy.visualization.mpl_normalize import ImageNormalize 

34 

35from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, ListField 

36 

37from ...interfaces import PlotAction, Vector 

38from .plotUtils import addPlotInfo 

39 

40if TYPE_CHECKING: 

41 from matplotlib.figure import Figure 

42 

43 from ...interfaces import KeyedData, KeyedDataSchema 

44 

45 

class GuideLinesConfig(Config):
    """Configuration for a set of labelled guide lines.

    ``lines`` maps each coordinate at which a line is drawn to the label shown
    next to it. The other fields set the colour of the lines and labels, the
    colour of the outline drawn around them, and the line style.
    """

    lines = DictField[float, str](
        optional=False,
        doc="Dictionary of x/y-values and the labels where vertical/horizontal lines are drawn.",
    )

    color = Field[str](
        default="red",
        doc="The color of the lines and labels.",
    )

    outlineColor = Field[str](
        default="white",
        doc="The color of the outline around the lines and labels.",
    )

    linestyle = Field[str](
        default="--",
        doc="The style of the lines.",
    )

66 

67 

class MatrixPlot(PlotAction):
    """Make the plot of a matrix (2D array).

    Notes
    -----
    The `xAxisTickLabels` and `yAxisTickLabels` attributes of this class serve
    as dictionaries to map axis tick positions to their corresponding labels.
    If any positions do not align with major ticks (either provided by
    `x/yAxisTickValues` or automatically set by matplotlib), they will be
    designated as minor ticks. Thus, these tick labels operate independently,
    meaning their corresponding positions do not need to match those in
    `x/yAxisTickValues` or anything else. The code automatically adjusts to
    handle any overlaps caused by user input and across various plotting
    scenarios.

    Note that when `component1Key` and `component2Key` are specified, the x and
    y tick values and labels will be dynamically configured, thereby
    eliminating the need for providing `x/yAxisTickValues` and
    `x/yAxisTickLabels`. When `componentGroup1Key` and `componentGroup2Key` are
    specified, the x and y axis labels are dynamically updated to include the
    group names, prefixed by `xAxisLabel` and `yAxisLabel` for a more
    descriptive labeling.

    The rotation given by `tickLabelsRotation` is applied to both the major
    and the minor tick labels.
    """

    inputDim = ChoiceField[int](
        doc="The dimensionality of the input data.",
        default=1,
        allowed={
            1: "1D inputs are automatically reshaped into square 2D matrices.",
            2: "2D inputs are directly utilized as is.",
        },
        optional=True,
    )

    matrixKey = Field[str](
        doc="The key for the input matrix.",
        default="matrix",
    )

    matrixOrigin = ChoiceField[str](
        doc="Determines the starting corner ('upper', 'lower') for matrix plots. It only affects the visual "
        "appearance of the plot.",
        default="upper",
        allowed={
            "upper": "The origin is at the upper left corner.",
            "lower": "The origin is at the lower left corner.",
        },
        optional=True,
    )

    component1Key = Field[str](
        doc="The key to access a list of names for the first set of components in a correlation analysis. "
        "This will be used to determine x-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    # Optional, like component1Key: the two keys are either both set or both
    # unset (enforced in _validateInput).
    component2Key = Field[str](
        doc="The key to access a list of names for the second set of components in a correlation analysis. "
        "This will be used to determine y-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    componentGroup1Key = Field[str](
        doc="The key to access a list of group names for the first set of components in a correlation "
        "analysis. This will be used to determine the x-axis label.",
        default=None,
        optional=True,
    )

    componentGroup2Key = Field[str](
        doc="The key to access a list of group names for the second set of components in a correlation "
        "analysis. This will be used to determine the y-axis label.",
        default=None,
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="",
        optional=True,
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="",
        optional=True,
    )

    axisLabelFontSize = Field[float](
        doc="The font size for the axis labels.",
        default=9,
        optional=True,
    )

    colorbarLabel = Field[str](
        doc="The label to use for the colorbar.",
        default="",
        optional=True,
    )

    colorbarLabelFontSize = Field[float](
        doc="The font size for the colorbar label.",
        default=10,
        optional=True,
    )

    colorbarTickLabelFontSize = Field[float](
        doc="The font size for the colorbar tick labels.",
        default=8,
        optional=True,
    )

    colorbarCmap = ChoiceField[str](
        doc="The colormap to use for the colorbar.",
        default="viridis",
        allowed={name: name for name in plt.colormaps()},
        optional=True,
    )

    vmin = Field[float](
        doc="The vmin value for the colorbar.",
        default=None,
        optional=True,
    )

    vmax = Field[float](
        doc="The vmax value for the colorbar.",
        default=None,
        optional=True,
    )

    figsize = ListField[float](
        doc="The size of the figure.",
        default=[5, 5],
        maxLength=2,
        optional=True,
    )

    title = Field[str](
        doc="The title of the figure.",
        default="",
        optional=True,
    )

    titleFontSize = Field[float](
        doc="The font size for the title.",
        default=10,
        optional=True,
    )

    xAxisTickValues = ListField[float](
        doc="List of x-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    xAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping x-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    yAxisTickValues = ListField[float](
        doc="List of y-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    yAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping y-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    tickLabelsFontSize = Field[float](
        doc="The font size for the tick labels.",
        default=8,
        optional=True,
    )

    tickLabelsRotation = Field[float](
        doc="The rotation of the tick labels.",
        default=0,
        optional=True,
    )

    setPositionsAtPixelBoundaries = Field[bool](
        doc="Whether to consider the positions at the pixel boundaries rather than the center of the pixel.",
        default=False,
        optional=True,
    )

    hideMajorTicks = ListField[str](
        doc="List of axis names for which to hide the major ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of major tick 'labels'. For example, setting this "
        "field to ['x', 'y'] will hide both major ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    hideMinorTicks = ListField[str](
        doc="List of axis names for which to hide the minor ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of minor tick labels. For example, setting this "
        "field to ['x', 'y'] will hide both minor ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    dpi = Field[int](
        doc="The resolution of the figure.",
        default=300,
        optional=True,
    )

    guideLines = ConfigDictField[str, GuideLinesConfig](
        doc="Dictionary of guide lines for the x and y axes. The keys are 'x' and 'y', and the values are "
        "instances of `GuideLinesConfig`.",
        default={},
        dictCheck=lambda d: all(k in ["x", "y"] for k in d),
        optional=True,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) read by this action.

        Returns
        -------
        schema : `~lsst.analysis.tools.interfaces.KeyedDataSchema`
            Always contains ``matrixKey``. ``component1Key`` and
            ``component2Key`` are added when they are set. The two
            component-group keys are added only when both are set, which is
            the only case in which `makePlot` reads them.
        """
        schema: list[tuple[str, type[Vector]]] = [(self.matrixKey, Vector)]
        for key in (self.component1Key, self.component2Key):
            if key is not None:
                schema.append((key, Vector))
        if self.componentGroup1Key is not None and self.componentGroup2Key is not None:
            schema.append((self.componentGroup1Key, Vector))
            schema.append((self.componentGroup2Key, Vector))
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        """Validate the configuration and inputs, then plot the matrix.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot.
        **kwargs
            Forwarded to `makePlot`. Only ``plotInfo``, ``skymap``, ``band``,
            ``metric_tags`` and ``fig`` are accepted.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.
        """
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs: Any) -> None:
        """Check the configuration, the keyword arguments and the input data.

        Raises
        ------
        ValueError
            Raised if only one of the component keys is set; if component keys
            are combined with explicit tick values/labels; if an unsupported
            keyword argument is given; if a required key is missing from
            ``data``; or if the matrix does not have the shape implied by
            ``inputDim``.
        """
        # Configuration consistency: the component keys come as a pair.
        if (self.component1Key is None) != (self.component2Key is None):
            raise ValueError(
                "Both 'component1Key' and 'component2Key' must be provided together if either is provided."
            )
        # Component keys configure the ticks dynamically, so explicit tick
        # values or labels would conflict with them.
        if self.component1Key is not None and (
            self.xAxisTickValues is not None
            or self.yAxisTickValues is not None
            or self.xAxisTickLabels is not None
            or self.yAxisTickLabels is not None
        ):
            raise ValueError(
                "If 'component1Key' and 'component2Key' are provided, 'xAxisTickValues', "
                "'yAxisTickValues', 'xAxisTickLabels', and 'yAxisTickLabels' should not be "
                "provided as they will be dynamically configured."
            )
        # Check that the keyword arguments are valid.
        acceptableKwargs = {"plotInfo", "skymap", "band", "metric_tags", "fig"}
        if not set(kwargs).issubset(acceptableKwargs):
            raise ValueError(
                f"Only the following keyword arguments are allowed: {acceptableKwargs}. Got: {kwargs}"
            )
        # Check that the input data contains all the required keys.
        schema = self.getInputSchema()
        needed = {key for key, _ in schema}
        if not needed.issubset(data.keys()):
            raise ValueError(f"Input data does not contain all required keys: {schema}")
        # Check the matrix shape. 1D inputs are reshaped into a square later,
        # so their size must be a perfect square; otherwise a 2D array is
        # required.
        matrix = np.asanyarray(data[self.matrixKey])
        if self.inputDim == 1:
            side = int(np.sqrt(matrix.size))
            if side * side != matrix.size:
                raise ValueError(
                    f"Input data of size {matrix.size} cannot be reshaped into a square 2d array."
                )
        elif matrix.ndim != 2:
            raise ValueError(f"Input data is not a 2d array: {data[self.matrixKey]}")

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """Plot a matrix of values.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted. When
            given, it is added to the figure with ``addPlotInfo``.
        **kwargs
            Additional keyword arguments. Only ``fig`` is read here: an
            existing figure to draw into (its current axes are used).

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            Raised if the component-group columns do not each contain exactly
            one unique value.
        """
        # Resolve the axis labels first so that a bad component-group column
        # raises before a figure is created.
        xAxisLabel, yAxisLabel = self._resolveAxisLabels(data)

        matrix = np.asanyarray(data[self.matrixKey])
        useComponents = self.component1Key is not None and self.component2Key is not None
        comp1 = comp2 = None
        if useComponents:
            # The components between which the correlation is calculated.
            comp1 = np.asanyarray(data[self.component1Key])
            comp2 = np.asanyarray(data[self.component2Key])

        if self.inputDim == 1:
            # Reshape the flattened inputs into square arrays.
            side = int(np.sqrt(matrix.size))
            matrix = matrix.reshape(side, side)
            if useComponents:
                comp1 = comp1.reshape(side, side)
                comp2 = comp2.reshape(side, side)

        vmin, vmax = self._resolveColorLimits(matrix)

        # Allow for the figure object to be passed in.
        fig = kwargs.get("fig")
        if fig is None:
            fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        if self.title:
            ax.set_title(self.title, fontsize=self.titleFontSize)

        ax.set_xlabel(xAxisLabel, fontsize=self.axisLabelFontSize)
        ax.set_ylabel(yAxisLabel, fontsize=self.axisLabelFontSize)

        # Draw the image and add a colorbar sized to the image's aspect ratio
        # so it sits flush with the image axis.
        img = ax.imshow(
            matrix,
            interpolation="none",
            norm=ImageNormalize(vmin=vmin, vmax=vmax),
            origin=self.matrixOrigin,
            cmap=self.colorbarCmap,
        )
        ratio = matrix.shape[0] / matrix.shape[1]
        cbar = fig.colorbar(img, fraction=0.0457 * ratio, pad=0.04)
        cbar.set_label(self.colorbarLabel, fontsize=self.colorbarLabelFontSize)
        cbar.ax.tick_params(labelsize=self.colorbarTickLabelFontSize)

        # If requested, shift all positions by 0.5 so that the zero-point is
        # at a pixel boundary rather than at the center of the pixel.
        shift = 0.5 if self.setPositionsAtPixelBoundaries else 0

        if useComponents:
            xAxisTickValues = np.arange(matrix.shape[1] + shift)
            yAxisTickValues = np.arange(matrix.shape[0] + shift)
            xAxisTickLabels = {
                index + shift: str(name) for index, name in enumerate(comp1[0, : matrix.shape[1]])
            }
            yAxisTickLabels = {
                index + shift: str(name) for index, name in enumerate(comp2[: matrix.shape[0], 0])
            }
        else:
            xAxisTickValues = self.xAxisTickValues
            yAxisTickValues = self.yAxisTickValues
            xAxisTickLabels = self.xAxisTickLabels
            yAxisTickLabels = self.yAxisTickLabels

        # Fall back to matplotlib's automatic tick locations when none are
        # given. Both axes are read before either is modified.
        xTicks = xAxisTickValues if xAxisTickValues is not None else ax.xaxis.get_ticklocs()
        yTicks = yAxisTickValues if yAxisTickValues is not None else ax.yaxis.get_ticklocs()
        xlim, ylim = ax.get_xlim(), ax.get_ylim()

        # Keep only the ticks within the visible range, expressed in
        # pixel-center coordinates.
        xMajor = self._ticksWithinLimits(xTicks, xlim, shift) - shift
        yMajor = self._ticksWithinLimits(yTicks, ylim, shift) - shift

        for dim, axis, majorTickValues, tickLabels in (
            ("x", ax.xaxis, xMajor, xAxisTickLabels),
            ("y", ax.yaxis, yMajor, yAxisTickLabels),
        ):
            self._configureAxisTicks(ax, dim, axis, majorTickValues, tickLabels, shift)

        # Add vertical and horizontal lines if provided.
        self._drawGuideLines(ax, shift)

        # Add plot info if provided.
        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig

    def _resolveAxisLabels(self, data: KeyedData) -> tuple[str, str]:
        """Return the x and y axis labels.

        When both component-group keys are set, each label is the configured
        axis label followed by the single unique value found in the
        corresponding group column. Otherwise the configured labels are
        returned unchanged.

        Raises
        ------
        ValueError
            Raised if either group column holds more or fewer than one unique
            value.
        """
        if self.componentGroup1Key is None or self.componentGroup2Key is None:
            return self.xAxisLabel, self.yAxisLabel
        componentGroup1 = set(data[self.componentGroup1Key])
        componentGroup2 = set(data[self.componentGroup2Key])
        if len(componentGroup1) != 1 or len(componentGroup2) != 1:
            raise ValueError(
                f"Each column specified by {self.componentGroup1Key} and {self.componentGroup2Key} must "
                "contain identical values within itself, but they do not."
            )
        return (
            self.xAxisLabel + str(componentGroup1.pop()),
            self.yAxisLabel + str(componentGroup2.pop()),
        )

    def _resolveColorLimits(self, matrix: np.ndarray) -> tuple[Any, Any]:
        """Return the (vmin, vmax) colour limits for the image.

        Configured ``vmin``/``vmax`` values take precedence. Any unset limit is
        taken from a 98% ``PercentileInterval`` of the absolute matrix values.
        That interval is computed only when at least one limit is unset.
        """
        if self.vmin is not None and self.vmax is not None:
            return self.vmin, self.vmax
        low, high = apViz.PercentileInterval(98.0).get_limits(np.abs(matrix.flatten()))
        return (
            low if self.vmin is None else self.vmin,
            high if self.vmax is None else self.vmax,
        )

    @staticmethod
    def _ticksWithinLimits(ticks: Any, limits: tuple[float, float], shift: float) -> np.ndarray:
        """Return the ticks whose shifted position lies within ``limits``.

        The result is returned as an array.
        """
        low, high = min(limits), max(limits)
        return np.array([tick for tick in ticks if low <= tick - shift <= high])

    @staticmethod
    def _formatTickValue(value: Any) -> str:
        """Format a tick position, dropping the decimal part of integral values.

        ``float()`` makes the integrality test work for numpy integer scalars
        too (tick locations are integer-typed when set from ``np.arange``).
        """
        return f"{value:.0f}" if float(value).is_integer() else f"{value}"

    def _configureAxisTicks(
        self,
        ax: Any,
        dim: str,
        axis: Any,
        majorTickValues: np.ndarray,
        tickLabels: Mapping[float, str] | None,
        shift: float,
    ) -> None:
        """Set the ticks, tick labels and tick styling of one axis.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axes that owns ``axis``.
        dim : `str`
            ``"x"`` or ``"y"``; compared against `hideMajorTicks` and
            `hideMinorTicks` and used for the label rotation.
        axis : `~matplotlib.axis.Axis`
            The axis to configure.
        majorTickValues : `numpy.ndarray`
            Major tick locations with ``shift`` already subtracted.
        tickLabels : `Mapping` [`float`, `str`] or `None`
            Labels keyed by their (unshifted) positions.
        shift : `float`
            Offset (0 or 0.5) between the displayed positions and the image
            pixel-center coordinates.
        """
        axis.set_ticks(majorTickValues, minor=False)
        # Numeric major labels report the unshifted position; integral values
        # are printed without a decimal point.
        axis.set_ticklabels(
            [self._formatTickValue(loc + shift) for loc in axis.get_ticklocs()],
            fontsize=self.tickLabelsFontSize,
        )

        if tickLabels:
            positions = np.array(list(tickLabels.keys())) - shift
            labels = list(tickLabels.values())
            # Every labelled position becomes a minor tick. A label that falls
            # on a major tick is shown on the major tick instead.
            axis.set_ticks(positions, minor=True)
            onMajor = {pos: lbl for pos, lbl in zip(positions, labels) if pos in majorTickValues}
            majorLabels = [onMajor.get(loc, "") for loc in majorTickValues]
            minorLabels = ["" if pos in majorTickValues else lbl for pos, lbl in zip(positions, labels)]

            if any(majorLabels):
                axis.set_ticklabels(majorLabels, minor=False, fontsize=self.tickLabelsFontSize)
            else:
                # No user label lands on a major tick: clear the numeric ones.
                axis.set_ticklabels([])

            if any(minorLabels):
                axis.set_ticklabels(minorLabels, minor=True, fontsize=self.tickLabelsFontSize)

        # Remove tick marks (not labels) for aesthetic reasons if requested.
        if dim in self.hideMajorTicks:
            axis.set_tick_params(which="major", length=0)
        if dim in self.hideMinorTicks:
            axis.set_tick_params(which="minor", length=0)

        # Rotate both major and minor tick labels: user-provided labels may
        # live on minor ticks, e.g. when positions are at pixel boundaries.
        ax.tick_params(axis=dim, which="both", rotation=self.tickLabelsRotation)

    @staticmethod
    def _outlineEffects(outlineColor: str) -> list:
        """Return path effects that draw an outline around a text label."""
        return [
            mpl_path_effects.Stroke(linewidth=2, foreground=outlineColor, alpha=0.8),
            mpl_path_effects.Normal(),
        ]

    def _drawGuideLines(self, ax: Any, shift: float) -> None:
        """Draw the configured vertical (``"x"``) and horizontal (``"y"``)
        guide lines with their labels.

        Each line is drawn twice: a wider line in the outline colour
        underneath a thinner styled line. Each label gets an outline so it
        stays readable on any background.
        """
        guideLines = self.guideLines or {}

        if "x" in guideLines:
            xLines = guideLines["x"]
            for xValue, text in xLines.lines.items():
                xPos = xValue - shift
                ax.axvline(x=xPos, color=xLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axvline(x=xPos, color=xLines.color, linestyle=xLines.linestyle, linewidth=1, alpha=0.85)
                textArtist = ax.text(
                    xPos,
                    0.03,
                    text,
                    rotation=90,
                    color=xLines.color,
                    transform=ax.get_xaxis_transform(),
                    horizontalalignment="right",
                    alpha=0.9,
                )
                textArtist.set_path_effects(self._outlineEffects(xLines.outlineColor))

        if "y" in guideLines:
            yLines = guideLines["y"]
            for yValue, text in yLines.lines.items():
                yPos = yValue - shift
                ax.axhline(y=yPos, color=yLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axhline(y=yPos, color=yLines.color, linestyle=yLines.linestyle, linewidth=1, alpha=0.85)
                textArtist = ax.text(
                    0.03,
                    yPos,
                    text,
                    color=yLines.color,
                    transform=ax.get_yaxis_transform(),
                    verticalalignment="bottom",
                    alpha=0.9,
                )
                textArtist.set_path_effects(self._outlineEffects(yLines.outlineColor))