Coverage for python / lsst / analysis / tools / actions / plot / scatterplotWithTwoHists.py: 13%

444 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-15 00:23 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26import math 

27from typing import Mapping, NamedTuple, Optional, cast 

28 

29import matplotlib.colors 

30import matplotlib.patheffects as pathEffects 

31import numpy as np 

32from lsst.pex.config import Field 

33from lsst.pex.config.configurableActions import ConfigurableActionField 

34from lsst.pex.config.listField import ListField 

35from lsst.utils.plotting import ( 

36 galaxies_cmap, 

37 galaxies_color, 

38 make_figure, 

39 set_rubin_plotstyle, 

40 stars_cmap, 

41 stars_color, 

42) 

43from matplotlib import gridspec 

44from matplotlib.axes import Axes 

45from matplotlib.collections import PolyCollection 

46from matplotlib.figure import Figure 

47from matplotlib.path import Path 

48from matplotlib.ticker import LogFormatterMathtext, NullFormatter 

49from mpl_toolkits.axes_grid1 import make_axes_locatable 

50 

51from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

52from ...math import nanMedian, nanSigmaMad 

53from ..keyedData.summaryStatistics import SummaryStatisticAction 

54from ..scalar import MedianAction 

55from ..vector import ConvertFluxToMag, SnSelector 

56from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats 

57 

58 

class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the per-selection statistics consumed by the scatter plot with
    two histograms.

    Two signal-to-noise selections (``highSNSelector`` and ``lowSNSelector``)
    are evaluated.  For each selection the output holds the selection mask,
    the summary statistics of ``vectorKey`` restricted to that mask, an
    approximate magnitude (median magnitude of the selected ``fluxType``
    values, or NaN when no band is given), and the selector threshold.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        # An explicit prefix wins, then the action identity, then nothing.
        if self.prefix:
            return self.prefix
        return self.identity or ""

    def _prefixForms(self):
        """Return ``(lowercased prefix, capitalized prefix, suffix)`` used to
        build every output key.
        """
        stem = self._get_key_prefix()
        if not stem:
            return "", "", self.suffix
        return stem.lower(), stem.capitalize(), self.suffix

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        lower, upper, suffix = self._prefixForms()
        schema = [
            (f"{lower}HighSNMask{suffix}", Vector),
            (f"{lower}LowSNMask{suffix}", Vector),
        ]
        # Low-S/N statistics come first, then high-S/N, each in this order.
        for level in ("low", "high"):
            for statName in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{level}SN{upper}_{statName}{suffix}", Scalar))
        schema.append((f"{lower}HighSNThreshold{suffix}", Scalar))
        schema.append((f"{lower}LowSNThreshold{suffix}", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        lower, upper, suffix = self._prefixForms()
        highMaskKey = f"{lower}HighSNMask{suffix}"
        lowMaskKey = f"{lower}LowSNMask{suffix}"
        # The dict literal evaluates the high selector before the low one.
        results = {
            highMaskKey: self.highSNSelector(data, **kwargs),
            lowMaskKey: self.lowSNSelector(data, **kwargs),
        }

        band = kwargs.get("band")
        bandPrefix = f"{band}_" if band else ""
        fluxes = None if band is None else data[self.fluxType.format(band=band)]

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)
        # Defaults have to be applied by hand here: pex_config's behavior in
        # this situation is considered broken but risky to fix.
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for binName, maskKey in (("low", lowMaskKey), ("high", highMaskKey)):
            mask = results[maskKey]
            stem = f"{bandPrefix}{binName}SN{upper}"
            # approxMag is the median magnitude inside this S/N selection.
            if band is None:
                approxMag = np.nan
            else:
                approxMag = medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
            results[f"{stem}_approxMag{suffix}".format(**kwargs)] = approxMag
            for statName, value in statAction(data, **(kwargs | {"mask": mask})).items():
                results[f"{stem}_{statName}{suffix}".format(**kwargs)] = value

        results[f"{lower}HighSNThreshold{suffix}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{lower}LowSNThreshold{suffix}"] = self.lowSNSelector.threshold  # type: ignore
        return results

149 

150 

151def _validObjectTypes(value): 

152 return value in ("stars", "galaxies", "unknown", "any") 

153 

154 

# The ``count`` field shadows the ``tuple.count`` method inherited from the
# NamedTuple base class; type checkers flag that, hence the ignore below.
class _StatsContainer(NamedTuple):
    """Summary statistics for one signal-to-noise selection of a dataset."""

    median: Scalar  # median of the plotted quantity in the selection
    sigmaMad: Scalar  # sigma MAD of the plotted quantity in the selection
    count: Scalar  # type: ignore
    approxMag: Scalar  # magnitude shown alongside the selection's stats

161 

162 

class DataTypeDefaults(NamedTuple):
    """Key suffixes and plotting colors associated with one object type."""

    suffix_stat: str  # inserted into statistic keys, e.g. "{band}_highSN<suffix_stat>_median"
    suffix_xy: str  # appended to "x"/"y" to form the vector keys
    color: str  # line/marker color for this object type
    colormap: matplotlib.colors.Colormap | None  # colormap for the 2D histogram, if any

168 

169 

class LogFormatterExponentSci(LogFormatterMathtext):
    """
    Log-axis tick formatter that labels non-decade ticks in scientific notation.

    Unlike the matplotlib LogFormatterExponent, a tick whose coefficient is
    (close to) an integer is written out in full when the base is 10 and the
    exponent is between 0 and 3 inclusive (e.g. 500 is labelled "500").  Any
    other non-decade tick is written as "<coefficient>e<exponent>" with one
    decimal place (e.g. "5.0e4").
    """

    def _non_decade_format(self, sign_string, base, fx, usetex):
        """Return the label for a tick that is not an exact power of ``base``.

        ``fx`` is the tick's logarithm in ``base``.  ``base`` is compared with
        the string ``"10"``, so it is expected as a string.  ``usetex`` is
        accepted for interface compatibility but not used.
        """
        exponent = math.floor(fx)
        mantissa = float(base) ** (fx - exponent)
        nearest = round(mantissa)
        if not math.isclose(mantissa, nearest):
            return f"{sign_string}{mantissa:1.1f}e{exponent}"
        if base == "10" and 0 <= exponent <= 3:
            return f"{sign_string}{nearest}{'0'*int(exponent)}"
        return f"{sign_string}{nearest:1.1f}e{exponent}"

190 

191 

192class ScatterPlotWithTwoHists(PlotAction): 

193 """Makes a scatter plot of the data with a marginal 

194 histogram for each axis. 

195 """ 

196 

197 yLims = ListField[float]( 

198 doc="ylimits of the plot, if not specified determined from data", 

199 length=2, 

200 optional=True, 

201 ) 

202 

203 xLims = ListField[float]( 

204 doc="xlimits of the plot, if not specified determined from data", length=2, optional=True 

205 ) 

206 xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False) 

