Coverage for python/lsst/analysis/tools/actions/plot/gridPlot.py: 28%

65 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-15 09:59 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("GridPlot", "GridPanelConfig") 

25 

26from typing import TYPE_CHECKING 

27 

28import matplotlib.pyplot as plt 

29from lsst.pex.config import Config, ConfigDictField, DictField, Field, ListField 

30from lsst.pex.config.configurableActions import ConfigurableActionField 

31from matplotlib.gridspec import GridSpec 

32 

33from ...interfaces import PlotAction, PlotElement 

34 

35if TYPE_CHECKING: 35 ↛ 36line 35 didn't jump to line 36, because the condition on line 35 was never true

36 from lsst.analysis.tools.interfaces import KeyedData, PlotResultType 

37 

38 

class GridPanelConfig(Config):
    """Configuration for a single panel of a `GridPlot`."""

    plotElement = ConfigurableActionField[PlotElement](
        doc="Plot element.",
    )
    # Keyword arguments for ``ax.set_title``. May be left unset (None), in
    # which case GridPlot skips setting a panel title; it must therefore be
    # optional, otherwise pex_config validation rejects the None value.
    title = DictField[str, str](
        doc="String arguments passed into ax.set_title() defining the plot element title.",
        optional=True,
    )
    # None leaves the vertical title position to matplotlib's default, so
    # the field must be optional for its own default to pass validation.
    titleY = Field[float](
        doc="Y position of plot element title.",
        default=None,
        optional=True,
    )

50 

51 

class GridPlot(PlotAction):
    """Plot a series of plot elements onto a regularly spaced grid.

    Panels are laid out row-major: the cell in row ``row`` and column ``col``
    has the integer panel ID ``row * numCols + col``. That ID is the key used
    to look the panel up in `panels`, `valsGroupBy` and `xDataKeys`. Cells
    whose ID is absent from `valsGroupBy` are left empty.
    """

    panels = ConfigDictField(
        doc="Plot elements.",
        keytype=int,
        itemtype=GridPanelConfig,
    )
    numRows = Field[int](
        doc="Number of rows.",
        default=1,
    )
    numCols = Field[int](
        doc="Number of columns.",
        default=1,
    )
    xDataKeys = DictField[int, str](
        doc="Dependent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
        default={},
    )
    valsGroupBy = DictField[int, str](
        doc="Independent data definitions. The key of this dict is the panel ID. The values are keys of data "
        "to plot (comma-separated for multiple) where each key may be a subset of a full key.",
    )
    figsize = ListField[float](
        doc="Figure size.",
        default=[8, 8],
    )
    dpi = Field[float](
        doc="Dots per inch.",
        default=150,
    )
    suptitle = DictField[str, str](
        doc="String arguments passed into fig.suptitle() defining the figure title.",
        optional=True,
    )
    xAxisLabel = Field[str](
        doc="String argument passed into fig.supxlabel() defining the figure x label.",
        optional=True,
    )
    yAxisLabel = Field[str](
        doc="String argument passed into fig.supylabel() defining the figure y label.",
        optional=True,
    )

    def __call__(self, data: KeyedData, **kwargs) -> PlotResultType:
        """Plot data.

        Parameters
        ----------
        data : `KeyedData`
            Mapping of column name to data. For each panel, every key that
            contains one of that panel's `valsGroupBy` substrings is plotted
            as a separate series.
        **kwargs
            Additional keyword arguments forwarded to every panel's plot
            element call.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure holding all populated panels.
        """
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
        if self.suptitle is not None:
            fig.suptitle(**self.suptitle)
        if self.xAxisLabel is not None:
            fig.supxlabel(self.xAxisLabel)
        if self.yAxisLabel is not None:
            fig.supylabel(self.yAxisLabel)

        gs = GridSpec(self.numRows, self.numCols, figure=fig)

        for row in range(self.numRows):
            for col in range(self.numCols):
                # Sequential (row-major) panel ID used as the key into
                # valsGroupBy / xDataKeys / panels.
                index = row * self.numCols + col
                if index not in self.valsGroupBy:
                    # No data requested for this cell; leave it empty.
                    continue
                ax = fig.add_subplot(gs[row, col])
                self._plotPanel(ax, index, data, **kwargs)

        # Lay out *this* figure explicitly instead of relying on pyplot's
        # notion of the "current" figure, which a plot element may change.
        fig.tight_layout()
        return fig

    def _plotPanel(self, ax, index, data, **kwargs):
        """Draw every configured series for panel ``index`` onto ``ax``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            Axes of the grid cell to draw on.
        index : `int`
            Panel ID; must be a key of both `valsGroupBy` and `panels`.
        data : `KeyedData`
            Full input data passed to `__call__`.
        **kwargs
            Forwarded to the panel's plot element.
        """
        panel = self.panels[index]
        plotElement = panel.plotElement

        # Comma-separated lists of series for this panel. Without an
        # xDataKeys entry, the plot element receives no x data.
        xList = x.split(",") if (x := self.xDataKeys.get(index)) else None
        valList = self.valsGroupBy[index].split(",")

        for i, val in enumerate(valList):
            for key in data:
                if val not in key:
                    # Column does not match this series identifier.
                    continue
                if xList is not None and key in xList:
                    # Columns named as x-coordinates are not plotted as values.
                    continue

                # Re-key the selected columns to the names the plot element
                # expects.
                newData = {}
                if xList is not None:
                    newData[plotElement.xKey] = data[xList[i]]
                newData[plotElement.valsKey] = data[key]

                _ = plotElement(data=newData, ax=ax, **kwargs)

        if panel.title is not None:
            # Merge explicitly so a "y" inside the title dict does not
            # collide with titleY (a duplicate keyword raises TypeError);
            # titleY wins when set, otherwise the title dict / matplotlib
            # default applies.
            titleArgs = dict(panel.title)
            if panel.titleY is not None:
                titleArgs["y"] = panel.titleY
            ax.set_title(**titleArgs)

    def validate(self):
        """Validate configuration.

        Raises
        ------
        RuntimeError
            Raised if `xDataKeys` (when non-empty) or `valsGroupBy` does not
            have exactly one entry per grid cell, or if a grid cell that will
            be plotted has no corresponding entry in `panels`.
        """
        super().validate()
        numPanels = self.numRows * self.numCols
        if self.xDataKeys and len(self.xDataKeys) != numPanels:
            raise RuntimeError("Number of xDataKeys keys must match number of rows * columns.")
        if len(self.valsGroupBy) != numPanels:
            raise RuntimeError("Number of valsGroupBy keys must match number of rows * columns.")
        # __call__ looks up self.panels[index] for every in-range panel ID
        # present in valsGroupBy; report a missing panel here instead of
        # failing with a KeyError part-way through plotting.
        missing = [i for i in range(numPanels) if i in self.valsGroupBy and i not in self.panels]
        if missing:
            raise RuntimeError(f"No panels configured for panel IDs: {missing}.")