Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 12%

361 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-11-09 03:50 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public names exported by this module.
__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists")

25 

26from itertools import chain 

27from typing import Mapping, NamedTuple, Optional, cast 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.listField import ListField 

33from lsst.pipe.tasks.configurableActions import ConfigurableActionField 

34from lsst.skymap import BaseSkyMap 

35from matplotlib import gridspec 

36from matplotlib.axes import Axes 

37from matplotlib.collections import PolyCollection 

38from matplotlib.figure import Figure 

39from matplotlib.path import Path 

40from mpl_toolkits.axes_grid1 import make_axes_locatable 

41 

42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, Vector 

43from ...statistics import nansigmaMad, sigmaMad 

44from ..keyedData.summaryStatistics import SummaryStatisticAction 

45from ..scalar import MedianAction 

46from ..vector import MagColumnNanoJansky, SnSelector 

47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

48 

# Copy of the diverging ``coolwarm`` colormap whose "bad" (masked/invalid)
# entries are rendered fully transparent. ``coolwarm`` really is an attribute
# of ``plt.cm``; type checkers just cannot see it, hence the ignore.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad("none")

52 

53 

class ScatterPlotStatsAction(KeyedDataAction):
    """Compute S/N masks and per-S/N-bin statistics for a scatter plot.

    Two `SnSelector` actions split the input into a high and a low
    signal-to-noise sample. For each sample the statistics produced by
    `SummaryStatisticAction` on ``vectorKey`` are stored, together with the
    median magnitude of the sample (``approxMag``; NaN when no ``band`` is
    given) and the two selector thresholds. Output keys are named so that
    `ScatterPlotWithTwoHists` can read them directly, e.g.
    ``starsHighSNMask``, ``{band}_highSNStars_median`` and
    ``highSnThreshold``.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def _maskPrefix(self) -> str:
        """Return the lower-cased identity used to prefix the mask keys.

        ``identity`` may be None, so it is normalized to an empty string
        before lower-casing (the previous code raised AttributeError).
        """
        return (self.identity or "").lower()

    def _statsSuffix(self) -> str:
        """Return the capitalized identity inserted into statistic keys."""
        return self.identity.capitalize() if self.identity else ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        # Must match the keys actually written by __call__.
        prefix = self._maskPrefix()
        suffix = self._statsSuffix()
        return (
            (f"{prefix}HighSNMask", Vector),
            (f"{prefix}LowSNMask", Vector),
            (f"{{band}}_lowSN{suffix}_median", Scalar),
            (f"{{band}}_lowSN{suffix}_sigmaMad", Scalar),
            (f"{{band}}_lowSN{suffix}_count", Scalar),
            (f"{{band}}_lowSN{suffix}_approxMag", Scalar),
            (f"{{band}}_highSN{suffix}_median", Scalar),
            (f"{{band}}_highSN{suffix}_sigmaMad", Scalar),
            (f"{{band}}_highSN{suffix}_count", Scalar),
            (f"{{band}}_highSN{suffix}_approxMag", Scalar),
            ("highSnThreshold", Scalar),
            ("lowSnThreshold", Scalar),
        )

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks, statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data containing ``vectorKey``, the formatted ``fluxType``
            column and whatever the selectors need.
        **kwargs
            Forwarded to the selectors and statistics action; ``band`` (if
            present) prefixes the statistic keys and selects the flux
            column used for ``approxMag``.

        Returns
        -------
        results : `KeyedData`
            Masks, statistics and thresholds keyed as in
            `getOutputSchema` (with ``{band}`` filled in).
        """
        results = {}
        maskPrefix = self._maskPrefix()
        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        statsSuffix = self._statsSuffix()
        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statsSuffix}"
            # set the approxMag to the median mag in the SN selection
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for statName, value in stats:
                results[f"{name}_{statName}".format(**kwargs)] = value
        # Key names consumed by ScatterPlotWithTwoHists.getInputSchema.
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

122 

123 

124def _validatePlotTypes(value): 

125 return value in ("stars", "galaxies", "unknown", "any", "mag") 

126 

127 