207 yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False) 

208 magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False) 

209 

210 legendLocation = Field[str](doc="Legend position within main plot", default="upper left") 

211 nBins = Field[float](doc="Number of bins on x axis", default=40.0) 

212 plot2DHist = Field[bool]( 

213 doc="Plot a 2D histogram in dense areas of points on the scatter plot." 

214 "Doesn't look great if plotting multiple datasets on top of each other.", 

215 default=True, 

216 ) 

217 plotTypes = ListField[str]( 

218 doc="Selection of types of objects to plot. Can take any combination of" 

219 " stars, galaxies, unknown, any", 

220 optional=False, 

221 itemCheck=_validObjectTypes, 

222 ) 

223 

224 addSummaryPlot = Field[bool]( 

225 doc="Add a summary plot to the figure?", 

226 default=True, 

227 ) 

228 histMinimum = Field[float]( 

229 doc="Minimum value for the histogram count axis", 

230 default=0.3, 

231 ) 

232 xHistMaxLabels = Field[int]( 

233 doc="Maximum number of labels for ticks on the x-axis marginal histogram", 

234 default=3, 

235 check=lambda x: x >= 2, 

236 ) 

237 yHistMaxLabels = Field[int]( 

238 doc="Maximum number of labels for ticks on the y-axis marginal histogram", 

239 default=3, 

240 check=lambda x: x >= 2, 

241 ) 

242 

243 suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="") 

244 suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="") 

245 suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="") 

246 

247 publicationStyle = Field[bool](doc="Slimmed down publication style plot?", default=False) 

248 

249 _stats = ("median", "sigmaMad", "count", "approxMag") 

250 _datatypes = { 

251 "galaxies": DataTypeDefaults( 

252 suffix_stat="Galaxies", 

253 suffix_xy="Galaxies", 

254 color=galaxies_color(), 

255 colormap=galaxies_cmap(single_color=True), 

256 ), 

257 "stars": DataTypeDefaults( 

258 suffix_stat="Stars", 

259 suffix_xy="Stars", 

260 color=stars_color(), 

261 colormap=stars_cmap(single_color=True), 

262 ), 

263 "unknown": DataTypeDefaults( 

264 suffix_stat="Unknown", 

265 suffix_xy="Unknown", 

266 color="green", 

267 colormap=None, 

268 ), 

269 "any": DataTypeDefaults( 

270 suffix_stat="Any", 

271 suffix_xy="", 

272 color="purple", 

273 colormap=None, 

274 ), 

275 } 

276 

277 def getInputSchema(self) -> KeyedDataSchema: 

278 base: list[tuple[str, type[Vector] | ScalarType]] = [] 

279 for name_datatype in self.plotTypes: 

280 config_datatype = self._datatypes[name_datatype] 

281 if not self.publicationStyle: 

282 base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector)) 

283 base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector)) 

284 base.append((f"{name_datatype}HighSNMask{self.suffix_stat}", Vector)) 

285 base.append((f"{name_datatype}LowSNMask{self.suffix_stat}", Vector)) 

286 # statistics 

287 for name in self._stats: 

288 base.append( 

289 (f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar) 

290 ) 

291 base.append( 

292 (f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar) 

293 ) 

294 base.append((f"{name_datatype}LowSNThreshold{self.suffix_stat}", Scalar)) 

295 base.append((f"{name_datatype}HighSNThreshold{self.suffix_stat}", Scalar)) 

296 

297 if self.addSummaryPlot and not self.publicationStyle: 

298 base.append(("patch", Vector)) 

299 

300 return base 

301 

302 def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure: 

303 self._validateInput(data, **kwargs) 

304 return self.makePlot(data, **kwargs) 

305 

