Coverage for python / lsst / analysis / tools / actions / plot / gridPlot.py: 24%

78 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-01 08:55 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("GridPlot", "GridPanelConfig") 

25 

26from typing import TYPE_CHECKING 

27 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from matplotlib.gridspec import GridSpec 

31 

32from lsst.pex.config import Config, ConfigDictField, DictField, Field, ListField 

33from lsst.pex.config.configurableActions import ConfigurableActionField 

34 

35from ...interfaces import PlotAction, PlotElement 

36 

37if TYPE_CHECKING: 

38 from lsst.analysis.tools.interfaces import KeyedData, PlotResultType 

39 

40 

class GridPanelConfig(Config):
    """Configuration for a single panel of a `GridPlot`.

    `GridPlot` stores one of these per panel in its ``panels`` dict, keyed by
    the panel's sequential index (``row * numCols + col``).
    """

    # The plot element invoked to draw this panel's data onto its axes.
    plotElement = ConfigurableActionField[PlotElement](
        doc="Plot element.",
    )
    # Unpacked as keyword arguments into ``ax.set_title(**title, ...)``; when
    # unset, no panel title is drawn.
    title = DictField[str, str](
        doc="String arguments passed into ax.set_title() defining the plot element title.",
        optional=True,
    )
    # Forwarded as ``y=`` to ``ax.set_title`` alongside ``title``; only used
    # when ``title`` is set.
    titleY = Field[float](
        doc="Y position of plot element title.",
        optional=True,
    )

53 

54 

