Coverage for python/lsst/analysis/tools/actions/plot/xyPlot.py: 30%

78 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-25 11:38 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("XYPlot",) 

25 

26from typing import TYPE_CHECKING, Any, Mapping 

27 

28import matplotlib.pyplot as plt 

29from lsst.pex.config import ChoiceField, DictField, Field, FieldValidationError 

30from matplotlib.ticker import SymmetricalLogLocator 

31 

32from ...interfaces import PlotAction, Vector 

33from .plotUtils import addPlotInfo 

34 

35if TYPE_CHECKING: 35 ↛ 36line 35 didn't jump to line 36, because the condition on line 35 was never true

36 from matplotlib.figure import Figure 

37 

38 from ...interfaces import KeyedData, KeyedDataSchema 

39 

40 

class XYPlot(PlotAction):
    """Make a plot (with errorbars) of one quantity (X) vs another (Y).

    The input data must provide the vectors ``x``, ``y``, ``xerr`` and
    ``yerr``. Extra keyword arguments for `~matplotlib.axes.Axes.errorbar`
    are taken from ``boolKwargs``, ``numKwargs`` and ``strKwargs``.
    """

    boolKwargs = DictField[str, bool](
        doc="Keyword arguments to ax.errorbar that take boolean values",
        default={},
        optional=True,
    )

    numKwargs = DictField[str, float](
        doc="Keyword arguments to ax.errorbar that take numerical (float or int) values",
        default={},
        optional=True,
    )

    strKwargs = DictField[str, str](
        doc="Keyword arguments to ax.errorbar that take string values",
        default={},
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="x",
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="y",
    )

    xScale = ChoiceField[str](
        doc="The scale to use for the x-axis.",
        default="linear",
        allowed={scale: scale for scale in ("linear", "log", "symlog")},
    )

    yScale = ChoiceField[str](
        doc="The scale to use for the y-axis.",
        default="linear",
        allowed={scale: scale for scale in ("linear", "log", "symlog")},
    )

    xLinThresh = Field[float](
        doc=(
            "The value around zero where the scale becomes linear in x-axis "
            "when symlog is set as the scale. Sets the `linthresh` parameter "
            "of `~matplotlib.axes.set_xscale`."
        ),
        default=1e-6,
        optional=True,
    )

    yLinThresh = Field[float](
        doc=(
            "The value around zero where the scale becomes linear in y-axis "
            "when symlog is set as the scale. Sets the `linthresh` parameter "
            "of `~matplotlib.axes.set_yscale`."
        ),
        default=1e-6,
        optional=True,
    )

    xLine = Field[float](
        doc=("The value of x where a vertical line is drawn."),
        default=None,
        optional=True,
    )

    yLine = Field[float](
        doc=("The value of y where a horizontal line is drawn."),
        default=None,
        optional=True,
    )

    def setDefaults(self):
        super().setDefaults()
        # Default errorbar format string; can be overridden via strKwargs.
        self.strKwargs = {"fmt": "o"}

    def validate(self):
        """Check that no errorbar keyword is supplied by more than one field.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the same keyword appears in more than one of
            ``boolKwargs``, ``numKwargs`` and ``strKwargs``.
        """
        seen: set[str] = set()
        for fieldName in ("boolKwargs", "numKwargs", "strKwargs"):
            # Optional fields may be None; treat that as "no keywords".
            for key in getattr(self, fieldName) or {}:
                if key in seen:
                    # FieldValidationError expects the Field descriptor (obtained
                    # from the class), not the field's current value.
                    raise FieldValidationError(
                        getattr(type(self), fieldName), self, "Keywords have been repeated"
                    )
                seen.add(key)

        super().validate()

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (all vectors) this action reads from its input."""
        base: list[tuple[str, type[Vector]]] = [
            ("x", Vector),
            ("y", Vector),
            ("xerr", Vector),
            ("yerr", Vector),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        self._validateInput(data)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData) -> None:
        """Raise `ValueError` if any key from the input schema is missing."""
        needed = {name for name, _ in self.getInputSchema()}
        if not needed.issubset(data.keys()):
            raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")

    def _applyScale(self, ax, axisName: str) -> None:
        """Apply the configured scale for one axis (``"x"`` or ``"y"``).

        For ``symlog`` this also sets a symmetric-log minor-tick locator and
        shades the linear region ``[-linthresh, linthresh]``.
        """
        if axisName == "x":
            scale, linThresh = self.xScale, self.xLinThresh
            setScale, axis, shade = ax.set_xscale, ax.xaxis, ax.axvspan
        else:
            scale, linThresh = self.yScale, self.yLinThresh
            setScale, axis, shade = ax.set_yscale, ax.yaxis, ax.axhspan

        if scale == "symlog":
            setScale("symlog", linthresh=linThresh)
            locator = SymmetricalLogLocator(
                linthresh=linThresh, base=10, subs=[0.1 * ii for ii in range(1, 10)]
            )
            axis.set_minor_locator(locator)
            shade(-linThresh, linThresh, color="gray", alpha=0.1)
        else:
            setScale(scale)  # type: ignore
            ax.tick_params(axis=axisName, which="minor")

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """Make the plot.

        Parameters
        ----------
        data : `KeyedData`
            Mapping providing the ``x``, ``y``, ``xerr`` and ``yerr`` vectors.
        plotInfo : `Mapping` [`str`, `str`], optional
            Information passed to ``addPlotInfo`` to annotate the figure.
            No annotation is added if `None`.
        **kwargs
            Additional keyword arguments. Only ``fig`` is used: an existing
            `~matplotlib.figure.Figure` whose current axes are drawn on,
            so that several curves can share one plot.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.
        """
        # Allow for multiple curves to lie on the same plot.
        fig = kwargs.get("fig", None)
        if fig is None:
            fig = plt.figure(dpi=300)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        ax.errorbar(
            data["x"],
            data["y"],
            xerr=data["xerr"],
            yerr=data["yerr"],
            # Optional fields may be None; unpack an empty mapping instead.
            **(self.boolKwargs or {}),  # type: ignore
            **(self.numKwargs or {}),  # type: ignore
            **(self.strKwargs or {}),  # type: ignore
        )
        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        if self.xLine is not None:
            ax.axvline(self.xLine, color="k", linestyle="--")
        if self.yLine is not None:
            ax.axhline(self.yLine, color="k", linestyle="--")

        # Each axis is configured exactly once, so the symlog minor locator
        # (with its subs) is not overwritten afterwards.
        self._applyScale(ax, "x")
        self._applyScale(ax, "y")

        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig