Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%

320 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-06 12:42 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26from typing import Mapping, NamedTuple, Optional, cast 

27 

28import matplotlib.colors 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from lsst.pex.config.listField import ListField 

34from lsst.skymap import BaseSkyMap 

35from matplotlib import gridspec 

36from matplotlib.axes import Axes 

37from matplotlib.collections import PolyCollection 

38from matplotlib.figure import Figure 

39from matplotlib.path import Path 

40from mpl_toolkits.axes_grid1 import make_axes_locatable 

41 

42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

43from ...math import nanMedian, nanSigmaMad 

44from ..keyedData.summaryStatistics import SummaryStatisticAction 

45from ..scalar import MedianAction 

46from ..vector import ConvertFluxToMag, SnSelector 

47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

48 

# Module-level copy of matplotlib's "coolwarm" colormap in which bad (masked)
# values are drawn with color "none".
# The type-ignore is needed because type checkers do not see ``coolwarm`` as
# an attribute of ``plt.cm`` even though it is actually part of the module.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")

52 

53 

class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    For each of the two signal-to-noise selections (``low`` and ``high``)
    the action stores the selection mask, the summary statistics of
    ``vectorKey`` restricted to that mask, and an approximate magnitude
    (the median magnitude of the selected sources). The selector
    thresholds themselves are stored under ``highSnThreshold`` and
    ``lowSnThreshold``.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys (and their types) this action reads from its input.

        These are the vector to summarise, the flux vector, and whatever the
        two signal-to-noise selectors declare in their own input schemas.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) written by ``__call__``.

        The key names are built from the same pieces ``__call__`` uses so the
        advertised schema matches the produced data: mask keys use the
        lower-cased identity and the thresholds are stored as
        ``highSnThreshold`` / ``lowSnThreshold``.
        """
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        schema: list = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("low", "high"):
            for statName in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{statName}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the selection masks and per-selection statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain ``vectorKey`` and, when a ``band``
            keyword is supplied, the formatted ``fluxType`` key.
        **kwargs
            Forwarded to the selectors and the statistics action. ``band``,
            if present, is used to prefix the statistic keys and to look up
            the flux vector.

        Returns
        -------
        results : `KeyedData`
            The two masks, the statistics for each selection, the approximate
            magnitudes, and the two selector thresholds.
        """
        results = {}
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statSuffix}"
            # set the approxMag to the median mag in the SN selection;
            # without a band there is no flux vector, so store NaN instead
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

126 

127 

def _validObjectTypes(value):
    """Tell whether ``value`` is one of the supported object type names.

    The supported names are ``"stars"``, ``"galaxies"``, ``"unknown"`` and
    ``"any"``. Used as the per-item check of the ``plotTypes`` list field.
    """
    supported = ("stars", "galaxies", "unknown", "any")
    for candidate in supported:
        if candidate == value:
            return True
    return False

130 

131 

132# ignore type because of conflicting name on tuple baseclass 

class _StatsContainer(NamedTuple):
    """Statistics of one signal-to-noise selection, as shown on the plot.

    The field names match the statistic names in
    ``ScatterPlotWithTwoHists._stats``; instances are built from the
    ``{band}_highSN..._{name}`` / ``{band}_lowSN..._{name}`` input keys.
    """

    # Median of the selection (printed as "Median" in the stats box).
    median: Scalar
    # Sigma MAD of the selection (printed as sigma_MAD in the stats box).
    sigmaMad: Scalar
    # Number of points in the selection; the type-ignore is needed because
    # the name conflicts with ``tuple.count`` on the base class.
    count: Scalar  # type: ignore
    # Approximate magnitude of the selection, used in the stats text and as
    # the position of the vertical marker line on the scatter plot.
    approxMag: Scalar

138 

139 

class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and styling used by `ScatterPlotWithTwoHists`."""

    # Suffix used in the statistic keys, e.g. ``{band}_highSN<suffix_stat>_median``,
    # and as the legend label of the top histogram.
    suffix_stat: str
    # Suffix appended to "x" / "y" to form the data keys of the plotted vectors.
    suffix_xy: str
    # Color of the lines, points and histograms drawn for this object type.
    color: str
    # Colormap passed to the 2D (hexbin) histogram, or None when not set.
    colormap: matplotlib.colors.Colormap | None

145 

146 

class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )

    # Names of the per-selection statistics expected in the input data.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key suffixes and plot styling for each supported entry of ``plotTypes``.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) required in the input data.

        For every configured plot type this is the x and y vectors, the high
        and low S/N masks and the high/low S/N statistics; in addition the two
        S/N thresholds and, if the summary plot is enabled, ``patch``.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}", Vector))
            base.append((f"y{config_datatype.suffix_xy}", Vector))
            base.append((f"{name_datatype}HighSNMask", Vector))
            base.append((f"{name_datatype}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{config_datatype.suffix_stat}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate the input data and produce the plot via `makePlot`."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that all required keys are present and not wrongly scalar.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing from ``data``, or if a key holds a
            scalar where the schema does not ask for one.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Forwarded to the plotting helpers. ``band`` is used to format the
            statistic keys; ``hlineColor`` and ``hlineStyle`` set the color
            and line style of the horizontal reference line at 0 (defaults:
            ``"black"`` and ``(0, (1, 4))``).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Plots the x and y vectors of every type listed in ``plotTypes``
        against each other, labelled with the ``xAxisLabel`` and
        ``yAxisLabel`` config options. A histogram of the points
        collapsed onto each axis is also plotted. If ``addSummaryPlot`` is
        set, a summary panel showing the median of the y value in each patch
        is shown in the upper right corner of the resultant plot. The high and
        low S/N masks determine which points are used for the printed
        statistics.

        If ``plotTypes`` is empty, a figure containing only a "no data"
        message is returned.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars in the above
          names with galaxies/unknown.

        * If it is for any (which covers all the points) then it
          becomes, x, y, and any instead of stars for the other
          parameters given above.

        * In every case it is expected that data contains:
            lowSnThreshold, highSnThreshold and patch
            (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        if "hlineColor" not in kwargs:
            kwargs["hlineColor"] = "black"

        if "hlineStyle" not in kwargs:
            kwargs["hlineStyle"] = (0, (1, 4))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel and return it with the hexbin artist.

        For every configured plot type the running median and the 1, 2 and 3
        sigma MAD envelopes are drawn; points outside the 3 sigma MAD envelope
        are plotted individually and, if ``plot2DHist`` is set, the points
        inside it are shown as a hexbin histogram. Data sets with fewer than
        10 points, or with too few populated x bins, fall back to a simpler
        median / sigma MAD rendering.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin artist of the last data set drawn with one, or `None`.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # Minimum number of points an x bin needs to take part in the
        # running statistics.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar, data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
                lowArgs[name] = cast(
                    Scalar, data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}"],
                    data[f"y{config_datatype.suffix_xy}"],
                    data[f"{name_datatype}HighSNMask"],
                    data[f"{name_datatype}LowSNMask"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
        xLims = self.xLims if self.xLims is not None else [np.inf, -np.inf]
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nanSigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0:
                continue
            elif n_xs < 10:
                xs = [nanMedian(xs)]
                sigMads = np.array([nanSigmaMad(ys)])
                ys = np.array([nanMedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Use the 1st to 97th percentile range of only the finite
                # xs values (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = nanMedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins that hold more than binThresh points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = nanMedian(ys[inBin])
                    sigMad = nanSigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper envelope fills the first half of the vertices,
                    # lower envelope the second half in reverse x order, so
                    # together they trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within 3 sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text; each data set gets its own column
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                nHighSn = np.sum(highSn)
                if 0 < nHighSn < 100:
                    # White, thicker markers underneath act as an outline
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                nLowSn = np.sum(lowSn)
                if 0 < nLowSn < 100:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated bins for running statistics: plot every
                # point plus flat median and sigma MAD lines.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([nanMedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nanSigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # TODO: Make this not work by accident
        # NOTE: xs, ys, meds and sigMadYs below are whatever the last
        # iteration of the loop above left behind.
        if "yStars" in data and (len(cast(Vector, data["yStars"])) > 0):
            plotMed = nanMedian(cast(Vector, data["yStars"]))
        elif "yGalaxies" in data and (len(cast(Vector, data["yGalaxies"])) > 0):
            plotMed = nanMedian(cast(Vector, data["yGalaxies"]))
        else:
            plotMed = np.nan

        # Ignore types below pending making this not working my accident
        if len(xs) < 2:  # type: ignore
            meds = [nanMedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed):
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen the limits one sigma MAD at a time (up to 10) until they
            # contain all of the medians. The limits must be recomputed on
            # every step, otherwise the condition never changes and numSig
            # always runs straight to 10.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # One extra sigma MAD of padding
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc=self.legendLocation,
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter plot.

        The grey filled histogram shows all points (the ``any`` vector if it
        is configured, otherwise the concatenation of every configured type);
        each other configured type is overplotted as a step histogram. The x
        axis is shared with ``ax`` and its current limits define the bins.
        """
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)

        if "any" in self.plotTypes:
            x_all = data[f"x{self._datatypes['any'].suffix_xy}"]
            keys_notall = [x for x in self.plotTypes if x != "any"]
        else:
            x_all = np.concatenate([data[f"x{self._datatypes[key].suffix_xy}"] for key in self.plotTypes])
            keys_notall = self.plotTypes

        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)
        topHist.hist(x_all, bins=bins, color="grey", alpha=0.3, log=True, label=f"All ({len(x_all)})")
        for key in keys_notall:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values to the right of the scatter plot.

        The grey filled histogram shows all points; for each other configured
        type three step histograms are drawn: all points, the high S/N
        selection (dashed) and the low S/N selection (dotted). The y axis is
        shared with ``ax`` and its current limits define the bins. If a hexbin
        artist is given and ``plot2DHist`` is set, its colorbar is attached to
        the right of this panel.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        if "any" in self.plotTypes:
            y_all = data[f"y{self._datatypes['any'].suffix_xy}"]
            keys_notall = [x for x in self.plotTypes if x != "any"]
        else:
            y_all = np.concatenate([data[f"y{self._datatypes[key].suffix_xy}"] for key in self.plotTypes])
            keys_notall = self.plotTypes

        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)
        sideHist.hist(np.array(y_all), bins=bins, color="grey", alpha=0.3, orientation="horizontal", log=True)
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in keys_notall:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}"]
            sideHist.hist(
                vector,
                color=config_datatype.color,
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")