Coverage for python/lsst/analysis/tools/actions/plot/gridPlot.py: 25%
80 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 04:37 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-16 04:37 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("GridPlot", "GridPanelConfig")
26from typing import TYPE_CHECKING
28import matplotlib.pyplot as plt
29import numpy as np
30from lsst.pex.config import Config, ConfigDictField, DictField, Field, ListField
31from lsst.pex.config.configurableActions import ConfigurableActionField
32from matplotlib.gridspec import GridSpec
34from ...interfaces import PlotAction, PlotElement
36if TYPE_CHECKING: 36 ↛ 37line 36 didn't jump to line 37, because the condition on line 36 was never true
37 from lsst.analysis.tools.interfaces import KeyedData, PlotResultType
class GridPanelConfig(Config):
    """Configuration for one panel of a `GridPlot`.

    A panel pairs a single plot element with optional title settings.
    """

    # The action that draws this panel's contents.
    plotElement = ConfigurableActionField[PlotElement](doc="Plot element.")

    # Keyword arguments unpacked into ``ax.set_title``.
    title = DictField[str, str](
        optional=True,
        doc="String arguments passed into ax.set_title() defining the plot element title.",
    )

    # Forwarded to ``ax.set_title`` as its ``y`` argument.
    titleY = Field[float](
        optional=True,
        doc="Y position of plot element title.",
    )
