Coverage for python / lsst / analysis / tools / actions / plot / completenessPlot.py: 15%

129 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 09:21 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

23from collections.abc import Mapping 

24 

25import numpy as np 

26from matplotlib.figure import Figure 

27 

28from lsst.pex.config import ChoiceField, Field 

29from lsst.pex.config.configurableActions import ConfigurableActionField 

30from lsst.utils.plotting import make_figure, set_rubin_plotstyle 

31 

32from ...actions.keyedData import CalcCompletenessHistogramAction 

33from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar 

34from .plotUtils import addPlotInfo 

35 

36__all__ = ("CompletenessHist",) 

37 

38 

class CompletenessHist(PlotAction):
    """Makes plots of completeness and purity.

    A completeness panel is always drawn; a purity panel is added below it
    when ``show_purity`` is True. The binned values are read from the output
    keys of ``action`` (a `CalcCompletenessHistogramAction`).
    """

    label_shift = Field[float](
        doc="Fraction of plot width to shift completeness/purity labels by. "
        "Ignored if percentiles_style is not 'below_line'",
        default=-0.1,
    )
    action = ConfigurableActionField[CalcCompletenessHistogramAction](
        doc="Action to compute completeness/purity",
    )
    color_counts = Field[str](doc="Color for the line showing object counts", default="#029E73")
    color_right = Field[str](
        doc="Color for the line showing the correctly classified fraction", default="#949494"
    )
    color_wrong = Field[str](
        doc="Color for the line showing the wrongly classified fraction", default="#DE8F05"
    )
    legendLocation = Field[str](doc="Legend position within main plot", default="lower left")
    mag_ref_label = Field[str](
        doc="Label for the completeness x axis.", default="{band}-band Reference Magnitude"
    )
    mag_target_label = Field[str](
        doc="Label for the purity x axis.", default="{band}-band Measured Magnitude"
    )
    object_label = Field[str](doc="Label for measured objects", default="Object")
    reference_label = Field[str](doc="Label for reference objects", default="Reference")
    percentiles_style = ChoiceField[str](
        doc="Style and locations for completeness threshold percentile labels",
        allowed={
            "above_plot": "Labels in a semicolon-separated list above plot",
            "below_line": "Labels under the horizontal part of each line",
        },
        default="below_line",
    )
    publicationStyle = Field[bool](doc="Make a publication-style of plot", default=False)
    show_purity = Field[bool](doc="Whether to include a purity plot below completness", default=True)

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the input schema, which is the output schema of ``action``."""
        yield from self.action.getOutputSchema()

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` and return the figure produced by `makePlot`."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key this plot needs.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data``, or if a key
            whose schema type is not Scalar holds a Scalar.
        """
        # getInputSchema is a generator, so the formatted schema may be a
        # one-shot iterator. Materialize it: otherwise the missing-key check
        # below would exhaust it and the type-check loop would never run.
        needed = tuple(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(self, data, plotInfo, **kwargs):
        """Make a plot of completeness (and optionally purity) versus
        magnitude.

        The behavior of this plot is controlled by `self.action`. This action
        must be added to a struct (usually self.process.calculateActions) by
        the tool that calls this plot.

        Parameters
        ----------
        data : `KeyedData`
            All the data, including the keys written by ``self.action``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``camera``
                The camera used to take the data (`lsst.afw.cameraGeom.Camera`)
            ``"cameraName"``
                The name of camera used to take the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"ccdKey"``
                The ccd/detector key associated with this camera (`str`).
            ``"visit"``
                The visit of the data; only included if the data is from a
                single epoch dataset (`str`).
            ``"patch"``
                The patch that the data is from; only included if the data is
                from a coadd dataset (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            ``"photoCalibDataset"``
                The dataset used for the calibration, e.g. "jointcal" or "fgcm"
                (`str`).
            ``"skyWcsDataset"``
                The sky Wcs dataset used (`str`).
            ``"rerun"``
                The rerun the data is stored in (`str`).
        **kwargs
            If ``band`` is given, it is substituted into the data key names
            and into ``{band}`` placeholders in the x-axis labels.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The figure to be saved.

        Notes
        -----
        Each panel shows fractions on the left y axis and counts per magnitude
        on a log-scaled right y axis. Completeness threshold lines and labels
        (see ``percentiles_style``) are only drawn on the completeness panel.

        Examples
        --------
        An example of the plot produced from this code from tract 3828 of the
        DC2 simulations is here:

        .. image:: /_static/analysis_tools/completenessPlotExample.png

        """
        # Make plot showing the fraction recovered in magnitude bins
        set_rubin_plotstyle()
        n_sub = 1 + self.show_purity
        fig = make_figure(dpi=300, figsize=(8, 4 * n_sub))
        if self.show_purity:
            axes = (fig.add_subplot(2, 1, 1), fig.add_subplot(2, 1, 2))
        else:
            axes = [fig.add_axes([0.1, 0.15, 0.8, 0.75])]
        # Upper y limit of the fraction axes, leaving headroom above 1
        max_left = 1.05

        band = kwargs.get("band")
        action_hist = self.action.action
        # Resolve the data keys written by the histogram action, filling in
        # the band when one was supplied.
        names = {}
        for name in (
            "range_minimum",
            "range_maximum",
            "count",
            "count_ref",
            "count_target",
            "completeness",
            "completeness_bad_match",
            "completeness_good_match",
            "purity",
            "purity_bad_match",
            "purity_good_match",
        ):
            key = getattr(action_hist, f"name_{name}")
            if band is not None:
                key = key.format(band=band)
            names[name] = key

        ranges_min = data[names["range_minimum"]]
        ranges_max = data[names["range_maximum"]]
        # Points are plotted at bin centers
        x = (ranges_max + ranges_min) / 2.0
        # assumes bins.mag_width is in millimagnitudes (hence / 1000) - TODO confirm
        interval = self.action.bins.mag_width / 1000.0
        x_err = interval / 2.0

        # NOTE(review): this single count column sets the error bars on both
        # the completeness and purity panels; confirm that is intended.
        counts_all = data[names["count"]]

        # Each line is (y values, draw error bars?, color, legend label)
        if self.publicationStyle:
            lineTuples = (
                (data[names["completeness"]], False, "k", "Completeness"),
                (data[names["completeness_bad_match"]], False, self.color_wrong, "Incorrect Class"),
            )
        else:
            lineTuples = (
                (data[names["completeness"]], True, "k", "Completeness"),
                (data[names["completeness_bad_match"]], False, self.color_wrong, "Incorrect class"),
                (data[names["completeness_good_match"]], False, self.color_right, "Correct Class"),
            )

        mag_ref_label = self.mag_ref_label
        if "{band}" in mag_ref_label:
            mag_ref_label = mag_ref_label.format(band=band)
        mag_target_label = self.mag_target_label
        if "{band}" in mag_target_label:
            mag_target_label = mag_target_label.format(band=band)

        plots = {
            "Completeness": {
                "count_type": self.reference_label,
                "counts": data[names["count_ref"]],
                "lines": lineTuples,
                "xlabel": mag_ref_label,
            },
        }
        if self.show_purity:
            plots["Purity"] = {
                "count_type": self.object_label,
                "counts": data[names["count_target"]],
                "lines": (
                    (data[names["purity"]], True, "k", "Purity"),
                    (data[names["purity_bad_match"]], False, self.color_wrong, "Incorrect class"),
                    (data[names["purity_good_match"]], False, self.color_right, "Correct class"),
                ),
                "xlabel": mag_target_label,
            }

        # idx == 0 should be completeness; update this if that assumption
        # is changed
        for idx, (ylabel, plot_data) in enumerate(plots.items()):
            axes_idx = axes[idx]
            xlim = (ranges_min[0], ranges_max[-1])
            axes_idx.set(
                xlabel=plot_data["xlabel"],
                ylabel=ylabel,
                xlim=xlim,
                ylim=(0, max_left),
                xticks=np.arange(round(xlim[0]), round(xlim[1])),
                yticks=np.linspace(0, 1, 11),
            )
            if not self.publicationStyle:
                axes_idx.grid(color="lightgrey", ls="-")
            # Counts go on a log-scaled secondary y axis sharing the x axis
            ax_right = axes_idx.twinx()
            ax_right.set_ylabel(f"{plot_data['count_type']} Counts/Magnitude", color="k")
            ax_right.set_yscale("log")

            for y_frac, do_err, color, label in plot_data["lines"]:
                axes_idx.errorbar(
                    x=x,
                    y=y_frac,
                    xerr=x_err if do_err else None,
                    yerr=1.0 / np.sqrt(counts_all + 1) if do_err else None,
                    capsize=0,
                    color=color,
                    label=label,
                )
            counts_per_mag = plot_data["counts"] / interval
            lines_left, labels_left = axes_idx.get_legend_handles_labels()
            # Pad with a zero-valued point beyond each end so the histogram
            # drops to zero at both edges
            ax_right.step(
                [x[0] - interval] + list(x) + [x[-1] + interval],
                [0] + list(counts_per_mag) + [0],
                where="mid",
                color=self.color_counts,
                label="Counts",
            )

            # Force the inputs counts histogram to the back
            ax_right.zorder = 1
            axes_idx.zorder = 2
            axes_idx.patch.set_visible(False)

            # It should be unusual for the maximum count to be zero;
            # nonetheless, the floor of 2 keeps the upper limit above 0.999.
            ax_right.set_ylim(
                0.999, 10 ** (max_left * np.log10(max(np.nanmax(counts_per_mag), 2)))
            )
            ax_right.tick_params(axis="y", labelcolor=self.color_counts)
            lines_right, labels_right = ax_right.get_legend_handles_labels()

            # With a single panel the legend goes on the figure; with two
            # panels each axes gets its own legend.
            (axes_idx if self.show_purity else fig).legend(
                lines_left + lines_right,
                labels_left + labels_right,
                loc=self.legendLocation,
                ncol=3 if self.publicationStyle else 2,
            )

            if idx == 0:
                self._addPercentileLabels(axes_idx, data, band, max_left)

        # Add useful information to the plot
        if not self.publicationStyle:
            addPlotInfo(fig, plotInfo)
        if self.show_purity:
            fig.tight_layout()
            fig.subplots_adjust(top=0.90)
        return fig

    def _addPercentileLabels(self, ax, data, band, y_max):
        """Draw completeness-threshold guide lines and labels on ``ax``.

        Parameters
        ----------
        ax : `matplotlib.axes.Axes`
            The completeness axes to annotate.
        data : `KeyedData`
            The input data; thresholds whose magnitude key is absent or
            non-finite are skipped.
        band : `str` or `None`
            Band substituted into the threshold key names, if given.
        y_max : `float`
            Upper y limit of ``ax``; ``above_plot`` labels are placed here.

        Raises
        ------
        RuntimeError
            Raised if ``percentiles_style`` has no implementation.
        """
        if not self.publicationStyle:
            percentiles = self.action.config_metrics.completeness_percentiles
        else:
            percentiles = [90.0, 50.0]
        if not percentiles:
            return
        above_plot = self.percentiles_style == "above_plot"
        below_line = self.percentiles_style == "below_line"
        kwargs_lines = {"color": "dimgrey", "ls": ":"}
        xlims = ax.get_xlim()
        if above_plot:
            texts = []
        elif below_line:
            offset = self.label_shift * (xlims[1] - xlims[0])
        else:
            raise RuntimeError(f"Unimplemented {self.percentiles_style=}")
        for pct in percentiles:
            name_pct = self.action.action.name_mag_completeness(
                self.action.getPercentileName(pct),
            )
            if band is not None:
                name_pct = name_pct.format(band=band)
            mag_completeness = data.get(name_pct, None)
            if mag_completeness is None or not np.isfinite(mag_completeness):
                continue
            frac = pct / 100.0
            # Horizontal line from the left edge to the threshold magnitude,
            # then a vertical drop down to the x axis
            ax.plot([xlims[0], mag_completeness], [frac, frac], **kwargs_lines)
            ax.plot([mag_completeness, mag_completeness], [0, frac], **kwargs_lines)
            # ":g" keeps values like 99.5 readable (".2g" gave "1e+02")
            text = f"{pct:g}%: {mag_completeness:.2f}"
            if above_plot:
                texts.append(text)
            elif below_line:
                ax.text(
                    mag_completeness + offset,
                    frac - 0.02,
                    text,
                    ha="right",
                    va="top",
                    fontsize=12,
                )
        if above_plot:
            ax.text(xlims[0], y_max, f"Thresholds: {'; '.join(texts)}", ha="left", va="bottom")