Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%

342 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-05-04 03:35 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26from typing import Mapping, NamedTuple, Optional, cast 

27 

28import matplotlib.colors 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from lsst.pex.config.listField import ListField 

34from matplotlib import gridspec 

35from matplotlib.axes import Axes 

36from matplotlib.collections import PolyCollection 

37from matplotlib.figure import Figure 

38from matplotlib.path import Path 

39from mpl_toolkits.axes_grid1 import make_axes_locatable 

40 

41from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

42from ...math import nanMedian, nanSigmaMad 

43from ..keyedData.summaryStatistics import SummaryStatisticAction 

44from ..scalar import MedianAction 

45from ..vector import ConvertFluxToMag, SnSelector 

46from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

47 

# The type checker does not see ``coolwarm`` on ``plt.cm`` even though it is
# part of the module, hence the ignore.
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
# Colour used for "bad" values in this private copy of the colormap.
cmapPatch.set_bad(color="none")

51 

52 

class ScatterPlotStatsAction(KeyedDataAction):
    """Compute the statistics required by the scatter plot with two
    histograms.

    For both a low and a high signal-to-noise selection this action stores
    the selection mask, the summary statistics of ``vectorKey`` within that
    selection, an approximate magnitude for the selection and the selector
    threshold.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        """Return ``self.prefix`` if set, otherwise ``self.identity`` if set,
        otherwise an empty string.
        """
        if self.prefix:
            return self.prefix
        if self.identity:
            return self.identity
        return ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys this action reads: the statistic vector, the flux
        vector and whatever both S/N selectors require.
        """
        for vectorName in (self.vectorKey, self.fluxType):
            yield (vectorName, Vector)
        for selector in (self.highSNSelector, self.lowSNSelector):
            yield from selector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the (key, type) pairs written by `__call__`.

        Per-band statistic keys keep a literal ``{band}`` placeholder.
        """
        keyPrefix = self._get_key_prefix()
        lowerPrefix = keyPrefix.lower()
        upperPrefix = keyPrefix.capitalize()
        suffix = self.suffix

        schema = [
            (f"{lowerPrefix}HighSNMask{suffix}", Vector),
            (f"{lowerPrefix}LowSNMask{suffix}", Vector),
        ]
        # Low S/N statistics come first, then high S/N ones.
        for binName in ("low", "high"):
            for statName in ("median", "sigmaMad", "count", "approxMag"):
                schema.append((f"{{band}}_{binName}SN{upperPrefix}_{statName}{suffix}", Scalar))
        schema.append((f"{lowerPrefix}HighSNThreshold{suffix}", Scalar))
        schema.append((f"{lowerPrefix}LowSNThreshold{suffix}", Scalar))
        return tuple(schema)

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Evaluate both S/N selections and their statistics.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must provide the keys from `getInputSchema`.
        **kwargs
            Forwarded to the selectors and the statistics action. ``band``,
            if present, is used to format the flux key and to prefix the
            statistic keys.

        Returns
        -------
        results : `KeyedData`
            Masks, statistics, approximate magnitudes and thresholds.
        """
        keyPrefix = self._get_key_prefix()
        lowerPrefix = keyPrefix.lower()
        upperPrefix = keyPrefix.capitalize()
        suffix = self.suffix

        highMaskKey = f"{lowerPrefix}HighSNMask{suffix}"
        lowMaskKey = f"{lowerPrefix}LowSNMask{suffix}"
        results = {
            highMaskKey: self.highSNSelector(data, **kwargs),
            lowMaskKey: self.lowSNSelector(data, **kwargs),
        }

        band = kwargs.get("band")
        bandPrefix = f"{band}_" if band else ""
        haveBand = band is not None
        fluxes = data[self.fluxType.format(band=band)] if haveBand else None

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for binName, maskKey in (("low", lowMaskKey), ("high", highMaskKey)):
            mask = results[maskKey]
            name = f"{bandPrefix}{binName}SN{upperPrefix}"
            # The approximate magnitude is the median magnitude of the
            # sources in this S/N selection (NaN when no band is given).
            if haveBand:
                mags = magAction({"flux": fluxes[mask]})  # type: ignore
                approxMag = medianAction({"mag": mags})
            else:
                approxMag = np.nan
            results[f"{name}_approxMag{suffix}".format(**kwargs)] = approxMag

            binStats = statAction(data, **(kwargs | {"mask": mask}))
            for statName, value in binStats.items():
                results[f"{name}_{statName}{suffix}".format(**kwargs)] = value

        results[f"{lowerPrefix}HighSNThreshold{suffix}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{lowerPrefix}LowSNThreshold{suffix}"] = self.lowSNSelector.threshold  # type: ignore

        return results

