Coverage for python / lsst / analysis / tools / actions / plot / scatterplotWithTwoHists.py: 13%

445 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-25 08:55 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26import math 

27from collections.abc import Mapping 

28from typing import NamedTuple, cast 

29 

30import matplotlib.colors 

31import matplotlib.patheffects as pathEffects 

32import numpy as np 

33from matplotlib import gridspec 

34from matplotlib.axes import Axes 

35from matplotlib.collections import PolyCollection 

36from matplotlib.figure import Figure 

37from matplotlib.path import Path 

38from matplotlib.ticker import LogFormatterMathtext, NullFormatter 

39from mpl_toolkits.axes_grid1 import make_axes_locatable 

40 

41from lsst.pex.config import Field 

42from lsst.pex.config.configurableActions import ConfigurableActionField 

43from lsst.pex.config.listField import ListField 

44from lsst.utils.plotting import ( 

45 galaxies_cmap, 

46 galaxies_color, 

47 make_figure, 

48 set_rubin_plotstyle, 

49 stars_cmap, 

50 stars_color, 

51) 

52 

53from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

54from ...math import nanMedian, nanSigmaMad 

55from ..keyedData.summaryStatistics import SummaryStatisticAction 

56from ..scalar import MedianAction 

57from ..vector import ConvertFluxToMag, SnSelector 

58from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats 

59 

60 

class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the statistics consumed by the scatter plot with two
    histograms.

    For a low and a high signal-to-noise selection the action stores the
    selection mask, the statistics returned by `SummaryStatisticAction` for
    ``vectorKey``, an approximate magnitude of the selection, and the
    threshold of each selector.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        """Return ``prefix`` when set, otherwise ``identity``, otherwise an
        empty string.
        """
        return self.prefix or self.identity or ""

    def _get_prefix_variants(self):
        """Return the key prefix in the two spellings used for output keys.

        Returns
        -------
        prefix_lower, prefix_upper : `str`
            The lower-cased prefix (used for mask and threshold keys) and
            the ``str.capitalize`` form (used for statistic keys). Both are
            empty when there is no prefix.
        """
        prefix = self._get_key_prefix()
        if not prefix:
            return "", ""
        return prefix.lower(), prefix.capitalize()

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the statistic and flux vectors followed by whatever the two
        S/N selectors require.
        """
        for key in (self.vectorKey, self.fluxType):
            yield (key, Vector)
        for selector in (self.highSNSelector, self.lowSNSelector):
            yield from selector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__`: two masks, four statistics
        per S/N bin (low first, then high) and the two thresholds.
        """
        prefix_lower, prefix_upper = self._get_prefix_variants()
        suffix = self.suffix
        schema = [
            (f"{prefix_lower}HighSNMask{suffix}", Vector),
            (f"{prefix_lower}LowSNMask{suffix}", Vector),
        ]
        for binName in ("low", "high"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{prefix_upper}_{stat}{suffix}", Scalar))
        schema.append((f"{prefix_lower}HighSNThreshold{suffix}", Scalar))
        schema.append((f"{prefix_lower}LowSNThreshold{suffix}", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Evaluate both S/N selections and their summary statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input columns; must hold ``vectorKey`` and, when a band is
            given, the formatted ``fluxType`` column.
        **kwargs
            Forwarded to the selectors and the statistic action. A ``band``
            entry, when present and non-empty, prefixes the statistic keys.

        Returns
        -------
        results : `KeyedData`
            Masks, per-bin statistics, approximate magnitudes and the
            selector thresholds.
        """
        prefix_lower, prefix_upper = self._get_prefix_variants()
        suffix = self.suffix

        highMaskKey = f"{prefix_lower}HighSNMask{suffix}"
        lowMaskKey = f"{prefix_lower}LowSNMask{suffix}"
        results = {
            highMaskKey: self.highSNSelector(data, **kwargs),
            lowMaskKey: self.lowSNSelector(data, **kwargs),
        }

        band = kwargs.get("band")
        prefix_band = f"{band}_" if band else ""
        fluxes = None if band is None else data[self.fluxType.format(band=band)]

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            mask = results[maskKey]
            name = f"{prefix_band}{binName}SN{prefix_upper}"
            # the approxMag is the median magnitude of the S/N selection;
            # it cannot be computed without a band to pick the flux column
            if band is not None:
                approxMag = medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
            else:
                approxMag = np.nan
            results[f"{name}_approxMag{suffix}".format(**kwargs)] = approxMag

            for name_stat, value in statAction(data, **(kwargs | {"mask": mask})).items():
                results[f"{name}_{name_stat}{suffix}".format(**kwargs)] = value

        results[f"{prefix_lower}HighSNThreshold{suffix}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{prefix_lower}LowSNThreshold{suffix}"] = self.lowSNSelector.threshold  # type: ignore

        return results

151 

152 

def _validObjectTypes(value):
    """Return whether ``value`` names one of the supported object types."""
    allowed = ("stars", "galaxies", "unknown", "any")
    return value in allowed

155 

156 

class _StatsContainer(NamedTuple):
    """Summary statistics of one signal-to-noise selection."""

    # Field order matters: instances are built by keyword from the names in
    # ``ScatterPlotWithTwoHists._stats``.
    median: Scalar
    sigmaMad: Scalar
    # ``count`` shadows ``tuple.count``, hence the type-checker suppression.
    count: Scalar  # type: ignore
    approxMag: Scalar

163 

164 

class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and colour defaults used by the plot."""

    # Inserted into the statistic keys, e.g. ``{band}_highSN<suffix_stat>_median``.
    suffix_stat: str
    # Appended to ``x``/``y`` to form the coordinate vector keys.
    suffix_xy: str
    # Colour used for lines, points and text boxes of this object type.
    color: str
    # Colormap handed to the 2D histogram; ``None`` leaves it unspecified.
    colormap: matplotlib.colors.Colormap | None

170 

171 

class LogFormatterExponentSci(LogFormatterMathtext):
    """Format non-decade log-axis tick values compactly.

    A value is split into ``coefficient * base**exponent`` with an integer
    exponent. When the coefficient is (nearly) an integer, the base is
    ``"10"`` and the exponent lies in 0..3, the value is written out in
    full (e.g. ``500``). Otherwise it is written as
    ``<coefficient>e<exponent>`` with one decimal in the coefficient
    (e.g. ``5.0e4``).
    """

    def _non_decade_format(self, sign_string, base, fx, usetex):
        """Return string for non-decade locations."""
        radix = float(base)
        exponent = math.floor(fx)
        mantissa = radix ** (fx - exponent)
        nearest = round(mantissa)

        if not math.isclose(mantissa, nearest):
            return f"{sign_string}{mantissa:1.1f}e{exponent}"

        # ``base`` is compared as a string, so only the exact spelling "10"
        # selects the plain-integer form.
        if (base == "10") and (0 <= exponent <= 3):
            return f"{sign_string}{nearest}{'0'*int(exponent)}"
        return f"{sign_string}{nearest:1.1f}e{exponent}"

192 

193 

194class ScatterPlotWithTwoHists(PlotAction): 

195 """Makes a scatter plot of the data with a marginal 

196 histogram for each axis. 

197 """ 

198 

199 yLims = ListField[float]( 

200 doc="ylimits of the plot, if not specified determined from data", 

201 length=2, 

202 optional=True, 

203 ) 

204 

205 xLims = ListField[float]( 

206 doc="xlimits of the plot, if not specified determined from data", length=2, optional=True 

207 ) 

208 xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False) 

209 yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False) 

210 magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False) 

211 

212 legendLocation = Field[str](doc="Legend position within main plot", default="upper left") 

213 nBins = Field[float](doc="Number of bins on x axis", default=40.0) 

214 plot2DHist = Field[bool]( 

215 doc="Plot a 2D histogram in dense areas of points on the scatter plot." 

216 "Doesn't look great if plotting multiple datasets on top of each other.", 

217 default=True, 

218 ) 

219 plotTypes = ListField[str]( 

220 doc="Selection of types of objects to plot. Can take any combination of" 

221 " stars, galaxies, unknown, any", 

222 optional=False, 

223 itemCheck=_validObjectTypes, 

224 ) 

225 

226 addSummaryPlot = Field[bool]( 

227 doc="Add a summary plot to the figure?", 

228 default=True, 

229 ) 

230 histMinimum = Field[float]( 

231 doc="Minimum value for the histogram count axis", 

232 default=0.3, 

233 ) 

234 xHistMaxLabels = Field[int]( 

235 doc="Maximum number of labels for ticks on the x-axis marginal histogram", 

236 default=3, 

237 check=lambda x: x >= 2, 

238 ) 

239 yHistMaxLabels = Field[int]( 

240 doc="Maximum number of labels for ticks on the y-axis marginal histogram", 

241 default=3, 

242 check=lambda x: x >= 2, 

243 ) 

244 

245 suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="") 

246 suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="") 

247 suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="") 

248 

249 publicationStyle = Field[bool](doc="Slimmed down publication style plot?", default=False) 

250 

251 _stats = ("median", "sigmaMad", "count", "approxMag") 

252 _datatypes = { 

253 "galaxies": DataTypeDefaults( 

254 suffix_stat="Galaxies", 

255 suffix_xy="Galaxies", 

256 color=galaxies_color(), 

257 colormap=galaxies_cmap(single_color=True), 

258 ), 

259 "stars": DataTypeDefaults( 

260 suffix_stat="Stars", 

261 suffix_xy="Stars", 

262 color=stars_color(), 

263 colormap=stars_cmap(single_color=True), 

264 ), 

265 "unknown": DataTypeDefaults( 

266 suffix_stat="Unknown", 

267 suffix_xy="Unknown", 

268 color="green", 

269 colormap=None, 

270 ), 

271 "any": DataTypeDefaults( 

272 suffix_stat="Any", 

273 suffix_xy="", 

274 color="purple", 

275 colormap=None, 

276 ), 

277 } 

278 

def getInputSchema(self) -> KeyedDataSchema:
    """Return the keys, and their kinds, that the plot reads from ``data``.

    The x and y vectors of every configured object type are always
    declared because the scatter plot reads them in both the default and
    the publication style. Masks, per-bin statistics and thresholds are
    only declared for the default style, which is the only one that
    displays them. ``patch`` is added when the summary plot is enabled
    outside publication style.

    Returns
    -------
    base : `list` [`tuple`]
        ``(key, type)`` pairs; statistic keys keep a literal ``{band}``
        placeholder.
    """
    base: list[tuple[str, type[Vector] | ScalarType]] = []
    suf_stat = self.suffix_stat
    for name_datatype in self.plotTypes:
        config_datatype = self._datatypes[name_datatype]
        base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector))
        base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector))
        if self.publicationStyle:
            # Publication style plots draw neither S/N markers nor stats
            continue

        base.append((f"{name_datatype}HighSNMask{suf_stat}", Vector))
        base.append((f"{name_datatype}LowSNMask{suf_stat}", Vector))
        # statistics
        for name in self._stats:
            base.append((f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}", Scalar))
            base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}", Scalar))
        base.append((f"{name_datatype}LowSNThreshold{suf_stat}", Scalar))
        base.append((f"{name_datatype}HighSNThreshold{suf_stat}", Scalar))

    if self.addSummaryPlot and not self.publicationStyle:
        base.append(("patch", Vector))

    return base

303 

def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
    """Validate ``data`` against the input schema, then build the plot.

    All keyword arguments are forwarded unchanged to both steps.
    """
    self._validateInput(data, **kwargs)
    figure = self.makePlot(data, **kwargs)
    return figure

307 

def _validateInput(self, data: KeyedData, **kwargs) -> None:
    """Check that ``data`` holds every required key with a plausible type.

    NOTE currently can only check that something is not a Scalar, not
    check that the data is consistent with Vector

    Parameters
    ----------
    data : `KeyedData`
        The mapping handed to the plot action.
    **kwargs
        Values used to fill ``{placeholder}`` fields in schema and data
        keys (e.g. ``band``).

    Raises
    ------
    ValueError
        Raised if a required key is missing from ``data``, or if a value is
        a scalar where the schema asks for something else.
    """
    # The schema is traversed twice below. Materialize it first so the
    # type check still sees every entry when the schema is a one-shot
    # iterable such as a generator.
    needed = list(self.getFormattedInputSchema(**kwargs))

    required = {key.format(**kwargs) for key, _ in needed}
    available = {key.format(**kwargs) for key in data}
    if remainder := required - available:
        raise ValueError(
            f"Task needs keys {remainder} but they were not found in input keys" f" {list(data.keys())}"
        )

    for name, typ in needed:
        isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
        if isScalar and typ != Scalar:
            raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

323 

def makePlot(
    self,
    data: KeyedData,
    plotInfo: Mapping[str, str],
    **kwargs,
) -> Figure:
    """Makes a generic plot with a 2D histogram and collapsed histograms of
    each axis.

    Parameters
    ----------
    data : `KeyedData`
        The catalog to plot the points from.
    plotInfo : `dict`
        A dictionary of information about the data being plotted with keys:

        * ``"run"``
            The output run for the plots (`str`).
        * ``"skymap"``
            The type of skymap used for the data (`str`).
        * ``"filter"``
            The filter used for this data (`str`).
        * ``"tract"``
            The tract that the data comes from (`str`).
    **kwargs
        Forwarded to the panel-drawing helpers. Recognised entries:

        * ``"skymap"``
            Skymap needed to build the summary panel; without it the
            panel is skipped.
        * ``"hlineColor"``, ``"hlineStyle"``
            Colour and line style of the horizontal reference line at 0
            (default black, ``(0, (1, 4))``).
        * ``"band"``
            Used to format ``{band}`` placeholders in keys and labels.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The resulting figure. When ``plotTypes`` is empty or no finite x
        value is available, a figure carrying only a "no data" message is
        returned instead.

    Notes
    -----
    The main panel is a scatter plot of y against x, labelled with the
    ``xAxisLabel`` and ``yAxisLabel`` config options. A histogram of the
    points collapsed onto each axis is also plotted. Unless
    ``publicationStyle`` is set, a summary panel built from the y
    statistics per patch is added in the upper right corner (when
    ``addSummaryPlot`` is set and a skymap is supplied) and the plot
    information is written onto the figure.

    If this function is being used within the pipetask framework
    that takes care of making sure that data has all the required
    elements but if you are running this as a standalone function
    then you will need to provide the following things in the
    input data.

    * If stars is in self.plotTypes:
        xStars, yStars, starsHighSNMask, starsLowSNMask,
        starsHighSNThreshold, starsLowSNThreshold and
        {band}_highSNStars_{name}, {band}_lowSNStars_{name}
        where name is median, sigmaMad, count and approxMag.

    * If it is for galaxies/unknown then replace stars in the above
        names with galaxies/unknown.

    * If it is for any (which covers all the points) then it
        becomes, x, y, and any instead of stars for the other
        parameters given above.

    * patch, if the summary plot is being plotted.

    The configured ``suffix_x``, ``suffix_y`` and ``suffix_stat`` are
    appended to the corresponding key names.

    Examples
    --------
    An example of the plot produced from this code is here:

    .. image:: /_static/analysis_tools/scatterPlotExample.png

    For a detailed example of how to make a plot from the command line
    please see the
    :ref:`getting started guide<analysis-tools-getting-started>`.
    """
    if not self.plotTypes:
        return self._makeNoDataFigure(plotInfo)

    # Set default color and line style for the horizontal
    # reference line at 0
    kwargs.setdefault("hlineColor", "black")
    kwargs.setdefault("hlineStyle", (0, (1, 4)))

    set_rubin_plotstyle()
    fig = make_figure()
    gs = gridspec.GridSpec(4, 4)

    # add the various plot elements
    ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
    if ax is None:
        return self._makeNoDataFigure(plotInfo)

    self._makeTopHistogram(data, fig, gs, ax, **kwargs)
    self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
    # Needs info from run quantum
    skymap = kwargs.get("skymap", None)
    if self.addSummaryPlot and skymap is not None and not self.publicationStyle:
        sumStats = generateSummaryStats(data, skymap, plotInfo)
        label = self.yAxisLabel
        fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

    fig.canvas.draw()
    # TODO: Check if these spacings can be defined less arbitrarily
    fig.subplots_adjust(
        wspace=0.0,
        hspace=0.0,
        bottom=0.13 if self.publicationStyle else 0.22,
        left=0.18 if self.publicationStyle else 0.21,
        right=0.92 if self.publicationStyle else None,
        top=0.98 if self.publicationStyle else None,
    )
    if not self.publicationStyle:
        fig = addPlotInfo(fig, plotInfo)
    return fig

def _makeNoDataFigure(self, plotInfo: Mapping[str, str]) -> Figure:
    """Return a figure that only states that there is nothing to plot.

    The plot information is added unless ``publicationStyle`` is set, so
    every "no data" path honours that option in the same way.
    """
    noDataFig = Figure()
    noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
    if not self.publicationStyle:
        noDataFig = addPlotInfo(noDataFig, plotInfo)
    return noDataFig

446 

def _scatterPlot(
    self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
) -> tuple[Axes | None, PolyCollection | None]:
    """Draw the main scatter panel with running statistics.

    For every configured object type the running median and sigma MAD
    lines are drawn, points outside 3 sigma MAD are plotted individually
    and the remaining points are shown as a hexbin (``plot2DHist``) or as
    points. Outside publication style the high/low S/N statistics are
    written below the axes and the S/N selections are marked.

    Parameters
    ----------
    data : `KeyedData`
        Vectors and statistics keyed as described by ``getInputSchema``.
    fig : `matplotlib.figure.Figure`
        Figure that receives the axes and the statistics text.
    gs : `matplotlib.gridspec.GridSpec`
        Grid whose ``[1:, :-1]`` cells hold the scatter panel.
    **kwargs
        Must contain ``hlineColor`` and ``hlineStyle``; ``band`` and any
        other entries are used to format key and label placeholders.

    Returns
    -------
    ax : `matplotlib.axes.Axes` or `None`
        The scatter axes, or `None` when no finite x value exists.
    histIm : `matplotlib.collections.PolyCollection` or `None`
        The hexbin artist of the last dataset drawn as a 2D histogram,
        if any.
    """
    suf_x = self.suffix_x
    suf_y = self.suffix_y
    suf_stat = self.suffix_stat
    # Main scatter plot
    ax = fig.add_subplot(gs[1:, :-1])

    binThresh = 5
    min_n_xs_for_stats = 10

    linesForLegend = []

    toPlotList = []
    histIm = None
    highStats: _StatsContainer
    lowStats: _StatsContainer

    # Only format the label with the kwargs it actually references
    magLabel = self.magLabel
    kwargs_in_label = {k: v for k, v in kwargs.items() if f"{{{k}}}" in magLabel}
    if kwargs_in_label:
        magLabel = magLabel.format(**kwargs_in_label)

    for name_datatype in self.plotTypes:
        config_datatype = self._datatypes[name_datatype]
        highArgs = {}
        lowArgs = {}
        if not self.publicationStyle:
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar,
                    data[
                        f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)
                    ],
                )
                lowArgs[name] = cast(
                    Scalar,
                    data[
                        f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)
                    ],
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}{suf_x}"],
                    data[f"y{config_datatype.suffix_xy}{suf_y}"],
                    data[f"{name_datatype}HighSNMask{suf_stat}"],
                    data[f"{name_datatype}LowSNMask{suf_stat}"],
                    data[f"{name_datatype}HighSNThreshold{suf_stat}"],
                    data[f"{name_datatype}LowSNThreshold{suf_stat}"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )
        else:
            # Publication style shows no S/N information, so only the
            # coordinates and colours are filled in.
            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}{suf_x}"],
                    data[f"y{config_datatype.suffix_xy}{suf_y}"],
                    [],
                    [],
                    [],
                    [],
                    config_datatype.color,
                    config_datatype.colormap,
                    [],
                    [],
                )
            )

    xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]

    # If there is no data to plot make a
    # no data figure
    numData = 0
    for xs, _, _, _, _, _, _, _, _, _ in toPlotList:
        numData += np.count_nonzero(np.isfinite(xs))
    if numData == 0:
        return None, None

    for j, (
        xs,
        ys,
        highSn,
        lowSn,
        highThresh,
        lowThresh,
        color,
        cmap,
        highStats,
        lowStats,
    ) in enumerate(toPlotList):
        highSn = cast(Vector, highSn)
        lowSn = cast(Vector, lowSn)
        # ensure the columns are actually array
        xs = np.array(xs)
        ys = np.array(ys)
        sigMadYs = nanSigmaMad(ys)
        # plot lone median point if there's not enough data to measure more
        n_xs = np.count_nonzero(np.isfinite(xs))
        if n_xs <= 1 or not (np.isfinite(sigMadYs) and sigMadYs >= 0.0):
            continue
        elif n_xs < min_n_xs_for_stats:
            xs = [nanMedian(xs)]
            sigMads = np.array([nanSigmaMad(ys)])
            ys = np.array([nanMedian(ys)])
            (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
            linesForLegend.append(medLine)
            (sigMadLine,) = ax.plot(
                xs,
                ys + 1.0 * sigMads,
                color,
                alpha=0.8,
                lw=0.8,
                label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
            )
            ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
            linesForLegend.append(sigMadLine)
            histIm = None
            continue

        if self.xLims:
            xMin, xMax = self.xLims
        else:
            # Clip to the 1st and 97th percentiles of the finite xs values
            # (there may be +/-np.Inf values)
            # TODO: This should be configurable
            # but is there a good way to avoid redundant config params
            # without using slightly annoying subconfigs?
            xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
            xMin, xMax = (xs1 - xScale, xs97 + xScale)
            xLims[0] = min(xLims[0], xMin)
            xLims[1] = max(xLims[1], xMax)

        xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
        medYs = nanMedian(ys)
        fiveSigmaHigh = medYs + 5.0 * sigMadYs
        fiveSigmaLow = medYs - 5.0 * sigMadYs
        binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
        # When the binsize is 0 try using the 1st and 99th
        # percentile instead of the sigmas.
        if binSize == 0.0:
            p1, p99 = np.nanpercentile(ys, [1, 99])
            binSize = (p99 - p1) / 101.0

        # If fiveSigmaHigh and fiveSigmaLow are the same
        # then use the 1st and 99th percentiles to define the
        # yEdges.
        yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)
        if fiveSigmaLow == fiveSigmaHigh:
            yEdges = np.arange(p1, p99, binSize)

        counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
        countsYs = np.sum(counts, axis=1)

        # Keep only the x bins populated well enough for running stats
        ids = np.where(countsYs > binThresh)[0]
        xEdgesPlot = xEdges[ids][1:]
        xEdges = xEdges[ids]

        if len(ids) > 1:
            # Create the codes needed to turn the sigmaMad lines
            # into a path to speed up checking which points are
            # inside the area.
            codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
            codes[0] = Path.MOVETO
            codes[-1] = Path.CLOSEPOLY

            meds = np.zeros(len(xEdgesPlot))
            threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
            sigMads = np.zeros(len(xEdgesPlot))

            for i, xEdge in enumerate(xEdgesPlot):
                ids = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                med = nanMedian(ys[ids])
                sigMad = nanSigmaMad(ys[ids])
                meds[i] = med
                sigMads[i] = sigMad
                # Upper boundary runs forwards through the first half of
                # the vertices, lower boundary backwards through the rest
                threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

            if self.publicationStyle:
                linecolor = "k"
            else:
                linecolor = color

            (medLine,) = ax.plot(xEdgesPlot, meds, linecolor, label="Running Median")
            linesForLegend.append(medLine)

            # Make path to check which points lie within three sigma MAD
            threeSigMadPath = Path(threeSigMadVerts, codes)

            if not self.publicationStyle:
                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    linecolor,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], linecolor, alpha=0.4)

            # Add lines for the median +/- 1 * sigma MAD
            (sigMadLine,) = ax.plot(
                xEdgesPlot,
                meds + 1.0 * sigMads,
                linecolor,
                alpha=0.8,
                label=r"$\sigma_{MAD}$",
                ls="dashed",
            )
            linesForLegend.append(sigMadLine)
            ax.plot(xEdgesPlot, meds - 1.0 * sigMads, linecolor, alpha=0.8, ls="dashed")

            if not self.publicationStyle:
                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, linecolor, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, linecolor, alpha=0.6)

            # Check which points are outside 3 sigma MAD of the median
            # and plot these as points.
            inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
            ax.plot(xs[~inside], ys[~inside], ".", ms=5, alpha=0.2, mfc=color, mec="none", zorder=-1)

            if not self.publicationStyle:
                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                statText = f"S/N > {highThresh:0.4g} Stats ({magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                statText = f"S/N > {lowThresh:0.4g} Stats ({magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

            if self.plot2DHist:
                extent = [xLims[0], xLims[1], self.yLims[0], self.yLims[1]] if self.yLims else None
                histIm = ax.hexbin(
                    xs[inside],
                    ys[inside],
                    gridsize=75,
                    extent=extent,
                    cmap=cmap,
                    mincnt=1,
                    zorder=-3,
                    edgecolors=None,
                )
            else:
                ax.plot(xs[inside], ys[inside], ".", ms=3, alpha=0.2, mfc=color, mec=color, zorder=-1)

            if not self.publicationStyle:
                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                nHighSn = np.sum(highSn)
                if nHighSn < 100 and nHighSn > 0:
                    # White, thicker markers underneath give the coloured
                    # markers an outline
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                nLowSn = np.sum(lowSn)
                if nLowSn < 100 and nLowSn > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

        else:
            # Too few populated bins for running statistics: show every
            # point with flat median and sigma MAD lines instead.
            ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
            meds = np.array([nanMedian(ys)] * len(xs))
            (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
            linesForLegend.append(medLine)
            sigMads = np.array([nanSigmaMad(ys)] * len(xs))
            (sigMadLine,) = ax.plot(
                xs,
                meds + 1.0 * sigMads,
                color,
                alpha=0.8,
                lw=0.8,
                label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
            )
            ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
            linesForLegend.append(sigMadLine)
            histIm = None

    # Add a horizontal reference line at 0 to the scatter plot
    ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

    # Set the scatter plot limits
    # TODO: Make this not work by accident
    # NOTE: xs, ys, sigMadYs and meds below are whatever the last
    # iteration of the loop above left behind.
    if f"yStars{suf_y}" in data and (len(cast(Vector, data[f"yStars{suf_y}"])) > 0):
        plotMed = nanMedian(cast(Vector, data[f"yStars{suf_y}"]))
    elif f"yGalaxies{suf_y}" in data and (len(cast(Vector, data[f"yGalaxies{suf_y}"])) > 0):
        plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_y}"]))
    else:
        plotMed = np.nan

    # Ignore types below pending making this not work by accident
    # If len(xs) < min_n_xs_for_stats then `meds` doesn't exist.
    if len(xs) < min_n_xs_for_stats:  # type: ignore
        meds = [nanMedian(ys)]  # type: ignore
    if self.yLims:
        ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
    elif np.isfinite(plotMed):
        medsMin = np.min(meds)  # type: ignore
        medsMax = np.max(meds)  # type: ignore
        numSig = 4
        yLimMin = plotMed - numSig * sigMadYs  # type: ignore
        yLimMax = plotMed + numSig * sigMadYs  # type: ignore
        # Widen the window one sigma at a time (capped at 10 sigma) until
        # the plotted medians fit; the limits must be recomputed on each
        # pass for the condition to reflect the wider window.
        while (yLimMax < medsMax or yLimMin > medsMin) and numSig < 10:
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore

        # Leave one extra sigma of margin around the medians
        numSig += 1
        yLimMin = plotMed - numSig * sigMadYs  # type: ignore
        yLimMax = plotMed + numSig * sigMadYs  # type: ignore

        # If len(y) == 0 for ys in toPlotList then sigMadY = NaN
        # ... in which case nothing was plotted and limits are irrelevant.
        if all(np.isfinite([yLimMin, yLimMax])):
            ax.set_ylim(yLimMin, yLimMax)

    # This could be false if len(x) == 0 for xs in toPlotList
    # ... in which case nothing was plotted and limits are irrelevant
    if all(np.isfinite(xLims)):
        ax.set_xlim(xLims)

    # Add a line legend
    ax.legend(
        handles=linesForLegend,
        ncol=4,
        fontsize=6,
        loc=self.legendLocation,
        framealpha=0.9,
        edgecolor="k",
        borderpad=0.4,
        handlelength=3,
    )

    # Add axes labels
    band = kwargs.get("band", "unspecified")
    xlabel = self.xAxisLabel
    ylabel = self.yAxisLabel
    if "{band}" in xlabel:
        xlabel = xlabel.format(band=band)
    if "{band}" in ylabel:
        ylabel = ylabel.format(band=band)
    if self.publicationStyle:
        ax.set_ylabel(ylabel, labelpad=10)
        ax.set_xlabel(xlabel, labelpad=2)
    else:
        ax.set_ylabel(ylabel, labelpad=10, fontsize=10)
        ax.set_xlabel(xlabel, labelpad=2, fontsize=10)
        ax.tick_params(labelsize=8)

    return ax, histIm

864 

def _makeTopHistogram(
    self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
) -> None:
    """Draw the histogram of x values above the scatter panel.

    A filled grey histogram shows all points: the ``any`` vector when
    ``"any"`` is in ``plotTypes``, otherwise the concatenation of the
    configured types when there is more than one. Each remaining type is
    overlaid as a step histogram in its own colour.

    Parameters
    ----------
    data : `KeyedData`
        Must hold the x vector of every configured object type.
    figure : `matplotlib.figure.Figure`
        Figure that receives the new axes.
    gs : `matplotlib.gridspec.GridSpec`
        Grid whose ``[0, :-1]`` cells hold the histogram.
    ax : `matplotlib.axes.Axes`
        Scatter axes; its x axis is shared and its limits define the bins.
    **kwargs
        Unused; accepted for a uniform helper signature.
    """
    suf_x = self.suffix_x
    # Top histogram
    topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
    x_min, x_max = ax.get_xlim()
    bins = np.linspace(x_min, x_max, 100)

    if "any" in self.plotTypes:
        # Look the vector up in ``data``; the key string itself is not
        # something that can be histogrammed.
        x_any = np.array(data[f"x{self._datatypes['any'].suffix_xy}{suf_x}"])
        keys_notany = [x for x in self.plotTypes if x != "any"]
    else:
        x_any = (
            np.concatenate(
                [np.array(data[f"x{self._datatypes[key].suffix_xy}{suf_x}"]) for key in self.plotTypes]
            )
            if (len(self.plotTypes) > 1)
            else None
        )
        keys_notany = self.plotTypes
    if x_any is not None:
        topHist.hist(
            x_any,
            bins=bins,
            color="grey",
            alpha=0.3,
            # a log count axis is only used when some value is positive
            log=bool(np.any(x_any > 0)),
            label=f"Any ({len(x_any)})",
        )

    for key in keys_notany:
        config_datatype = self._datatypes[key]
        vector = np.array(data[f"x{config_datatype.suffix_xy}{suf_x}"])
        topHist.hist(
            vector,
            bins=bins,
            color=config_datatype.color,
            histtype="step",
            log=bool(np.any(vector > 0)),
            label=f"{config_datatype.suffix_stat} ({len(vector)})",
        )
    topHist.axes.get_xaxis().set_visible(False)
    # ``publicationStyle`` is a bool, so the label grows by 4pt when set
    topHist.set_ylabel("Count", fontsize=10 + 4 * self.publicationStyle)
    if not self.publicationStyle:
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")
        topHist.tick_params(labelsize=8)

    self._modifyHistogramTicks(topHist, do_x=False, max_labels=self.xHistMaxLabels)

913 

def _makeSideHistogram(
    self,
    data: KeyedData,
    figure: Figure,
    gs: gridspec.GridSpec,
    ax: Axes,
    histIm: PolyCollection | None,
    **kwargs,
) -> None:
    """Draw the histogram of y values in the panel beside the scatter plot.

    Parameters
    ----------
    data : `KeyedData`
        Mapping holding the y vectors, read from the key
        ``f"y{suffix_xy}{self.suffix_y}"`` of each entry in
        ``self._datatypes``. Outside publication style it must also hold
        the masks ``f"{key}HighSNMask{self.suffix_stat}"`` and
        ``f"{key}LowSNMask{self.suffix_stat}"`` for each plotted type.
    figure : `matplotlib.figure.Figure`
        Figure to which the histogram axes are added.
    gs : `matplotlib.gridspec.GridSpec`
        Grid layout; the histogram is placed at ``gs[1:, -1]``.
    ax : `matplotlib.axes.Axes`
        Scatter plot axes. Its y axis is shared with the histogram and its
        current y limits define the range of the 100 bin edges.
    histIm : `matplotlib.collections.PolyCollection` or `None`
        Artist used as the mappable for the colour bar, if one is drawn.
    **kwargs
        Must contain ``hlineColor`` and ``hlineStyle``, which style the
        horizontal reference line at zero.

    Raises
    ------
    KeyError
        Raised if ``hlineColor`` or ``hlineStyle`` is missing from
        ``kwargs``.
    """
    suf_y = self.suffix_y
    # Side histogram
    sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
    y_min, y_max = ax.get_ylim()
    bins = np.linspace(y_min, y_max, 100)

    if "any" in self.plotTypes:
        y_any = np.array(data[f"y{self._datatypes['any'].suffix_xy}{suf_y}"])
        keys_notany = [x for x in self.plotTypes if x != "any"]
    else:
        # Without an explicit "any" type, build the combined sample from all
        # plotted types, but only when there is more than one of them.
        y_any = (
            np.concatenate(
                [np.array(data[f"y{self._datatypes[key].suffix_xy}{suf_y}"]) for key in self.plotTypes]
            )
            if (len(self.plotTypes) > 1)
            else None
        )
        keys_notany = self.plotTypes
    if y_any is not None:
        sideHist.hist(
            y_any,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=bool(np.any(y_any > 0)),
        )

    for key in keys_notany:
        config_datatype = self._datatypes[key]
        vector = np.array(data[f"y{config_datatype.suffix_xy}{suf_y}"])
        # If the data has no positive values then it cannot be log scaled
        # and it prints a bunch of irritating warnings, so don't try.
        # The flag is rebuilt for every data type so that one type without
        # positive values does not switch log scaling off for the others.
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=bool(np.any(vector > 0)),
            orientation="horizontal",
        )
        sideHist.hist(
            vector,
            color=config_datatype.color,
            **kwargs_hist,
        )
        if not self.publicationStyle:
            # Overlay the high- and low-S/N subsets with distinct line styles.
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                linestyle="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                linestyle=":",
                **kwargs_hist,
            )

    # Add a horizontal reference line at 0 to the side histogram
    sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

    # The y axis is shared with the scatter plot, so hide it on this panel.
    sideHist.axes.get_yaxis().set_visible(False)
    sideHist.set_xlabel("Count", fontsize=10 + 4 * self.publicationStyle)
    self._modifyHistogramTicks(sideHist, do_x=True, max_labels=self.yHistMaxLabels)

    if not self.publicationStyle:
        sideHist.tick_params(labelsize=8)
        # NOTE(review): nesting reconstructed from a flattened listing;
        # confirm the colour bar is meant to be skipped in publication style.
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="25%", pad=0)
            sideHist.get_figure().colorbar(histIm, cax=cax, orientation="vertical")
            text = cax.text(
                0.5,
                0.5,
                "Points Per Bin",
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            # White outline keeps the label legible on top of the colour bar.
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

1013 

def _modifyHistogramTicks(self, histogram, do_x: bool, max_labels: int):
    """Clamp one axis of a histogram and thin out its tick labels.

    Parameters
    ----------
    histogram
        Axes-like object holding the histogram; its x axis is adjusted when
        ``do_x`` is true and its y axis otherwise.
    do_x : `bool`
        Whether to operate on the x axis (`True`) or the y axis (`False`).
    max_labels : `int`
        Largest number of tick labels to keep; when more are in range, the
        others are replaced by empty strings.
    """
    if do_x:
        axis = histogram.get_xaxis()
        lower, upper = histogram.get_xlim()
        get_ticks = histogram.get_xticks
        set_limits = histogram.set_xlim
    else:
        axis = histogram.get_yaxis()
        lower, upper = histogram.get_ylim()
        get_ticks = histogram.get_yticks
        set_limits = histogram.set_ylim
    ticks = get_ticks()

    # Allow the minimum to exceed the configured one if the histogram has
    # large values everywhere, but shave 10% off so that the lowest-valued
    # bin remains easy to see.
    lower = max(self.histMinimum, 0.9 * lower)
    # Push a positive upper limit up to the next power of 10.
    if upper > 0:
        upper = 10 ** (np.ceil(np.log10(upper)))

    # First pass looks at major ticks; the second, if reached, at minor ones.
    for minor in (False, True):
        # Ticks outside the new limits are not considered.
        valid = (ticks >= lower) & (ticks <= upper)
        labels = [
            tick_label for tick_label, keep in zip(axis.get_ticklabels(minor=minor), valid) if keep
        ]
        n_labels = len(labels)
        if n_labels > max_labels:
            labels_new = [""] * n_labels
            # On the major pass the first label is skipped, which helps
            # avoid overlap with the scatter plot labels.
            chosen = np.round(np.linspace(1 - minor, n_labels - 1, max_labels)).astype(int)
            for idx_keep in chosen:
                labels_new[idx_keep] = labels[idx_keep]
            # set_ticks is called without a ``minor`` argument on both passes.
            axis.set_ticks(ticks[valid], labels_new)
        if n_labels >= 2:
            # Enough labels were found on this pass: turn minor labels off
            # and stop.
            axis.set_minor_formatter(NullFormatter())
            break
        # Too few labels: enable minor tick labels and retry with them.
        axis.set_minor_formatter(LogFormatterExponentSci(minor_thresholds=(1, self.histMinimum / 10.0)))
        ticks = get_ticks(minor=True)

    set_limits([lower, upper])