Coverage for python / lsst / analysis / tools / actions / plot / xyPlot.py: 28%

75 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-05 18:53 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("XYPlot",) 

25 

26from typing import TYPE_CHECKING, Any, Mapping 

27 

28import matplotlib.pyplot as plt 

29from lsst.pex.config import ChoiceField, DictField, Field, FieldValidationError 

30from matplotlib.ticker import SymmetricalLogLocator 

31 

32from ...interfaces import PlotAction, Vector 

33from .plotUtils import addPlotInfo 

34 

35if TYPE_CHECKING: 

36 from matplotlib.figure import Figure 

37 

38 from ...interfaces import KeyedData, KeyedDataSchema 

39 

40 

class XYPlot(PlotAction):
    """Make a plot (with errorbars) of one quantity (X) vs another (Y).

    The input data must provide the vectors ``x``, ``y``, ``xerr`` and
    ``yerr`` (see `getInputSchema`). Extra keyword arguments for
    ``ax.errorbar`` are supplied through the three typed ``*Kwargs``
    dictionaries, which must not share any keys.
    """

    boolKwargs = DictField[str, bool](
        doc="Keyword arguments to ax.errorbar that take boolean values",
        default={},
        optional=True,
    )

    numKwargs = DictField[str, float](
        doc="Keyword arguments to ax.errorbar that take numerical (float or int) values",
        default={},
        optional=True,
    )

    strKwargs = DictField[str, str](
        doc="Keyword arguments to ax.errorbar that take string values",
        default={},
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="x",
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="y",
    )

    xScale = ChoiceField[str](
        doc="The scale to use for the x-axis.",
        default="linear",
        allowed={scale: scale for scale in ("linear", "log", "symlog")},
    )

    yScale = ChoiceField[str](
        doc="The scale to use for the y-axis.",
        default="linear",
        allowed={scale: scale for scale in ("linear", "log", "symlog")},
    )

    xLinThresh = Field[float](
        doc=(
            "The value around zero where the scale becomes linear in x-axis "
            "when symlog is set as the scale. Sets the `linthresh` parameter "
            "of `~matplotlib.axes.set_xscale`."
        ),
        default=1e-6,
        optional=True,
    )

    yLinThresh = Field[float](
        doc=(
            "The value around zero where the scale becomes linear in y-axis "
            "when symlog is set as the scale. Sets the `linthresh` parameter "
            "of `~matplotlib.axes.set_yscale`."
        ),
        default=1e-6,
        optional=True,
    )

    xLine = Field[float](
        doc=("The value of x where a vertical line is drawn."),
        default=None,
        optional=True,
    )

    yLine = Field[float](
        doc=("The value of y where a horizontal line is drawn."),
        default=None,
        optional=True,
    )

    def setDefaults(self):
        super().setDefaults()
        # Default errorbar format string; can be overridden via strKwargs.
        self.strKwargs = {"fmt": "o"}

    def validate(self):
        """Validate the configuration.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the same keyword appears in more than one of
            ``boolKwargs``, ``numKwargs`` and ``strKwargs``, or if a
            ``symlog`` scale is requested without a positive linear
            threshold for that axis.
        """
        # The *Kwargs fields are optional, so a value of None means "no extra kwargs".
        boolKeys = set((self.boolKwargs or {}).keys())
        numKeys = set((self.numKwargs or {}).keys())
        strKeys = set((self.strKwargs or {}).keys())

        # Every pair must be disjoint; otherwise errorbar() would receive the
        # same keyword twice.
        pairs = (
            ("boolKwargs", boolKeys, numKeys),
            ("boolKwargs", boolKeys, strKeys),
            ("numKwargs", numKeys, strKeys),
        )
        for fieldName, firstKeys, secondKeys in pairs:
            repeated = firstKeys & secondKeys
            if repeated:
                # FieldValidationError needs the Field descriptor (looked up on
                # the class), not the field's value on this instance.
                raise FieldValidationError(
                    getattr(type(self), fieldName),
                    self,
                    f"Keywords have been repeated: {sorted(repeated)}",
                )

        # A symlog axis needs a usable linear threshold.
        for axisName in ("x", "y"):
            if getattr(self, f"{axisName}Scale") != "symlog":
                continue
            linThresh = getattr(self, f"{axisName}LinThresh")
            if linThresh is None or linThresh <= 0:
                raise FieldValidationError(
                    getattr(type(self), f"{axisName}LinThresh"),
                    self,
                    f"{axisName}LinThresh must be positive when {axisName}Scale is 'symlog'",
                )

        super().validate()

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (key, type) pairs this action reads from its input."""
        base: list[tuple[str, type[Vector]]] = [
            ("x", Vector),
            ("y", Vector),
            ("xerr", Vector),
            ("yerr", Vector),
        ]
        return base

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        """Check that ``data`` has the required keys, then make the plot."""
        self._validateInput(data)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData) -> None:
        """Raise `ValueError` if ``data`` lacks any key from the input schema."""
        schema = self.getInputSchema()
        missing = {key for key, _ in schema} - set(data.keys())
        if missing:
            raise ValueError(
                f"Input data does not contain all required keys: {schema} (missing: {sorted(missing)})"
            )

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """Make the plot.

        Parameters
        ----------
        data : `~lsst.analysis.tools.interfaces.KeyedData`
            The data to plot; must contain the vectors ``x``, ``y``,
            ``xerr`` and ``yerr``.
        plotInfo : `dict` [`str`, `str`], optional
            Information about the plot. When given, it is passed to
            ``addPlotInfo`` together with the figure.
        **kwargs
            Additional keyword arguments. Only ``fig`` is read: an existing
            `~matplotlib.figure.Figure` to draw on (using its current axes),
            so that several curves can share one plot.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.
        """
        # Allow for multiple curves to lie on the same plot.
        fig = kwargs.get("fig", None)
        if fig is None:
            fig = plt.figure(dpi=300)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        # The optional *Kwargs fields may be None; treat that as empty.
        ax.errorbar(
            data["x"],
            data["y"],
            xerr=data["xerr"],
            yerr=data["yerr"],
            **(self.boolKwargs or {}),  # type: ignore
            **(self.numKwargs or {}),  # type: ignore
            **(self.strKwargs or {}),  # type: ignore
        )
        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        if self.xLine is not None:
            ax.axvline(self.xLine, color="k", linestyle="--")
        if self.yLine is not None:
            ax.axhline(self.yLine, color="k", linestyle="--")

        self._setAxisScale(ax, "x", self.xScale, self.xLinThresh)
        self._setAxisScale(ax, "y", self.yScale, self.yLinThresh)

        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig

    @staticmethod
    def _setAxisScale(ax, axisName: str, scale: str, linThresh: float | None) -> None:
        """Apply ``scale`` to the ``axisName`` ("x" or "y") axis of ``ax``.

        For ``symlog``, also install a minor-tick locator and shade the
        interval ``[-linThresh, linThresh]`` where the scale is linear.
        """
        isX = axisName == "x"
        setScale = ax.set_xscale if isX else ax.set_yscale
        if scale != "symlog":
            setScale(scale)
            return

        setScale("symlog", linthresh=linThresh)
        locator = SymmetricalLogLocator(linthresh=linThresh, base=10, subs=[0.1 * ii for ii in range(1, 10)])
        (ax.xaxis if isX else ax.yaxis).set_minor_locator(locator)
        # Shade the region where the axis is linear rather than logarithmic.
        shade = ax.axvspan if isX else ax.axhspan
        shade(-linThresh, linThresh, color="gray", alpha=0.1)