Coverage for python / lsst / analysis / tools / actions / plot / matrixPlot.py: 25%
160 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:36 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 09:36 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("MatrixPlot",)
26from collections.abc import Mapping
27from typing import TYPE_CHECKING, Any
29import astropy.visualization as apViz
30import matplotlib.patheffects as mpl_path_effects
31import matplotlib.pyplot as plt
32import numpy as np
33from astropy.visualization.mpl_normalize import ImageNormalize
35from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, ListField
37from ...interfaces import PlotAction, Vector
38from .plotUtils import addPlotInfo
40if TYPE_CHECKING:
41 from matplotlib.figure import Figure
43 from ...interfaces import KeyedData, KeyedDataSchema
class GuideLinesConfig(Config):
    """Configuration for a set of guide lines drawn on one axis of a
    `MatrixPlot`.

    Each entry of ``lines`` maps a position along the axis to the text label
    drawn next to the line at that position.
    """

    # Required mapping of line position -> label text.
    lines = DictField[float, str](
        optional=False,
        doc="Dictionary of x/y-values and the labels where vertical/horizontal lines are drawn.",
    )

    # Main color shared by the line and its label.
    color = Field[str](default="red", doc="The color of the lines and labels.")

    # Color of the wider underlay/stroke that keeps lines and labels visible.
    outlineColor = Field[str](default="white", doc="The color of the outline around the lines and labels.")

    # Matplotlib linestyle string for the foreground line.
    linestyle = Field[str](default="--", doc="The style of the lines.")
