Coverage for python/lsst/analysis/tools/actions/plot/gridPlot.py: 25%

80 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 04:37 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("GridPlot", "GridPanelConfig") 

25 

26from typing import TYPE_CHECKING 

27 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from lsst.pex.config import Config, ConfigDictField, DictField, Field, ListField 

31from lsst.pex.config.configurableActions import ConfigurableActionField 

32from matplotlib.gridspec import GridSpec 

33 

34from ...interfaces import PlotAction, PlotElement 

35 

36if TYPE_CHECKING: 36 ↛ 37line 36 didn't jump to line 37, because the condition on line 36 was never true

37 from lsst.analysis.tools.interfaces import KeyedData, PlotResultType 

38 

39 

class GridPanelConfig(Config):
    """Configuration of a single panel drawn by `GridPlot`.

    A panel pairs one plot element with the optional keyword arguments and
    vertical position used for that panel's title.
    """

    # The configurable action that renders this panel.
    plotElement = ConfigurableActionField[PlotElement](doc="Plot element.")

    # Keyword arguments for the panel title; no title is drawn when unset.
    title = DictField[str, str](
        optional=True,
        doc="String arguments passed into ax.set_title() defining the plot element title.",
    )

    # Vertical placement of the title, passed along with ``title``.
    titleY = Field[float](optional=True, doc="Y position of plot element title.")

52 

53 