306 def _validateInput(self, data: KeyedData, **kwargs) -> None: 

307 """NOTE currently can only check that something is not a Scalar, not 

308 check that the data is consistent with Vector 

309 """ 

310 needed = self.getFormattedInputSchema(**kwargs) 

311 if remainder := {key.format(**kwargs) for key, _ in needed} - { 

312 key.format(**kwargs) for key in data.keys() 

313 }: 

314 raise ValueError( 

315 f"Task needs keys {remainder} but they were not found in input keys" f" {list(data.keys())}" 

316 ) 

317 for name, typ in needed: 

318 isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar) 

319 if isScalar and typ != Scalar: 

320 raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}") 

321 

322 def makePlot( 

323 self, 

324 data: KeyedData, 

325 plotInfo: Mapping[str, str], 

326 **kwargs, 

327 ) -> Figure: 

328 """Makes a generic plot with a 2D histogram and collapsed histograms of 

329 each axis. 

330 

331 Parameters 

332 ---------- 

333 data : `KeyedData` 

334 The catalog to plot the points from. 

335 plotInfo : `dict` 

336 A dictionary of information about the data being plotted with keys: 

337 

338 * ``"run"`` 

339 The output run for the plots (`str`). 

340 * ``"skymap"`` 

341 The type of skymap used for the data (`str`). 

342 * ``"filter"`` 

343 The filter used for this data (`str`). 

344 * ``"tract"`` 

345 The tract that the data comes from (`str`). 

346 

347 Returns 

348 ------- 

349 fig : `matplotlib.figure.Figure` 

350 The resulting figure. 

351 

352 Notes 

353 ----- 

354 Uses the axisLabels config options `x` and `y` and the axisAction 

355 config options `xAction` and `yAction` to plot a scatter 

356 plot of the values against each other. A histogram of the points 

357 collapsed onto each axis is also plotted. A summary panel showing the 

358 median of the y value in each patch is shown in the upper right corner 

359 of the resultant plot. The code uses the selectorActions to decide 

360 which points to plot and the statisticSelector actions to determine 

361 which points to use for the printed statistics. 

362 

363 If this function is being used within the pipetask framework 

364 that takes care of making sure that data has all the required 

365 elements but if you are running this as a standalone function 

366 then you will need to provide the following things in the 

367 input data. 

368 

369 * If stars is in self.plotTypes: 

370 xStars, yStars, starsHighSNMask, starsLowSNMask and 

371 {band}_highSNStars_{name}, {band}_lowSNStars_{name} 

372 where name is median, sigma_Mad, count and approxMag. 

373 

374 * If it is for galaxies/unknowns then replace stars in the above 

375 names with galaxies/unknowns. 

376 

377 * If it is for any (which covers all the points) then it 

378 becomes, x, y, and any instead of stars for the other 

379 parameters given above. 

380 

381 * In every case it is expected that data contains: 

382 lowSnThreshold, highSnThreshold and patch 

383 (if the summary plot is being plotted). 

384 

385 Examples 

386 -------- 

387 An example of the plot produced from this code is here: 

388 

389 .. image:: /_static/analysis_tools/scatterPlotExample.png 

390 

391 For a detailed example of how to make a plot from the command line 

392 please see the 

393 :ref:`getting started guide<analysis-tools-getting-started>`. 

394 """ 

395 if not self.plotTypes: 

396 noDataFig = Figure() 

397 noDataFig.text(0.3, 0.5, "No data to plot after selectors applied") 

398 noDataFig = addPlotInfo(noDataFig, plotInfo) 

399 return noDataFig 

400 

401 # Set default color and line style for the horizontal 

402 # reference line at 0 

403 if "hlineColor" not in kwargs: 

404 kwargs["hlineColor"] = "black" 

405 

406 if "hlineStyle" not in kwargs: 

407 kwargs["hlineStyle"] = (0, (1, 4)) 

408 

409 set_rubin_plotstyle() 

410 fig = make_figure() 

411 gs = gridspec.GridSpec(4, 4) 

412 

413 # add the various plot elements 

414 ax, imhist = self._scatterPlot(data, fig, gs, **kwargs) 

415 if ax is None: 

416 noDataFig = Figure() 

417 noDataFig.text(0.3, 0.5, "No data to plot after selectors applied") 

418 if not self.publicationStyle: 

419 noDataFig = addPlotInfo(noDataFig, plotInfo) 

420 return noDataFig 

421 

422 self._makeTopHistogram(data, fig, gs, ax, **kwargs) 

423 self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs) 

424 # Needs info from run quantum 

425 skymap = kwargs.get("skymap", None) 

426 if self.addSummaryPlot and skymap is not None and not self.publicationStyle: 

427 sumStats = generateSummaryStats(data, skymap, plotInfo) 

428 label = self.yAxisLabel 

429 fig = addSummaryPlot(fig, gs[0, -1], sumStats, label) 

430 

431 fig.canvas.draw() 

432 # TODO: Check if these spacings can be defined less arbitrarily 

433 fig.subplots_adjust( 

434 wspace=0.0, 

435 hspace=0.0, 

436 bottom=0.13 if self.publicationStyle else 0.22, 

437 left=0.18 if self.publicationStyle else 0.21, 

438 right=0.92 if self.publicationStyle else None, 

439 top=0.98 if self.publicationStyle else None, 

440 ) 