# ignore type because of conflicting name on tuple baseclass
# (the ``count`` field shadows ``tuple.count``).
class _StatsContainer(NamedTuple):
    """Summary statistics for one S/N bin of one object type.

    Built by `ScatterPlotWithTwoHists` from the
    ``{band}_<bin>SN<Type>_<stat>`` scalars of its input data and used for
    the statistics text boxes and the approximate-magnitude marker lines.
    """

    median: Scalar
    sigmaMad: Scalar
    count: Scalar  # type: ignore
    approxMag: Scalar

134 

135 

class ScatterPlotWithTwoHists(PlotAction):
    """Scatter plot of y against x with histograms collapsed onto each axis.

    For each requested object type (``plotTypes``) the main panel shows the
    points with running median and 1/2/3 sigma_MAD lines, high/low S/N
    statistics text boxes and, optionally, a hexbin 2D histogram of the
    points lying inside 3 sigma_MAD. Histograms of the x and y values are
    drawn above and to the right of the main panel, and a per-patch summary
    panel can be added in the top right corner.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the per-S/N-bin statistics read from the input data.
    _stats = ("median", "sigmaMad", "count", "approxMag")

    # (plotType, suffix of the x/y data keys, line colour) for every object
    # type that can be drawn, in drawing order. The statistic keys use
    # ``plotType.capitalize()`` (e.g. ``{band}_highSNAny_median``) and the
    # mask keys ``<plotType>HighSNMask`` / ``<plotType>LowSNMask``.
    _plotTypeSpecs = (
        ("stars", "Stars", "midnightblue"),
        ("galaxies", "Galaxies", "firebrick"),
        ("unknown", "Unknown", "green"),
        ("any", "", "purple"),
    )

    def _selectedPlotTypes(self) -> list[tuple[str, str, str]]:
        """Return the `_plotTypeSpecs` entries selected in ``plotTypes``."""
        return [spec for spec in self._plotTypeSpecs if spec[0] in self.plotTypes]  # type: ignore

    def getInputSchema(self) -> KeyedDataSchema:
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        for plotType, suffix, _ in self._selectedPlotTypes():
            base.append((f"x{suffix}", Vector))
            base.append((f"y{suffix}", Vector))
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            statsName = plotType.capitalize()
            for name in self._stats:
                base.append((f"{{band}}_highSN{statsName}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{statsName}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` holds every key required by the input schema.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or a Scalar is supplied where a
            Vector is required.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a scatter plot with histograms of each axis collapsed.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must contain the keys listed by
            `getInputSchema` (with ``{band}`` formatted from ``kwargs``).
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map passed to `generateSummaryStats` when a summary panel is
            requested and ``sumStats`` is not supplied.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch. Only used when ``addSummaryPlot`` is
            set; computed from ``data`` when not supplied.
        **kwargs
            Forwarded to the panel helpers; ``band`` is used to format the
            statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The main panel plots the ``x*``/``y*`` vectors of each selected
        object type against each other, labelled with ``xAxisLabel`` and
        ``yAxisLabel``. Histograms of the points collapsed onto each axis are
        drawn alongside, and a summary panel of the per-patch statistic can
        be shown in the upper right corner. Points are split into high and
        low S/N samples using the precomputed ``*HighSNMask`` /
        ``*LowSNMask`` vectors, whose statistics are printed below the plot.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            # Honor caller-supplied statistics instead of silently discarding them.
            if sumStats is None:
                sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _collectStats(self, data: KeyedData, binName: str, statsName: str, **kwargs) -> _StatsContainer:
        """Gather the `_stats` scalars for one S/N bin of one object type.

        Keys have the form ``{band}_<binName>SN<statsName>_<stat>``, e.g.
        ``{band}_highSNStars_median``, with ``{band}`` filled from kwargs.
        """
        return _StatsContainer(
            **{
                name: cast(Scalar, data[f"{{band}}_{binName}SN{statsName}_{name}".format(**kwargs)])
                for name in self._stats
            }
        )

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Parameters
        ----------
        data : `KeyedData`
            Input data holding the keys from `getInputSchema`.
        fig : `matplotlib.figure.Figure`
            Figure to add the scatter axes to.
        gs : `matplotlib.gridspec.GridSpec`
            Grid; the panel occupies rows ``1:`` and all but the last column.
        **kwargs
            Used to format ``{band}`` in the statistic keys.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin image of the last dataset drawn, if any.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        hexbinCmaps = {"stars": newBlues, "galaxies": newReds}

        # Minimum number of points an x bin needs to enter the running stats.
        binThresh = 5

        linesForLegend = []
        histIm = None

        toPlotList = []
        for plotType, suffix, color in self._selectedPlotTypes():
            statsName = plotType.capitalize()
            toPlotList.append(
                (
                    data[f"x{suffix}"],
                    data[f"y{suffix}"],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    color,
                    hexbinCmaps.get(plotType),
                    self._collectStats(data, "high", statsName, **kwargs),
                    self._collectStats(data, "low", statsName, **kwargs),
                )
            )

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            if len(xs) < 2:
                # Too few points for running statistics: draw flat median and
                # +/- 1 sigma_MAD lines (also safe for an empty dataset).
                medY = np.nanmedian(ys)
                (medLine,) = ax.plot(
                    xs, np.full(len(xs), medY), color, label=f"Median: {medY:.2g}", lw=0.8
                )
                linesForLegend.append(medLine)
                sigMads = np.full(len(xs), sigMadYs)
                (sigMadLine,) = ax.plot(
                    xs,
                    medY + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:.2g}",
                )
                ax.plot(xs, medY - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only x bins with enough points for meaningful statistics.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge filled forwards, lower edge backwards, so the
                    # vertices trace a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits. Automatic y limits are centred on the
        # median y of the first non-empty dataset and scaled by the spread
        # (sigMadYs, meds) of the last dataset drawn in the loop above.
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif toPlotList:
            plotMed = np.nan
            for toPlot in toPlotList:
                if len(cast(Vector, toPlot[1])) > 0:
                    plotMed = np.nanmedian(cast(Vector, toPlot[1]))
                    break
            if len(xs) < 2:  # type: ignore
                meds = [np.median(ys)]  # type: ignore
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen until the running median fits (up to 10 sigma_MAD).
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1

            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif toPlotList and len(xs) > 2:  # type: ignore
            if xMin is None:
                xMin = xs1 - 2 * xScale  # type: ignore
            ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter panel.

        The histogram axes share their x axis with ``ax``. The grey filled
        histogram covers all selected types; galaxies and stars (when
        selected) are overplotted as step histograms.
        """
        totalX: list[Vector] = [
            cast(Vector, data[f"x{suffix}"]) for _, suffix, _ in self._selectedPlotTypes()
        ]
        # ``x == x`` is False only for NaN, so this drops NaNs.
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xGalaxies"],
                bins=100,
                color="firebrick",
                histtype="step",
                log=True,
                label=f"Galaxies ({len(cast(Vector, data['xGalaxies']))})",
            )
        if "stars" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xStars"],
                bins=100,
                color="midnightblue",
                histtype="step",
                log=True,
                label=f"Stars ({len(cast(Vector, data['xStars']))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the horizontal histogram of y values right of the scatter panel.

        Bins span the current y limits of ``ax``. When ``plot2DHist`` is set
        and ``histIm`` is given, a colorbar for it is attached on the right.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = [
            cast(Vector, data[f"y{suffix}"]) for _, suffix, _ in self._selectedPlotTypes()
        ]
        # ``y == y`` is False only for NaN, so this drops NaNs.
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # Bin over the y range already chosen for the scatter panel.
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        # Per-type step histograms: all non-NaN values, then the high S/N
        # (dashed) and low S/N (dotted) subsets. Galaxies are drawn first.
        for plotType, suffix, color in (
            ("galaxies", "Galaxies", "firebrick"),
            ("stars", "Stars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            yValues = cast(Vector, data[f"y{suffix}"])
            sideHist.hist(
                [y for y in yValues if y == y],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            for maskSuffix, lineStyle in (("HighSNMask", "--"), ("LowSNMask", ":")):
                sideHist.hist(
                    yValues[cast(Vector, data[f"{plotType}{maskSuffix}"])],
                    bins=bins,
                    color=color,
                    histtype="step",
                    orientation="horizontal",
                    log=True,
                    ls=lineStyle,
                )

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")