Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 12%

367 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-04 11:09 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26from itertools import chain 

27from typing import Mapping, NamedTuple, Optional, cast 

28 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from lsst.pex.config.listField import ListField 

34from lsst.skymap import BaseSkyMap 

35from matplotlib import gridspec 

36from matplotlib.axes import Axes 

37from matplotlib.collections import PolyCollection 

38from matplotlib.figure import Figure 

39from matplotlib.path import Path 

40from mpl_toolkits.axes_grid1 import make_axes_locatable 

41 

42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

43from ...statistics import nansigmaMad, sigmaMad 

44from ..keyedData.summaryStatistics import SummaryStatisticAction 

45from ..scalar import MedianAction 

46from ..vector import MagColumnNanoJansky, SnSelector 

47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

48 

# Module-level copy of matplotlib's "coolwarm" colormap in which bad (masked
# or invalid) values are drawn with no colour.
# ignore because coolwarm is actually part of module
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")

52 

53 

class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    For both the low and the high signal to noise selection this action
    stores the selection mask, the summary statistics of ``vectorKey``
    computed with that mask, and an approximate magnitude (the median
    magnitude of the selected fluxes). The two selector thresholds are
    stored as well.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys this action reads: the statistics vector, the
        flux vector, and everything the two selectors require.
        """
        yield (self.vectorKey, Vector)
        yield (self.fluxType, Vector)
        yield from self.highSNSelector.getInputSchema()
        yield from self.lowSNSelector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the keys written by `__call__`.

        The mask and threshold key names are built exactly as in
        `__call__` (lower-cased identity for the masks,
        ``highSnThreshold``/``lowSnThreshold`` for the thresholds) so that
        the declared schema matches the produced data.
        """
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        schema = [
            (f"{maskPrefix}HighSNMask", Vector),
            (f"{maskPrefix}LowSNMask", Vector),
        ]
        for binName in ("low", "high"):
            for stat in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{statSuffix}_{stat}", Scalar))
        schema.append(("highSnThreshold", Scalar))
        schema.append(("lowSnThreshold", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks and statistics for the low and high SN bins.

        Parameters
        ----------
        data : `KeyedData`
            The input data; must contain ``vectorKey`` and, when a band is
            given, the formatted ``fluxType`` key.
        **kwargs
            Forwarded to the selectors and the statistics action. ``band``,
            if present, is used to format the flux key and to prefix the
            statistic keys.

        Returns
        -------
        results : `KeyedData`
            The masks, per-bin statistics and selector thresholds.
        """
        results = {}
        maskPrefix = (self.identity or "").lower()
        statSuffix = self.identity.capitalize() if self.identity else ""

        highMaskKey = f"{maskPrefix}HighSNMask"
        results[highMaskKey] = self.highSNSelector(data, **kwargs)

        lowMaskKey = f"{maskPrefix}LowSNMask"
        results[lowMaskKey] = self.lowSNSelector(data, **kwargs)

        prefix = f"{band}_" if (band := kwargs.get("band")) else ""
        fluxes = data[self.fluxType.format(band=band)] if band is not None else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = MagColumnNanoJansky(vectorKey="flux")

        for maskKey, binName in ((lowMaskKey, "low"), (highMaskKey, "high")):
            name = f"{prefix}{binName}SN{statSuffix}"
            mask = results[maskKey]
            # set the approxMag to the median mag in the SN selection
            if band is not None:
                approxMag = medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
            else:
                approxMag = np.nan
            results[f"{name}_approxMag".format(**kwargs)] = approxMag

            stats = statAction(data, **(kwargs | {"mask": mask}))
            for suffix, value in stats.items():
                results[f"{name}_{suffix}".format(**kwargs)] = value

        results["highSnThreshold"] = self.highSNSelector.threshold  # type: ignore
        results["lowSnThreshold"] = self.lowSNSelector.threshold  # type: ignore

        return results

126 

127 

128def _validatePlotTypes(value): 

129 return value in ("stars", "galaxies", "unknown", "any", "mag") 

130 

131 

# ignore type because of conflicting name on tuple baseclass
class _StatsContainer(NamedTuple):
    """Summary statistics of one signal to noise selection, with fields in
    the same order as ``ScatterPlotWithTwoHists._stats``.
    """

    median: Scalar
    sigmaMad: Scalar
    # ``count`` shadows ``tuple.count``, hence the type ignore
    count: Scalar  # type: ignore
    approxMag: Scalar

138 

139 

class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, mag, any",
        optional=False,
        itemCheck=_validatePlotTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=False,
    )

    _stats = ("median", "sigmaMad", "count", "approxMag")

    # One row per plottable object type, in plotting order:
    # (plot type, x key, y key, SN mask key prefix, statistic key suffix).
    # Keeping the key names in a single table stops the schema, the scatter
    # plot and the histograms from drifting apart.
    _plotTypeKeys = (
        ("stars", "xStars", "yStars", "stars", "Stars"),
        ("galaxies", "xGalaxies", "yGalaxies", "galaxies", "Galaxies"),
        ("unknown", "xUnknown", "yUnknown", "unknown", "Unknown"),
        ("any", "x", "y", "any", "Any"),
    )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the keys (and their types) required for the configured
        ``plotTypes``, plus the SN thresholds and, if the summary plot is
        requested, the patch column.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for plotType, xKey, yKey, maskPrefix, statSuffix in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            base.append((xKey, Vector))
            base.append((yKey, Vector))
            base.append((f"{maskPrefix}HighSNMask", Vector))
            base.append((f"{maskPrefix}LowSNMask", Vector))
            # statistics
            for name in self._stats:
                base.append((f"{{band}}_highSN{statSuffix}_{name}", Scalar))
                base.append((f"{{band}}_lowSN{statSuffix}_{name}", Scalar))
        base.append(("lowSnThreshold", Scalar))
        base.append(("highSnThreshold", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a key
            that must not be a scalar holds a scalar.
        """
        # Materialise the schema because it is iterated twice below.
        needed = list(self.getFormattedInputSchema(**kwargs))
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in input")
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        sumStats: Optional[Mapping] = None,
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch. Any value passed in is ignored; the
            statistics are regenerated from ``data`` when the summary plot
            is enabled.
        **kwargs
            Must contain the values needed to format the statistic keys
            (i.e. ``band``). ``hlineColor`` and ``hlineStyle`` optionally
            override the style of the horizontal reference line at 0.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars in the above
            names with galaxies/unknown.

        * If it is for any (which covers all the points) then it
            becomes, x, y, and any instead of stars for the other
            parameters given above.

        * In every case it is expected that data contains:
            lowSnThreshold, highSnThreshold and patch
            (if the summary plot is being plotted).

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        kwargs.setdefault("hlineColor", "black")
        kwargs.setdefault("hlineStyle", (0, (1, 4)))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        For every configured plot type this draws the running median and
        the 1, 2 and 3 sigma MAD lines, the points outside 3 sigma MAD,
        optionally a hexbin of the points inside, and the statistics text
        boxes. Returns the axes and the last hexbin drawn (or `None`).
        """
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        # line colour and 2D histogram colormap used for each plot type
        styles = {
            "stars": ("midnightblue", newBlues),
            "galaxies": ("firebrick", newReds),
            "unknown": ("green", None),
            "any": ("purple", None),
        }

        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer
        for plotType, xKey, yKey, maskPrefix, statSuffix in self._plotTypeKeys:
            if plotType not in self.plotTypes:  # type: ignore
                continue
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(Scalar, data[f"{{band}}_highSN{statSuffix}_{name}".format(**kwargs)])
                lowArgs[name] = cast(Scalar, data[f"{{band}}_lowSN{statSuffix}_{name}".format(**kwargs)])
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            color, cmap = styles[plotType]
            toPlotList.append(
                (
                    data[xKey],
                    data[yKey],
                    data[f"{maskPrefix}HighSNMask"],
                    data[f"{maskPrefix}LowSNMask"],
                    color,
                    cmap,
                    highStats,
                    lowStats,
                )
            )

        xMin = None
        for j, (xs, ys, highSn, lowSn, color, cmap, highStats, lowStats) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nansigmaMad(ys)
            if len(xs) < 2:
                # Too few points for running statistics: draw flat lines.
                # Arrays of len(xs) are used (rather than a scalar) so that
                # an empty selection does not make the plot call fail.
                medYs = np.nanmedian(ys)
                flatMeds = np.full(len(xs), medYs)
                flatSigMads = np.full(len(xs), sigMadYs)
                (medLine,) = ax.plot(xs, flatMeds, color, label=f"Median: {medYs:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    flatMeds + 1.0 * flatSigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMadYs:.2g}",
                )
                ax.plot(xs, flatMeds - 1.0 * flatSigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow) / self.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # only keep the x bins that hold more than binThresh points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.nanmedian(ys[inBin])
                    sigMad = sigmaMad(ys[inBin], nan_policy="omit")
                    meds[i] = med
                    sigMads[i] = sigMad
                    # upper edge runs forwards, lower edge backwards, so the
                    # vertices trace a closed outline
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = data["highSnThreshold"]
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = data["lowSnThreshold"]
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if 0 < np.sum(highSn) < 100:
                    # white markers underneath give the coloured ones an
                    # outline
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                    xMin = np.min(cast(Vector, xs[highSn]))
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if 0 < np.sum(lowSn) < 100:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                    lowSnXMin = np.min(cast(Vector, xs[lowSn]))
                    if xMin is None or xMin > lowSnXMin:
                        xMin = lowSnXMin
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                # Not enough populated bins for running statistics: plot
                # every point with flat median and sigma MAD lines.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nansigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # Centre the y range on the first populated y column, preferring
        # stars, then galaxies (as before), then unknown and any, so that
        # configurations without stars or galaxies do not raise a KeyError.
        plotMed = np.nan
        for _, _, yKey, _, _ in self._plotTypeKeys:
            if yKey in data and len(cast(Vector, data[yKey])) > 0:
                plotMed = np.nanmedian(cast(Vector, data[yKey]))
                break
        # TODO: the values below (xs, ys, meds, sigMadYs, xs1, xs97, xScale)
        # are those of the last data set plotted in the loop above.
        # Ignore types below pending making this not work by accident
        if len(xs) < 2:  # type: ignore
            meds = [np.nanmedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        else:
            # Widen the range one sigma MAD at a time (up to 10) until it
            # contains the median line, then add one more for padding.
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # matplotlib rejects NaN/Inf limits; keep its autoscaled range
            # when nothing finite could be computed
            if np.isfinite(yLimMin) and np.isfinite(yLimMax):
                ax.set_ylim(yLimMin, yLimMax)

        if self.xLims:
            ax.set_xlim(self.xLims[0], self.xLims[1])  # type: ignore
        elif len(xs) > 2:  # type: ignore
            if xMin is None:
                xMin = xs1 - 2 * xScale  # type: ignore
            ax.set_xlim(xMin, xs97 + 2 * xScale)  # type: ignore

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc="upper left",
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of the x values above the scatter plot."""
        # Top histogram
        totalX: list[Vector] = [
            cast(Vector, data[xKey])
            for plotType, xKey, *_ in self._plotTypeKeys
            if plotType in self.plotTypes  # type: ignore
        ]

        # x == x is False only for NaN, so this drops the NaNs
        totalXChained = [x for x in chain.from_iterable(totalX) if x == x]

        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(
            totalXChained, bins=100, color="grey", alpha=0.3, log=True, label=f"All ({len(totalXChained)})"
        )
        for plotType, xKey, color, label in (
            ("galaxies", "xGalaxies", "firebrick", "Galaxies"),
            ("stars", "xStars", "midnightblue", "Stars"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            topHist.hist(
                data[xKey],
                bins=100,
                color=color,
                histtype="step",
                log=True,
                label=f"{label} ({len(cast(Vector, data[xKey]))})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    # Side histogram

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of the y values to the right of the scatter
        plot and, if a 2D histogram was drawn, its colorbar.

        ``kwargs`` must contain ``hlineColor`` and ``hlineStyle``.
        """
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)

        totalY: list[Vector] = [
            cast(Vector, data[yKey])
            for plotType, _, yKey, *_ in self._plotTypeKeys
            if plotType in self.plotTypes  # type: ignore
        ]
        # y == y is False only for NaN, so this drops the NaNs
        totalYChained = [y for y in chain.from_iterable(totalY) if y == y]

        # bin over the y range currently shown on the scatter plot
        yLimMin, yLimMax = ax.get_ylim()
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(
            totalYChained,
            bins=bins,
            color="grey",
            alpha=0.3,
            orientation="horizontal",
            log=True,
        )
        for plotType, yKey, maskPrefix, color in (
            ("galaxies", "yGalaxies", "galaxies", "firebrick"),
            ("stars", "yStars", "stars", "midnightblue"),
        ):
            if plotType not in self.plotTypes:  # type: ignore
                continue
            ys = cast(Vector, data[yKey])
            stepStyle = dict(bins=bins, color=color, histtype="step", orientation="horizontal", log=True)
            # all finite points, then the high SN (dashed) and low SN
            # (dotted) selections
            sideHist.hist([y for y in ys if y == y], **stepStyle)
            sideHist.hist(ys[cast(Vector, data[f"{maskPrefix}HighSNMask"])], ls="--", **stepStyle)
            sideHist.hist(ys[cast(Vector, data[f"{maskPrefix}LowSNMask"])], ls=":", **stepStyle)

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")