class GridPlot(PlotAction):
    """Plot a series of plot elements onto a regularly spaced grid.

    Panels are numbered row-major: the panel in row ``row`` and column
    ``col`` has index ``row * numCols + col``. That index is the key used
    to look up the panel's configuration in `panels`, its value series in
    `valsGroupBy`, and (optionally) its x-coordinate series in `xDataKeys`.
    Grid cells whose index is not a key of `valsGroupBy` are left empty.
    """

    panels = ConfigDictField(
        doc="Plot elements.",
        keytype=int,
        itemtype=GridPanelConfig,
    )
    numRows = Field[int](
        doc="Number of rows.",
        default=1,
    )
    numCols = Field[int](
        doc="Number of columns.",
        default=1,
    )
    width_ratios = ListField[float](
        doc="Width ratios",
        optional=True,
    )
    height_ratios = ListField[float](
        doc="Height ratios",
        optional=True,
    )
    xDataKeys = DictField[int, str](
        doc="Independent (x-coordinate) data definitions. The key of this dict is the panel ID. The values "
        "are full keys of data to use as the x-coordinate (comma-separated, one per series in valsGroupBy).",
        default={},
    )
    valsGroupBy = DictField[int, str](
        doc="Dependent (y-coordinate) data definitions. The key of this dict is the panel ID. The values are "
        "keys of data to plot (comma-separated for multiple) where each key may be a subset of a full key.",
    )
    figsize = ListField[float](
        doc="Figure size.",
        default=[8, 8],
    )
    dpi = Field[float](
        doc="Dots per inch.",
        default=150,
    )
    suptitle = DictField[str, str](
        doc="String arguments passed into fig.suptitle() defining the figure title.",
        optional=True,
    )
    xAxisLabel = Field[str](
        doc="String argument passed into fig.supxlabel() defining the figure x label.",
        optional=True,
    )
    yAxisLabel = Field[str](
        doc="String argument passed into fig.supylabel() defining the figure y label.",
        optional=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> PlotResultType:
        """Plot data.

        Parameters
        ----------
        data : `KeyedData`
            Data to plot. For each panel, every key of ``data`` that contains
            one of the panel's `valsGroupBy` entries as a substring is drawn
            as a series; x-coordinate series are looked up by exact key from
            `xDataKeys`.
        **kwargs
            Forwarded unchanged to every panel's plot element.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure holding the grid of panels.
        """
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
        figureInfo = {"figsize": fig.get_size_inches(), "dpi": fig.get_dpi()}

        if self.suptitle is not None:
            fig.suptitle(**self.suptitle)
        if self.xAxisLabel is not None:
            fig.supxlabel(self.xAxisLabel)
        if self.yAxisLabel is not None:
            fig.supylabel(self.yAxisLabel)

        # TODO: See DM-44283:Add subplot_mosaic functionality to plotElements
        gs = GridSpec(
            self.numRows,
            self.numCols,
            figure=fig,
            height_ratios=self._normalizeRatios(self.height_ratios, self.numRows),
            width_ratios=self._normalizeRatios(self.width_ratios, self.numCols),
        )

        # Iterate over every grid cell in row-major order; the sequential
        # index identifies which panel configuration and data to use.
        for row in range(self.numRows):
            for col in range(self.numCols):
                index = row * self.numCols + col
                if index not in self.valsGroupBy:
                    continue
                ax = fig.add_subplot(gs[row, col])
                self._plotPanel(index, ax, data, figureInfo, kwargs)

        # Lay out the figure being returned explicitly: plt.tight_layout()
        # acts on pyplot's *current* figure, which a plot element may change.
        fig.tight_layout()
        return fig

    @staticmethod
    def _normalizeRatios(ratios, count):
        """Return ``ratios`` scaled to sum to one, or equal ratios if unset.

        Parameters
        ----------
        ratios : sequence of `float` or `None`
            Configured ratios. `None` or an empty sequence means "not set",
            matching the truthiness check used in `validate`.
        count : `int`
            Number of rows or columns the ratios apply to.

        Returns
        -------
        normalized : `numpy.ndarray`
            Ratios normalized so that they sum to one.
        """
        if not ratios:
            return np.ones(count) / count
        # Convert explicitly instead of relying on a pex_config List being
        # divisible by a numpy scalar.
        values = np.asarray(ratios, dtype=float)
        return values / np.sum(values)

    def _plotPanel(self, index, ax, data, figureInfo, plotKwargs):
        """Draw every series configured for panel ``index`` onto ``ax``.

        Parameters
        ----------
        index : `int`
            Panel index (key into `panels`, `valsGroupBy` and `xDataKeys`).
        ax : `matplotlib.axes.Axes`
            Axes to draw onto.
        data : `KeyedData`
            Full input data passed to `__call__`.
        figureInfo : `dict`
            Figure size and dpi, forwarded to the plot element.
        plotKwargs : `dict`
            Extra keyword arguments forwarded to the plot element.
        """
        panel = self.panels[index]
        element = panel.plotElement

        # Comma-separated lists allow several series on one panel. When this
        # panel has no xDataKeys entry, no x-coordinate data is supplied to
        # the plot element.
        xList = x.split(",") if (x := self.xDataKeys.get(index)) else None
        valList = self.valsGroupBy[index].split(",")

        for i, val in enumerate(valList):
            for key in data:
                if val not in key:
                    # Skip columns that do not match this series identifier.
                    continue
                newData = {}
                if xList is not None:
                    # Store the x-coordinate data under the element's xKey.
                    newData[element.xKey] = data[xList[i]]
                    if key in xList:
                        # An x-coordinate column must not be plotted as a value.
                        continue

                # Not every element takes value data (e.g. an info element),
                # so only supply it when the element declares a valsKey.
                if hasattr(element, "valsKey"):
                    newData[element.valsKey] = data[key]

                _ = element(data=newData, ax=ax, figureInfo=figureInfo, **plotKwargs)

        if panel.title is not None:
            ax.set_title(**panel.title, y=panel.titleY)

    def validate(self):
        """Validate configuration.

        Raises
        ------
        RuntimeError
            Raised if the number of `xDataKeys` (when given) or `valsGroupBy`
            entries does not equal ``numRows * numCols``, if supplied width or
            height ratios do not match the number of columns or rows, or if a
            panel lists fewer x keys than value series.
        """
        super().validate()
        numPanels = self.numRows * self.numCols
        if self.xDataKeys and len(self.xDataKeys) != numPanels:
            raise RuntimeError("Number of xDataKeys keys must match number of rows * columns.")
        if len(self.valsGroupBy) != numPanels:
            raise RuntimeError("Number of valsGroupBy keys must match number of rows * columns.")
        if self.width_ratios and len(self.width_ratios) != self.numCols:
            raise RuntimeError("Number of supplied width ratios must match number of columns.")
        if self.height_ratios and len(self.height_ratios) != self.numRows:
            raise RuntimeError("Number of supplied height ratios must match number of rows.")
        # Each value series i is paired with x key i when plotting, so a
        # panel needs at least as many x keys as value series.
        for index, xKeys in self.xDataKeys.items():
            if not xKeys or index not in self.valsGroupBy:
                continue
            if len(xKeys.split(",")) < len(self.valsGroupBy[index].split(",")):
                raise RuntimeError(
                    f"Panel {index} has fewer xDataKeys entries than valsGroupBy entries."
                )