Coverage for python / lsst / analysis / tools / actions / plot / gridPlot.py: 24%
78 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:26 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 09:26 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("GridPlot", "GridPanelConfig")
26from typing import TYPE_CHECKING
28import matplotlib.pyplot as plt
29import numpy as np
30from matplotlib.gridspec import GridSpec
32from lsst.pex.config import Config, ConfigDictField, DictField, Field, ListField
33from lsst.pex.config.configurableActions import ConfigurableActionField
35from ...interfaces import PlotAction, PlotElement
37if TYPE_CHECKING:
38 from lsst.analysis.tools.interfaces import KeyedData, PlotResultType
class GridPanelConfig(Config):
    """Configuration describing a single panel of a `GridPlot`.

    A panel wraps one plot element plus optional settings for the title
    drawn above its axes.
    """

    # The element that is drawn onto this panel's axes.
    plotElement = ConfigurableActionField[PlotElement](doc="Plot element.")
    # Keyword arguments expanded into ``ax.set_title()`` for this panel.
    title = DictField[str, str](
        optional=True,
        doc="String arguments passed into ax.set_title() defining the plot element title.",
    )
    # Vertical title position, passed to ``ax.set_title()`` as ``y``.
    titleY = Field[float](optional=True, doc="Y position of plot element title.")