class GridPlot(PlotAction):
    """Plot a series of plot elements onto a regularly spaced grid.

    Panels are addressed by a sequential, row-major index
    (``row * numCols + col``). That index is the key used by ``panels``,
    ``xDataKeys`` and ``valsGroupBy``.
    """

    panels = ConfigDictField(
        doc="Plot elements.",
        keytype=int,
        itemtype=GridPanelConfig,
    )
    numRows = Field[int](
        doc="Number of rows.",
        default=1,
    )
    numCols = Field[int](
        doc="Number of columns.",
        default=1,
    )
    width_ratios = ListField[float](
        doc="Width ratios",
        optional=True,
    )
    height_ratios = ListField[float](
        doc="Height ratios",
        optional=True,
    )
    # NOTE(review): the words "Dependent" (here) and "Independent" (on
    # ``valsGroupBy``) look swapped relative to how the fields are used in
    # ``__call__`` (x-coordinate vs. plotted values) -- confirm before
    # changing the doc strings.
    xDataKeys = DictField[int, str](
        doc="Dependent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
        default={},
    )
    valsGroupBy = DictField[int, str](
        doc="Independent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
    )
    figsize = ListField[float](
        doc="Figure size.",
        default=[8, 8],
    )
    dpi = Field[float](
        doc="Dots per inch.",
        default=150,
    )
    suptitle = DictField[str, str](
        doc="String arguments passed into fig.suptitle() defining the figure title.",
        optional=True,
    )
    xAxisLabel = Field[str](
        doc="String argument passed into fig.supxlabel() defining the figure x label.",
        optional=True,
    )
    yAxisLabel = Field[str](
        doc="String argument passed into fig.supylabel() defining the figure y label.",
        optional=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> PlotResultType:
        """Plot data.

        Parameters
        ----------
        data
            Keyed data to plot. Keys are matched by substring against the
            comma-separated entries of ``valsGroupBy``.
        **kwargs
            Forwarded unchanged to every plot element call.

        Returns
        -------
        fig
            The figure containing all configured panels.
        """
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
        figureInfo = {"figsize": fig.get_size_inches(), "dpi": fig.get_dpi()}

        height_ratios = self._normalizedRatios(self.height_ratios, self.numRows)
        width_ratios = self._normalizedRatios(self.width_ratios, self.numCols)

        if self.suptitle is not None:
            fig.suptitle(**self.suptitle)
        if self.xAxisLabel is not None:
            fig.supxlabel(self.xAxisLabel)
        if self.yAxisLabel is not None:
            fig.supylabel(self.yAxisLabel)

        # TODO: See DM-44283:Add subplot_mosaic functionality to plotElements
        gs = GridSpec(
            self.numRows,
            self.numCols,
            figure=fig,
            height_ratios=height_ratios,
            width_ratios=width_ratios,
        )

        # Iterate over all of the plots we'll make:
        for row in range(self.numRows):
            for col in range(self.numCols):
                # Sequential, row-major panel index. Grid cells without a
                # `valsGroupBy` entry are left empty.
                index = row * self.numCols + col
                if index not in self.valsGroupBy:
                    continue
                ax = fig.add_subplot(gs[row, col])
                # kwargs is handed over as a plain dict so that caller
                # keywords cannot collide with the helper's own parameters.
                self._plotPanel(index, ax, data, figureInfo, kwargs)

        # Lay out the figure built here, rather than whatever figure pyplot
        # currently considers "current".
        fig.tight_layout()
        return fig

    @staticmethod
    def _normalizedRatios(ratios, count):
        """Return grid ratios scaled to sum to one.

        Parameters
        ----------
        ratios
            Configured ratios, or `None` to request ``count`` equal shares.
        count
            Number of rows or columns used when ``ratios`` is `None`.

        Returns
        -------
        normalized
            Array of ratios summing to one.
        """
        if ratios is None:
            return np.ones(count) / count
        # Convert the config container explicitly instead of relying on
        # implicit coercion in the division.
        values = np.array(list(ratios), dtype=float)
        return values / np.sum(values)

    def _plotPanel(self, index, ax, data, figureInfo, kwargs):
        """Draw every series configured for one panel onto ``ax``.

        Parameters
        ----------
        index
            Sequential panel index; key into ``panels``, ``xDataKeys`` and
            ``valsGroupBy``.
        ax
            Axes to draw on.
        data
            Keyed data passed to ``__call__``.
        figureInfo
            Figure size and dpi, forwarded to the plot element.
        kwargs
            Extra keyword arguments, as a dict, forwarded to the plot element.
        """
        panel = self.panels[index]
        element = panel.plotElement

        # Comma-separated entries allow several series on the same panel.
        # When `xDataKeys` has no (non-empty) entry for this panel, no
        # x-coordinate data is handed to the plot element.
        xKeys = self.xDataKeys.get(index)
        xList = xKeys.split(",") if xKeys else None
        valList = self.valsGroupBy[index].split(",")

        # Iterate over the series to plot in this panel:
        for i, val in enumerate(valList):
            for key in data:
                if val not in key:
                    # Skip columns in data that do not match our series
                    # identifier (substring match).
                    continue
                newData = {}
                if xList is not None:
                    # Store the x-coordinate data under the column name the
                    # plot element reads it from. The i-th series uses the
                    # i-th entry of the panel's `xDataKeys`.
                    newData[element.xKey] = data[xList[i]]
                    if key in xList:
                        # This key is itself an x-coordinate column, so it
                        # is not plotted as a series.
                        continue

                # If the element takes y-coordinate data, store it under the
                # column name the element reads it from. Not all elements
                # need y-coordinate data, such as plotInfoElement.
                if hasattr(element, "valsKey"):
                    newData[element.valsKey] = data[key]

                # Actually make the plot.
                element(data=newData, ax=ax, figureInfo=figureInfo, **kwargs)

        if panel.title is not None:
            ax.set_title(**panel.title, y=panel.titleY)

    def validate(self):
        """Validate configuration.

        Raises
        ------
        RuntimeError
            Raised if the number of ``xDataKeys`` (when non-empty) or
            ``valsGroupBy`` entries differs from ``numRows * numCols``, or
            if the supplied width/height ratios do not match the number of
            columns/rows.
        """
        super().validate()
        if self.xDataKeys and len(self.xDataKeys) != self.numRows * self.numCols:
            raise RuntimeError("Number of xDataKeys keys must match number of rows * columns.")
        if len(self.valsGroupBy) != self.numRows * self.numCols:
            raise RuntimeError("Number of valsGroupBy keys must match number of rows * columns.")
        if self.width_ratios and len(self.width_ratios) != self.numCols:
            raise RuntimeError("Number of supplied width ratios must match number of columns.")
        if self.height_ratios and len(self.height_ratios) != self.numRows:
            raise RuntimeError("Number of supplied height ratios must match number of rows.")