Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%

383 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-06 09:59 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from functools import partial 

24from itertools import chain 

25from typing import Mapping, NamedTuple, Optional, cast 

26 

27import astropy.units as u 

28import matplotlib.pyplot as plt 

29import numpy as np 

30import scipy.stats as sps 

31from lsst.analysis.tools.actions.scalar.scalarActions import CountAction, MedianAction, SigmaMadAction 

32from lsst.pex.config import Field 

33from lsst.pex.config.listField import ListField 

34from lsst.pipe.tasks.configurableActions import ConfigurableActionField 

35from lsst.skymap import BaseSkyMap 

36from matplotlib import gridspec 

37from matplotlib.axes import Axes 

38from matplotlib.collections import PolyCollection 

39from matplotlib.figure import Figure 

40from matplotlib.path import Path 

41from mpl_toolkits.axes_grid1 import make_axes_locatable 

42 

43from ...interfaces import ( 

44 KeyedData, 

45 KeyedDataAction, 

46 KeyedDataSchema, 

47 PlotAction, 

48 Scalar, 

49 ScalarAction, 

50 Vector, 

51 VectorAction, 

52) 

53from ..keyedData import KeyedScalars 

54from ..vector import SnSelector 

55from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

56 

# Private copy of matplotlib's "coolwarm" colormap whose "bad" color (used
# for masked / invalid cells) is fully transparent. Working on a copy leaves
# the colormap registered with matplotlib unchanged.
# The type-ignore is needed because coolwarm is actually part of the module
# even though type checkers cannot see it.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad("none")

# sigma_MAD estimator: scipy's median absolute deviation with the "normal"
# scale, i.e. rescaled to be comparable to a Gaussian standard deviation.
sigmaMad = partial(sps.median_abs_deviation, scale="normal")  # type: ignore

62 

63 

class _ApproxMedian(ScalarAction):
    """Approximate the value reached by the faint tail of a masked sample.

    The masked vector is converted from ``inputUnit`` to ``outputUnit`` (when
    they differ), non-finite entries are discarded, the remaining values are
    sorted, and the median of the largest 10% is returned. With the default
    nJy -> mag(AB) conversion this is the faintest tenth of the sample, which
    the scatter plot quotes as ``mag < approxMag``. Samples with fewer than
    ten usable values fall back to the median of all of them; an empty sample
    gives NaN.
    """

    vectorKey = Field[str](doc="Key for the vector to perform action on", optional=False)
    inputUnit = Field[str](doc="Input unit of the vector", default="nJy")
    outputUnit = Field[str](doc="Output unit of the vector", default="mag(AB)")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        return ((self.vectorKey.format(**kwargs), Vector),)

    def __call__(self, data: KeyedData, **kwargs) -> Scalar:
        mask = self.getMask(**kwargs)
        values = np.asarray(cast(Vector, data[self.vectorKey.format(**kwargs)])[mask], dtype=float)
        if self.inputUnit != self.outputUnit:
            # Convert before sorting: flux -> magnitude reverses the ordering,
            # so sorting raw fluxes would select the brightest objects rather
            # than the faint tail. Non-positive fluxes become NaN/inf here and
            # are discarded below.
            with np.errstate(divide="ignore", invalid="ignore"):
                values = (values * u.Unit(self.inputUnit)).to(u.Unit(self.outputUnit)).value
        # np.sort places NaNs at the end, where they would fill the tail slice,
        # so drop every non-finite value first.
        values = np.sort(values[np.isfinite(values)])
        if len(values) == 0:
            return np.nan
        nTail = len(values) // 10
        tail = values[-nTail:] if nTail > 0 else values
        return np.nanmedian(tail)

80 

81 

