Coverage for python / lsst / analysis / tools / actions / plot / completenessPlot.py: 15%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:09 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

23from typing import Mapping 

24 

25import numpy as np 

26from lsst.pex.config import ChoiceField, Field 

27from lsst.pex.config.configurableActions import ConfigurableActionField 

28from lsst.utils.plotting import make_figure, set_rubin_plotstyle 

29from matplotlib.figure import Figure 

30 

31from ...actions.keyedData import CalcCompletenessHistogramAction 

32from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar 

33from .plotUtils import addPlotInfo 

34 

__all__ = ("CompletenessHist",)


class CompletenessHist(PlotAction):
    """Makes plots of completeness and, optionally, purity versus magnitude.

    All binned quantities are computed by ``action`` (a
    `CalcCompletenessHistogramAction`); this class only looks up that
    action's output columns in the input data and draws them.
    """

    label_shift = Field[float](
        doc="Fraction of plot width to shift completeness/purity labels by. "
        "Ignored if percentiles_style is not 'below_line'",
        default=-0.1,
    )
    action = ConfigurableActionField[CalcCompletenessHistogramAction](
        doc="Action to compute completeness/purity",
    )
    color_counts = Field[str](doc="Color for the line showing object counts", default="#029E73")
    color_right = Field[str](
        doc="Color for the line showing the correctly classified fraction", default="#949494"
    )
    color_wrong = Field[str](
        doc="Color for the line showing the wrongly classified fraction", default="#DE8F05"
    )
    legendLocation = Field[str](doc="Legend position within main plot", default="lower left")
    mag_ref_label = Field[str](
        doc="Label for the completeness x axis.", default="{band}-band Reference Magnitude"
    )
    mag_target_label = Field[str](
        doc="Label for the purity x axis.", default="{band}-band Measured Magnitude"
    )
    object_label = Field[str](doc="Label for measured objects", default="Object")
    reference_label = Field[str](doc="Label for reference objects", default="Reference")
    percentiles_style = ChoiceField[str](
        doc="Style and locations for completeness threshold percentile labels",
        allowed={
            "above_plot": "Labels in a semicolon-separated list above plot",
            "below_line": "Labels under the horizontal part of each line",
        },
        default="below_line",
    )
    publicationStyle = Field[bool](doc="Make a publication-style version of the plot", default=False)
    show_purity = Field[bool](doc="Whether to include a purity plot below completeness", default=True)

    def getInputSchema(self) -> KeyedDataSchema:
        # The plot consumes exactly the columns the histogram action produces.
        yield from self.action.getOutputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` contains every key in the input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a key
            holds a Scalar while the schema requires a non-Scalar type.
        """
        # Materialize the schema once: it is iterated twice below, and a
        # one-shot iterable (e.g. a generator) would otherwise leave the
        # second loop empty, silently skipping the type check.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(self, data, plotInfo, **kwargs):
        """Makes a plot showing the fraction of injected sources recovered by
        input magnitude and, if ``show_purity`` is set, a second panel
        showing purity by measured magnitude.

        The behavior of this plot is controlled by `self.action`. This action
        must be added to a struct (usually self.process.calculateActions) by
        the tool that calls this plot.

        Parameters
        ----------
        data : `KeyedData`
            All the data
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            ``"camera"``
                The camera used to take the data (`lsst.afw.cameraGeom.Camera`)
            ``"cameraName"``
                The name of camera used to take the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"ccdKey"``
                The ccd/detector key associated with this camera (`str`).
            ``"visit"``
                The visit of the data; only included if the data is from a
                single epoch dataset (`str`).
            ``"patch"``
                The patch that the data is from; only included if the data is
                from a coadd dataset (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            ``"photoCalibDataset"``
                The dataset used for the calibration, e.g. "jointcal" or "fgcm"
                (`str`).
            ``"skyWcsDataset"``
                The sky Wcs dataset used (`str`).
            ``"rerun"``
                The rerun the data is stored in (`str`).
        **kwargs
            If ``band`` is given, it is substituted into the action's output
            column names and into the x-axis labels.

        Returns
        -------
        ``fig``
            The figure to be saved (`matplotlib.figure.Figure`).

        Notes
        -----
        The behaviour of this plot is largely determined by ``self.action``:
        the names of the columns read from ``data`` (bin edges, counts,
        completeness, purity and threshold magnitudes) are taken from that
        action's ``name_*`` attributes.

        Examples
        --------
        An example of the plot produced from this code from tract 3828 of the
        DC2 simulations is here:

        .. image:: /_static/analysis_tools/completenessPlotExample.png

        """

        # Make plot showing the fraction recovered in magnitude bins
        set_rubin_plotstyle()
        n_sub = 1 + self.show_purity
        fig = make_figure(dpi=300, figsize=(8, 4 * n_sub))
        if self.show_purity:
            axes = (fig.add_subplot(2, 1, 1), fig.add_subplot(2, 1, 2))
        else:
            axes = [fig.add_axes([0.1, 0.15, 0.8, 0.75])]
        # Upper y limit of the fraction axes; also used as the headroom
        # exponent factor for the log-scaled counts axis below.
        max_left = 1.05

        band = kwargs.get("band")
        action_hist = self.action.action
        # Resolve the (optionally band-formatted) column name of each output.
        names = {}
        for name in (
            "range_minimum",
            "range_maximum",
            "count",
            "count_ref",
            "count_target",
            "completeness",
            "completeness_bad_match",
            "completeness_good_match",
            "purity",
            "purity_bad_match",
            "purity_good_match",
        ):
            key = getattr(action_hist, f"name_{name}")
            if band is not None:
                key = key.format(band=band)
            names[name] = key

        ranges_min = data[names["range_minimum"]]
        ranges_max = data[names["range_maximum"]]
        x = (ranges_max + ranges_min) / 2.0
        # Bin width in magnitudes; mag_width is divided by 1000, presumably
        # because it is configured in millimag — TODO confirm.
        interval = self.action.bins.mag_width / 1000.0
        x_err = interval / 2.0

        counts_all = data[names["count"]]

        # Each line is (values, draw_error_bars, color, legend label).
        if self.publicationStyle:
            lineTuples = (
                (data[names["completeness"]], False, "k", "Completeness"),
                (data[names["completeness_bad_match"]], False, self.color_wrong, "Incorrect Class"),
            )
        else:
            lineTuples = (
                (data[names["completeness"]], True, "k", "Completeness"),
                (data[names["completeness_bad_match"]], False, self.color_wrong, "Incorrect class"),
                (data[names["completeness_good_match"]], False, self.color_right, "Correct Class"),
            )

        mag_ref_label = self.mag_ref_label
        if "{band}" in mag_ref_label:
            mag_ref_label = mag_ref_label.format(band=band)
        mag_target_label = self.mag_target_label
        if "{band}" in mag_target_label:
            mag_target_label = mag_target_label.format(band=band)

        plots = {
            "Completeness": {
                "count_type": self.reference_label,
                "counts": data[names["count_ref"]],
                "lines": lineTuples,
                "xlabel": mag_ref_label,
            },
        }
        if self.show_purity:
            plots["Purity"] = {
                "count_type": self.object_label,
                "counts": data[names["count_target"]],
                "lines": (
                    (data[names["purity"]], True, "k", "Purity"),
                    (data[names["purity_bad_match"]], False, self.color_wrong, "Incorrect class"),
                    (data[names["purity_good_match"]], False, self.color_right, "Correct class"),
                ),
                "xlabel": mag_target_label,
            }

        # idx == 0 should be completeness; update this if that assumption
        # is changed
        for idx, (ylabel, plot_data) in enumerate(plots.items()):
            axes_idx = axes[idx]
            xlim = (ranges_min[0], ranges_max[-1])
            axes_idx.set(
                xlabel=plot_data["xlabel"],
                ylabel=ylabel,
                xlim=xlim,
                ylim=(0, max_left),
                xticks=np.arange(round(xlim[0]), round(xlim[1])),
                yticks=np.linspace(0, 1, 11),
            )
            if not self.publicationStyle:
                axes_idx.grid(color="lightgrey", ls="-")
            # Secondary axis for the counts histogram.
            ax_right = axes_idx.twinx()
            ax_right.set_ylabel(f"{plot_data['count_type']} Counts/Magnitude", color="k")
            ax_right.set_yscale("log")

            for y, do_err, color, label in plot_data["lines"]:
                axes_idx.errorbar(
                    x=x,
                    y=y,
                    xerr=x_err if do_err else None,
                    # NOTE(review): the same ``count`` column sets the error
                    # bars on both the completeness and purity panels —
                    # confirm this is intended for purity.
                    yerr=1.0 / np.sqrt(counts_all + 1) if do_err else None,
                    capsize=0,
                    color=color,
                    label=label,
                )
            # Counts per unit magnitude for the histogram on the right axis.
            y = plot_data["counts"] / interval
            lines_left, labels_left = axes_idx.get_legend_handles_labels()
            # Pad with an empty bin on each side so the outline drops to zero
            # at both ends of the magnitude range.
            ax_right.step(
                [x[0] - interval] + list(x) + [x[-1] + interval],
                [0] + list(y) + [0],
                where="mid",
                color=self.color_counts,
                label="Counts",
            )

            # Force the inputs counts histogram to the back
            ax_right.zorder = 1
            axes_idx.zorder = 2
            axes_idx.patch.set_visible(False)

            # It should be unusual for np.nanmax(y) to be zero; nonetheless,
            # floor it at 2 so the upper limit stays above the lower one.
            ax_right.set_ylim(0.999, 10 ** (max_left * np.log10(max(np.nanmax(y), 2))))
            ax_right.tick_params(axis="y", labelcolor=self.color_counts)
            lines_right, labels_right = ax_right.get_legend_handles_labels()

            # One combined legend per panel; with a single panel it is
            # attached to the figure instead of the axes.
            (axes_idx if self.show_purity else fig).legend(
                lines_left + lines_right,
                labels_left + labels_right,
                loc=self.legendLocation,
                ncol=3 if self.publicationStyle else 2,
            )

            if idx == 0:
                self._addThresholdLabels(axes_idx, data, band, max_left)

        # Add useful information to the plot
        if not self.publicationStyle:
            addPlotInfo(fig, plotInfo)
        if self.show_purity:
            fig.tight_layout()
            fig.subplots_adjust(top=0.90)
        return fig

    def _addThresholdLabels(self, ax, data, band, y_top):
        """Mark where completeness crosses each configured percentile.

        For every percentile whose threshold magnitude is present in ``data``
        and finite, draw dotted guide lines to that magnitude and label it,
        either under the line or in a single list above the plot, depending
        on ``percentiles_style``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            The completeness axes to annotate.
        data : `KeyedData`
            Input data from which the threshold magnitudes are read.
        band : `str` or `None`
            Band substituted into the threshold column names, if not None.
        y_top : `float`
            Upper y limit of ``ax``; the "above_plot" label is placed here.

        Raises
        ------
        RuntimeError
            Raised if ``percentiles_style`` has an unimplemented value.
        """
        if not self.publicationStyle:
            percentiles = self.action.config_metrics.completeness_percentiles
        else:
            percentiles = [90.0, 50.0]
        if not percentiles:
            return

        above_plot = self.percentiles_style == "above_plot"
        below_line = self.percentiles_style == "below_line"
        kwargs_lines = dict(color="dimgrey", ls=":")
        xlims = ax.get_xlim()
        if above_plot:
            texts = []
        elif below_line:
            offset = self.label_shift * (xlims[1] - xlims[0])
        else:
            raise RuntimeError(f"Unimplemented {self.percentiles_style=}")

        action_hist = self.action.action
        for pct in percentiles:
            name_pct = action_hist.name_mag_completeness(
                self.action.getPercentileName(pct),
            )
            if band is not None:
                name_pct = name_pct.format(band=band)
            mag_completeness = data.get(name_pct, None)
            if mag_completeness is None or not np.isfinite(mag_completeness):
                continue
            frac = pct / 100.0
            ax.plot([xlims[0], mag_completeness], [frac, frac], **kwargs_lines)
            ax.plot([mag_completeness, mag_completeness], [0, frac], **kwargs_lines)
            # ``g`` keeps the configured precision: 99.5 renders as "99.5%",
            # where ``.2g`` would have produced "1e+02%".
            text = f"{pct:g}%: {mag_completeness:.2f}"
            if above_plot:
                texts.append(text)
            else:  # below_line
                ax.text(
                    mag_completeness + offset,
                    frac - 0.02,
                    text,
                    ha="right",
                    va="top",
                    fontsize=12,
                )
        # Skip the list entirely rather than drawing an empty "Thresholds: ".
        if above_plot and texts:
            ax.text(xlims[0], y_top, f"Thresholds: {'; '.join(texts)}", ha="left", va="bottom")