Coverage for python/lsst/analysis/tools/actions/plot/gridPlot.py: 28%
65 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-04 03:35 -0700
« prev ^ index » next coverage.py v7.5.0, created at 2024-05-04 03:35 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("GridPlot", "GridPanelConfig")
26from typing import TYPE_CHECKING
28import matplotlib.pyplot as plt
29from lsst.pex.config import Config, ConfigDictField, DictField, Field, ListField
30from lsst.pex.config.configurableActions import ConfigurableActionField
31from matplotlib.gridspec import GridSpec
33from ...interfaces import PlotAction, PlotElement
35if TYPE_CHECKING: 35 ↛ 36line 35 didn't jump to line 36, because the condition on line 35 was never true
36 from lsst.analysis.tools.interfaces import KeyedData, PlotResultType
class GridPanelConfig(Config):
    """Configuration for a single panel of a `GridPlot`."""

    plotElement = ConfigurableActionField[PlotElement](
        doc="Plot element.",
    )
    # Optional: GridPlot only calls ax.set_title() when this is not None,
    # so an unset title must not fail config validation.
    title = DictField[str, str](
        doc="String arguments passed into ax.set_title() defining the plot element title.",
        optional=True,
    )
    # Optional so that the None default passes pex_config validation; the
    # value is forwarded as ``y=`` to ax.set_title(), where None presumably
    # lets matplotlib choose the position (see matplotlib docs).
    titleY = Field[float](
        doc="Y position of plot element title.",
        default=None,
        optional=True,
    )
class GridPlot(PlotAction):
    """Plot a series of plot elements onto a regularly spaced grid.

    Panels are numbered in row-major order (``index = row * numCols + col``).
    That index is the key used in ``panels``, ``valsGroupBy`` and
    ``xDataKeys``.
    """

    panels = ConfigDictField(
        doc="Plot elements.",
        keytype=int,
        itemtype=GridPanelConfig,
    )
    numRows = Field[int](
        doc="Number of rows.",
        default=1,
    )
    numCols = Field[int](
        doc="Number of columns.",
        default=1,
    )
    xDataKeys = DictField[int, str](
        doc="Independent data definitions. The key of this dict is the panel ID. The values are full keys of "
        "data to use as the x-coordinate (comma-separated, one per series listed in valsGroupBy).",
        default={},
    )
    valsGroupBy = DictField[int, str](
        doc="Dependent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
    )
    figsize = ListField[float](
        doc="Figure size.",
        default=[8, 8],
    )
    dpi = Field[float](
        doc="Dots per inch.",
        default=150,
    )
    suptitle = DictField[str, str](
        doc="String arguments passed into fig.suptitle() defining the figure title.",
        optional=True,
    )
    xAxisLabel = Field[str](
        doc="String argument passed into fig.supxlabel() defining the figure x label.",
        optional=True,
    )
    yAxisLabel = Field[str](
        doc="String argument passed into fig.supylabel() defining the figure y label.",
        optional=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> PlotResultType:
        """Plot data.

        Parameters
        ----------
        data : `KeyedData`
            Input data. A key is plotted on a panel when one of that panel's
            ``valsGroupBy`` identifiers is a substring of the key.
        **kwargs
            Forwarded unchanged to every panel's plot element.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure containing all of the panels.
        """
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
        if self.suptitle is not None:
            fig.suptitle(**self.suptitle)
        if self.xAxisLabel is not None:
            fig.supxlabel(self.xAxisLabel)
        if self.yAxisLabel is not None:
            fig.supylabel(self.yAxisLabel)

        gs = GridSpec(self.numRows, self.numCols, figure=fig)

        # Iterate over all of the plots we'll make:
        for row in range(self.numRows):
            for col in range(self.numCols):
                # Sequential row-major panel index; panels without a
                # `valsGroupBy` entry are left empty.
                index = row * self.numCols + col
                if index not in self.valsGroupBy:
                    continue
                panel = self.panels[index]
                plotElement = panel.plotElement
                ax = fig.add_subplot(gs[row, col])

                # Comma-separated series for this panel. When `xDataKeys`
                # has no (or an empty) entry for this panel, no x data is
                # passed to the plot element.
                xList = x.split(",") if (x := self.xDataKeys.get(index)) else None
                valList = self.valsGroupBy[index].split(",")

                # Iterate over the series to plot in this panel:
                for i, val in enumerate(valList):
                    for key in data:
                        # Skip columns that do not match this series
                        # identifier (substring match).
                        if val not in key:
                            continue
                        # Never plot an x-coordinate column as a y value.
                        # Checked before touching the x data so skipped
                        # columns cost nothing.
                        if xList is not None and key in xList:
                            continue

                        # Re-key the columns under the names the plot
                        # element expects.
                        newData = {}
                        if xList is not None:
                            # x keys are exact keys, paired by position
                            # with the series in `valList`.
                            newData[plotElement.xKey] = data[xList[i]]
                        newData[plotElement.valsKey] = data[key]

                        # Actually make the plot.
                        plotElement(data=newData, ax=ax, **kwargs)

                if panel.title is not None:
                    ax.set_title(**panel.title, y=panel.titleY)

        # Lay out this figure explicitly rather than pyplot's "current"
        # figure, which a plot element could have changed.
        fig.tight_layout()
        return fig

    def validate(self):
        """Validate configuration.

        Raises
        ------
        RuntimeError
            Raised if ``xDataKeys`` (when non-empty) or ``valsGroupBy`` does
            not have exactly ``numRows * numCols`` entries, or if a
            ``valsGroupBy`` key has no corresponding entry in ``panels``.
        """
        super().validate()
        numPanels = self.numRows * self.numCols
        if self.xDataKeys and len(self.xDataKeys) != numPanels:
            raise RuntimeError("Number of xDataKeys keys must match number of rows * columns.")
        if len(self.valsGroupBy) != numPanels:
            raise RuntimeError("Number of valsGroupBy keys must match number of rows * columns.")
        # __call__ looks up self.panels[index] for every valsGroupBy index,
        # so a missing panel would otherwise only surface as a KeyError
        # at plot time.
        missing = sorted(set(self.valsGroupBy) - set(self.panels or ()))
        if missing:
            raise RuntimeError(f"valsGroupBy keys {missing} have no corresponding entry in panels.")