class GridPlot(PlotAction):
    """Plot a series of plot elements onto a regularly spaced grid.

    Panels are laid out row-major on a ``numRows`` x ``numCols``
    `~matplotlib.gridspec.GridSpec`. The sequential panel index
    ``row * numCols + col`` is the key shared by ``panels``,
    ``valsGroupBy`` and ``xDataKeys``.
    """

    panels = ConfigDictField(
        doc="Plot elements.",
        keytype=int,
        itemtype=GridPanelConfig,
    )
    numRows = Field[int](
        doc="Number of rows.",
        default=1,
    )
    numCols = Field[int](
        doc="Number of columns.",
        default=1,
    )
    width_ratios = ListField[float](
        doc="Width ratios",
        optional=True,
    )
    height_ratios = ListField[float](
        doc="Height ratios",
        optional=True,
    )
    xDataKeys = DictField[int, str](
        doc="Independent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
        default={},
    )
    valsGroupBy = DictField[int, str](
        doc="Dependent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
    )
    figsize = ListField[float](
        doc="Figure size.",
        default=[8, 8],
    )
    dpi = Field[float](
        doc="Dots per inch.",
        default=150,
    )
    suptitle = DictField[str, str](
        doc="String arguments passed into fig.suptitle() defining the figure title.",
        optional=True,
    )
    xAxisLabel = Field[str](
        doc="String argument passed into fig.supxlabel() defining the figure x label.",
        optional=True,
    )
    yAxisLabel = Field[str](
        doc="String argument passed into fig.supylabel() defining the figure y label.",
        optional=True,
    )

    @staticmethod
    def _normalizedRatios(ratios, count):
        """Return ``ratios`` scaled to sum to one.

        Unset (`None`) or empty ratios fall back to ``count`` equal ratios,
        so an empty list (which `validate` accepts) behaves like an unset
        field instead of producing an empty array for `GridSpec`.
        """
        if not ratios:
            return np.ones(count) / count
        ratios = np.asarray(ratios, dtype=float)
        return ratios / np.sum(ratios)

    def __call__(self, data: KeyedData, **kwargs) -> PlotResultType:
        """Plot data.

        Parameters
        ----------
        data : `KeyedData`
            Input data. For each panel, every key containing one of that
            panel's ``valsGroupBy`` identifiers is plotted as a series.
        **kwargs
            Extra keyword arguments forwarded to every plot element call.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure holding the grid of panels.
        """
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
        figureInfo = {"figsize": fig.get_size_inches(), "dpi": fig.get_dpi()}

        height_ratios = self._normalizedRatios(self.height_ratios, self.numRows)
        width_ratios = self._normalizedRatios(self.width_ratios, self.numCols)

        if self.suptitle is not None:
            fig.suptitle(**self.suptitle)
        if self.xAxisLabel is not None:
            fig.supxlabel(self.xAxisLabel)
        if self.yAxisLabel is not None:
            fig.supylabel(self.yAxisLabel)

        # TODO: See DM-44283:Add subplot_mosaic functionality to plotElements
        gs = GridSpec(
            self.numRows,
            self.numCols,
            figure=fig,
            height_ratios=height_ratios,
            width_ratios=width_ratios,
        )

        # Iterate over all of the plots we'll make:
        for row in range(self.numRows):
            for col in range(self.numCols):
                # This row-major sequential index identifies the panel; it
                # is the key looked up in `valsGroupBy`, `xDataKeys` and
                # `panels`. Indices without a `valsGroupBy` entry are left
                # empty.
                index = row * self.numCols + col
                if index not in self.valsGroupBy:
                    continue
                ax = fig.add_subplot(gs[row, col])
                panel = self.panels[index]
                plotElement = panel.plotElement

                # These lists hold the columns that will be plotted,
                # comma separated to allow multiple series to be
                # plotted on the same panel. If `xDataKeys` does not
                # contain this panel's index, no x-coordinate data is
                # passed to the plot element.
                xList = x.split(",") if (x := self.xDataKeys.get(index)) else None
                valList = self.valsGroupBy[index].split(",")

                # Iterate over the series to plot in this panel:
                for i, val in enumerate(valList):
                    for key in data:
                        if val not in key:
                            # Skip columns in data that do not match
                            # our series identifier.
                            continue
                        if xList is not None and key in xList:
                            # x-coordinate columns are never plotted as
                            # series themselves; skip before any lookup.
                            continue

                        newData = {}
                        if xList is not None:
                            # Store the x-coordinate data under the
                            # element's own x key.
                            newData[plotElement.xKey] = data[xList[i]]

                        # If provided, store the y-coordinate data under
                        # the element's own values key. Not all elements
                        # need y-coordinate data, such as plotInfoElement.
                        if hasattr(plotElement, "valsKey"):
                            newData[plotElement.valsKey] = data[key]

                        # Actually make the plot.
                        _ = plotElement(data=newData, ax=ax, figureInfo=figureInfo, **kwargs)

                if panel.title is not None:
                    ax.set_title(**panel.title, y=panel.titleY)

        # Lay out the figure we built explicitly rather than whatever
        # pyplot currently considers the "current" figure.
        fig.tight_layout()
        return fig

    def validate(self):
        """Validate configuration.

        Raises
        ------
        RuntimeError
            Raised if the number of ``xDataKeys`` (when given) or
            ``valsGroupBy`` entries differs from ``numRows * numCols``, if
            non-empty width/height ratios do not match the number of
            columns/rows, or if a ``valsGroupBy`` index has no entry in
            ``panels``.
        """
        super().validate()
        numPanels = self.numRows * self.numCols
        if self.xDataKeys and len(self.xDataKeys) != numPanels:
            raise RuntimeError("Number of xDataKeys keys must match number of rows * columns.")
        if len(self.valsGroupBy) != numPanels:
            raise RuntimeError("Number of valsGroupBy keys must match number of rows * columns.")
        if self.width_ratios and len(self.width_ratios) != self.numCols:
            raise RuntimeError("Number of supplied width ratios must match number of columns.")
        if self.height_ratios and len(self.height_ratios) != self.numRows:
            raise RuntimeError("Number of supplied height ratios must match number of rows.")
        # __call__ indexes `panels` for every `valsGroupBy` key, so a missing
        # panel would otherwise surface as a bare KeyError mid-plot.
        missingPanels = sorted(set(self.valsGroupBy) - set(self.panels))
        if missingPanels:
            raise RuntimeError(f"No panel configured for valsGroupBy keys: {missingPanels}.")