Coverage for python/lsst/analysis/tools/actions/plot/matrixPlot.py: 24%
162 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-02 04:48 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-02 04:48 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("MatrixPlot",)
26from typing import TYPE_CHECKING, Any, Mapping
28import astropy.visualization as apViz
29import matplotlib.patheffects as mpl_path_effects
30import matplotlib.pyplot as plt
31import numpy as np
32from astropy.visualization.mpl_normalize import ImageNormalize
33from lsst.pex.config import ChoiceField, Config, ConfigDictField, DictField, Field, ListField
35from ...interfaces import PlotAction, Vector
36from .plotUtils import addPlotInfo
38if TYPE_CHECKING: 38 ↛ 39line 38 didn't jump to line 39, because the condition on line 38 was never true
39 from matplotlib.figure import Figure
41 from ...interfaces import KeyedData, KeyedDataSchema
class GuideLinesConfig(Config):
    """Settings for one family of guide lines drawn across a matrix plot.

    Each entry of ``lines`` maps a coordinate to the text label shown next to
    the line drawn at that coordinate; the remaining fields control styling.
    """

    # Coordinate -> label mapping. It has no default and optional=False, so a
    # value has to be supplied explicitly.
    lines = DictField[float, str](
        optional=False,
        doc="Dictionary of x/y-values and the labels where vertical/horizontal lines are drawn.",
    )

    # Colour shared by the lines and their labels.
    color = Field[str](
        default="red",
        doc="The color of the lines and labels.",
    )

    # Colour of the contrasting outline drawn underneath lines and labels.
    outlineColor = Field[str](
        default="white",
        doc="The color of the outline around the lines and labels.",
    )

    # Matplotlib line style string for the lines.
    linestyle = Field[str](
        default="--",
        doc="The style of the lines.",
    )