441 if not self.publicationStyle: 

442 fig = addPlotInfo(fig, plotInfo) 

443 return fig 

444 

445 def _scatterPlot( 

446 self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs 

447 ) -> tuple[Axes, Optional[PolyCollection]]: 

448 suf_x = self.suffix_x 

449 suf_y = self.suffix_y 

450 suf_stat = self.suffix_stat 

451 # Main scatter plot 

452 ax = fig.add_subplot(gs[1:, :-1]) 

453 

454 binThresh = 5 

455 min_n_xs_for_stats = 10 

456 

457 yBinsOut = [] 

458 linesForLegend = [] 

459 

460 toPlotList = [] 

461 histIm = None 

462 highStats: _StatsContainer 

463 lowStats: _StatsContainer 

464 

465 magLabel = self.magLabel 

466 kwargs_in_label = {k: v for k, v in kwargs.items() if f"{{{k}}}" in magLabel} 

467 if kwargs_in_label: 

468 magLabel = magLabel.format(**kwargs_in_label) 

469 

470 for name_datatype in self.plotTypes: 

471 config_datatype = self._datatypes[name_datatype] 

472 highArgs = {} 

473 lowArgs = {} 

474 if not self.publicationStyle: 

475 for name in self._stats: 

476 highArgs[name] = cast( 

477 Scalar, 

478 data[ 

479 f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs) 

480 ], 

481 ) 

482 lowArgs[name] = cast( 

483 Scalar, 

484 data[ 

485 f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs) 

486 ], 

487 ) 

488 highStats = _StatsContainer(**highArgs) 

489 lowStats = _StatsContainer(**lowArgs) 

490 

491 toPlotList.append( 

492 ( 

493 data[f"x{config_datatype.suffix_xy}{suf_x}"], 

494 data[f"y{config_datatype.suffix_xy}{suf_y}"], 

495 data[f"{name_datatype}HighSNMask{suf_stat}"], 

496 data[f"{name_datatype}LowSNMask{suf_stat}"], 

497 data[f"{name_datatype}HighSNThreshold{suf_stat}"], 

498 data[f"{name_datatype}LowSNThreshold{suf_stat}"], 

499 config_datatype.color, 

500 config_datatype.colormap, 

501 highStats, 

502 lowStats, 

503 ) 

504 ) 

505 else: 

506 toPlotList.append( 

507 ( 

508 data[f"x{config_datatype.suffix_xy}{suf_x}"], 

509 data[f"y{config_datatype.suffix_xy}{suf_y}"], 

510 [], 

511 [], 

512 [], 

513 [], 

514 config_datatype.color, 

515 config_datatype.colormap, 

516 [], 

517 [], 

518 ) 

519 ) 

520 

521 xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf] 

522 

523 # If there is no data to plot make a 

524 # no data figure 

525 numData = 0 

526 for xs, _, _, _, _, _, _, _, _, _ in toPlotList: 

527 numData += np.count_nonzero(np.isfinite(xs)) 

528 if numData == 0: 

529 return None, None 

530 

531 for j, ( 

532 xs, 

533 ys, 

534 highSn, 

535 lowSn, 

536 highThresh, 

537 lowThresh, 

538 color, 

539 cmap, 

540 highStats, 

541 lowStats, 

542 ) in enumerate(toPlotList): 

543 highSn = cast(Vector, highSn) 

544 lowSn = cast(Vector, lowSn) 

545 # ensure the columns are actually array 

546 xs = np.array(xs) 

547 ys = np.array(ys) 

548 sigMadYs = nanSigmaMad(ys) 

549 # plot lone median point if there's not enough data to measure more 

550 n_xs = np.count_nonzero(np.isfinite(xs)) 

551 if n_xs <= 1 or not (np.isfinite(sigMadYs) and sigMadYs >= 0.0): 

552 continue 

553 elif n_xs < min_n_xs_for_stats: 

554 xs = [nanMedian(xs)] 

555 sigMads = np.array([nanSigmaMad(ys)]) 

556 ys = np.array([nanMedian(ys)]) 

557 (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8) 

558 linesForLegend.append(medLine) 

559 (sigMadLine,) = ax.plot( 

560 xs, 

561 ys + 1.0 * sigMads, 

562 color, 

563 alpha=0.8, 

564 lw=0.8, 

565 label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}", 

566 ) 

567 ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8) 

568 linesForLegend.append(sigMadLine) 

569 histIm = None 

570 continue 

571 

572 if self.xLims: 

573 xMin, xMax = self.xLims 

574 else: 

575 # Chop off 1/3% from only the finite xs values 

576 # (there may be +/-np.Inf values) 

577 # TODO: This should be configurable 

578 # but is there a good way to avoid redundant config params 

579 # without using slightly annoying subconfigs? 

580 xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97)) 

581 xScale = (xs97 - xs1) / 20.0 # This is ~5% of the data range 

582 xMin, xMax = (xs1 - xScale, xs97 + xScale) 

583 xLims[0] = min(xLims[0], xMin) 

584 xLims[1] = max(xLims[1], xMax) 

585 

586 xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins) 

587 medYs = nanMedian(ys) 

588 fiveSigmaHigh = medYs + 5.0 * sigMadYs 

589 fiveSigmaLow = medYs - 5.0 * sigMadYs 

590 binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0 

591 # When the binsize is 0 try using the 1st and 99th 

