Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%

319 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-14 03:17 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26from typing import Mapping, NamedTuple, Optional, cast 

27 

28import matplotlib.colors 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from lsst.pex.config.listField import ListField 

34from lsst.skymap import BaseSkyMap 

35from matplotlib import gridspec 

36from matplotlib.axes import Axes 

37from matplotlib.collections import PolyCollection 

38from matplotlib.figure import Figure 

39from matplotlib.path import Path 

40from mpl_toolkits.axes_grid1 import make_axes_locatable 

41 

42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

43from ...statistics import nansigmaMad, sigmaMad 

44from ..keyedData.summaryStatistics import SummaryStatisticAction 

45from ..scalar import MedianAction 

46from ..vector import ConvertFluxToMag, SnSelector 

47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

48 

# Module-level colormap built from a copy of matplotlib's "coolwarm" map
# (presumably copied so that the ``set_bad`` call below does not modify the
# colormap object shared through ``plt.cm`` -- TODO confirm).
# The ``type: ignore`` is needed because the type checker does not see
# ``coolwarm`` as an attribute of ``plt.cm`` even though it exists at runtime.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
# Render masked/invalid ("bad") values with no color at all.
cmapPatch.set_bad(color="none")

52 

53 

class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    Two signal-to-noise selections (``highSNSelector`` and ``lowSNSelector``)
    are applied to the input data.  For each selection the resulting mask,
    the summary statistics of ``vectorKey`` computed with that mask, and an
    approximate magnitude (the median magnitude of the selected fluxes) are
    returned, together with the two selector thresholds.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys (and their types) this action reads from the data.

        The schema is the statistics vector, the flux vector and whatever
        the two S/N selectors themselves require.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) produced by ``__call__``.

        The key names are built exactly the way ``__call__`` builds them:
        the mask keys use the lower-cased identity, the statistic keys use
        the capitalized identity, and the thresholds are stored under
        ``highSnThreshold`` / ``lowSnThreshold``.
        """
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        schema: list = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        # Same ordering as before: all low S/N statistics, then all high S/N.
        for binName in ("low", "high"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{stat}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the S/N masks and the per-selection statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain ``vectorKey``, the inputs of the two
            selectors and, when a band is given, the formatted ``fluxType``.
        **kwargs
            Forwarded to the selectors and the statistics action.  ``band``,
            if present, is used to format the flux key and to prefix the
            names of the output statistics.

        Returns
        -------
        results : `KeyedData`
            The two masks, the statistics for each selection, the
            approximate magnitudes and the selector thresholds.
        """
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        results = {}
        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        band = kwargs.get("band")
        prefix = f"{band}_" if band else ""
        # Without a band the flux key cannot be formatted, so no approximate
        # magnitude can be computed (NaN is stored instead, see below).
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            mask = results[maskKey]
            name = f"{prefix}{binName}SN{statSuffix}"
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
                if band is not None
                else np.nan
            )
            # The mask is handed to the statistics action through the kwargs.
            stats = statAction(data, **(kwargs | {"mask": mask})).items()
            for suffix, value in stats:
                results[f"{name}_{suffix}".format(**kwargs)] = value

        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

126 

127 

128def _validObjectTypes(value): 

129 return value in ("stars", "galaxies", "unknown", "any") 

130 

131 

132# ignore type because of conflicting name on tuple baseclass 

class _StatsContainer(NamedTuple):
    """Summary statistics of one signal-to-noise selection.

    The field names match the entries of ``ScatterPlotWithTwoHists._stats``
    so that an instance can be built from a dict keyed by those names.
    """

    # Median of the selected values.
    median: Scalar
    # Sigma MAD of the selected values.
    sigmaMad: Scalar
    # Number of selected points; the ``type: ignore`` is needed because
    # ``count`` clashes with the ``tuple.count`` method of the base class.
    count: Scalar  # type: ignore
    # Approximate magnitude associated with the S/N selection.
    approxMag: Scalar

138 

139 

class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and styling used by `ScatterPlotWithTwoHists`."""

    # Suffix used in the statistics keys, e.g. ``{band}_highSN<suffix_stat>_median``.
    suffix_stat: str
    # Suffix used in the coordinate keys, i.e. ``x<suffix_xy>`` and ``y<suffix_xy>``.
    suffix_xy: str
    # Matplotlib color used for lines, points and histogram outlines.
    color: str
    # Colormap handed to the hexbin 2D histogram; `None` leaves the choice
    # to matplotlib.
    colormap: matplotlib.colors.Colormap | None

145 

146 

class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the per-selection statistics read from the input data; they
    # are also the field names of ``_StatsContainer``.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key suffixes and plot styling for every allowed entry of ``plotTypes``.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) the plot needs for the
        configured ``plotTypes``.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}", Vector))
            base.append((f"y{config_datatype.suffix_xy}", Vector))
            base.append((f"{name_datatype}HighSNMask", Vector))
            base.append((f"{name_datatype}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{config_datatype.suffix_stat}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate the input data and then build the figure."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a key
            that should not hold a scalar does hold one.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Used to format the ``{band}`` placeholder of the statistics
            keys.  ``hlineColor`` and ``hlineStyle`` set the color and line
            style of the horizontal reference line at 0; they default to
            ``"black"`` and ``(0, (1, 4))``.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the ``xAxisLabel`` and ``yAxisLabel`` config options to label
        a scatter plot of the ``y`` values against the ``x`` values.
        A histogram of the points
        collapsed onto each axis is also plotted. If ``addSummaryPlot`` is
        set, a summary panel is added in the upper right corner
        of the resultant plot. The statistics printed on the plot are read
        from the input data, they are not computed here.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars in the above
            names with galaxies/unknown (capitalized where Stars is).

        * If it is for any (which covers all the points) then it
            becomes, x, y, and any instead of stars for the other
            parameters given above.

        * In every case it is expected that data contains:
            lowSnThreshold, highSnThreshold and patch
            (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        kwargs.setdefault("hlineColor", "black")
        kwargs.setdefault("hlineStyle", (0, (1, 4)))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        For every configured plot type the running median and the
        1/2/3 sigma MAD envelopes are drawn, points outside the 3 sigma MAD
        envelope are plotted individually and, if ``plot2DHist`` is set,
        the points inside it are shown as a hexbin histogram.  Data sets
        with fewer than 10 points, or with too few populated x bins, fall
        back to a single median / sigma MAD summary.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The axes of the scatter panel.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin artist of the last data set drawn, or `None` when
            that data set was not drawn as a 2D histogram.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # Minimum number of points an x bin needs to enter the running stats.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar, data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
                lowArgs[name] = cast(
                    Scalar, data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}".format(**kwargs)]
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}"],
                    data[f"y{config_datatype.suffix_xy}"],
                    data[f"{name_datatype}HighSNMask"],
                    data[f"{name_datatype}LowSNMask"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        # Work on a copy: the running limits are updated below and the
        # configured list must not be modified through this local name.
        xLims = list(self.xLims) if self.xLims is not None else [np.inf, -np.inf]
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0:
                continue
            elif n_xs < 10:
                xs = [np.nanmedian(xs)]
                sigMads = np.array([nansigmaMad(ys)])
                ys = np.array([np.nanmedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Chop off 1/3% from only the finite xs values
                # (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins that hold enough points.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[inBin])
                    sigMad = sigmaMad(ys[inBin], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper envelope is stored front to back, lower envelope
                    # back to front, so the vertices form a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text: one box per S/N selection, high S/N
                # above low S/N, shifted left for every further data set.
                xPos = 0.65 - 0.4 * j
                for stats, threshKey, boxStyle, yPos in (
                    (highStats, "highSnThreshold", "--", 0.090),
                    (lowStats, "lowSnThreshold", ":", 0.020),
                ):
                    bbox = dict(edgecolor=color, linestyle=boxStyle, facecolor="none")
                    thresh = data[threshKey]
                    statText = f"S/N > {thresh:0.4g} Stats ({self.magLabel} < {stats.approxMag:0.4g})\n"
                    statText += (
                        f"Median: {stats.median:0.4g} "
                        + r"$\sigma_{MAD}$: "
                        + f"{stats.sigmaMad:0.4g} "
                        + r"N$_{points}$: "
                        + f"{stats.count}"
                    )
                    fig.text(xPos, yPos, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                for snMask, stats, marker, lineStyle, label in (
                    (highSn, highStats, "x", "--", "High SN"),
                    (lowSn, lowStats, "+", ":", "Low SN"),
                ):
                    nSelected = np.sum(snMask)
                    if nSelected < 100 and nSelected > 0:
                        # Thick white markers first, so the colored markers
                        # drawn on top of them get a white outline.
                        ax.plot(
                            cast(Vector, xs[snMask]),
                            cast(Vector, ys[snMask]),
                            marker=marker,
                            ms=4,
                            mec="w",
                            mew=2,
                            ls="none",
                        )
                        (snLine,) = ax.plot(
                            cast(Vector, xs[snMask]),
                            cast(Vector, ys[snMask]),
                            color=color,
                            marker=marker,
                            ms=4,
                            ls="none",
                            label=label,
                        )
                        linesForLegend.append(snLine)
                    else:
                        ax.axvline(stats.approxMag, color=color, ls=lineStyle)

            else:
                # Too few populated bins for running statistics: show every
                # point plus a flat median / sigma MAD summary.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # TODO: Make this not work by accident
        if "yStars" in data and (len(cast(Vector, data["yStars"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yStars"]))
        elif "yGalaxies" in data and (len(cast(Vector, data["yGalaxies"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yGalaxies"]))
        else:
            plotMed = np.nan

        # Ignore types below pending making this not work by accident.
        # NOTE: xs, ys, meds and sigMadYs are whatever the last data set of
        # the loop above left behind.
        if len(xs) < 2:  # type: ignore
            meds = [np.nanmedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed):
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen the range one sigma MAD at a time (up to 10) until the
            # medians fit; the limits have to be recomputed on every pass,
            # otherwise the condition can never change.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # One extra sigma MAD of margin around the range found above.
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _combinedVector(self, data: KeyedData, axis: str) -> np.ndarray:
        """Return the ``axis`` ("x" or "y") values of all plotted points.

        The ``any`` type already covers all the points, so when it is
        configured its vector is used on its own; concatenating it with the
        other types would count their points twice.  Otherwise the vectors
        of the configured types are concatenated.
        """
        if "any" in self.plotTypes:
            return np.asarray(data[f"{axis}{self._datatypes['any'].suffix_xy}"])
        return np.concatenate(
            [data[f"{axis}{self._datatypes[key].suffix_xy}"] for key in self.plotTypes]
        )

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter panel.

        A filled grey histogram of all points is drawn first, followed by
        one outline histogram per configured plot type.
        """
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)

        x_all = self._combinedVector(data, "x")

        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)
        topHist.hist(x_all, bins=bins, color="grey", alpha=0.3, log=True, label=f"All ({len(x_all)})")
        for key in self.plotTypes:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values to the right of the scatter
        panel.

        A filled grey histogram of all points is drawn first.  For every
        configured plot type an outline histogram of all its points, a
        dashed one of its high S/N points and a dotted one of its low S/N
        points follow.  If a hexbin artist is passed in ``histIm`` (and
        ``plot2DHist`` is set) its colorbar is attached on the right.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        y_all = self._combinedVector(data, "y")

        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)
        sideHist.hist(y_all, bins=bins, color="grey", alpha=0.3, orientation="horizontal", log=True)
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in self.plotTypes:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}"]
            sideHist.hist(
                vector,
                color=config_datatype.color,
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")