143 

144 

145def _validObjectTypes(value): 

146 return value in ("stars", "galaxies", "unknown", "any") 

147 

148 

149# ignore type because of conflicting name on tuple baseclass 

150class _StatsContainer(NamedTuple): 

151 median: Scalar 

152 sigmaMad: Scalar 

153 count: Scalar # type: ignore 

154 approxMag: Scalar 

155 

156 

class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and colour defaults used when plotting."""

    # Suffix used in the names of the statistic keys for this object type.
    suffix_stat: str
    # Suffix appended to the ``x``/``y`` vector keys for this object type.
    suffix_xy: str
    # Matplotlib colour for lines and points of this object type.
    color: str
    # Optional colormap for this object type.
    colormap: matplotlib.colors.Colormap | None

162 

163 

class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot. "
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )
    suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="")
    suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="")
    suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="")

    # Names of the per-selection statistics read from the input data.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key suffixes and colours for each value allowed in ``plotTypes``.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (key, type) pairs required for the configured
        ``plotTypes``, plus ``patch`` when the summary plot is enabled.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector))
            base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector))
            base.append((f"{name_datatype}HighSNMask{self.suffix_stat}", Vector))
            base.append((f"{name_datatype}LowSNMask{self.suffix_stat}", Vector))
            # statistics
            for name in self._stats:
                base.append(
                    (f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
                )
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}LowSNThreshold{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}HighSNThreshold{self.suffix_stat}", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Check that all required keys are present in ``data``.

        NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            If a required key is missing, or a key whose schema type is not
            Scalar holds a Scalar value.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(
                f"Task needs keys {remainder} but they were not found in input keys {list(data.keys())}"
            )
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Additional options. ``hlineColor`` and ``hlineStyle`` set the
            colour and line style of the horizontal reference line at 0
            (defaults: ``"black"`` and ``(0, (1, 4))``). ``skymap`` is
            needed for the summary plot. ``band`` is used to format the
            statistic keys.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the ``xAxisLabel`` and ``yAxisLabel`` config options to label a
        scatter plot of the ``x`` and ``y`` vectors against each other.
        A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot when ``addSummaryPlot`` is set and a ``skymap``
        is supplied.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data.

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask,
            starsHighSNThreshold, starsLowSNThreshold and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars in the above
            names with galaxies/unknown.

        * If it is for any (which covers all the points) then it
            becomes, x, y, and any instead of stars for the other
            parameters given above.

        * In every case it is expected that data contains patch
            (if the summary plot is being plotted).

        The ``suffix_x``, ``suffix_y`` and ``suffix_stat`` config options are
        appended to the x, y and statistic/mask/threshold keys respectively.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        if "hlineColor" not in kwargs:
            kwargs["hlineColor"] = "black"

        if "hlineStyle" not in kwargs:
            kwargs["hlineStyle"] = (0, (1, 4))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        skymap = kwargs.get("skymap", None)
        if self.addSummaryPlot and skymap is not None:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel for every configured object type.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The main scatter plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin image of the last object type that drew one, or
            `None` if the last plotted type did not.
        """
        suf_x = self.suffix_x
        suf_y = self.suffix_y
        suf_stat = self.suffix_stat
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        # Minimum number of points an x bin needs to enter the running stats
        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
                lowArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}{suf_x}"],
                    data[f"y{config_datatype.suffix_xy}{suf_y}"],
                    data[f"{name_datatype}HighSNMask{suf_stat}"],
                    data[f"{name_datatype}LowSNMask{suf_stat}"],
                    data[f"{name_datatype}HighSNThreshold{suf_stat}"],
                    data[f"{name_datatype}LowSNThreshold{suf_stat}"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        # Work on a copy so the config field is never written to, and use
        # np.inf because the np.Inf alias was removed in NumPy 2.0.
        xLims = list(self.xLims) if self.xLims is not None else [np.inf, -np.inf]
        for j, (
            xs,
            ys,
            highSn,
            lowSn,
            highThresh,
            lowThresh,
            color,
            cmap,
            highStats,
            lowStats,
        ) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nanSigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0 or not np.isfinite(sigMadYs):
                continue
            elif n_xs < 10:
                xs = [nanMedian(xs)]
                sigMads = np.array([nanSigmaMad(ys)])
                ys = np.array([nanMedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Use the 1st to 97th percentile range of only the finite
                # xs values (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = nanMedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins with more than binThresh points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = nanMedian(ys[inBin])
                    sigMad = nanSigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper boundary runs forwards, lower boundary backwards,
                    # so the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within 3 sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                statText = f"S/N > {highThresh:0.4g} Stats ({self.magLabel} < {highStats.approxMag:0.4g})\n"
                highStatsStr = (
                    f"Median: {highStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{highStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{highStats.count}"
                )
                statText += highStatsStr
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                statText = f"S/N > {lowThresh:0.4g} Stats ({self.magLabel} < {lowStats.approxMag:0.4g})\n"
                lowStatsStr = (
                    f"Median: {lowStats.median:0.4g}    "
                    + r"$\sigma_{MAD}$: "
                    + f"{lowStats.sigmaMad:0.4g}    "
                    + r"N$_{points}$: "
                    + f"{lowStats.count}"
                )
                statText += lowStatsStr
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        marker="x",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (highSnLine,) = ax.plot(
                        cast(Vector, xs[highSn]),
                        cast(Vector, ys[highSn]),
                        color=color,
                        marker="x",
                        ms=4,
                        ls="none",
                        label="High SN",
                    )
                    linesForLegend.append(highSnLine)
                else:
                    ax.axvline(highStats.approxMag, color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        marker="+",
                        ms=4,
                        mec="w",
                        mew=2,
                        ls="none",
                    )
                    (lowSnLine,) = ax.plot(
                        cast(Vector, xs[lowSn]),
                        cast(Vector, ys[lowSn]),
                        color=color,
                        marker="+",
                        ms=4,
                        ls="none",
                        label="Low SN",
                    )
                    linesForLegend.append(lowSnLine)
                else:
                    ax.axvline(lowStats.approxMag, color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([nanMedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nanSigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits
        # TODO: Make this not work by accident
        # (xs, ys, meds and sigMadYs below are whatever the last plotted
        # object type left behind.)
        if f"yStars{suf_y}" in data and (len(cast(Vector, data[f"yStars{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yStars{suf_y}"]))
        elif f"yGalaxies{suf_y}" in data and (len(cast(Vector, data[f"yGalaxies{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_y}"]))
        else:
            plotMed = np.nan

        # Ignore types below pending making this not working my accident
        if len(xs) < 2:  # type: ignore
            meds = [nanMedian(ys)]  # type: ignore
        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed):
            numSig = 4
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            # Widen the limits one sigma MAD at a time (capped at 10) until
            # they contain the running medians; the limits have to be
            # recomputed on every pass for the condition to change.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:  # type: ignore
                numSig += 1
                yLimMin = plotMed - numSig * sigMadYs  # type: ignore
                yLimMax = plotMed + numSig * sigMadYs  # type: ignore

            # One extra sigma MAD of padding
            numSig += 1
            yLimMin = plotMed - numSig * sigMadYs  # type: ignore
            yLimMax = plotMed + numSig * sigMadYs  # type: ignore
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc=self.legendLocation,
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the histogram of x values above the scatter plot, sharing
        its x axis.
        """
        suf_x = self.suffix_x
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)

        if "any" in self.plotTypes:
            # Look the vector up in data; the key on its own is just a string
            x_any = data[f"x{self._datatypes['any'].suffix_xy}{suf_x}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
        else:
            x_any = (
                np.concatenate([data[f"x{self._datatypes[key].suffix_xy}{suf_x}"] for key in self.plotTypes])
                if (len(self.plotTypes) > 1)
                else None
            )
            keys_notany = self.plotTypes
        if x_any is not None:
            topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=True, label=f"Any ({len(x_any)})")

        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}{suf_x}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the histogram of y values beside the scatter plot, sharing
        its y axis, and add a colorbar for ``histIm`` when one was drawn.

        ``kwargs`` must contain ``hlineColor`` and ``hlineStyle``.
        """
        suf_y = self.suffix_y
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)

        if "any" in self.plotTypes:
            # Look the vector up in data; the key on its own is just a string
            y_any = data[f"y{self._datatypes['any'].suffix_xy}{suf_y}"]
            keys_notany = [x for x in self.plotTypes if x != "any"]
        else:
            y_any = (
                np.concatenate([data[f"y{self._datatypes[key].suffix_xy}{suf_y}"] for key in self.plotTypes])
                if (len(self.plotTypes) > 1)
                else None
            )
            keys_notany = self.plotTypes
        if y_any is not None:
            sideHist.hist(
                np.array(y_any),
                bins=bins,
                color="grey",
                alpha=0.3,
                orientation="horizontal",
                log=True,
            )
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}{suf_y}"]
            sideHist.hist(
                vector,
                color=config_datatype.color,
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")