Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 12%

367 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-02-10 11:17 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26from itertools import chain 

27from typing import Mapping, NamedTuple, Optional, cast 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.listField import ListField 

33from lsst.pipe.tasks.configurableActions import ConfigurableActionField 

34from lsst.skymap import BaseSkyMap 

35from matplotlib import gridspec 

36from matplotlib.axes import Axes 

37from matplotlib.collections import PolyCollection 

38from matplotlib.figure import Figure 

39from matplotlib.path import Path 

40from mpl_toolkits.axes_grid1 import make_axes_locatable 

41 

42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, Vector 

43from ...statistics import nansigmaMad, sigmaMad 

44from ..keyedData.summaryStatistics import SummaryStatisticAction 

45from ..scalar import MedianAction 

46from ..vector import MagColumnNanoJansky, SnSelector 

47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

48 

# ``plt.cm.coolwarm`` exists at runtime, but static type checkers do not see
# it as an attribute of the module, hence the ``type: ignore``. A copy is
# modified (presumably so the globally registered colormap is left untouched).
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
# Masked / invalid values are drawn fully transparent.
cmapPatch.set_bad("none")

52 

53 

class ScatterPlotStatsAction(KeyedDataAction):
    """Compute S/N selections and per-selection statistics for scatter plots.

    Two `SnSelector` actions split the input into a high- and a low-S/N
    selection. For each selection this action records:

    - the boolean selection mask;
    - the statistics produced by `SummaryStatisticAction` on ``vectorKey``;
    - an approximate magnitude, which is the median magnitude of
      ``fluxType`` within the selection, or NaN when no ``band`` keyword is
      supplied.

    The two S/N thresholds are returned as well.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def _maskPrefix(self) -> str:
        """Return the lower-cased identity used as prefix of the mask keys.

        Safe when ``identity`` is unset (returns an empty string).
        """
        return (self.identity or "").lower()

    def _statsLabel(self) -> str:
        """Return the capitalized identity embedded in the statistic keys."""
        return self.identity.capitalize() if self.identity else ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the statistic vector, the flux vector and the selectors' inputs."""
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__` (``{band}`` left unformatted)."""
        maskPrefix = self._maskPrefix()
        statsLabel = self._statsLabel()
        schema: list[tuple[str, type[Vector] | type[Scalar]]] = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("lowSN", "highSN"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}{statsLabel}_{stat}", Scalar))
        # These names must match the keys written in ``__call__`` (and read
        # by ScatterPlotWithTwoHists).
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the S/N masks, their statistics and the thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data containing the keys from `getInputSchema`.
        **kwargs
            Passed to the selectors and the statistic action. When ``band``
            is given it prefixes the statistic keys and is substituted into
            ``fluxType`` to compute ``approxMag``.

        Returns
        -------
        results : `KeyedData`
            The masks, the statistics for each S/N bin and the thresholds.
        """
        results = {}
        maskPrefix = self._maskPrefix()
        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        band = kwargs.get("band")
        prefix = f"{band}_" if band else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        statsLabel = self._statsLabel()
        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statsLabel}"
            mask = results[maskKey]
            # set the approxMag to the median mag in the SN selection
            if fluxes is not None:
                results[f"{name}_approxMag"] = medianAction(
                    {"mag": magAction({"flux": fluxes[mask]})}  # type: ignore
                )
            else:
                results[f"{name}_approxMag"] = np.nan
            stats = statAction(data, **(kwargs | {"mask": mask}))
            for suffix, value in stats.items():
                results[f"{name}_{suffix}"] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

122 

123 

124def _validatePlotTypes(value): 

125 return value in ("stars", "galaxies", "unknown", "any", "mag") 

126 

127 

128# ignore type because of conflicting name on tuple baseclass 

129class _StatsContainer(NamedTuple): 

130 median: Scalar 

131 sigmaMad: Scalar 

132 count: Scalar # type: ignore 

133 approxMag: Scalar 

134 

135 

class ScatterPlotWithTwoHists(PlotAction):
    """Scatter plot of y against x with collapsed histograms of each axis.

    For every object type listed in ``plotTypes`` the main panel shows the
    points together with running-median and sigmaMad lines. A histogram of
    the x values is drawn above the main panel and a histogram of the y
    values to its right. An optional per-patch summary panel is drawn in the
    upper-right corner.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    _stats = ("median", "sigmaMad", "count", "approxMag")

    # Input keys of each supported object type, in plotting order:
    # (plotTypes entry, x key, y key, S/N mask prefix, statistic label, color).
    # Keeping these in one table avoids the per-type copy/paste that let the
    # key names drift apart.
    _typeKeys = (
        ("stars", "xStars", "yStars", "stars", "Stars", "midnightblue"),
        ("galaxies", "xGalaxies", "yGalaxies", "galaxies", "Galaxies", "firebrick"),
        ("unknown", "xUnknown", "yUnknown", "unknown", "Unknown", "green"),
        ("any", "x", "y", "any", "Any", "purple"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys needed for the selected ``plotTypes``.

        For each selected type this is the x/y vectors, the high/low S/N
        masks and the ``{band}``-prefixed statistics. It also includes the
        two S/N thresholds and, if ``addSummaryPlot`` is set, ``patch``.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        for plotType, xKey, yKey, maskPrefix, label, _color in self._typeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{maskPrefix}HighSNMask", Vector))
            base.append((f"{maskPrefix}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{label}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{label}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that ``data`` provides every key this action needs.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector.

        Raises
        ------
        ValueError
            Raised if a required key is missing, or if a key declared as a
            Vector holds a Scalar.
        """
        # Materialise the schema: it is iterated twice below and may be a
        # one-shot iterator (in which case the type loop would never run).
        needed = tuple(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a scatter plot with collapsed histograms of each axis.

        Parameters
        ----------
        data : `KeyedData`
            The data to plot; must contain the keys listed by
            `getInputSchema` (after ``{band}`` substitution).
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap, used to build the per-patch summary panel when
            ``addSummaryPlot`` is set.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Ignored. When ``addSummaryPlot`` is set, the per-patch statistics
            are regenerated from ``data`` with `generateSummaryStats`.
        **kwargs
            ``band`` is used to format the statistic keys. ``hlineColor`` and
            ``hlineStyle`` (defaults ``"black"`` and ``(0, (1, 4))``) style
            the horizontal reference line drawn at zero.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The x and y vectors of each type in ``plotTypes`` are plotted against
        each other. Where there are enough points, a running median and 1, 2
        and 3 sigmaMad lines are drawn, with points outside 3 sigmaMad shown
        individually and (if ``plot2DHist``) a hexbin for those inside. The
        high/low S/N statistics are printed below the axes. Histograms of the
        points collapsed onto each axis are drawn alongside.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Default color and line style of the horizontal reference line at 0.
        kwargs.setdefault("hlineColor", "black")
        kwargs.setdefault("hlineStyle", (0, (1, 4)))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _getStats(self, data: KeyedData, binLabel: str, **kwargs) -> _StatsContainer:
        """Collect the precomputed statistics of one S/N bin.

        ``binLabel`` is e.g. ``"highSNStars"``. Each value is read from
        ``"{band}_<binLabel>_<stat>"`` (formatted with ``kwargs``) for every
        name in ``_stats``.
        """
        return _StatsContainer(
            **{
                name: cast(Scalar, data[f"{{band}}_{binLabel}_{name}".format(**kwargs)])
                for name in self._stats
            }
        )

    def _addStatsText(
        self,
        fig: Figure,
        data: KeyedData,
        j: int,
        color: str,
        highStats: _StatsContainer,
        lowStats: _StatsContainer,
    ) -> None:
        """Write the high- and low-S/N statistics of dataset ``j`` below the axes."""
        xPos = 0.65 - 0.4 * j
        for stats, threshKey, boxStyle, yPos in (
            (highStats, "highSnThreshold", "--", 0.090),
            (lowStats, "lowSnThreshold", ":", 0.020),
        ):
            bbox = dict(edgecolor=color, linestyle=boxStyle, facecolor="none")
            thresh = data[threshKey]
            statText = f"S/N > {thresh:0.4g} Stats ({self.magLabel} < {stats.approxMag:0.4g})\n"
            statText += (
                f"Median: {stats.median:0.4g} "
                + r"$\sigma_{MAD}$: "
                + f"{stats.sigmaMad:0.4g} "
                + r"N$_{points}$: "
                + f"{stats.count}"
            )
            fig.text(xPos, yPos, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

    @staticmethod
    def _plotSnSample(
        ax: Axes,
        xs: Vector,
        ys: Vector,
        mask: Vector,
        color: str,
        marker: str,
        label: str,
        approxMag: Scalar,
        vlineStyle: str,
        linesForLegend: list,
    ) -> Optional[float]:
        """Mark one S/N selection on the scatter panel.

        If there are not many sources being used for the statistics (fewer
        than 100, but at least one), plot them individually: just plotting a
        line makes the statistics look wrong, because the magnitude
        estimation is iffy for low numbers of sources. Otherwise, draw a
        vertical line at ``approxMag``.

        Returns
        -------
        xMin : `float` or `None`
            The smallest finite x value plotted individually, or `None` if no
            points were plotted individually.
        """
        nSelected = np.sum(mask)
        if 0 < nSelected < 100:
            xSel = cast(Vector, xs[mask])
            ySel = cast(Vector, ys[mask])
            # White, thick-edged markers underneath outline the colored ones.
            ax.plot(xSel, ySel, marker=marker, ms=4, mec="w", mew=2, ls="none")
            (sampleLine,) = ax.plot(xSel, ySel, color=color, marker=marker, ms=4, ls="none", label=label)
            linesForLegend.append(sampleLine)
            sampleMin = np.nanmin(xSel)
            return float(sampleMin) if np.isfinite(sampleMin) else None
        ax.axvline(approxMag, color=color, ls=vlineStyle)
        return None

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel with running statistics.

        Parameters
        ----------
        data : `KeyedData`
            The input data (see `getInputSchema`).
        fig : `matplotlib.figure.Figure`
            Figure to draw into.
        gs : `matplotlib.gridspec.GridSpec`
            The grid; the scatter panel occupies ``gs[1:, :-1]``.
        **kwargs
            Must contain ``hlineColor`` and ``hlineStyle``; also used to
            format the ``{band}`` statistic keys.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter panel.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin of the last dataset drawn, or `None` if that dataset
            was not drawn as a 2D histogram.
        """
        ax = fig.add_subplot(gs[1:, :-1])

        # Colormaps for the hexbin; types without an entry get None.
        cmaps = {
            "stars": mkColormap(["paleturquoise", "midnightBlue"]),
            "galaxies": mkColormap(["lemonchiffon", "firebrick"]),
        }

        # An x bin needs more than this many points (within +/- 5 sigmaMad
        # in y) to be used for the running statistics.
        binThresh = 5

        linesForLegend = []
        histIm = None

        toPlotList = []
        for plotType, xKey, yKey, maskPrefix, label, color in self._typeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{maskPrefix}HighSNMask"],
                    data[f"{maskPrefix}LowSNMask"],
                    color,
                    cmaps.get(plotType),
                    self._getStats(data, f"highSN{label}", **kwargs),
                    self._getStats(data, f"lowSN{label}", **kwargs),
                )
            )

        # Quantities of the most recently drawn dataset, used afterwards to
        # set the automatic axis limits.
        xs = None
        meds = None
        sigMadYs = np.nan
        xs1 = xs97 = xScale = np.nan
        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually arrays
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            medYs = np.nanmedian(ys)

            if len(xs) < 2:
                # Too few points for running statistics: draw flat median and
                # sigmaMad lines over whatever x values exist (possibly none).
                medLineYs = np.full(len(xs), medYs)
                (medLine,) = ax.plot(xs, medLineYs, color, label=f"Median: {medYs:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    medLineYs + 1.0 * sigMadYs,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:.2g}",
                )
                ax.plot(xs, medLineYs - 1.0 * sigMadYs, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                meds = [medYs]
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # x bins padded by xScale on either side; 40 was used as the
            # default number of bins because it looked good.
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow) / self.nBins)
            # ~101 y bins spanning the median +/- 5 sigmaMad.
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins that contain enough points.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[inBin])
                    sigMad = sigmaMad(ys[inBin], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper boundary is stored forwards and lower boundary
                    # backwards so the vertices trace a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within 3 sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                self._addStatsText(fig, data, j, color, highStats, lowStats)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # Mark the high and low S/N selections; track the smallest x
                # drawn individually across all datasets so it stays visible.
                for snMask, marker, snLabel, approxMag, vlineStyle in (
                    (highSn, "x", "High SN", highStats.approxMag, "--"),
                    (lowSn, "+", "Low SN", lowStats.approxMag, ":"),
                ):
                    sampleMin = self._plotSnSample(
                        ax, xs, ys, snMask, color, marker, snLabel, approxMag, vlineStyle, linesForLegend
                    )
                    if sampleMin is not None and (xMin is None or sampleMin < xMin):
                        xMin = sampleMin

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.full(len(xs), medYs)
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {medYs:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.full(len(xs), sigMadYs)
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif meds is not None:
            # Centre on the median of the first non-empty dataset (in
            # ``_typeKeys`` order, so stars take priority). Widen the range in
            # steps of the last dataset's sigmaMad until it contains its median
            # line (at most 10 sigmaMad), then add one sigmaMad of margin.
            plotMed = np.nan
            for entry in toPlotList:
                yVals = cast(Vector, entry[1])
                if len(yVals) > 0:
                    plotMed = np.nanmedian(yVals)
                    break
            numSig = 4
            while numSig < 10 and (
                plotMed + numSig * sigMadYs < np.nanmax(meds) or plotMed - numSig * sigMadYs > np.nanmin(meds)
            ):
                numSig += 1
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs
            yLimMax = plotMed + numSig * sigMadYs
            # Non-finite limits make set_ylim raise; fall back to autoscaling.
            if np.isfinite(yLimMin) and np.isfinite(yLimMax):
                ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif xs is not None and len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2 * xScale
            ax.set_xlim(xMin, xs97 + 2 * xScale)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    @staticmethod
    def _concatNonNan(vectors: list[Vector]) -> np.ndarray:
        """Concatenate ``vectors`` into one float array with NaNs removed."""
        if not vectors:
            return np.array([], dtype=float)
        combined = np.concatenate([np.asarray(v, dtype=float) for v in vectors])
        return combined[~np.isnan(combined)]

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter panel.

        A grey filled histogram combines every plotted type (NaNs removed).
        Step histograms are added for galaxies and stars when they are plotted.
        """
        totalXChained = self._concatNonNan(
            [
                data[xKey]
                for plotType, xKey, _yKey, _maskPrefix, _label, _color in self._typeKeys
                if plotType in self.plotTypes  # type: ignore
            ]
        )

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        for plotType, xKey, legendName, color in (
            ("galaxies", "xGalaxies", "Galaxies", "firebrick"),
            ("stars", "xStars", "Stars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            topHist.hist(
                data[xKey],
                bins=100,
                color=color,
                histtype="step",
                log=True,
                label=f"{legendName} ({len(cast(Vector, data[xKey]))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values to the right of the scatter panel.

        A grey filled histogram combines every plotted type (NaNs removed).
        For galaxies and stars, step histograms of all objects, the high-S/N
        selection (dashed) and the low-S/N selection (dotted) are added. When
        ``histIm`` is given and ``plot2DHist`` is set, its colorbar is
        attached on the right.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalYChained = self._concatNonNan(
            [
                data[yKey]
                for plotType, _xKey, yKey, _maskPrefix, _label, _color in self._typeKeys
                if plotType in self.plotTypes  # type: ignore
            ]
        )

        # Bin over the y range currently shown by the scatter panel.
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        for plotType, yKey, maskPrefix, color in (
            ("galaxies", "yGalaxies", "galaxies", "firebrick"),
            ("stars", "yStars", "stars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            yVals = cast(Vector, data[yKey])
            sideHist.hist(
                self._concatNonNan([yVals]),
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            for maskKey, lineStyle in ((f"{maskPrefix}HighSNMask", "--"), (f"{maskPrefix}LowSNMask", ":")):
                sideHist.hist(
                    yVals[cast(Vector, data[maskKey])],
                    bins=bins,
                    color=color,
                    histtype="step",
                    orientation="horizontal",
                    log=True,
                    ls=lineStyle,
                )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")