class GridPlot(PlotAction):
    """Plot a series of plot elements onto a regularly spaced grid.

    Panels are laid out on a ``numRows`` x ``numCols`` `GridSpec`. Each panel
    is identified by the sequential index ``row * numCols + col``, which is
    the key used to look it up in ``panels``, ``valsGroupBy`` and (optionally)
    ``xDataKeys``. Grid cells whose index is absent from ``valsGroupBy`` are
    left empty.
    """

    panels = ConfigDictField(
        doc="Plot elements.",
        keytype=int,
        itemtype=GridPanelConfig,
    )
    numRows = Field[int](
        doc="Number of rows.",
        default=1,
    )
    numCols = Field[int](
        doc="Number of columns.",
        default=1,
    )
    width_ratios = ListField[float](
        doc="Width ratios",
        optional=True,
    )
    height_ratios = ListField[float](
        doc="Height ratios",
        optional=True,
    )
    xDataKeys = DictField[int, str](
        doc="Dependent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
        default={},
    )
    valsGroupBy = DictField[int, str](
        doc="Independent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
    )
    figsize = ListField[float](
        doc="Figure size.",
        default=[8, 8],
    )
    dpi = Field[float](
        doc="Dots per inch.",
        default=150,
    )
    suptitle = DictField[str, str](
        doc="String arguments passed into fig.suptitle() defining the figure title.",
        optional=True,
    )
    xAxisLabel = Field[str](
        doc="String argument passed into fig.supxlabel() defining the figure x label.",
        optional=True,
    )
    yAxisLabel = Field[str](
        doc="String argument passed into fig.supylabel() defining the figure y label.",
        optional=True,
    )

    @staticmethod
    def _normalizeRatios(ratios, n):
        """Return grid ratios normalized to sum to one.

        Parameters
        ----------
        ratios : sequence of `float` or `None`
            User-supplied ratios. `None` or an empty sequence means "unset",
            matching how `validate` treats these fields.
        n : `int`
            Number of rows or columns; used to build equal ratios when
            ``ratios`` is unset.

        Returns
        -------
        normalized : `numpy.ndarray`
            Ratios whose sum is one.
        """
        if not ratios:
            return np.ones(n) / n
        # Convert explicitly: config list fields are not plain lists, so do
        # not rely on numpy's reflected operators to coerce them.
        values = np.asarray(ratios, dtype=float)
        return values / np.sum(values)

    def __call__(self, data: KeyedData, **kwargs) -> PlotResultType:
        """Plot data.

        Parameters
        ----------
        data : `KeyedData`
            Input data. For each panel, every key of ``data`` that contains
            one of the panel's comma-separated ``valsGroupBy`` identifiers is
            plotted as a separate series.
        **kwargs
            Forwarded unchanged to every panel's plot element.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The assembled figure.
        """
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
        figureInfo = {"figsize": fig.get_size_inches(), "dpi": fig.get_dpi()}

        heightRatios = self._normalizeRatios(self.height_ratios, self.numRows)
        widthRatios = self._normalizeRatios(self.width_ratios, self.numCols)

        if self.suptitle is not None:
            fig.suptitle(**self.suptitle)
        if self.xAxisLabel is not None:
            fig.supxlabel(self.xAxisLabel)
        if self.yAxisLabel is not None:
            fig.supylabel(self.yAxisLabel)

        # TODO: See DM-44283:Add subplot_mosaic functionality to plotElements
        gs = GridSpec(
            self.numRows,
            self.numCols,
            figure=fig,
            height_ratios=heightRatios,
            width_ratios=widthRatios,
        )

        for row in range(self.numRows):
            for col in range(self.numCols):
                # Sequential panel index; the key into panels/valsGroupBy/
                # xDataKeys.
                index = row * self.numCols + col
                if index not in self.valsGroupBy:
                    continue
                ax = fig.add_subplot(gs[row, col])
                self._plotPanel(data, index, ax, figureInfo, kwargs)

        # Lay out the figure built here; plt.tight_layout() would act on
        # pyplot's *current* figure, which a plot element may have changed.
        fig.tight_layout()
        return fig

    def _plotPanel(self, data, index, ax, figureInfo, elementKwargs):
        """Draw every series configured for panel ``index`` onto ``ax``.

        Parameters
        ----------
        data : `KeyedData`
            Full input data passed to `__call__`.
        index : `int`
            Sequential panel index.
        ax : `matplotlib.axes.Axes`
            Axes the panel's plot element draws into.
        figureInfo : `dict`
            Figure size/dpi information forwarded to the plot element.
        elementKwargs : `dict`
            Extra keyword arguments forwarded to the plot element; passed as
            a dict so they cannot collide with this helper's parameters.
        """
        panel = self.panels[index]
        plotElement = panel.plotElement

        # Comma-separated series identifiers. If xDataKeys has no entry for
        # this panel, no x-coordinate column is supplied to the element.
        xList = x.split(",") if (x := self.xDataKeys.get(index)) else None
        valList = self.valsGroupBy[index].split(",")

        for i, val in enumerate(valList):
            for key in data:
                # Skip columns that do not match this series identifier.
                if val not in key:
                    continue
                # Columns used as x-coordinates are never plotted as values.
                if xList is not None and key in xList:
                    continue

                newData = {}
                if xList is not None:
                    # NOTE(review): assumes xList has at least as many
                    # entries as valList; otherwise this raises IndexError.
                    newData[plotElement.xKey] = data[xList[i]]
                # Not all elements take y-coordinate data (e.g. info
                # elements), so only supply it when the element has a slot.
                if hasattr(plotElement, "valsKey"):
                    newData[plotElement.valsKey] = data[key]

                _ = plotElement(data=newData, ax=ax, figureInfo=figureInfo, **elementKwargs)

        if panel.title is not None:
            ax.set_title(**panel.title, y=panel.titleY)

    def validate(self):
        """Validate configuration.

        Raises
        ------
        RuntimeError
            If ``xDataKeys`` (when non-empty) or ``valsGroupBy`` does not have
            exactly ``numRows * numCols`` entries, or if non-empty
            ``width_ratios``/``height_ratios`` do not match the number of
            columns/rows.
        """
        super().validate()
        numPanels = self.numRows * self.numCols
        if self.xDataKeys and len(self.xDataKeys) != numPanels:
            raise RuntimeError("Number of xDataKeys keys must match number of rows * columns.")
        if len(self.valsGroupBy) != numPanels:
            raise RuntimeError("Number of valsGroupBy keys must match number of rows * columns.")
        if self.width_ratios and len(self.width_ratios) != self.numCols:
            raise RuntimeError("Number of supplied width ratios must match number of columns.")
        if self.height_ratios and len(self.height_ratios) != self.numRows:
            raise RuntimeError("Number of supplied height ratios must match number of rows.")