592 # percentile instead of the sigmas. 

593 if binSize == 0.0: 

594 p1, p99 = np.nanpercentile(ys, [1, 99]) 

595 binSize = (p99 - p1) / 101.0 

596 

597 # If fiveSigmaHigh and fiveSigmaLow are the same 

598 # then use the 1st and 99th percentiles to define the 

599 # yEdges. 

600 yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize) 

601 if fiveSigmaLow == fiveSigmaHigh: 

602 yEdges = np.arange(p1, p99, binSize) 

603 

604 counts, xBins, yBins = np.histogram2d(xs, ys, bins=(xEdges, yEdges)) 

605 yBinsOut.append(yBins) 

606 countsYs = np.sum(counts, axis=1) 

607 

608 ids = np.where((countsYs > binThresh))[0] 

609 xEdgesPlot = xEdges[ids][1:] 

610 xEdges = xEdges[ids] 

611 

612 if len(ids) > 1: 

613 # Create the codes needed to turn the sigmaMad lines 

614 # into a path to speed up checking which points are 

615 # inside the area. 

616 codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO 

617 codes[0] = Path.MOVETO 

618 codes[-1] = Path.CLOSEPOLY 

619 

620 meds = np.zeros(len(xEdgesPlot)) 

621 threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2)) 

622 sigMads = np.zeros(len(xEdgesPlot)) 

623 

624 for i, xEdge in enumerate(xEdgesPlot): 

625 ids = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0] 

626 med = nanMedian(ys[ids]) 

627 sigMad = nanSigmaMad(ys[ids]) 

628 meds[i] = med 

629 sigMads[i] = sigMad 

630 threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad] 

631 threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad] 

632 

633 if self.publicationStyle: 

634 linecolor = "k" 

635 else: 

636 linecolor = color 

637 

638 (medLine,) = ax.plot(xEdgesPlot, meds, linecolor, label="Running Median") 

639 linesForLegend.append(medLine) 

640 

641 # Make path to check which points lie within one sigma mad 

642 threeSigMadPath = Path(threeSigMadVerts, codes) 

643 

644 if not self.publicationStyle: 

645 # Add lines for the median +/- 3 * sigma MAD 

646 (threeSigMadLine,) = ax.plot( 

647 xEdgesPlot, 

648 threeSigMadVerts[: len(xEdgesPlot), 1], 

649 linecolor, 

650 alpha=0.4, 

651 label=r"3$\sigma_{MAD}$", 

652 ) 

653 ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], linecolor, alpha=0.4) 

654 

655 # Add lines for the median +/- 1 * sigma MAD 

656 (sigMadLine,) = ax.plot( 

657 xEdgesPlot, 

658 meds + 1.0 * sigMads, 

659 linecolor, 

660 alpha=0.8, 

661 label=r"$\sigma_{MAD}$", 

662 ls="dashed", 

663 ) 

664 linesForLegend.append(sigMadLine) 

665 ax.plot(xEdgesPlot, meds - 1.0 * sigMads, linecolor, alpha=0.8, ls="dashed") 

666 

667 if not self.publicationStyle: 

668 # Add lines for the median +/- 2 * sigma MAD 

669 (twoSigMadLine,) = ax.plot( 

670 xEdgesPlot, meds + 2.0 * sigMads, linecolor, alpha=0.6, label=r"2$\sigma_{MAD}$" 

671 ) 

672 linesForLegend.append(twoSigMadLine) 

673 linesForLegend.append(threeSigMadLine) 

674 ax.plot(xEdgesPlot, meds - 2.0 * sigMads, linecolor, alpha=0.6) 

675 

676 # Check which points are outside 3 sigma MAD of the median 

677 # and plot these as points. 

678 inside = threeSigMadPath.contains_points(np.array([xs, ys]).T) 

679 ax.plot(xs[~inside], ys[~inside], ".", ms=5, alpha=0.2, mfc=color, mec="none", zorder=-1) 

680 

681 if not self.publicationStyle: 

682 # Add some stats text 

683 xPos = 0.65 - 0.4 * j 

684 bbox = dict(edgecolor=color, linestyle="--", facecolor="none") 

685 statText = f"S/N > {highThresh:0.4g} Stats ({magLabel} < {highStats.approxMag:0.4g})\n" 

686 highStatsStr = ( 

687 f"Median: {highStats.median:0.4g} " 

688 + r"$\sigma_{MAD}$: " 

689 + f"{highStats.sigmaMad:0.4g} " 

690 + r"N$_{points}$: " 

691 + f"{highStats.count}" 

692 ) 

693 statText += highStatsStr 

694 fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6) 

695 

696 bbox = dict(edgecolor=color, linestyle=":", facecolor="none") 

697 statText = f"S/N > {lowThresh:0.4g} Stats ({magLabel} < {lowStats.approxMag:0.4g})\n" 

698 lowStatsStr = ( 

699 f"Median: {lowStats.median:0.4g} " 

700 + r"$\sigma_{MAD}$: " 

701 + f"{lowStats.sigmaMad:0.4g} " 

702 + r"N$_{points}$: " 

703 + f"{lowStats.count}" 

704 ) 

705 statText += lowStatsStr 

706 fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6) 

707 

708 if self.plot2DHist: 

709 extent = [xLims[0], xLims[1], self.yLims[0], self.yLims[1]] if self.yLims else None 