class MatrixPlot(PlotAction):
    """Make the plot of a matrix (2D array).

    Notes
    -----
    The `xAxisTickLabels` and `yAxisTickLabels` attributes of this class serve
    as dictionaries to map axis tick positions to their corresponding labels.
    If any positions do not align with major ticks (either provided by
    `x/yAxisTickValues` or automatically set by matplotlib), they will be
    designated as minor ticks. Thus, these tick labels operate independently,
    meaning their corresponding positions do not need to match those in
    `x/yAxisTickValues` or anything else. The code automatically adjusts to
    handle any overlaps caused by user input and across various plotting
    scenarios.

    Note that when `component1Key` and `component2Key` are specified, the x and
    y tick values and labels will be dynamically configured, thereby
    eliminating the need for providing `x/yAxisTickValues` and
    `x/yAxisTickLabels`. When `componentGroup1Key` and `componentGroup2Key` are
    specified, the x and y axis labels are dynamically updated to include the
    group names, prefixed by `xAxisLabel` and `yAxisLabel` for a more
    descriptive labeling.

    When `inputDim` is 1, the input is reshaped into a square matrix, so its
    number of elements must be a perfect square.
    """

    inputDim = ChoiceField[int](
        doc="The dimensionality of the input data.",
        default=1,
        allowed={
            1: "1D inputs are automatically reshaped into square 2D matrices.",
            2: "2D inputs are directly utilized as is.",
        },
        optional=True,
    )

    matrixKey = Field[str](
        doc="The key for the input matrix.",
        default="matrix",
    )

    matrixOrigin = ChoiceField[str](
        doc="Determines the starting corner ('upper', 'lower') for matrix plots. It only affects the visual "
        "appearance of the plot.",
        default="upper",
        allowed={
            "upper": "The origin is at the upper left corner.",
            "lower": "The origin is at the lower left corner.",
        },
        optional=True,
    )

    component1Key = Field[str](
        doc="The key to access a list of names for the first set of components in a correlation analysis. "
        "This will be used to determine x-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    # Like component1Key this is optional; without default/optional it would
    # be a required field and config validation would fail when left unset.
    component2Key = Field[str](
        doc="The key to access a list of names for the second set of components in a correlation analysis. "
        "This will be used to determine y-axis tick values and tick labels.",
        default=None,
        optional=True,
    )

    componentGroup1Key = Field[str](
        doc="The key to access a list of group names for the first set of components in a correlation "
        "analysis. This will be used to determine the x-axis label.",
        default=None,
        optional=True,
    )

    componentGroup2Key = Field[str](
        doc="The key to access a list of group names for the second set of components in a correlation "
        "analysis. This will be used to determine the y-axis label.",
        default=None,
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="",
        optional=True,
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="",
        optional=True,
    )

    axisLabelFontSize = Field[float](
        doc="The font size for the axis labels.",
        default=9,
        optional=True,
    )

    colorbarLabel = Field[str](
        doc="The label to use for the colorbar.",
        default="",
        optional=True,
    )

    colorbarLabelFontSize = Field[float](
        doc="The font size for the colorbar label.",
        default=10,
        optional=True,
    )

    colorbarTickLabelFontSize = Field[float](
        doc="The font size for the colorbar tick labels.",
        default=8,
        optional=True,
    )

    colorbarCmap = ChoiceField[str](
        doc="The colormap to use for the colorbar.",
        default="viridis",
        allowed={name: name for name in plt.colormaps()},
        optional=True,
    )

    vmin = Field[float](
        doc="The vmin value for the colorbar.",
        default=None,
        optional=True,
    )

    vmax = Field[float](
        doc="The vmax value for the colorbar.",
        default=None,
        optional=True,
    )

    figsize = ListField[float](
        doc="The size of the figure.",
        default=[5, 5],
        maxLength=2,
        optional=True,
    )

    title = Field[str](
        doc="The title of the figure.",
        default="",
        optional=True,
    )

    titleFontSize = Field[float](
        doc="The font size for the title.",
        default=10,
        optional=True,
    )

    xAxisTickValues = ListField[float](
        doc="List of x-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    xAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping x-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    yAxisTickValues = ListField[float](
        doc="List of y-axis tick values. If not set, the ticks will be set automatically by matplotlib.",
        default=None,
        optional=True,
    )

    yAxisTickLabels = DictField[float, str](
        doc="Dictionary mapping y-axis tick positions to their corresponding labels. For behavior details, "
        "refer to the 'Notes' section of the class docstring.",
        default=None,
        optional=True,
    )

    tickLabelsFontSize = Field[float](
        doc="The font size for the tick labels.",
        default=8,
        optional=True,
    )

    tickLabelsRotation = Field[float](
        doc="The rotation of the tick labels.",
        default=0,
        optional=True,
    )

    setPositionsAtPixelBoundaries = Field[bool](
        doc="Whether to consider the positions at the pixel boundaries rather than the center of the pixel.",
        default=False,
        optional=True,
    )

    hideMajorTicks = ListField[str](
        doc="List of axis names for which to hide the major ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of major tick 'labels'. For example, setting this "
        "field to ['x', 'y'] will hide both major ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    hideMinorTicks = ListField[str](
        doc="List of axis names for which to hide the minor ticks. The options to include in the list are "
        "'x' and 'y'. This does not affect the visibility of minor tick labels. For example, setting this "
        "field to ['x', 'y'] will hide both minor ticks.",
        default=[],
        maxLength=2,
        itemCheck=lambda s: s in ["x", "y"],
        optional=True,
    )

    dpi = Field[int](
        doc="The resolution of the figure.",
        default=300,
        optional=True,
    )

    guideLines = ConfigDictField[str, GuideLinesConfig](
        doc="Dictionary of guide lines for the x and y axes. The keys are 'x' and 'y', and the values are "
        "instances of `GuideLinesConfig`.",
        default={},
        dictCheck=lambda d: all(k in ("x", "y") for k in d),
        optional=True,
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the declared input schema of this action.

        Only ``matrixKey`` is declared here; the optional component and
        component-group keys, when configured, are checked for presence in
        `_validateInput`.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            A single ``(matrixKey, Vector)`` entry.
        """
        base: list[tuple[str, type[Vector]]] = []
        base.append((self.matrixKey, Vector))
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        """Validate ``data`` and ``kwargs`` and then build the figure via `makePlot`."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs: Any) -> None:
        """Check that the input data, keyword arguments and config agree.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to be plotted.
        **kwargs
            Keyword arguments destined for `makePlot`.

        Raises
        ------
        ValueError
            If required keys are missing, the matrix is not a numpy array of
            a usable shape, unexpected keyword arguments are given, or the
            component-related configuration is inconsistent.
        """
        # Check that the input data contains all the required keys.
        needed = {k[0] for k in self.getInputSchema()}
        if not needed.issubset(data.keys()):
            raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")

        # Check the input is an array whose shape can be plotted as a matrix.
        matrix = data[self.matrixKey]
        if not isinstance(matrix, np.ndarray):
            raise ValueError(f"Input data is not a numpy array: {matrix}")
        if matrix.size == 0:
            raise ValueError("Input data is empty.")
        if self.inputDim == 1:
            # 1D inputs are reshaped into a square, so the size must be a
            # perfect square.
            side = int(round(np.sqrt(matrix.size)))
            if side * side != matrix.size:
                raise ValueError(
                    f"Input data with {matrix.size} elements cannot be reshaped into a square matrix."
                )
        elif matrix.ndim != 2:
            raise ValueError(f"Input data is not a 2d array: {data[self.matrixKey]}")

        # Check that the keyword arguments are valid.
        acceptableKwargs = {"plotInfo", "skymap", "band", "metric_tags", "fig"}
        if not set(kwargs).issubset(acceptableKwargs):
            raise ValueError(
                f"Only the following keyword arguments are allowed: {acceptableKwargs}. Got: {kwargs}"
            )

        # Check that if one component key is provided, the other must be too.
        hasComponent1 = self.component1Key is not None
        hasComponent2 = self.component2Key is not None
        if hasComponent1 != hasComponent2:
            raise ValueError(
                "Both 'component1Key' and 'component2Key' must be provided together if either is provided."
            )

        # Check that if component keys are provided, any of the tick values or
        # labels are not and vice versa.
        if (hasComponent1 and hasComponent2) and (
            self.xAxisTickValues is not None
            or self.yAxisTickValues is not None
            or self.xAxisTickLabels is not None
            or self.yAxisTickLabels is not None
        ):
            raise ValueError(
                "If 'component1Key' and 'component2Key' are provided, 'xAxisTickValues', "
                "'yAxisTickValues', 'xAxisTickLabels', and 'yAxisTickLabels' should not be "
                "provided as they will be dynamically configured."
            )

        # Check that the keys which makePlot will actually read are present;
        # component and group keys are only read when both of a pair are set.
        keysToRead = []
        if hasComponent1 and hasComponent2:
            keysToRead += [self.component1Key, self.component2Key]
        if self.componentGroup1Key is not None and self.componentGroup2Key is not None:
            keysToRead += [self.componentGroup1Key, self.componentGroup2Key]
        missing = [key for key in keysToRead if key not in data]
        if missing:
            raise ValueError(f"Input data does not contain the configured component keys: {missing}")

    @staticmethod
    def _formatTickValue(value: float) -> str:
        """Format a tick value, dropping the trailing ``.0`` of integral values."""
        # float() makes this work for integer numpy scalars too (e.g. ticks
        # coming from an integer np.arange).
        return f"{value:.0f}" if float(value).is_integer() else f"{value}"

    @staticmethod
    def _addOutline(text: Any, outlineColor: str) -> None:
        """Give a text artist a contrasting stroke for visibility on any background."""
        text.set_path_effects(
            [
                mpl_path_effects.Stroke(linewidth=2, foreground=outlineColor, alpha=0.8),
                mpl_path_effects.Normal(),
            ]
        )

    def _drawGuideLines(self, ax: Any, shift: float) -> None:
        """Draw the configured vertical ("x") and horizontal ("y") guide lines.

        Each line is drawn twice, a wider outline underneath a thinner styled
        line, and is labelled with outlined text. Positions are shifted by
        ``shift`` to match the tick-position convention used by `makePlot`.
        """
        if "x" in self.guideLines:
            xLines = self.guideLines["x"]
            for x, label in xLines.lines.items():
                ax.axvline(x=x - shift, color=xLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axvline(
                    x=x - shift, color=xLines.color, linestyle=xLines.linestyle, linewidth=1, alpha=0.85
                )
                text = ax.text(
                    x - shift,
                    0.03,
                    label,
                    rotation=90,
                    color=xLines.color,
                    transform=ax.get_xaxis_transform(),
                    horizontalalignment="right",
                    alpha=0.9,
                )
                self._addOutline(text, xLines.outlineColor)

        if "y" in self.guideLines:
            yLines = self.guideLines["y"]
            for y, label in yLines.lines.items():
                ax.axhline(y=y - shift, color=yLines.outlineColor, linewidth=2, alpha=0.6)
                ax.axhline(
                    y=y - shift, color=yLines.color, linestyle=yLines.linestyle, linewidth=1, alpha=0.85
                )
                text = ax.text(
                    0.03,
                    y - shift,
                    label,
                    color=yLines.color,
                    transform=ax.get_yaxis_transform(),
                    verticalalignment="bottom",
                    alpha=0.9,
                )
                self._addOutline(text, yLines.outlineColor)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """
        Plot a matrix of values.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted.
        **kwargs
            Additional keyword arguments to pass to the plot. A ``fig`` entry,
            if given, is drawn into instead of creating a new figure.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.

        Raises
        ------
        ValueError
            If component-group keys are configured and either group column
            contains more than one distinct value.
        """
        # Retrieve the matrix info from the input data.
        matrix = data[self.matrixKey]

        # Fetch the components between which the correlation is calculated.
        useComponents = self.component1Key is not None and self.component2Key is not None
        if useComponents:
            comp1 = data[self.component1Key]
            comp2 = data[self.component2Key]

        if self.inputDim == 1:
            # Reshape flat inputs (and their component names) into a square.
            squareSize = int(round(np.sqrt(matrix.size)))
            matrix = matrix.reshape(squareSize, squareSize)
            if useComponents:
                comp1 = comp1.reshape(squareSize, squareSize)
                comp2 = comp2.reshape(squareSize, squareSize)

        # Calculate default limits only if needed.
        if self.vmin is None or self.vmax is None:
            defaultLimits = apViz.PercentileInterval(98.0).get_limits(np.abs(matrix.flatten()))
        else:
            defaultLimits = (None, None)

        # Set the value range using overrides or defaults.
        vrange = (
            defaultLimits[0] if self.vmin is None else self.vmin,
            defaultLimits[1] if self.vmax is None else self.vmax,
        )

        # Allow for the figure object to be passed in.
        fig = kwargs.get("fig")
        if fig is None:
            fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        if self.title:
            ax.set_title(self.title, fontsize=self.titleFontSize)

        # Axis labels, optionally suffixed with the single group name found in
        # each component-group column.
        xAxisLabel = self.xAxisLabel
        yAxisLabel = self.yAxisLabel
        if self.componentGroup1Key is not None and self.componentGroup2Key is not None:
            componentGroup1 = set(data[self.componentGroup1Key])
            componentGroup2 = set(data[self.componentGroup2Key])
            if len(componentGroup1) != 1 or len(componentGroup2) != 1:
                raise ValueError(
                    f"Each column specified by {self.componentGroup1Key} and {self.componentGroup2Key} must "
                    "contain identical values within itself, but they do not."
                )
            xAxisLabel = self.xAxisLabel + str(componentGroup1.pop())
            yAxisLabel = self.yAxisLabel + str(componentGroup2.pop())

        ax.set_xlabel(xAxisLabel, fontsize=self.axisLabelFontSize)
        ax.set_ylabel(yAxisLabel, fontsize=self.axisLabelFontSize)

        # Set the colorbar and draw the image.
        norm = ImageNormalize(vmin=vrange[0], vmax=vrange[1])
        img = ax.imshow(
            matrix, interpolation="none", norm=norm, origin=self.matrixOrigin, cmap=self.colorbarCmap
        )

        # Calculate the aspect ratio of the image.
        ratio = matrix.shape[0] / matrix.shape[1]

        # Add the colorbar flush with the image axis.
        cbar = fig.colorbar(img, fraction=0.0457 * ratio, pad=0.04)
        cbar.set_label(self.colorbarLabel, fontsize=self.colorbarLabelFontSize)
        cbar.ax.tick_params(labelsize=self.colorbarTickLabelFontSize)

        # If requested, we shift all the positions by 0.5 considering the
        # zero-point at a pixel boundary rather than the center of the pixel.
        shift = 0.5 if self.setPositionsAtPixelBoundaries else 0

        if useComponents:
            # With a non-zero shift, arange(n + shift) yields n + 1 values, so
            # major ticks land on every pixel boundary while the labelled
            # (minor) positions sit at the pixel centers.
            xAxisTickValues = np.arange(matrix.shape[1] + shift)
            yAxisTickValues = np.arange(matrix.shape[0] + shift)
            xAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[1]), comp1[0, :])}
            yAxisTickLabels = {key + shift: str(val) for key, val in zip(range(matrix.shape[0]), comp2[:, 0])}
        else:
            xAxisTickValues = self.xAxisTickValues
            yAxisTickValues = self.yAxisTickValues
            xAxisTickLabels = self.xAxisTickLabels
            yAxisTickLabels = self.yAxisTickLabels

        # If the tick values are not provided, retrieve them from the axes.
        xticks = xAxisTickValues if xAxisTickValues is not None else ax.xaxis.get_ticklocs()
        yticks = yAxisTickValues if yAxisTickValues is not None else ax.yaxis.get_ticklocs()

        # Keep only tick locations that fall within the visible x/y range.
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        xticks = np.array([tick for tick in xticks if min(xlim) <= tick - shift <= max(xlim)])
        yticks = np.array([tick for tick in yticks if min(ylim) <= tick - shift <= max(ylim)])
        tick_data = {
            "x": (
                xticks - shift,
                np.array(list(xAxisTickLabels.keys())) - shift if xAxisTickLabels else None,
                list(xAxisTickLabels.values()) if xAxisTickLabels else None,
            ),
            "y": (
                yticks - shift,
                np.array(list(yAxisTickLabels.keys())) - shift if yAxisTickLabels else None,
                list(yAxisTickLabels.values()) if yAxisTickLabels else None,
            ),
        }

        for dim, axis in [("x", ax.xaxis), ("y", ax.yaxis)]:
            major_tick_values, positions, labels = tick_data[dim]

            # Set major ticks.
            axis.set_ticks(major_tick_values, minor=False)

            # Label them in the user's (unshifted) frame, dropping the
            # trailing ".0" of integral values.
            axis.set_ticklabels(
                [self._formatTickValue(tick + shift) for tick in axis.get_ticklocs()],
                fontsize=self.tickLabelsFontSize,
            )

            if positions is not None:
                # Assign specified positions as minor ticks.
                axis.set_ticks(positions, minor=True)

                if labels is not None:
                    # A label goes on the major tick when its position
                    # coincides with one, otherwise on the minor tick.
                    positions_labels_lookup = dict(zip(positions, labels))
                    major_labels = [positions_labels_lookup.get(m, "") for m in major_tick_values]
                    minor_labels = ["" if p in major_tick_values else l for p, l in zip(positions, labels)]

                    if any(major_labels):
                        axis.set_ticklabels(major_labels, minor=False, fontsize=self.tickLabelsFontSize)
                    else:
                        # If no major labels, clear major tick labels.
                        axis.set_ticklabels([])

                    if any(minor_labels):
                        axis.set_ticklabels(minor_labels, minor=True, fontsize=self.tickLabelsFontSize)

            if dim in self.hideMajorTicks:
                # Remove major tick marks for aesthetic reasons.
                axis.set_tick_params(which="major", length=0)

            if dim in self.hideMinorTicks:
                # Remove minor tick marks for aesthetic reasons.
                axis.set_tick_params(which="minor", length=0)

            # Rotate the tick labels by the specified angle.
            ax.tick_params(axis=dim, rotation=self.tickLabelsRotation)

        # Add vertical and horizontal lines if provided.
        self._drawGuideLines(ax, shift)

        # Add plot info if provided.
        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig