Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 13%

360 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-18 12:39 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

24 

25from itertools import chain 

26from typing import Mapping, NamedTuple, Optional, cast 

27 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from lsst.pex.config import Field 

31from lsst.pex.config.listField import ListField 

32from lsst.pipe.tasks.configurableActions import ConfigurableActionField 

33from lsst.skymap import BaseSkyMap 

34from matplotlib import gridspec 

35from matplotlib.axes import Axes 

36from matplotlib.collections import PolyCollection 

37from matplotlib.figure import Figure 

38from matplotlib.path import Path 

39from mpl_toolkits.axes_grid1 import make_axes_locatable 

40 

41from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, Vector 

42from ..keyedData.summaryStatistics import SummaryStatisticAction, sigmaMad 

43from ..scalar import MedianAction 

44from ..vector import MagColumnNanoJansky, SnSelector 

45from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

46 

# Module-level colormap: a private copy of matplotlib's "coolwarm" map whose
# "bad" (masked/invalid) entries are rendered with no colour at all.
# NOTE(review): not referenced elsewhere in the visible code; presumably kept
# for per-patch summary panels - confirm before removing.
# The ``type: ignore`` is needed because type checkers do not see the
# dynamically registered ``coolwarm`` attribute on ``plt.cm``.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")

50 

51 

class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the S/N masks and summary statistics consumed by scatter plots.

    Two signal-to-noise selections (``highSNSelector`` and ``lowSNSelector``)
    are evaluated on the input data.  For each selection the action stores the
    resulting mask, the summary statistics of ``vectorKey`` restricted to that
    mask, and an approximate magnitude (the median magnitude of the selected
    fluxes).  The selector thresholds are stored as well.

    Output keys are built from ``self.identity``: masks use the lower-cased
    identity (``"starsHighSNMask"``), statistics use the capitalised identity
    (``"{band}_highSNStars_median"``).  A missing identity contributes an
    empty string to both.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def _maskKey(self, level: str) -> str:
        """Return the output key of the ``level`` ("High"/"Low") S/N mask."""
        # ``identity`` may be None; fall back to an empty prefix rather than
        # calling ``.lower()`` on None.
        return f"{(self.identity or '').lower()}{level}SNMask"

    def _statsTag(self) -> str:
        """Return the capitalised identity used inside statistic keys."""
        return self.identity.capitalize() if self.identity else ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys read by this action and by both S/N selectors."""
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__`.

        The mask and threshold keys are spelt exactly as `__call__` writes
        them; the statistic keys carry a ``{band}`` placeholder.
        """
        tag = self._statsTag()
        schema = [
            (self._maskKey("High"), Vector),
            (self._maskKey("Low"), Vector),
        ]
        for level in ("low", "high"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{level}SN{tag}_{stat}", Scalar))
        # These must match the keys stored at the end of ``__call__``.
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Evaluate both S/N selections and their statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input vectors; must contain ``vectorKey`` and, when a band is
            supplied, the band-formatted ``fluxType`` key.
        **kwargs
            Forwarded to the selectors and the statistics action.  ``band``,
            if present and non-empty, prefixes the statistic keys.

        Returns
        -------
        results : `KeyedData`
            The two masks, the per-selection statistics and approximate
            magnitudes, and the two selector thresholds.
        """
        results = {}
        highMaskKey = self._maskKey("High")
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = self._maskKey("Low")
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        tag = self._statsTag()
        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{tag}"
            # set the approxMag to the median mag in the SN selection;
            # without a band there is no flux column to use, so store NaN
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

120 

121 

def _validatePlotTypes(value):
    """Return `True` when ``value`` names a supported plot type.

    The accepted names are ``"stars"``, ``"galaxies"``, ``"unknown"``,
    ``"any"`` and ``"mag"``.
    """
    supported = ("stars", "galaxies", "unknown", "any", "mag")
    return value in supported

124 

125 

class _StatsContainer(NamedTuple):
    """Summary statistics of one signal-to-noise selection.

    The field names match the statistic suffixes listed in
    ``ScatterPlotWithTwoHists._stats``, so an instance can be built by keyword
    from values looked up under those suffixes.
    """

    median: Scalar
    sigmaMad: Scalar
    # ``count`` collides with the ``tuple.count`` method of the base class,
    # hence the type-checker suppression.
    count: Scalar  # type: ignore
    approxMag: Scalar

132 

133 

class ScatterPlotWithTwoHists(PlotAction):
    """Scatter plot with a histogram of each axis along the top and side.

    For every enabled entry of ``plotTypes`` the action plots the matching
    x/y vectors, a running median with 1, 2 and 3 sigma-MAD envelopes, and a
    text box with the high and low signal-to-noise statistics.  The input
    keys it reads are listed by `getInputSchema`.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Statistic suffixes; these are also the field names of _StatsContainer.
    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One row per plottable type:
    # (plotTypes entry, x key, y key, tag used in statistic keys, line colour).
    # The S/N mask keys are "<plotTypes entry>HighSNMask" / "...LowSNMask".
    _typeSpecs = (
        ("stars", "xStars", "yStars", "Stars", "midnightblue"),
        ("galaxies", "xGalaxies", "yGalaxies", "Galaxies", "firebrick"),
        ("unknown", "xUnknown", "yUnknown", "Unknown", "green"),
        ("any", "x", "y", "Any", "purple"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and kinds) this plot needs for its ``plotTypes``.

        Statistic keys carry a ``{band}`` placeholder.  The two S/N threshold
        keys are always required; ``"patch"`` is required only when
        ``addSummaryPlot`` is set.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        for plotType, xKey, yKey, tag, _ in self._typeSpecs:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{tag}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{tag}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every required key is present and not wrongly scalar.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing from ``data``, or a key that should
            not be a scalar holds a scalar.
        """
        # Materialise the schema: it is iterated twice below, which would
        # silently skip the type check if it were a one-shot iterator.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a scatter plot with collapsed histograms of each axis.

        Parameters
        ----------
        data : `KeyedData`
            The vectors and scalars to plot, keyed as described by
            `getInputSchema`.
        skymap : `lsst.skymap.BaseSkyMap`
            Passed to ``generateSummaryStats`` when ``addSummaryPlot`` is set.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Not read: when ``addSummaryPlot`` is set the value is replaced by
            the output of ``generateSummaryStats``; otherwise it is unused.
        **kwargs
            Forwarded to the plotting helpers; ``band`` is needed to format
            the statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The main panel shows the points, a running median and sigma-MAD
        envelopes; a histogram of the points collapsed onto each axis is
        drawn along the top and the right-hand side.  When ``addSummaryPlot``
        is set a summary panel is added in the upper right corner.  If
        ``plotTypes`` is empty a figure carrying only a "no data" message is
        returned.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _gatherDatasets(self, data: KeyedData, **kwargs) -> list:
        """Collect one plotting tuple per enabled plot type.

        Each tuple is ``(xs, ys, highSnMask, lowSnMask, colour, colormap,
        highStats, lowStats)``.  Only stars and galaxies get a dedicated
        colormap; the other types pass `None`.
        """
        colormaps = {
            "stars": mkColormap(["paleturquoise", "midnightBlue"]),
            "galaxies": mkColormap(["lemonchiffon", "firebrick"]),
        }
        datasets = []
        for plotType, xKey, yKey, tag, color in self._typeSpecs:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(Scalar, data[f"{{band}}_highSN{tag}_{name}".format(**kwargs)])
                lowArgs[name] = cast(Scalar, data[f"{{band}}_lowSN{tag}_{name}".format(**kwargs)])
            datasets.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    color,
                    colormaps.get(plotType),
                    _StatsContainer(**highArgs),
                    _StatsContainer(**lowArgs),
                )
            )
        return datasets

    @staticmethod
    def _plotFlatStats(ax: Axes, xs, ys, color: str, fmt: str):
        """Draw a constant median line with +/- 1 sigma-MAD lines.

        Used when a running median cannot be computed.  The constant arrays
        are sized to ``xs`` so that an empty dataset plots nothing instead of
        raising.  Returns ``(medLine, sigMadLine, meds)``.
        """
        med = np.nanmedian(ys)
        sig = sigmaMad(ys, nan_policy="omit")
        meds = np.full(len(xs), med)
        sigMads = np.full(len(xs), sig)
        (medLine,) = ax.plot(xs, meds, color, label=f"Median: {med:{fmt}}", lw=0.8)
        (sigMadLine,) = ax.plot(
            xs,
            meds + 1.0 * sigMads,
            color,
            alpha=0.8,
            lw=0.8,
            label=r"$\sigma_{MAD}$: " + f"{sig:{fmt}}",
        )
        ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
        return medLine, sigMadLine, meds

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Returns the axes and the hexbin collection of the last dataset drawn
        (`None` when no 2D histogram was drawn for it).
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # Minimum number of points an x bin needs to enter the running stats.
        binThresh = 5

        linesForLegend = []
        histIm = None
        toPlotList = self._gatherDatasets(data, **kwargs)
        plottedYs = []

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            plottedYs.append(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                medLine, sigMadLine, _ = self._plotFlatStats(ax, xs, ys, color, ".2g")
                linesForLegend.append(medLine)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow) / self.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins that hold enough points.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge fills from the front, lower edge from the
                    # back, so the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text: one box per S/N selection, each
                # dataset shifted left of the previous one.
                xPos = 0.65 - 0.4 * j
                for stats, threshKey, lineStyle, yPos in (
                    (highStats, "highSnThreshold", "--", 0.090),
                    (lowStats, "lowSnThreshold", ":", 0.020),
                ):
                    bbox = dict(edgecolor=color, linestyle=lineStyle, facecolor="none")
                    thresh = data[threshKey]
                    statText = f"S/N > {thresh:0.4g} Stats ({self.magLabel} < {stats.approxMag:0.4g})\n"
                    statText += (
                        f"Median: {stats.median:0.4g} "
                        + r"$\sigma_{MAD}$: "
                        + f"{stats.sigmaMad:0.4g} "
                        + r"N$_{points}$: "
                        + f"{stats.count}"
                    )
                    fig.text(xPos, yPos, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    # White under-layer so the markers stand out.
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated bins for a running median: plot every
                # point with a flat median and sigma-MAD band.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                medLine, sigMadLine, meds = self._plotFlatStats(ax, xs, ys, color, "0.3g")
                linesForLegend.append(medLine)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits.  The data-driven fallbacks use the
        # values left over from the last dataset drawn, so they are skipped
        # when nothing was drawn.
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif toPlotList:
            # Centre on the first non-empty plotted dataset (stars first,
            # then galaxies, ...) rather than on hard-coded keys that may
            # not be present in ``data``.
            plotMed = next((np.nanmedian(v) for v in plottedYs if len(v) > 0), np.nan)
            if len(xs) < 2:  # type: ignore
                meds = [np.median(ys)]  # type: ignore
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen until the medians fit (capped at 10 sigma); the limits
            # are recomputed on every step so the loop stops when they do.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # One extra sigma of padding.
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif toPlotList and len(xs) > 2:  # type: ignore
            if xMin is None:
                xMin = xs1 - 2 * xScale  # type: ignore
            ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter panel.

        All enabled plot types contribute to the grey "All" histogram;
        galaxies and stars additionally get their own outline.
        """
        # Top histogram
        totalX: list[Vector] = [
            cast(Vector, data[xKey])
            for plotType, xKey, _, _, _ in self._typeSpecs
            if plotType in self.plotTypes  # type: ignore
        ]

        # ``x == x`` is False only for NaN, so this drops the NaN entries.
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xGalaxies"],
                bins=100,
                color="firebrick",
                histtype="step",
                log=True,
                label=f"Galaxies ({len(cast(Vector, data['xGalaxies']))})",
            )
        if "stars" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xStars"],
                bins=100,
                color="midnightblue",
                histtype="step",
                log=True,
                label=f"Stars ({len(cast(Vector, data['xStars']))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values to the right of the scatter panel.

        Galaxies and stars each get three outlines: all finite values
        (solid), the high S/N subset (dashed) and the low S/N subset
        (dotted).  A colour bar for ``histIm`` is appended when a 2D
        histogram was drawn.
        """
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = [
            cast(Vector, data[yKey])
            for plotType, _, yKey, _, _ in self._typeSpecs
            if plotType in self.plotTypes  # type: ignore
        ]
        # ``y == y`` is False only for NaN, so this drops the NaN entries.
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # Bin over the y range already chosen for the scatter panel.
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        for plotType, yKey, color in (
            ("galaxies", "yGalaxies", "firebrick"),
            ("stars", "yStars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            values = cast(Vector, data[yKey])
            common = dict(bins=bins, color=color, histtype="step", orientation="horizontal", log=True)
            sideHist.hist([v for v in values if v == v], **common)
            sideHist.hist(values[cast(Vector, data[f"{plotType}HighSNMask"])], ls="--", **common)
            sideHist.hist(values[cast(Vector, data[f"{plotType}LowSNMask"])], ls=":", **common)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

798 plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")