class _StatsImpl(KeyedScalars):
    """Summary scalars for one selection of a scatter-plot dataset.

    Computes the median, sigma_MAD and count of ``vectorKey`` plus the
    approximate magnitude of ``snFluxType``, all restricted by the ``mask``
    keyword passed at call time.
    """

    vectorKey = Field[str](doc="Column key to compute scalars")

    snFluxType = Field[str](doc="column key for the flux type used in SN selection")

    def setDefaults(self):
        # The actions read vectorKey / snFluxType, so this only produces
        # useful keys once those fields have been set.
        super().setDefaults()
        statActions = {
            "median": MedianAction(vectorKey=self.vectorKey),
            "sigmaMad": SigmaMadAction(vectorKey=self.vectorKey),
            "count": CountAction(vectorKey=self.vectorKey),
            "approxMag": _ApproxMedian(vectorKey=self.snFluxType),
        }
        for statName, statAction in statActions.items():
            setattr(self.scalarActions, statName, statAction)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        # Always forward a "mask" entry, defaulting to None when absent.
        forwarded = {**kwargs, "mask": kwargs.get("mask")}
        return super().__call__(data, **forwarded)

97 

98 

class ScatterPlotStatsAction(KeyedDataAction):
    """Compute S/N masks and summary statistics for a scatter-plot dataset.

    Produces a high and a low S/N mask (from ``highSNSelector`` and
    ``lowSNSelector``), the median / sigmaMad / count of ``vectorKey`` and the
    approximate magnitude of ``fluxType`` within each mask, and the two
    selector thresholds.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[VectorAction](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[VectorAction](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Describe the keys written by ``__call__``.

        ``{band}`` is a format placeholder. The threshold keys must match the
        ones ``__call__`` actually writes (``highSnThreshold`` /
        ``lowSnThreshold``).
        """
        identity = self.identity or ""
        suffix = identity.capitalize()
        schema: list[tuple[str, type[Vector] | type[Scalar]]] = [
            (f"{identity}HighSNMask", Vector),
            (f"{identity}LowSNMask", Vector),
        ]
        for typ in ("low", "high"):
            for name in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{typ}SN{suffix}_{name}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        identity = self.identity or ""
        suffix = identity.capitalize()
        results = {}
        highMaskKey = f"{identity}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{identity}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""

        stats = _StatsImpl(vectorKey=self.vectorKey, snFluxType=self.fluxType)
        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        stats.setDefaults()
        for maskKey, typ in ((lowMaskKey, "low"), (highMaskKey, "high")):
            for name, value in stats(data, **(kwargs | {"mask": results[maskKey]})).items():
                results[f"{prefix}{typ}SN{suffix}_{name}".format(**kwargs)] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

157 

158 

159def _validatePlotTypes(value): 

160 return value in ("stars", "galaxies", "unknown", "any", "mag") 

161 

162 

# Summary statistics for one S/N selection, read back from the keyed data.
# The field name ``count`` conflicts with ``tuple.count`` on the base class,
# hence the type-ignore.
_StatsContainer = NamedTuple(  # type: ignore
    "_StatsContainer",
    [
        ("median", Scalar),
        ("sigmaMad", Scalar),
        ("count", Scalar),
        ("approxMag", Scalar),
    ],
)

169 

170 