class MatrixPlot(PlotAction):
    """Make the plot of a matrix (2D array).

    Notes
    -----
    The `xAxisTickLabels` and `yAxisTickLabels` attributes of this class serve
    as dictionaries to map axis tick positions to their corresponding labels.
    If any positions do not align with major ticks (either provided by
    `x/yAxisTickValues` or automatically set by matplotlib), they will be
    designated as minor ticks. Thus, these tick labels operate independently,
    meaning their corresponding positions do not need to match those in
    `x/yAxisTickValues` or anything else. The code automatically adjusts to
    handle any overlaps caused by user input and across various plotting
    scenarios.
    Note that when `component1Key` and `component2Key` are specified, the x and
    y tick values and labels will be dynamically configured, thereby
    eliminating the need for providing `x/yAxisTickValues` and
    `x/yAxisTickLabels`. When `componentGroup1Key` and `componentGroup2Key` are
    specified, the x and y axis labels are dynamically updated to include the
    group names, prefixed by `xAxisLabel` and `yAxisLabel` for a more
    descriptive labeling.
    """

    inputDim = ChoiceField[int](
        doc="The dimensionality of the input data.",
        default=1,
        allowed={
            1: "1D inputs are automatically reshaped into square 2D matrices.",
            2: "2D inputs are directly utilized as is.",
        },
        optional=True,
    )

    matrixKey = Field[str](
        doc="The key for the input matrix.",
        default="matrix",
    )

    matrixOrigin = ChoiceField[str](
        doc="Determines the starting corner ('upper', 'lower') for matrix plots. It only affects the visual "
        "appearance of the plot.",
        default="upper",
        allowed={
            "upper": "The origin is at the upper left corner.",
            "lower": "The origin is at the lower left corner.",
        },
        optional=True,
    )

    component1Key = Field[str](
        doc="The key to access a list of names for the first set of components in a correlation analysis. "
        "This will be used to determine x-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    # Must be optional with a None default, like component1Key: otherwise
    # pex_config treats it as required and every config that does not use
    # component keys fails validation.
    component2Key = Field[str](
        doc="The key to access a list of names for the second set of components in a correlation analysis. "
        "This will be used to determine y-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    componentGroup1Key = Field[str](
        doc="The key to access a list of group names for the first set of components in a correlation "
        "analysis. This will be used to determine the x-axis label.",
        default=None,
        optional=True,
    )

    componentGroup2Key = Field[str](
        doc="The key to access a list of group names for the second set of components in a correlation "
        "analysis. This will be used to determine the y-axis label.",
        default=None,
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="",
        optional=True,
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="",
        optional=True,
    )

    axisLabelFontSize = Field[float](
        doc="The font size for the axis labels.",
        default=9,
        optional=True,
    )

    colorbarLabel = Field[str](
        doc="The label to use for the colorbar.",
        default="",
        optional=True,
    )

    colorbarLabelFontSize = Field[float](
        doc="The font size for the colorbar label.",
        default=10,
        optional=True,
    )

    colorbarTickLabelFontSize = Field[float](
        doc="The font size for the colorbar tick labels.",
        default=8,
        optional=True,
    )

    colorbarCmap = ChoiceField[str](
        doc="The colormap to use for the colorbar.",
        default="viridis",
        allowed={name: name for name in plt.colormaps()},
        optional=True,
    )

    vmin = Field[float](
        doc="The vmin value for the colorbar.",
        default=None,
        optional=True,
    )

    vmax = Field[float](
        doc="The vmax value for the colorbar.",
        default=None,
        optional=True,
    )

    figsize = ListField[float](
        doc="The size of the figure.",
        default=[5, 5],
        maxLength=2,
        optional=True,
    )

    title = Field[str](
        doc="The title of the figure.",
        default="",
        optional=True,
    )

    titleFontSize = Field[float](
        doc="The font size for the title.",
        default=10,
        optional=True,
    )

    xAxisTickValues = ListField[float](
        doc="List of x-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    xAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping x-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    yAxisTickValues = ListField[float](
        doc="List of y-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    yAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping y-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    tickLabelsFontSize = Field[float](
        doc="The font size for the tick labels.",
        default=8,
        optional=True,
    )

    tickLabelsRotation = Field[float](
        doc="The rotation of the tick labels.",
        default=0,
        optional=True,
    )

    setPositionsAtPixelBoundaries = Field[bool](
        doc="Whether to consider the positions at the pixel boundaries rather than the center of the pixel.",
        default=False,
        optional=True,
    )

    hideMajorTicks = ListField[str](
        doc="List of axis names for which to hide the major ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of major tick 'labels'. For example, setting this "
        "field to ['x', 'y'] will hide both major ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    hideMinorTicks = ListField[str](
        doc="List of axis names for which to hide the minor ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of minor tick labels. For example, setting this "
        "field to ['x', 'y'] will hide both minor ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    dpi = Field[int](
        doc="The resolution of the figure.",
        default=300,
        optional=True,
    )

    guideLines = ConfigDictField[str, GuideLinesConfig](
        doc="Dictionary of guide lines for the x and y axes. The keys are 'x' and 'y', and the values are "
        "instances of `GuideLinesConfig`.",
        default={},
        dictCheck=lambda d: all(k in ("x", "y") for k in d),
        optional=True,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the input schema: only the matrix key is declared."""
        base: list[tuple[str, type[Vector]]] = []
        base.append((self.matrixKey, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        """Validate the inputs, then build and return the matrix plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs: Any) -> None:
        """Check the input data, keyword arguments and configuration.

        Raises
        ------
        ValueError
            Raised if required keys are missing, if the matrix cannot be
            interpreted according to ``inputDim``, if unsupported keyword
            arguments are passed, or if the component-key configuration is
            inconsistent.
        """
        # Check that the input data contains all the required keys.
        needed = {key for key, _ in self.getInputSchema()}
        if not needed.issubset(data.keys()):
            raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")

        # Check that the input data can be turned into a matrix. The expected
        # shape depends on ``inputDim``, mirroring how ``makePlot`` uses it.
        matrix = data[self.matrixKey]
        size = int(np.size(matrix))
        if size == 0:
            raise ValueError(f"Input data is empty: {matrix}")
        if self.inputDim == 1:
            # Flat inputs are reshaped into a square, so the number of
            # elements must be a perfect square.
            side = int(np.sqrt(size))
            if side * side != size:
                raise ValueError(
                    f"Input data with {size} elements cannot be reshaped into a square 2d array."
                )
        elif np.ndim(matrix) != 2:
            # Any other inputDim means the input is used as is.
            raise ValueError(f"Input data is not a 2d array: {matrix}")

        # Check that the keyword arguments are valid.
        acceptableKwargs = {"plotInfo", "skymap", "band", "metric_tags", "fig"}
        if not set(kwargs).issubset(acceptableKwargs):
            raise ValueError(
                f"Only the following keyword arguments are allowed: {acceptableKwargs}. Got: {kwargs}"
            )

        # Check that if one component key is provided, the other must be too.
        hasComp1 = self.component1Key is not None
        hasComp2 = self.component2Key is not None
        if hasComp1 != hasComp2:
            raise ValueError(
                "Both 'component1Key' and 'component2Key' must be provided together if either is provided."
            )

        # Check that if component keys are provided, none of the tick values
        # or labels are, since those are then derived from the components.
        if (hasComp1 and hasComp2) and (
            self.xAxisTickValues is not None
            or self.yAxisTickValues is not None
            or self.xAxisTickLabels is not None
            or self.yAxisTickLabels is not None
        ):
            raise ValueError(
                "If 'component1Key' and 'component2Key' are provided, 'xAxisTickValues', "
                "'yAxisTickValues', 'xAxisTickLabels', and 'yAxisTickLabels' should not be "
                "provided as they will be dynamically configured."
            )

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """Plot a matrix of values.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted.
        **kwargs
            Additional keyword arguments. Only ``fig`` is read here: if it is
            given, the plot is drawn on that figure's current axes instead of
            on a newly created figure.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.
        """
        # Retrieve the matrix info from the input data.
        matrix = data[self.matrixKey]

        # Fetch the components between which the correlation is calculated.
        useComponents = self.component1Key is not None and self.component2Key is not None
        if useComponents:
            comp1 = data[self.component1Key]
            comp2 = data[self.component2Key]

        if self.inputDim == 1:
            # Reshape flat inputs (and their component names) into a square.
            square_size = int(np.sqrt(matrix.size))
            matrix = matrix.reshape(square_size, square_size)
            if useComponents:
                comp1 = comp1.reshape(square_size, square_size)
                comp2 = comp2.reshape(square_size, square_size)

        # Calculate default limits only if needed.
        if self.vmin is None or self.vmax is None:
            default_limits = apViz.PercentileInterval(98.0).get_limits(np.abs(matrix.flatten()))
        else:
            default_limits = (None, None)

        # Set the value range using overrides or defaults.
        vrange = (
            default_limits[0] if self.vmin is None else self.vmin,
            default_limits[1] if self.vmax is None else self.vmax,
        )

        # Allow for the figure object to be passed in.
        fig = kwargs.get("fig")
        if fig is None:
            fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        if self.title:
            ax.set_title(self.title, fontsize=self.titleFontSize)

        # Axis labels: optionally suffixed with the (unique) group name.
        if self.componentGroup1Key is not None and self.componentGroup2Key is not None:
            componentGroup1 = set(data[self.componentGroup1Key])
            componentGroup2 = set(data[self.componentGroup2Key])
            if len(componentGroup1) != 1 or len(componentGroup2) != 1:
                raise ValueError(
                    f"Each column specified by {self.componentGroup1Key} and {self.componentGroup2Key} must "
                    "contain identical values within itself, but they do not."
                )
            xAxisLabel = self.xAxisLabel + str(componentGroup1.pop())
            yAxisLabel = self.yAxisLabel + str(componentGroup2.pop())
        else:
            xAxisLabel = self.xAxisLabel
            yAxisLabel = self.yAxisLabel

        ax.set_xlabel(xAxisLabel, fontsize=self.axisLabelFontSize)
        ax.set_ylabel(yAxisLabel, fontsize=self.axisLabelFontSize)

        # Set the colorbar and draw the image.
        norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1])
        img = ax.imshow(
            matrix, interpolation="none", norm=norm, origin=self.matrixOrigin, cmap=self.colorbarCmap
        )

        # Calculate the aspect ratio of the image.
        ratio = matrix.shape[0] / matrix.shape[1]

        # Add the colorbar flush with the image axis.
        cbar = fig.colorbar(img, fraction=0.0457 * ratio, pad=0.04)

        # Set the colorbar label and its font size.
        cbar.set_label(self.colorbarLabel, fontsize=self.colorbarLabelFontSize)

        # Set the colorbar tick label font size.
        cbar.ax.tick_params(labelsize=self.colorbarTickLabelFontSize)

        # If requested, we shift all the positions by 0.5 considering the
        # zero-point at a pixel boundary rather than the center of the pixel.
        shift = 0.5 if self.setPositionsAtPixelBoundaries else 0

        if useComponents:
            # With shift == 0.5, ``np.arange(n + shift)`` yields n + 1 values,
            # i.e. major ticks at every pixel boundary, while the component
            # labels sit at pixel centers (and so become minor ticks).
            xAxisTickValues = np.arange(matrix.shape[1] + shift)
            yAxisTickValues = np.arange(matrix.shape[0] + shift)
            xAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[1]), comp1[0, :])}
            yAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[0]), comp2[:, 0])}
        else:
            xAxisTickValues = self.xAxisTickValues
            yAxisTickValues = self.yAxisTickValues
            xAxisTickLabels = self.xAxisTickLabels
            yAxisTickLabels = self.yAxisTickLabels

        # If the tick values are not provided, retrieve them from the axes.
        xticks = xAxisTickValues if xAxisTickValues is not None else ax.xaxis.get_ticklocs()
        yticks = yAxisTickValues if yAxisTickValues is not None else ax.yaxis.get_ticklocs()

        # Retrieve the current limits of the x and y axes.
        xlim, ylim = ax.get_xlim(), ax.get_ylim()

        # Filter out tick locations that fall outside the current x/y-axis
        # limits to ensure that only tick locations within the visible range
        # are kept.
        xticks = np.array([tick for tick in xticks if min(xlim) <= tick - shift <= max(xlim)])
        yticks = np.array([tick for tick in yticks if min(ylim) <= tick - shift <= max(ylim)])

        # Per axis: (major tick positions, label positions, label texts), all
        # converted from user coordinates to data coordinates.
        tick_data = {
            "x": (
                xticks - shift,
                np.array(list(xAxisTickLabels.keys())) - shift if xAxisTickLabels else None,
                list(xAxisTickLabels.values()) if xAxisTickLabels else None,
            ),
            "y": (
                yticks - shift,
                np.array(list(yAxisTickLabels.keys())) - shift if yAxisTickLabels else None,
                list(yAxisTickLabels.values()) if yAxisTickLabels else None,
            ),
        }

        for dim, axis in [("x", ax.xaxis), ("y", ax.yaxis)]:
            # Get the major tick positions and labels.
            major_tick_values, positions, labels = tick_data[dim]

            # Set major ticks.
            axis.set_ticks(major_tick_values, minor=False)

            # Set tick labels while compensating for the potential shift in the
            # tick positions and removing trailing zeros and the decimal point
            # for integer values.
            axis.set_ticklabels(
                [
                    f"{tick + shift:.0f}" if (tick + shift).is_integer() else f"{tick + shift}"
                    for tick in axis.get_ticklocs()
                ],
                fontsize=self.tickLabelsFontSize,
            )

            # Check if positions are provided.
            if positions is not None:
                # Assign specified positions as minor ticks.
                axis.set_ticks(positions, minor=True)

                # Conditionally assign labels to major and/or minor ticks.
                if labels is not None:
                    # Create a lookup for positions to labels.
                    positions_labels_lookup = {
                        p: lbl if p in major_tick_values else "" for p, lbl in zip(positions, labels)
                    }
                    # Generate labels for major ticks, leaving blanks for
                    # non-major positions.
                    major_labels = [positions_labels_lookup.get(m, "") for m in major_tick_values]
                    # Generate labels for minor ticks, excluding those
                    # designated as major.
                    minor_labels = [
                        "" if p in major_tick_values else lbl for p, lbl in zip(positions, labels)
                    ]

                    # Apply labels to major ticks if any exist.
                    if any(major_labels):
                        axis.set_ticklabels(major_labels, minor=False, fontsize=self.tickLabelsFontSize)
                    else:
                        # If no major labels, clear major tick labels.
                        axis.set_ticklabels([])

                    # Apply labels to minor ticks if any exist.
                    if any(minor_labels):
                        axis.set_ticklabels(minor_labels, minor=True, fontsize=self.tickLabelsFontSize)

            if dim in self.hideMajorTicks:
                # Remove major tick marks for aesthetic reasons.
                axis.set_tick_params(which="major", length=0)

            if dim in self.hideMinorTicks:
                # Remove minor tick marks for aesthetic reasons.
                axis.set_tick_params(which="minor", length=0)

            # Rotate the tick labels by the specified angle.
            ax.tick_params(axis=dim, rotation=self.tickLabelsRotation)

        # Add vertical and horizontal lines if provided.
        self._drawGuideLines(ax, shift)

        # Add plot info if provided.
        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig

    def _drawGuideLines(self, ax, shift: float) -> None:
        """Draw the configured guide lines and their labels on ``ax``.

        Vertical lines come from ``guideLines["x"]`` and horizontal lines from
        ``guideLines["y"]``. Each line is drawn twice (a wide outline-colored
        underlay plus a thin styled line), and each label gets an outline
        stroke for legibility on varied backgrounds. Positions are shifted by
        ``shift`` to match the tick coordinate convention.
        """
        if "x" in self.guideLines:
            xLines = self.guideLines["x"]
            for x, label in xLines.lines.items():
                ax.axvline(x=x - shift, color=xLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axvline(
                    x=x - shift, color=xLines.color, linestyle=xLines.linestyle, linewidth=1, alpha=0.85
                )
                # x in data coordinates, y in axes coordinates.
                text = ax.text(
                    x - shift,
                    0.03,
                    label,
                    rotation=90,
                    color=xLines.color,
                    transform=ax.get_xaxis_transform(),
                    horizontalalignment="right",
                    alpha=0.9,
                )
                # Add a distinct outline around the label for better visibility
                # in various backgrounds.
                text.set_path_effects(
                    [
                        mpl_path_effects.Stroke(linewidth=2, foreground=xLines.outlineColor, alpha=0.8),
                        mpl_path_effects.Normal(),
                    ]
                )

        if "y" in self.guideLines:
            yLines = self.guideLines["y"]
            for y, label in yLines.lines.items():
                ax.axhline(y=y - shift, color=yLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axhline(
                    y=y - shift, color=yLines.color, linestyle=yLines.linestyle, linewidth=1, alpha=0.85
                )
                # x in axes coordinates, y in data coordinates.
                text = ax.text(
                    0.03,
                    y - shift,
                    label,
                    color=yLines.color,
                    transform=ax.get_yaxis_transform(),
                    verticalalignment="bottom",
                    alpha=0.9,
                )
                # Add a distinct outline around the label for better visibility
                # in various backgrounds.
                text.set_path_effects(
                    [
                        mpl_path_effects.Stroke(linewidth=2, foreground=yLines.outlineColor, alpha=0.8),
                        mpl_path_effects.Normal(),
                    ]
                )