710 histIm = ax.hexbin( 

711 xs[inside], 

712 ys[inside], 

713 gridsize=75, 

714 extent=extent, 

715 cmap=cmap, 

716 mincnt=1, 

717 zorder=-3, 

718 edgecolors=None, 

719 ) 

720 else: 

721 ax.plot(xs[inside], ys[inside], ".", ms=3, alpha=0.2, mfc=color, mec=color, zorder=-1) 

722 

723 if not self.publicationStyle: 

724 # If there are not many sources being used for the 

725 # statistics then plot them individually as just 

726 # plotting a line makes the statistics look wrong 

727 # as the magnitude estimation is iffy for low 

728 # numbers of sources. 

729 if np.sum(highSn) < 100 and np.sum(highSn) > 0: 

730 ax.plot( 

731 cast(Vector, xs[highSn]), 

732 cast(Vector, ys[highSn]), 

733 marker="x", 

734 ms=4, 

735 mec="w", 

736 mew=2, 

737 ls="none", 

738 ) 

739 (highSnLine,) = ax.plot( 

740 cast(Vector, xs[highSn]), 

741 cast(Vector, ys[highSn]), 

742 color=color, 

743 marker="x", 

744 ms=4, 

745 ls="none", 

746 label="High SN", 

747 ) 

748 linesForLegend.append(highSnLine) 

749 else: 

750 ax.axvline(highStats.approxMag, color=color, ls="--") 

751 

752 if np.sum(lowSn) < 100 and np.sum(lowSn) > 0: 

753 ax.plot( 

754 cast(Vector, xs[lowSn]), 

755 cast(Vector, ys[lowSn]), 

756 marker="+", 

757 ms=4, 

758 mec="w", 

759 mew=2, 

760 ls="none", 

761 ) 

762 (lowSnLine,) = ax.plot( 

763 cast(Vector, xs[lowSn]), 

764 cast(Vector, ys[lowSn]), 

765 color=color, 

766 marker="+", 

767 ms=4, 

768 ls="none", 

769 label="Low SN", 

770 ) 

771 linesForLegend.append(lowSnLine) 

772 else: 

773 ax.axvline(lowStats.approxMag, color=color, ls=":") 

774 

775 else: 

776 ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1) 

777 meds = np.array([nanMedian(ys)] * len(xs)) 

778 (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8) 

779 linesForLegend.append(medLine) 

780 sigMads = np.array([nanSigmaMad(ys)] * len(xs)) 

781 (sigMadLine,) = ax.plot( 

782 xs, 

783 meds + 1.0 * sigMads, 

784 color, 

785 alpha=0.8, 

786 lw=0.8, 

787 label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}", 

788 ) 

789 ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8) 

790 linesForLegend.append(sigMadLine) 

791 histIm = None 

792 

793 # Add a horizontal reference line at 0 to the scatter plot 

794 ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2) 

795 

796 # Set the scatter plot limits 

797 suf_x = self.suffix_y 

798 # TODO: Make this not work by accident 

799 if f"yStars{suf_x}" in data and (len(cast(Vector, data[f"yStars{suf_x}"])) > 0): 

800 plotMed = nanMedian(cast(Vector, data[f"yStars{suf_x}"])) 

801 elif f"yGalaxies{suf_x}" in data and (len(cast(Vector, data[f"yGalaxies{suf_x}"])) > 0): 

802 plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_x}"])) 

803 else: 

804 plotMed = np.nan 

805 

806 # Ignore types below pending making this not working my accident 

807 # If len(xs) < min_n_xs_for_stats then `meds` doesn't exist. 

808 if len(xs) < min_n_xs_for_stats: # type: ignore 

809 meds = [nanMedian(ys)] # type: ignore 

810 if self.yLims: 

811 ax.set_ylim(self.yLims[0], self.yLims[1]) # type: ignore 

812 elif np.isfinite(plotMed): 

813 numSig = 4 

814 yLimMin = plotMed - numSig * sigMadYs # type: ignore 

815 yLimMax = plotMed + numSig * sigMadYs # type: ignore 

816 while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10: # type: ignore 

817 numSig += 1 

818 

819 numSig += 1 

820 yLimMin = plotMed - numSig * sigMadYs # type: ignore 

821 yLimMax = plotMed + numSig * sigMadYs # type: ignore 

822 

823 # If len(y) == 0 for ys in toPlotList then sigMadY = NaN 

824 # ... in which case nothing was plotted and limits are irrelevant. 

825 if all(np.isfinite([yLimMin, yLimMax])): 

826 ax.set_ylim(yLimMin, yLimMax) 

827 

828 # This could be false if len(x) == 0 for xs in toPlotList 

829 # ... in which case nothing was plotted and limits are irrelevant 

830 if all(np.isfinite(xLims)): 

831 ax.set_xlim(xLims) 

832 

833 # Add a line legend 

834 ax.legend( 

835 handles=linesForLegend, 

836 ncol=4, 

837 fontsize=6, 

838 loc=self.legendLocation, 

839 framealpha=0.9, 

840 edgecolor="k", 

841 borderpad=0.4, 

842 handlelength=3, 

843 ) 

844 

845 # Add axes labels 

846 band = kwargs.get("band", "unspecified") 

847 xlabel = self.xAxisLabel 