class ScatterPlotWithTwoHists(PlotAction):
    """Scatter plot of ``y`` against ``x`` with histograms of both axes.

    For every entry of ``plotTypes`` (``stars``, ``galaxies``, ``unknown``,
    ``any``) the points are drawn with a running median and 1/2/3 sigma_MAD
    envelopes. Points outside the 3 sigma_MAD envelope are plotted
    individually; if ``plot2DHist`` is set, the points inside it are shown as
    a hexbin density map. Boxes below the axes report the high and low S/N
    statistics read from the input data. A histogram of ``x`` is drawn above
    the scatter plot and a histogram of ``y`` to its right.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the per-type summary statistics read from the input data.
    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One row per drawable object type, in plotting order:
    # (plotTypes entry, suffix used in statistic keys, x key, y key).
    # The mask keys are "<plotType>HighSNMask" and "<plotType>LowSNMask".
    _typeKeys = (
        ("stars", "Stars", "xStars", "yStars"),
        ("galaxies", "Galaxies", "xGalaxies", "yGalaxies"),
        ("unknown", "Unknown", "xUnknown", "yUnknown"),
        ("any", "Any", "x", "y"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (key, type) pairs this action reads.

        ``{band}`` in a key is a format placeholder filled from the call's
        keyword arguments.
        """
        base: list[tuple[str, type[Vector] | type[Scalar]]] = []
        for plotType, suffix, xKey, yKey in self._typeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{suffix}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{suffix}_{name}", Scalar))
        # The thresholds are read for every plotted type.
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that every key in the input schema is present in ``data``.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or a key that must hold a Vector
            holds a Scalar instead.
        """
        # Materialize the schema: it is iterated twice below.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Make a scatter plot with histograms of each axis.

        Parameters
        ----------
        data : `KeyedData`
            Mapping holding the vectors and scalars listed by
            ``getInputSchema``.
        skymap : `lsst.skymap.BaseSkyMap`
            Sky map passed to ``generateSummaryStats`` when
            ``addSummaryPlot`` is set.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted with keys:

            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"filter"``
                The filter used for this data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`, optional
            Not used as given: when ``addSummaryPlot`` is set the summary
            statistics are regenerated from ``data``.
        **kwargs
            Format arguments (e.g. ``band``) used to build statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The scatter panel occupies the lower-left 3x3 cells of a 4x4 grid.
        The x histogram sits above it, the y histogram to its right, and the
        optional summary panel in the upper-right cell. If ``plotTypes`` is
        empty a placeholder figure is returned.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _getStats(self, data: KeyedData, suffix: str, **kwargs) -> tuple[_StatsContainer, _StatsContainer]:
        """Read the (high, low) S/N summary statistics for one object type.

        ``suffix`` is the capitalized type name used in the keys, e.g.
        ``"Stars"`` reads ``{band}_highSNStars_<stat>`` and
        ``{band}_lowSNStars_<stat>`` with ``{band}`` filled from ``kwargs``.
        """
        highArgs = {}
        lowArgs = {}
        for name in self._stats:
            highArgs[name] = cast(Scalar, data[f"{{band}}_highSN{suffix}_{name}".format(**kwargs)])
            lowArgs[name] = cast(Scalar, data[f"{{band}}_lowSN{suffix}_{name}".format(**kwargs)])
        return _StatsContainer(**highArgs), _StatsContainer(**lowArgs)

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Returns the scatter axes and the hexbin collection of the last
        dataset drawn with one (or None), which is used for the colorbar.
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        # (line color, hexbin colormap) per object type.
        styles = {
            "stars": ("midnightblue", newBlues),
            "galaxies": ("firebrick", newReds),
            "unknown": ("green", None),
            "any": ("purple", None),
        }

        # x bins must hold more than this many points (within +/- 5
        # sigma_MAD in y) to be used for the running statistics.
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        for plotType, suffix, xKey, yKey in self._typeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            highStats, lowStats = self._getStats(data, suffix, **kwargs)
            color, cmap = styles[plotType]
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    color,
                    cmap,
                    highStats,
                    lowStats,
                )
            )

        if not toPlotList:
            # Nothing drawable was requested (e.g. plotTypes == ["mag"]).
            ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
            ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)
            return ax, None

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                # Too few points for binned statistics: draw flat median and
                # +/- 1 sigma_MAD lines. np.full keeps this working for zero
                # points, where plotting a scalar against an empty x raises.
                medY = np.nanmedian(ys)
                sigMads = np.full(len(xs), sigMadYs)
                (medLine,) = ax.plot(
                    xs, np.full(len(xs), medY), color, label=f"Median: {medY:0.3g}", lw=0.8
                )
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    medY + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:0.3g}",
                )
                ax.plot(xs, medY - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # Pad the x range by xScale on both sides and split it into nBins
            # bins (40 was used as the default number of bins because it
            # looked good).
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xStep = (xHigh - xLow) / self.nBins
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0

            if np.isfinite(xStep) and xStep > 0 and np.isfinite(binSize) and binSize > 0:
                xEdges = np.arange(xLow, xHigh, xStep)
                yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)
                counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
                countsYs = np.sum(counts, axis=1)
                ids = np.where(countsYs > binThresh)[0]
            else:
                # Zero-width or NaN range (e.g. all x or most y identical):
                # np.arange cannot build bins, so use the plain scatter below.
                xEdges = np.array([])
                ids = np.array([], dtype=int)
            # Left edges of the well-populated bins; consecutive entries are
            # used as the interval ends for the running statistics.
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper envelope fills the vertices forwards and the lower
                    # envelope backwards, so together they trace a polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh} Stats ({self.magLabel} < {highStats.approxMag})\n"
                highStatsStr = (
                    f"Median: {highStats.median} "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad} "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh} Stats ({self.magLabel} < {lowStats.approxMag})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median} "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad} "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                nHighSn = np.sum(highSn)
                if 0 < nHighSn < 100:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                nLowSn = np.sum(lowSn)
                if 0 < nLowSn < 100:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits. The y range is centred on the median of
        # the first non-empty y vector (in plotting order, so stars take
        # precedence). Its width uses xs, ys, meds and sigMadYs, which are
        # left over from the final loop iteration.
        plotMed = np.nan
        for entry in toPlotList:
            yVals = cast(Vector, entry[1])
            if len(yVals) > 0:
                plotMed = np.nanmedian(yVals)
                break
        if len(xs) < 2:  # type: ignore
            meds = [np.median(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        else:
            numSig = 4
            finiteMeds = np.asarray(meds, dtype=float)  # type: ignore
            finiteMeds = finiteMeds[np.isfinite(finiteMeds)]
            if finiteMeds.size > 0:
                medMax = np.max(finiteMeds)
                medMin = np.min(finiteMeds)
                # Widen the window (up to 10 sigma) until it contains every
                # running median; the bounds are recomputed on each step.
                while (
                    plotMed + numSig * sigMadYs < medMax  # type: ignore
                    or plotMed - numSig * sigMadYs > medMin  # type: ignore
                ) and numSig < 10:
                    numSig += 1

            # One extra sigma of head room.
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # matplotlib rejects NaN/inf limits (e.g. empty or all-NaN data).
            if np.isfinite(yLimMin) and np.isfinite(yLimMax):
                ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif len(xs) > 2:  # type: ignore
            if xMin is None:
                xMin = xs1 - 2 * xScale  # type: ignore
            xLimMax = xs97 + 2 * xScale  # type: ignore
            if np.isfinite(xMin) and np.isfinite(xLimMax):
                ax.set_xlim(xMin, xLimMax)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x above the scatter panel (sharing its x axis)."""
        # Top histogram
        totalX: list[Vector] = [
            cast(Vector, data[xKey])
            for plotType, _, xKey, _ in self._typeKeys
            if plotType in self.plotTypes  # type: ignore
        ]

        # Drop NaNs (NaN != NaN) before histogramming.
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        # Galaxies are drawn before stars.
        for plotType, xKey, color, label in (
            ("galaxies", "xGalaxies", "firebrick", "Galaxies"),
            ("stars", "xStars", "midnightblue", "Stars"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            xVals = data[xKey]
            topHist.hist(
                xVals,
                bins=100,
                color=color,
                histtype="step",
                log=True,
                label=f"{label} ({len(cast(Vector, xVals))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y to the right of the scatter panel.

        Bins span the scatter panel's current y limits. For galaxies and
        stars, the full sample (solid), high S/N (dashed) and low S/N (dotted)
        subsets are drawn. If ``histIm`` is given and ``plot2DHist`` is set, a
        colorbar for it is attached on the right.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = [
            cast(Vector, data[yKey])
            for plotType, _, _, yKey in self._typeKeys
            if plotType in self.plotTypes  # type: ignore
        ]
        # Drop NaNs (NaN != NaN) before histogramming.
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        # Galaxies are drawn before stars.
        for plotType, yKey, color in (
            ("galaxies", "yGalaxies", "firebrick"),
            ("stars", "yStars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            yVals = cast(Vector, data[yKey])
            sideHist.hist(
                [v for v in yVals if v == v],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            for maskSuffix, lineStyle in (("HighSNMask", "--"), ("LowSNMask", ":")):
                sideHist.hist(
                    yVals[cast(Vector, data[f"{plotType}{maskSuffix}"])],
                    bins=bins,
                    color=color,
                    histtype="step",
                    orientation="horizontal",
                    log=True,
                    ls=lineStyle,
                )

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")