Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 11%

375 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-12 03:08 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26from itertools import chain 

27from typing import Mapping, NamedTuple, Optional, cast 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from lsst.pex.config.listField import ListField 

34from lsst.skymap import BaseSkyMap 

35from matplotlib import gridspec 

36from matplotlib.axes import Axes 

37from matplotlib.collections import PolyCollection 

38from matplotlib.figure import Figure 

39from matplotlib.path import Path 

40from mpl_toolkits.axes_grid1 import make_axes_locatable 

41 

42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

43from ...statistics import nansigmaMad, sigmaMad 

44from ..keyedData.summaryStatistics import SummaryStatisticAction 

45from ..scalar import MedianAction 

46from ..vector import MagColumnNanoJansky, SnSelector 

47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

48 

# Module-level colormap: a private copy of matplotlib's "coolwarm" whose
# "bad" (masked/invalid) entries are drawn with no colour at all.
# The type checker is silenced because ``coolwarm`` is actually part of
# the ``plt.cm`` module even though it is not visible statically.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")

52 

53 

class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    Two signal-to-noise selections (``lowSNSelector`` and
    ``highSNSelector``) are evaluated and, for each one, the outputs of a
    `SummaryStatisticAction` on ``vectorKey`` plus an approximate magnitude
    are stored in the returned mapping.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys this action reads: the statistic vector, the flux
        vector, and whatever both S/N selectors require.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__`.

        The mask keys use the lower-cased identity and the threshold keys
        are ``highSnThreshold``/``lowSnThreshold`` so that the schema agrees
        with what `__call__` actually stores.
        """
        maskPrefix = (self.identity or "").lower()
        ident = self.identity.capitalize() if self.identity else ""
        # Same ordering as before: all low-S/N statistics, then all high-S/N.
        statKeys = tuple(
            (f"{{band}}_{binName}SN{ident}_{stat}", Scalar)
            for binName in ("low", "high")
            for stat in ("median", "sigmaMad", "count", "approxMag")
        )
        return (
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
            *statKeys,
            ("highSnThreshold", Scalar),
            ("lowSnThreshold", Scalar),
        )

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the S/N masks and the per-selection statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input columns. The formatted ``fluxType`` column is only looked
            up when a ``band`` keyword is supplied.
        **kwargs
            Forwarded to the selectors and to the statistic action. A truthy
            ``band`` value is used as prefix of the statistic keys.

        Returns
        -------
        results : `KeyedData`
            The two masks, one ``<name>_approxMag`` entry and the statistic
            action outputs per selection, and the two selector thresholds.
        """
        results = {}
        maskPrefix = (self.identity or "").lower()
        ident = self.identity.capitalize() if self.identity else ""

        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{ident}"
            # set the approxMag to the median mag in the SN selection;
            # without a band there is no flux column, so store NaN instead
            results[f"{name}_approxMag".format(**kwargs)] = (
                medianAction({"mag": magAction({"flux": fluxes[results[maskKey]]})})  # type: ignore
                if band is not None
                else np.nan
            )
            stats = statAction(data, **(kwargs | {"mask": results[maskKey]})).items()
            for suffix, value in stats:
                tmpKey = f"{name}_{suffix}".format(**kwargs)
                results[tmpKey] = value
        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

126 

127 

128def _validatePlotTypes(value): 

129 return value in ("stars", "galaxies", "unknown", "any", "mag") 

130 

131 

132# ignore type because of conflicting name on tuple baseclass 

133class _StatsContainer(NamedTuple): 

134 median: Scalar 

135 sigmaMad: Scalar 

136 count: Scalar # type: ignore 

137 approxMag: Scalar 

138 

139 

class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    # Names of the per-selection statistics; they double as the field names
    # of `_StatsContainer`.
    _stats = ("median", "sigmaMad", "count", "approxMag")

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys required in the input data for the configured
        ``plotTypes``.

        For each selected type this is the x and y vectors, the high and
        low S/N masks and the high/low S/N statistics. The two S/N
        thresholds are always required, and ``patch`` is required when
        ``addSummaryPlot`` is set.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        # (plot type, x key, y key, suffix used in the statistic keys)
        typeKeys = (
            ("stars", "xStars", "yStars", "Stars"),
            ("galaxies", "xGalaxies", "yGalaxies", "Galaxies"),
            ("unknown", "xUnknown", "yUnknown", "Unknown"),
            ("any", "x", "y", "Any"),
        )
        for plotType, xKey, yKey, suffix in typeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            # The mask keys are the ones `_scatterPlot` reads; the "any"
            # case previously asked for "anySNMask" instead of
            # "anyLowSNMask".
            base.append((f"{plotType}HighSNMask", Vector))
            base.append((f"{plotType}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{suffix}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{suffix}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a
            scalar value is supplied where a non-scalar type is required.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch. The passed value is replaced by the
            result of ``generateSummaryStats`` when ``addSummaryPlot`` is
            set.
        **kwargs
            Forwarded to the helpers that draw the panels. ``hlineColor``
            and ``hlineStyle`` control the reference line at 0 and default
            to black and a loosely dotted style.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Plots the x and y vectors found in ``data`` against each other,
        labelled with the ``xAxisLabel`` and ``yAxisLabel`` config options.
        A histogram of the points collapsed onto each axis is also plotted.
        If ``addSummaryPlot`` is set, a summary panel is added in the upper
        right corner of the resultant plot. The S/N masks found in ``data``
        decide which points are highlighted, and the statistics found in
        ``data`` are the ones printed on the figure.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars in the above
          names with galaxies/unknown.

        * If it is for any (which covers all the points) then it
          becomes, x, y, and any instead of stars for the other
          parameters given above.

        * In every case it is expected that data contains:
            lowSnThreshold, highSnThreshold and patch
            (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        if "hlineColor" not in kwargs:
            kwargs["hlineColor"] = "black"

        if "hlineStyle" not in kwargs:
            kwargs["hlineStyle"] = (0, (1, 4))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Returns the axes of the panel and the hexbin collection of the last
        plotted data set (`None` when no 2D histogram was drawn for it).
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        binThresh = 5

        yBinsOut = []
        linesForLegend = []

        toPlotList = []
        histIm = None
        # (plot type, x key, y key, statistic-key suffix, line colour, cmap)
        typeSpecs = (
            ("stars", "xStars", "yStars", "Stars", "midnightblue", newBlues),
            ("galaxies", "xGalaxies", "yGalaxies", "Galaxies", "firebrick", newReds),
            ("unknown", "xUnknown", "yUnknown", "Unknown", "green", None),
            # "any" reads its own statistics; it previously reused the
            # "Unknown" ones by mistake.
            ("any", "x", "y", "Any", "purple", None),
        )
        for plotType, xKey, yKey, suffix, typeColor, typeCmap in typeSpecs:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(Scalar, data[f"{{band}}_highSN{suffix}_{name}".format(**kwargs)])
                lowArgs[name] = cast(Scalar, data[f"{{band}}_lowSN{suffix}_{name}".format(**kwargs)])
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{plotType}HighSNMask"],
                    data[f"{plotType}LowSNMask"],
                    typeColor,
                    typeCmap,
                    _StatsContainer(**highArgs),
                    _StatsContainer(**lowArgs),
                )
            )

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0:
                continue
            elif n_xs < 10:
                xs = [np.nanmedian(xs)]
                sigMads = np.array([nansigmaMad(ys)])
                ys = [np.nanmedian(ys)]
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            [xs1, xs25, xs50, xs75, xs95, xs97] = np.nanpercentile(xs, [1, 25, 50, 75, 95, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(
                np.nanmin(xs) - xScale,
                np.nanmax(xs) + xScale,
                (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale)) / self.nBins,
            )
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, xBins, yBins = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            yBinsOut.append(yBins)
            countsYs = np.sum(counts, axis=1)

            # Only keep the x bins holding more than binThresh points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    ids = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[ids])
                    sigMad = sigmaMad(ys[ids], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge is filled front to back, lower edge back to
                    # front, so the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    # White underlay first so the coloured markers stand out
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(cast(Vector, xs[lowSn])):
                        xMin = np.min(cast(Vector, xs[lowSn]))
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Too few populated bins for a running median: plot every
                # point together with flat median and sigma MAD lines.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # TODO: Make this not work by accident
        # NOTE: the code below relies on xs, ys, sigMadYs, meds, xs1, xs97 and
        # xScale left over from the last iteration of the loop above.
        if "yStars" in data and (len(cast(Vector, data["yStars"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yStars"]))
        elif "yGalaxies" in data and (len(cast(Vector, data["yGalaxies"])) > 0):
            plotMed = np.nanmedian(cast(Vector, data["yGalaxies"]))
        else:
            plotMed = np.nan

        # Ignore types below pending making this not working my accident
        if len(xs) < 2:  # type: ignore
            meds = [np.nanmedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed):
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen the window until it contains every median (capped at 10
            # sigma). The limits must be recomputed inside the loop,
            # otherwise the condition never changes and the loop always
            # runs to the cap.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # One extra sigma of padding around the chosen window
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif len(xs) > 2:  # type: ignore
            if xMin is None:
                xMin = xs1 - 2 * xScale  # type: ignore
            ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter panel,
        sharing its x axis with ``ax``.
        """
        # Top histogram
        xKeys = (
            ("stars", "xStars"),
            ("galaxies", "xGalaxies"),
            # key was previously misspelled "xUknown", which is not the key
            # listed in the input schema
            ("unknown", "xUnknown"),
            ("any", "x"),
        )
        totalX: list[Vector] = [
            cast(Vector, data[xKey]) for plotType, xKey in xKeys if plotType in self.plotTypes  # type: ignore
        ]

        # ``x == x`` is False only for NaN, so this drops the NaN entries
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        if "galaxies" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xGalaxies"],
                bins=100,
                color="firebrick",
                histtype="step",
                log=True,
                label=f"Galaxies ({len(cast(Vector, data['xGalaxies']))})",
            )
        if "stars" in self.plotTypes:  # type: ignore
            topHist.hist(
                data["xStars"],
                bins=100,
                color="midnightblue",
                histtype="step",
                log=True,
                label=f"Stars ({len(cast(Vector, data['xStars']))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    # Side histogram

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values to the right of the scatter
        panel, sharing its y axis with ``ax``, and attach a colorbar for
        ``histIm`` when a 2D histogram was drawn.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: dict[str, Vector] = {}
        # (plot type, y key in the data, key used in totalY)
        for plotType, yKey, label in (
            ("stars", "yStars", "stars"),
            ("galaxies", "yGalaxies", "galaxies"),
            ("unknown", "yUnknown", "unknown"),
            ("any", "y", "y"),
        ):
            if plotType in self.plotTypes and yKey in data:  # type: ignore
                totalY[label] = cast(Vector, data[yKey])
        # ``y == y`` is False only for NaN, so this drops the NaN entries
        totalYChained = [y for y in chain.from_iterable(totalY.values()) if y == y]

        # Bin over the y range currently shown on the scatter panel
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        # Galaxies first, then stars: one outline for all points and one
        # each for the high ("--") and low (":") S/N selections.
        for plotType, yKey, color in (
            ("galaxies", "yGalaxies", "firebrick"),
            ("stars", "yStars", "midnightblue"),
        ):
            if plotType not in totalY:
                continue
            sideHist.hist(
                [v for v in cast(Vector, data[yKey]) if v == v],
                bins=bins,
                color=color,
                histtype="step",
                orientation="horizontal",
                log=True,
            )
            for maskKey, lineStyle in ((f"{plotType}HighSNMask", "--"), (f"{plotType}LowSNMask", ":")):
                sideHist.hist(
                    cast(Vector, data[yKey])[cast(Vector, data[maskKey])],
                    bins=bins,
                    color=color,
                    histtype="step",
                    orientation="horizontal",
                    log=True,
                    ls=lineStyle,
                )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")