848 ylabel = self.yAxisLabel 

849 if "{band}" in xlabel: 

850 xlabel = xlabel.format(band=band) 

851 if "{band}" in ylabel: 

852 ylabel = ylabel.format(band=band) 

853 if self.publicationStyle: 

854 ax.set_ylabel(ylabel, labelpad=10) 

855 ax.set_xlabel(xlabel, labelpad=2) 

856 else: 

857 ax.set_ylabel(ylabel, labelpad=10, fontsize=10) 

858 ax.set_xlabel(xlabel, labelpad=2, fontsize=10) 

859 ax.tick_params(labelsize=8) 

860 

861 return ax, histIm 

862 

def _makeTopHistogram(
    self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
) -> None:
    """Draw the x-axis histogram above the main scatter panel.

    Parameters
    ----------
    data : `KeyedData`
        Input columns; the ``x{suffix_xy}{suffix_x}`` column is read for
        each datatype listed in ``self.plotTypes``.
    figure : `matplotlib.figure.Figure`
        Figure to which the histogram axes are added.
    gs : `matplotlib.gridspec.GridSpec`
        Grid spec; the histogram occupies ``gs[0, :-1]``.
    ax : `matplotlib.axes.Axes`
        Main scatter axes. The histogram shares its x axis and is binned
        over its current x limits.
    **kwargs
        Unused here.
    """
    suf_x = self.suffix_x
    # Top histogram; it shares the x axis of the main panel and is binned
    # (100 edges) over that panel's current x limits.
    topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
    x_min, x_max = ax.get_xlim()
    bins = np.linspace(x_min, x_max, 100)

    if "any" in self.plotTypes:
        # Fetch the column itself (not just its key) so that it can be
        # compared, counted and histogrammed below, as the side histogram
        # does for its "any" column.
        x_any = np.array(data[f"x{self._datatypes['any'].suffix_xy}{suf_x}"])
        keys_notany = [x for x in self.plotTypes if x != "any"]
    else:
        # With several datatypes the grey background histogram shows all of
        # them combined; with a single datatype it is skipped.
        x_any = (
            np.concatenate([data[f"x{self._datatypes[key].suffix_xy}{suf_x}"] for key in self.plotTypes])
            if (len(self.plotTypes) > 1)
            else None
        )
        keys_notany = self.plotTypes
    if x_any is not None:
        # Only request a log-scaled count axis when at least one value is
        # positive (the same test is applied to each datatype below).
        log = bool(np.any(x_any > 0))
        topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=log, label=f"Any ({len(x_any)})")

    for key in keys_notany:
        config_datatype = self._datatypes[key]
        vector = np.array(data[f"x{config_datatype.suffix_xy}{suf_x}"])
        log = bool(np.any(vector > 0))
        topHist.hist(
            vector,
            bins=bins,
            color=config_datatype.color,
            histtype="step",
            log=log,
            label=f"{config_datatype.suffix_stat} ({len(vector)})",
        )
    # The shared x axis is already labelled on the main panel.
    topHist.axes.get_xaxis().set_visible(False)
    topHist.set_ylabel("Count", fontsize=10 + 4 * self.publicationStyle)
    if not self.publicationStyle:
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")
        topHist.tick_params(labelsize=8)

    self._modifyHistogramTicks(topHist, do_x=False, max_labels=self.xHistMaxLabels)

911 

