Coverage for python/lsst/analysis/tools/actions/plot/xyPlot.py: 29%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 09:21 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("XYPlot",) 

25 

26from collections.abc import Mapping 

27from typing import TYPE_CHECKING, Any 

28 

29import matplotlib.pyplot as plt 

30from matplotlib.ticker import SymmetricalLogLocator 

31 

32from lsst.pex.config import ChoiceField, DictField, Field, FieldValidationError 

33 

34from ...interfaces import PlotAction, Vector 

35from .plotUtils import addPlotInfo 

36 

37if TYPE_CHECKING: 

38 from matplotlib.figure import Figure 

39 

40 from ...interfaces import KeyedData, KeyedDataSchema 

41 

42 

class XYPlot(PlotAction):
    """Make a plot (with errorbars) of one quantity (X) vs another (Y).

    The action reads the keys ``x``, ``y``, ``xerr`` and ``yerr`` from the
    input data and draws them with a single `matplotlib.axes.Axes.errorbar`
    call. Extra keyword arguments for that call are configured through
    ``boolKwargs``, ``numKwargs`` and ``strKwargs``, split by value type.
    """

    boolKwargs = DictField[str, bool](
        doc="Keyword arguments to ax.errorbar that take boolean values",
        default={},
        optional=True,
    )

    numKwargs = DictField[str, float](
        doc="Keyword arguments to ax.errorbar that take numerical (float or int) values",
        default={},
        optional=True,
    )

    strKwargs = DictField[str, str](
        doc="Keyword arguments to ax.errorbar that take string values",
        default={},
        optional=True,
    )

    xAxisLabel = Field[str](
        doc="The label to use for the x-axis.",
        default="x",
    )

    yAxisLabel = Field[str](
        doc="The label to use for the y-axis.",
        default="y",
    )

    xScale = ChoiceField[str](
        doc="The scale to use for the x-axis.",
        default="linear",
        allowed={scale: scale for scale in ("linear", "log", "symlog")},
    )

    yScale = ChoiceField[str](
        doc="The scale to use for the y-axis.",
        default="linear",
        allowed={scale: scale for scale in ("linear", "log", "symlog")},
    )

    xLinThresh = Field[float](
        doc=(
            "The value around zero where the scale becomes linear in x-axis "
            "when symlog is set as the scale. Sets the `linthresh` parameter "
            "of `~matplotlib.axes.set_xscale`."
        ),
        default=1e-6,
        optional=True,
    )

    yLinThresh = Field[float](
        doc=(
            "The value around zero where the scale becomes linear in y-axis "
            "when symlog is set as the scale. Sets the `linthresh` parameter "
            "of `~matplotlib.axes.set_yscale`."
        ),
        default=1e-6,
        optional=True,
    )

    xLine = Field[float](
        doc=("The value of x where a vertical line is drawn."),
        default=None,
        optional=True,
    )

    yLine = Field[float](
        doc=("The value of y where a horizontal line is drawn."),
        default=None,
        optional=True,
    )

    def setDefaults(self):
        """Set the default errorbar format to unconnected circular markers."""
        super().setDefaults()
        self.strKwargs = {"fmt": "o"}

    def validate(self):
        """Check the configuration for errors that would break plotting.

        Raises
        ------
        FieldValidationError
            Raised if the same keyword appears in more than one of
            ``boolKwargs``, ``numKwargs`` and ``strKwargs``, or if an axis
            uses the ``symlog`` scale without a positive linear threshold.
        """
        # Every pair of kwargs dictionaries must be disjoint: they are all
        # unpacked into the same ``ax.errorbar`` call, where a repeated
        # keyword is a TypeError.
        seenKeys: set[str] = set()
        for kwargGroup in (self.boolKwargs, self.numKwargs, self.strKwargs):
            groupKeys = set(kwargGroup.keys()) if kwargGroup is not None else set()
            if seenKeys & groupKeys:
                # Pass the Field descriptor (class attribute), not the value
                # stored on this instance.
                raise FieldValidationError(type(self).boolKwargs, self, "Keywords have been repeated")
            seenKeys |= groupKeys

        # The thresholds are optional fields, but symlog cannot work without
        # a positive one; fail here rather than halfway through plotting.
        for scaleName, threshName in (("xScale", "xLinThresh"), ("yScale", "yLinThresh")):
            if getattr(self, scaleName) != "symlog":
                continue
            linThresh = getattr(self, threshName)
            if linThresh is None or linThresh <= 0:
                raise FieldValidationError(
                    getattr(type(self), threshName),
                    self,
                    f"{threshName} must be a positive number when {scaleName} is 'symlog'",
                )

        super().validate()

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys, and their types, that this action reads.

        Returns
        -------
        schema : `list` [`tuple` [`str`, `type`]]
            One ``(key, Vector)`` pair for each of ``x``, ``y``, ``xerr``
            and ``yerr``.
        """
        schema: list[tuple[str, type[Vector]]] = [(key, Vector) for key in ("x", "y", "xerr", "yerr")]
        return schema

    def __call__(self, data: KeyedData, **kwargs) -> Figure:
        """Validate the input data and make the plot.

        Parameters
        ----------
        data : `KeyedData`
            Mapping that must contain every key of the input schema.
        **kwargs
            Forwarded unchanged to `makePlot`.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.
        """
        self._validateInput(data)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData) -> None:
        """Raise `ValueError` if ``data`` lacks any key of the input schema."""
        needed = {key for key, _ in self.getInputSchema()}
        if not needed.issubset(data.keys()):
            raise ValueError(f"Input data does not contain all required keys: {self.getInputSchema()}")

    @staticmethod
    def _applyAxisScale(ax: Any, axis: str, scale: str, linThresh: float | None) -> None:
        """Apply the configured scale to one axis of ``ax``.

        Parameters
        ----------
        ax : `~matplotlib.axes.Axes`
            The axes to modify.
        axis : `str`
            Either ``"x"`` or ``"y"``.
        scale : `str`
            The scale name; ``"symlog"`` gets special handling.
        linThresh : `float` or `None`
            Half-width of the linear region around zero; only used for
            ``"symlog"``.
        """
        if axis == "x":
            setScale, axisObj, shadeSpan = ax.set_xscale, ax.xaxis, ax.axvspan
        else:
            setScale, axisObj, shadeSpan = ax.set_yscale, ax.yaxis, ax.axhspan

        if scale != "symlog":
            setScale(scale)
            ax.tick_params(axis=axis, which="minor")
            return

        setScale("symlog", linthresh=linThresh)
        # Minor ticks at 0.1, 0.2, ..., 0.9 of each decade.
        minorLocator = SymmetricalLogLocator(
            linthresh=linThresh, base=10, subs=[0.1 * ii for ii in range(1, 10)]
        )
        axisObj.set_minor_locator(minorLocator)
        # Shade the region in which the axis is linear rather than log.
        shadeSpan(-linThresh, linThresh, color="gray", alpha=0.1)

    def makePlot(self, data: KeyedData, plotInfo: Mapping[str, str] | None = None, **kwargs: Any) -> Figure:
        """Make the plot.

        Parameters
        ----------
        data : `KeyedData`
            Mapping providing the ``x``, ``y``, ``xerr`` and ``yerr`` values
            to plot.
        plotInfo : `~collections.abc.Mapping` [`str`, `str`], optional
            If not `None`, passed together with the figure to
            ``addPlotInfo``, whose return value becomes the returned figure.
        **kwargs
            Additional keyword arguments. Only ``fig`` is used: an existing
            figure whose current axes are drawn on, which allows multiple
            curves to lie on the same plot. All other keywords are ignored.

        Returns
        -------
        fig : `~matplotlib.figure.Figure`
            The resulting figure.
        """
        # Allow for multiple curves to lie on the same plot.
        fig = kwargs.get("fig")
        if fig is None:
            fig = plt.figure(dpi=300)
            ax = fig.add_subplot(111)
        else:
            ax = fig.gca()

        # The kwargs fields are optional, so guard against them being None.
        ax.errorbar(
            data["x"],
            data["y"],
            xerr=data["xerr"],
            yerr=data["yerr"],
            **(self.boolKwargs or {}),  # type: ignore
            **(self.numKwargs or {}),  # type: ignore
            **(self.strKwargs or {}),  # type: ignore
        )
        ax.set_xlabel(self.xAxisLabel)
        ax.set_ylabel(self.yAxisLabel)

        if self.xLine is not None:
            ax.axvline(self.xLine, color="k", linestyle="--")
        if self.yLine is not None:
            ax.axhline(self.yLine, color="k", linestyle="--")

        # Each axis is configured exactly once so that the symlog minor
        # locator (with its sub-decade ticks) is not overwritten afterwards.
        self._applyAxisScale(ax, "x", self.xScale, self.xLinThresh)
        self._applyAxisScale(ax, "y", self.yScale, self.yLinThresh)

        if plotInfo is not None:
            fig = addPlotInfo(fig, plotInfo)

        return fig