def _makeSideHistogram(
    self,
    data: KeyedData,
    figure: Figure,
    gs: gridspec.GridSpec,
    ax: Axes,
    histIm: Optional[PolyCollection],
    **kwargs,
) -> None:
    """Draw the y-axis histogram to the right of the main scatter panel.

    Parameters
    ----------
    data : `KeyedData`
        Input columns. The ``y{suffix_xy}{suffix_y}`` column is read for
        each datatype in ``self.plotTypes``. Outside publication style,
        the ``{key}HighSNMask{suffix_stat}`` and
        ``{key}LowSNMask{suffix_stat}`` columns are also read.
    figure : `matplotlib.figure.Figure`
        Figure to which the histogram axes are added.
    gs : `matplotlib.gridspec.GridSpec`
        Grid spec; the histogram occupies ``gs[1:, -1]``.
    ax : `matplotlib.axes.Axes`
        Main scatter axes. The histogram shares its y axis and is binned
        over its current y limits.
    histIm : `matplotlib.collections.PolyCollection` or `None`
        2D-histogram artist from the main panel. When ``self.plot2DHist``
        is set and this is not `None`, a colorbar for it is attached.
    **kwargs
        Must contain ``hlineColor`` and ``hlineStyle``, used for the
        reference line drawn at y = 0.
    """
    suf_y = self.suffix_y
    # Side histogram; it shares the y axis of the main panel and is binned
    # (100 edges) over that panel's current y limits.
    sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
    y_min, y_max = ax.get_ylim()
    bins = np.linspace(y_min, y_max, 100)

    if "any" in self.plotTypes:
        y_any = np.array(data[f"y{self._datatypes['any'].suffix_xy}{suf_y}"])
        keys_notany = [x for x in self.plotTypes if x != "any"]
    else:
        # With several datatypes the grey background histogram shows all of
        # them combined; with a single datatype it is skipped.
        y_any = (
            np.concatenate(
                [np.array(data[f"y{self._datatypes[key].suffix_xy}{suf_y}"]) for key in self.plotTypes]
            )
            if (len(self.plotTypes) > 1)
            else None
        )
        keys_notany = self.plotTypes
    if y_any is not None:
        sideHist.hist(
            np.array(y_any),
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=np.any(y_any > 0),
        )

    for key in keys_notany:
        config_datatype = self._datatypes[key]
        vector = data[f"y{config_datatype.suffix_xy}{suf_y}"]
        # If the data has no positive values then it cannot be log scaled,
        # and trying prints a bunch of irritating warnings, so don't try.
        # The decision is made for each datatype separately. Otherwise, one
        # datatype without positive values would switch off log scaling for
        # every datatype plotted after it.
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=bool(np.sum(vector > 0) > 0),
            orientation="horizontal",
        )
        sideHist.hist(
            vector,
            color=config_datatype.color,
            **kwargs_hist,
        )
        if not self.publicationStyle:
            # Dashed / dotted outlines for the entries selected by the
            # HighSNMask / LowSNMask columns respectively.
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                linestyle="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                linestyle=":",
                **kwargs_hist,
            )

    # Add a horizontal reference line at 0 to the side histogram
    sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

    # The shared y axis is already labelled on the main panel.
    sideHist.axes.get_yaxis().set_visible(False)
    sideHist.set_xlabel("Count", fontsize=10 + 4 * self.publicationStyle)
    self._modifyHistogramTicks(sideHist, do_x=True, max_labels=self.yHistMaxLabels)

    if not self.publicationStyle:
        sideHist.tick_params(labelsize=8)
        if self.plot2DHist and histIm is not None:
            # Attach a thin colorbar for the main panel's 2D histogram to the
            # right of the side histogram.
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="25%", pad=0)
            sideHist.get_figure().colorbar(histIm, cax=cax, orientation="vertical")
            text = cax.text(
                0.5,
                0.5,
                "Points Per Bin",
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            # White outline keeps the label legible on top of the colorbar.
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

1011 

def _modifyHistogramTicks(self, histogram: Axes, do_x: bool, max_labels: int) -> None:
    """Clamp a histogram's axis limits and thin out its tick labels.

    The lower limit of the selected axis is raised to at least
    ``self.histMinimum``, and the upper limit is rounded up to the next
    power of ten. The tick labels lying within those limits are then
    thinned to at most ``max_labels`` non-empty labels. The major ticks
    are tried first. If fewer than two major labels fall within the
    limits, the minor ticks are used instead.

    Parameters
    ----------
    histogram : `matplotlib.axes.Axes`
        Histogram axes to modify. Only the axis selected by ``do_x`` is
        changed.
    do_x : `bool`
        If `True` modify the x axis, otherwise the y axis. (In this file the
        horizontal side histogram passes `True` and the top histogram
        passes `False`.)
    max_labels : `int`
        Maximum number of tick labels to keep when thinning.
    """
    axis = histogram.get_xaxis() if do_x else histogram.get_yaxis()
    limits = list(histogram.get_xlim() if do_x else histogram.get_ylim())
    get_ticks = histogram.get_xticks if do_x else histogram.get_yticks
    ticks = get_ticks()
    # Let the minimum be larger than specified if the histogram has large
    # values everywhere, but cut it down a little so the lowest-valued bin
    # is still easily visible
    limits[0] = max(self.histMinimum, 0.9 * limits[0])
    # Round the upper limit up to the next power of 10 (left unchanged if it
    # is not positive, where log10 is undefined)
    limits[1] = 10 ** (np.ceil(np.log10(limits[1]))) if (limits[1] > 0) else limits[1]
    # First pass uses the major ticks; the second (minor) pass is only
    # reached when the first one found fewer than two labels.
    for minor in (False, True):
        # Ignore ticks that fall outside the new [lower, upper] limits
        valid = (ticks >= limits[0]) & (ticks <= limits[1])
        # zip() pairs labels with the validity mask; it stops at the shorter
        # of the two sequences.
        labels = [label for label, _valid in zip(axis.get_ticklabels(minor=minor), valid) if _valid]
        if (n_labels := len(labels)) > max_labels:
            labels_new = [""] * n_labels
            # Keep max_labels evenly spaced labels and blank the rest.
            # Skip the first label if we're not using minor axis labels
            # (1 - minor is 1 on the major pass, 0 on the minor pass).
            # This helps avoid overlap with the scatter plot labels
            for idx_fill in np.round(np.linspace(1 - minor, n_labels - 1, max_labels)).astype(int):
                labels_new[idx_fill] = labels[idx_fill]
            # NOTE(review): set_ticks is called without minor=..., so on the
            # minor pass this replaces the *major* ticks with the minor tick
            # positions and thinned labels — confirm this is intended.
            axis.set_ticks(ticks[valid], labels_new)
        # If there are enough major tick labels, disable minor tick labels
        if len(labels) >= 2:
            axis.set_minor_formatter(NullFormatter())
            break
        else:
            # Too few labels: label the minor ticks instead and retry the
            # thinning on the minor tick positions.
            axis.set_minor_formatter(
                LogFormatterExponentSci(minor_thresholds=(1, self.histMinimum / 10.0))
            )
            ticks = get_ticks(minor=True)

    # Apply the clamped limits to the axis that was modified.
    (histogram.set_xlim if do_x else histogram.set_ylim)(limits)