Coverage for python/lsst/analysis/tools/actions/plot/scatterplotWithTwoHists.py: 16%

342 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-13 11:47 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("ScatterPlotStatsAction", "ScatterPlotWithTwoHists") 

25 

26from typing import Mapping, NamedTuple, Optional, cast 

27 

28import matplotlib.colors 

29import matplotlib.pyplot as plt 

30import numpy as np 

31from lsst.pex.config import Field 

32from lsst.pex.config.configurableActions import ConfigurableActionField 

33from lsst.pex.config.listField import ListField 

34from lsst.skymap import BaseSkyMap 

35from matplotlib import gridspec 

36from matplotlib.axes import Axes 

37from matplotlib.collections import PolyCollection 

38from matplotlib.figure import Figure 

39from matplotlib.path import Path 

40from mpl_toolkits.axes_grid1 import make_axes_locatable 

41 

42from ...interfaces import KeyedData, KeyedDataAction, KeyedDataSchema, PlotAction, Scalar, ScalarType, Vector 

43from ...math import nanMedian, nanSigmaMad 

44from ..keyedData.summaryStatistics import SummaryStatisticAction 

45from ..scalar import MedianAction 

46from ..vector import ConvertFluxToMag, SnSelector 

47from .plotUtils import addPlotInfo, addSummaryPlot, generateSummaryStats, mkColormap 

48 

# Module-level copy of matplotlib's ``coolwarm`` colormap whose "bad"
# (masked/invalid) colour is set to "none".  The copy is what gets modified,
# not the colormap object held by ``plt.cm``.
# NOTE(review): presumably intended for per-patch colouring (going by the
# name); it is not referenced anywhere in the visible part of this file.
# ignore because coolwarm is actually part of module
cmapPatch = plt.cm.coolwarm.copy()  # type: ignore
cmapPatch.set_bad(color="none")

52 

53 

class ScatterPlotStatsAction(KeyedDataAction):
    """Calculates the statistics needed for the
    scatter plot with two hists.

    Two signal-to-noise selections ("high" and "low") are evaluated on the
    input data.  For each selection the boolean mask, the summary statistics
    of ``vectorKey`` restricted to that mask, an approximate magnitude (the
    median magnitude of the selected fluxes) and the selector threshold are
    stored in the returned mapping.
    """

    vectorKey = Field[str](doc="Vector on which to compute statistics")
    highSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine high SN Objects", default=SnSelector(threshold=2700)
    )
    lowSNSelector = ConfigurableActionField[SnSelector](
        doc="Selector used to determine low SN Objects", default=SnSelector(threshold=500)
    )
    prefix = Field[str](
        doc="Prefix for all output fields; will use self.identity if None",
        optional=True,
        default=None,
    )
    fluxType = Field[str](doc="Vector key to use to compute signal to noise ratio", default="{band}_psfFlux")
    suffix = Field[str](doc="Suffix for all output fields", default="")

    def _get_key_prefix(self):
        """Return the configured prefix, else ``self.identity``, else ``""``."""
        return self.prefix or self.identity or ""

    def getInputSchema(self, **kwargs) -> KeyedDataSchema:
        """Yield the keys this action reads: the statistics vector, the flux
        vector and whatever the two S/N selectors require.
        """
        yield from ((key, Vector) for key in (self.vectorKey, self.fluxType))
        for selector in (self.highSNSelector, self.lowSNSelector):
            yield from selector.getInputSchema()

    def getOutputSchema(self) -> KeyedDataSchema:
        """Return the (key template, type) pairs produced by `__call__`.

        The order is: the two masks, the four low-S/N statistics, the four
        high-S/N statistics, then the two thresholds.
        """
        prefix = self._get_key_prefix()
        lower = prefix.lower()
        upper = prefix.capitalize()
        suffix = self.suffix

        masks = tuple((f"{lower}{level}SNMask{suffix}", Vector) for level in ("High", "Low"))
        statistics = tuple(
            (f"{{band}}_{level}SN{upper}_{stat}{suffix}", Scalar)
            for level in ("low", "high")
            for stat in ("median", "sigmaMad", "count", "approxMag")
        )
        thresholds = tuple((f"{lower}{level}SNThreshold{suffix}", Scalar) for level in ("High", "Low"))
        return masks + statistics + thresholds

    def __call__(self, data: KeyedData, **kwargs) -> KeyedData:
        """Compute the masks, per-selection statistics and thresholds.

        Parameters
        ----------
        data : `KeyedData`
            Input data; must contain ``vectorKey``, the selector inputs and,
            when a band is supplied, the formatted ``fluxType`` key.
        **kwargs
            Forwarded to the selectors and the statistics action.  A
            ``band`` entry, when present and non-empty, prefixes the
            statistic keys.

        Returns
        -------
        results : `KeyedData`
            Mapping holding the masks, statistics and thresholds.
        """
        prefix = self._get_key_prefix()
        lower = prefix.lower()
        upper = prefix.capitalize()
        suffix = self.suffix

        maskKeys = {
            "high": f"{lower}HighSNMask{suffix}",
            "low": f"{lower}LowSNMask{suffix}",
        }
        # The high selection is evaluated (and inserted) before the low one.
        results = {
            maskKeys["high"]: self.highSNSelector(data, **kwargs),
            maskKeys["low"]: self.lowSNSelector(data, **kwargs),
        }

        band = kwargs.get("band")
        prefix_band = f"{band}_" if band else ""
        fluxes = None
        if band is not None:
            fluxes = data[self.fluxType.format(band=band)]

        statAction = SummaryStatisticAction(vectorKey=self.vectorKey)

        # this is sad, but pex_config seems to have broken behavior that
        # is dangerous to fix
        statAction.setDefaults()

        medianAction = MedianAction(vectorKey="mag")
        magAction = ConvertFluxToMag(vectorKey="flux")

        for binName in ("low", "high"):
            mask = results[maskKeys[binName]]
            name = f"{prefix_band}{binName}SN{upper}"
            # set the approxMag to the median mag in the SN selection
            if band is not None:
                approxMag = medianAction({"mag": magAction({"flux": fluxes[mask]})})  # type: ignore
            else:
                approxMag = np.nan
            results[f"{name}_approxMag{suffix}".format(**kwargs)] = approxMag

            binStats = statAction(data, **(kwargs | {"mask": mask}))
            for name_stat, value in binStats.items():
                results[f"{name}_{name_stat}{suffix}".format(**kwargs)] = value

        results[f"{lower}HighSNThreshold{suffix}"] = self.highSNSelector.threshold  # type: ignore
        results[f"{lower}LowSNThreshold{suffix}"] = self.lowSNSelector.threshold  # type: ignore

        return results

144 

145 

146def _validObjectTypes(value): 

147 return value in ("stars", "galaxies", "unknown", "any") 

148 

149 

# ignore type because of conflicting name on tuple baseclass
class _StatsContainer(NamedTuple):
    """Scalar statistics for one S/N selection of one object type.

    The field names match the entries of ``ScatterPlotWithTwoHists._stats``;
    instances are built in ``_scatterPlot`` from the scalars found under the
    ``{band}_highSN.../{band}_lowSN...`` keys of the input data.
    """

    # Median of the selection.
    median: Scalar
    # Sigma MAD of the selection.
    sigmaMad: Scalar
    # Number of points in the selection; the ignore is needed because the
    # field name collides with ``tuple.count``.
    count: Scalar  # type: ignore
    # Approximate magnitude of the selection, shown in the stats text and
    # drawn as a vertical line on the scatter plot.
    approxMag: Scalar

156 

157 

class DataTypeDefaults(NamedTuple):
    """Per-object-type naming and styling defaults used by
    `ScatterPlotWithTwoHists`.
    """

    # Name fragment used in statistic keys
    # (``{band}_highSN{suffix_stat}_{stat}``) and in histogram legend labels.
    suffix_stat: str
    # Name fragment appended to ``x``/``y`` to form the coordinate vector keys.
    suffix_xy: str
    # Matplotlib colour used for lines, markers and step histograms.
    color: str
    # Colormap passed to the hexbin 2D histogram; `None` leaves the choice
    # to matplotlib.
    colormap: matplotlib.colors.Colormap | None

163 

164 

class ScatterPlotWithTwoHists(PlotAction):
    """Makes a scatter plot of the data with a marginal
    histogram for each axis.

    For each entry of ``plotTypes`` the action reads the ``x``/``y`` vectors,
    the high/low S/N masks, thresholds and summary statistics from the input
    data (see `getInputSchema`) and draws a running median with sigma-MAD
    envelopes, optionally a hexbin density, the marginal histograms and,
    if requested, a per-patch summary panel.
    """

    yLims = ListField[float](
        doc="ylimits of the plot, if not specified determined from data",
        length=2,
        optional=True,
    )

    xLims = ListField[float](
        doc="xlimits of the plot, if not specified determined from data", length=2, optional=True
    )
    xAxisLabel = Field[str](doc="Label to use for the x axis", optional=False)
    yAxisLabel = Field[str](doc="Label to use for the y axis", optional=False)
    magLabel = Field[str](doc="Label to use for the magnitudes used for SNR", optional=False)

    legendLocation = Field[str](doc="Legend position within main plot", default="upper left")
    nBins = Field[float](doc="Number of bins on x axis", default=40.0)
    plot2DHist = Field[bool](
        doc="Plot a 2D histogram in dense areas of points on the scatter plot."
        "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
    )
    plotTypes = ListField[str](
        doc="Selection of types of objects to plot. Can take any combination of"
        " stars, galaxies, unknown, any",
        optional=False,
        itemCheck=_validObjectTypes,
    )

    addSummaryPlot = Field[bool](
        doc="Add a summary plot to the figure?",
        default=True,
    )
    suffix_x = Field[str](doc="Suffix for all x-axis action inputs", optional=True, default="")
    suffix_y = Field[str](doc="Suffix for all y-axis action inputs", optional=True, default="")
    suffix_stat = Field[str](doc="Suffix for all binned statistic action inputs", optional=True, default="")

    # Names of the scalar statistics expected for every S/N selection; these
    # are also the field names of `_StatsContainer`.
    _stats = ("median", "sigmaMad", "count", "approxMag")
    # Key fragments and plot styling for every supported object type.
    _datatypes = {
        "galaxies": DataTypeDefaults(
            suffix_stat="Galaxies",
            suffix_xy="Galaxies",
            color="firebrick",
            colormap=mkColormap(["lemonchiffon", "firebrick"]),
        ),
        "stars": DataTypeDefaults(
            suffix_stat="Stars",
            suffix_xy="Stars",
            color="midnightblue",
            colormap=mkColormap(["paleturquoise", "midnightBlue"]),
        ),
        "unknown": DataTypeDefaults(
            suffix_stat="Unknown",
            suffix_xy="Unknown",
            color="green",
            colormap=None,
        ),
        "any": DataTypeDefaults(
            suffix_stat="Any",
            suffix_xy="",
            color="purple",
            colormap=None,
        ),
    }

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the (key template, type) pairs this action needs.

        For every configured object type this is the ``x``/``y`` vectors,
        the high/low S/N masks, the statistics in ``_stats`` for both S/N
        selections and the two thresholds; ``patch`` is added when the
        summary plot is enabled.
        """
        base: list[tuple[str, type[Vector] | ScalarType]] = []
        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            base.append((f"x{config_datatype.suffix_xy}{self.suffix_x}", Vector))
            base.append((f"y{config_datatype.suffix_xy}{self.suffix_y}", Vector))
            base.append((f"{name_datatype}HighSNMask{self.suffix_stat}", Vector))
            base.append((f"{name_datatype}LowSNMask{self.suffix_stat}", Vector))
            # statistics
            for name in self._stats:
                base.append(
                    (f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar)
                )
                base.append((f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}LowSNThreshold{self.suffix_stat}", Scalar))
            base.append((f"{name_datatype}HighSNThreshold{self.suffix_stat}", Scalar))

        if self.addSummaryPlot:
            base.append(("patch", Vector))

        return base

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema and make the plot."""
        self._validateInput(data, **kwargs)
        return self.makePlot(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """NOTE currently can only check that something is not a Scalar, not
        check that the data is consistent with Vector

        Raises
        ------
        ValueError
            Raised if a required key is missing from ``data`` or if a key
            that should hold a vector holds a scalar.
        """
        needed = self.getFormattedInputSchema(**kwargs)
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(
                f"Task needs keys {remainder} but they were not found in input keys" f" {list(data.keys())}"
            )
        for name, typ in needed:
            isScalar = issubclass((colType := type(data[name.format(**kwargs)])), Scalar)
            if isScalar and typ != Scalar:
                raise ValueError(f"Data keyed by {name} has type {colType} but action requires type {typ}")

    def makePlot(
        self,
        data: KeyedData,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str],
        **kwargs,
    ) -> Figure:
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        data : `KeyedData`
            The catalog to plot the points from.
        skymap : `lsst.skymap.BaseSkyMap`
            The skymap that gives the patch locations
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:

            * ``"run"``
                The output run for the plots (`str`).
            * ``"skymap"``
                The type of skymap used for the data (`str`).
            * ``"filter"``
                The filter used for this data (`str`).
            * ``"tract"``
                The tract that the data comes from (`str`).
        **kwargs
            Used to format the ``{band}`` key templates.  ``hlineColor`` and
            ``hlineStyle`` set the style of the horizontal reference line at
            zero and default to ``"black"`` and ``(0, (1, 4))``.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        The ``xAxisLabel`` and ``yAxisLabel`` config options label the main
        scatter plot of the y values against the x values. A histogram of
        the points collapsed onto each axis is also plotted. If
        ``addSummaryPlot`` is set, a summary panel built from the ``patch``
        column is added in the upper right corner of the figure. The high
        and low S/N masks found in ``data`` decide which points are
        highlighted, and the statistics found in ``data`` are printed below
        the plot.

        If this function is being used within the pipetask framework
        that takes care of making sure that data has all the required
        elements but if you are running this as a standalone function
        then you will need to provide the following things in the
        input data (each followed by the relevant configured suffix).

        * If stars is in self.plotTypes:
            xStars, yStars, starsHighSNMask, starsLowSNMask,
            starsHighSNThreshold, starsLowSNThreshold and
            {band}_highSNStars_{name}, {band}_lowSNStars_{name}
            where name is median, sigmaMad, count and approxMag.

        * If it is for galaxies/unknown then replace stars/Stars in the
            above names with galaxies/Galaxies or unknown/Unknown.

        * If it is for any (which covers all the points) then it
            becomes x, y, and any/Any instead of stars/Stars for the other
            parameters given above.

        * If the summary plot is being plotted, data must also contain
            patch.

        Examples
        --------
        An example of the plot produced from this code is here:

        .. image:: /_static/analysis_tools/scatterPlotExample.png

        For a detailed example of how to make a plot from the command line
        please see the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        if not self.plotTypes:
            noDataFig = Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Set default color and line style for the horizontal
        # reference line at 0
        if "hlineColor" not in kwargs:
            kwargs["hlineColor"] = "black"

        if "hlineStyle" not in kwargs:
            kwargs["hlineStyle"] = (0, (1, 4))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # add the various plot elements
        ax, imhist = self._scatterPlot(data, fig, gs, **kwargs)
        self._makeTopHistogram(data, fig, gs, ax, **kwargs)
        self._makeSideHistogram(data, fig, gs, ax, imhist, **kwargs)
        # Needs info from run quantum
        if self.addSummaryPlot:
            sumStats = generateSummaryStats(data, skymap, plotInfo)
            label = self.yAxisLabel
            fig = addSummaryPlot(fig, gs[0, -1], sumStats, label)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)
        return fig

    def _addStatsText(self, fig: Figure, xPos: float, yPos: float, threshold, stats, color, linestyle) -> None:
        """Write the boxed statistics summary of one S/N selection.

        The text is placed at figure coordinates (``xPos``, ``yPos``) inside
        a box outlined in ``color`` with the given ``linestyle``.
        """
        bbox = dict(edgecolor=color, linestyle=linestyle, facecolor="none")
        statText = f"S/N > {threshold:0.4g} Stats ({self.magLabel} < {stats.approxMag:0.4g})\n"
        statText += (
            f"Median: {stats.median:0.4g} "
            + r"$\sigma_{MAD}$: "
            + f"{stats.sigmaMad:0.4g} "
            + r"N$_{points}$: "
            + f"{stats.count}"
        )
        fig.text(xPos, yPos, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

    def _markSnSelection(self, ax: Axes, xs, ys, mask, approxMag, color, marker, linestyle, label):
        """Highlight one S/N selection on the scatter plot.

        If there are not many sources being used for the statistics then
        plot them individually, as just plotting a line makes the statistics
        look wrong because the magnitude estimation is iffy for low numbers
        of sources.  Otherwise a vertical line is drawn at ``approxMag``.

        Returns
        -------
        line
            The labelled marker artist for the legend, or `None` when the
            vertical line was drawn instead.
        """
        nSelected = np.sum(mask)
        if 0 < nSelected < 100:
            # White, thicker markers underneath give the coloured ones a halo.
            ax.plot(
                cast(Vector, xs[mask]),
                cast(Vector, ys[mask]),
                marker=marker,
                ms=4,
                mec="w",
                mew=2,
                ls="none",
            )
            (line,) = ax.plot(
                cast(Vector, xs[mask]),
                cast(Vector, ys[mask]),
                color=color,
                marker=marker,
                ms=4,
                ls="none",
                label=label,
            )
            return line
        ax.axvline(approxMag, color=color, ls=linestyle)
        return None

    def _scatterPlot(
        self, data: KeyedData, fig: Figure, gs: gridspec.GridSpec, **kwargs
    ) -> tuple[Axes, Optional[PolyCollection]]:
        """Draw the main scatter panel.

        Parameters
        ----------
        data : `KeyedData`
            Input data holding the keys listed by `getInputSchema`.
        fig : `matplotlib.figure.Figure`
            Figure to draw on.
        gs : `matplotlib.gridspec.GridSpec`
            Grid describing the figure layout; the main panel takes every
            row but the first and every column but the last.
        **kwargs
            Used to format the ``{band}`` key templates; must contain
            ``hlineColor`` and ``hlineStyle``.

        Returns
        -------
        ax : `matplotlib.axes.Axes`
            The scatter plot axes.
        histIm : `matplotlib.collections.PolyCollection` or `None`
            The hexbin artist of the most recent 2D histogram, if any.
        """
        suf_x = self.suffix_x
        suf_y = self.suffix_y
        suf_stat = self.suffix_stat
        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])

        binThresh = 5

        linesForLegend = []

        toPlotList = []
        histIm = None
        highStats: _StatsContainer
        lowStats: _StatsContainer

        for name_datatype in self.plotTypes:
            config_datatype = self._datatypes[name_datatype]
            highArgs = {}
            lowArgs = {}
            for name in self._stats:
                highArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_highSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
                lowArgs[name] = cast(
                    Scalar,
                    data[f"{{band}}_lowSN{config_datatype.suffix_stat}_{name}{suf_stat}".format(**kwargs)],
                )
            highStats = _StatsContainer(**highArgs)
            lowStats = _StatsContainer(**lowArgs)

            toPlotList.append(
                (
                    data[f"x{config_datatype.suffix_xy}{suf_x}"],
                    data[f"y{config_datatype.suffix_xy}{suf_y}"],
                    data[f"{name_datatype}HighSNMask{suf_stat}"],
                    data[f"{name_datatype}LowSNMask{suf_stat}"],
                    data[f"{name_datatype}HighSNThreshold{suf_stat}"],
                    data[f"{name_datatype}LowSNThreshold{suf_stat}"],
                    config_datatype.color,
                    config_datatype.colormap,
                    highStats,
                    lowStats,
                )
            )

        # Work on a copy so the config list is never written to.  ``np.inf``
        # is used because the ``np.Inf`` alias was removed in NumPy 2.0.
        xLims = list(self.xLims) if self.xLims is not None else [np.inf, -np.inf]

        # Running medians and global sigma MAD of the most recent object
        # type that was actually drawn; these drive the automatic y limits.
        limMeds = None
        limSigMad = np.nan

        for j, (
            xs,
            ys,
            highSn,
            lowSn,
            highThresh,
            lowThresh,
            color,
            cmap,
            highStats,
            lowStats,
        ) in enumerate(toPlotList):
            highSn = cast(Vector, highSn)
            lowSn = cast(Vector, lowSn)
            # ensure the columns are actually array
            xs = np.array(xs)
            ys = np.array(ys)
            sigMadYs = nanSigmaMad(ys)
            # plot lone median point if there's not enough data to measure more
            n_xs = len(xs)
            if n_xs == 0 or not np.isfinite(sigMadYs):
                continue
            elif n_xs < 10:
                xs = [nanMedian(xs)]
                sigMads = np.array([nanSigmaMad(ys)])
                ys = np.array([nanMedian(ys)])
                (medLine,) = ax.plot(xs, ys, color, label=f"Median: {ys[0]:.2g}", lw=0.8)
                linesForLegend.append(medLine)
                (sigMadLine,) = ax.plot(
                    xs,
                    ys + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:.2g}",
                )
                ax.plot(xs, ys - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                # ``ys`` now holds just the median of this object type.
                limMeds, limSigMad = ys, sigMadYs
                continue

            if self.xLims:
                xMin, xMax = self.xLims
            else:
                # Chop off 1/3% from only the finite xs values
                # (there may be +/-np.inf values)
                # TODO: This should be configurable
                # but is there a good way to avoid redundant config params
                # without using slightly annoying subconfigs?
                xs1, xs97 = np.nanpercentile(xs[np.isfinite(xs)], (1, 97))
                xScale = (xs97 - xs1) / 20.0  # This is ~5% of the data range
                xMin, xMax = (xs1 - xScale, xs97 + xScale)
                xLims[0] = min(xLims[0], xMin)
                xLims[1] = max(xLims[1], xMax)

            xEdges = np.arange(xMin, xMax, (xMax - xMin) / self.nBins)
            medYs = nanMedian(ys)
            fiveSigmaHigh = medYs + 5.0 * sigMadYs
            fiveSigmaLow = medYs - 5.0 * sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow) / 101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins holding more than ``binThresh`` points.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot) * 2) * Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot) * 2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for i, xEdge in enumerate(xEdgesPlot):
                    ids = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = nanMedian(ys[ids])
                    sigMad = nanSigmaMad(ys[ids])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper envelope runs forwards, lower envelope backwards,
                    # so the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3 * sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3 * sigMad]

                (medLine,) = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                (threeSigMadLine,) = ax.plot(
                    xEdgesPlot,
                    threeSigMadVerts[: len(xEdgesPlot), 1],
                    color,
                    alpha=0.4,
                    label=r"3$\sigma_{MAD}$",
                )
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot) :, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                (sigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 1.0 * sigMads, color, alpha=0.8, label=r"$\sigma_{MAD}$"
                )
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0 * sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                (twoSigMadLine,) = ax.plot(
                    xEdgesPlot, meds + 2.0 * sigMads, color, alpha=0.6, label=r"2$\sigma_{MAD}$"
                )
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0 * sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4 * j
                self._addStatsText(fig, xPos, 0.090, highThresh, highStats, color, "--")
                self._addStatsText(fig, xPos, 0.020, lowThresh, lowStats, color, ":")

                if self.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-3)

                highSnLine = self._markSnSelection(
                    ax, xs, ys, highSn, highStats.approxMag, color, "x", "--", "High SN"
                )
                if highSnLine is not None:
                    linesForLegend.append(highSnLine)

                lowSnLine = self._markSnSelection(
                    ax, xs, ys, lowSn, lowStats.approxMag, color, "+", ":", "Low SN"
                )
                if lowSnLine is not None:
                    linesForLegend.append(lowSnLine)

            else:
                # Too few populated bins for a running median: plot every
                # point with a constant median and sigma MAD instead.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([nanMedian(ys)] * len(xs))
                (medLine,) = ax.plot(xs, meds, color, label=f"Median: {nanMedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([nanSigmaMad(ys)] * len(xs))
                (sigMadLine,) = ax.plot(
                    xs,
                    meds + 1.0 * sigMads,
                    color,
                    alpha=0.8,
                    lw=0.8,
                    label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}",
                )
                ax.plot(xs, meds - 1.0 * sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

            limMeds, limSigMad = meds, sigMadYs

        # Add a horizontal reference line at 0 to the scatter plot
        ax.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        # Set the scatter plot limits, centred on the stars if present and
        # non-empty, else on the galaxies.
        if f"yStars{suf_y}" in data and (len(cast(Vector, data[f"yStars{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yStars{suf_y}"]))
        elif f"yGalaxies{suf_y}" in data and (len(cast(Vector, data[f"yGalaxies{suf_y}"])) > 0):
            plotMed = nanMedian(cast(Vector, data[f"yGalaxies{suf_y}"]))
        else:
            plotMed = np.nan

        if self.yLims:
            ax.set_ylim(self.yLims[0], self.yLims[1])  # type: ignore
        elif np.isfinite(plotMed) and limMeds is not None:
            # Widen the window one sigma MAD at a time (up to 10) until it
            # contains every running median, then add one more for padding.
            medMax = np.max(limMeds)
            medMin = np.min(limMeds)
            numSig = 4
            while (
                plotMed + numSig * limSigMad < medMax or plotMed - numSig * limSigMad > medMin
            ) and numSig < 10:
                numSig += 1

            numSig += 1
            yLimMin = plotMed - numSig * limSigMad
            yLimMax = plotMed + numSig * limSigMad
            ax.set_ylim(yLimMin, yLimMax)

        # This could be false if len(x) == 0 for xs in toPlotList
        # ... in which case nothing was plotted and limits are irrelevant
        if all(np.isfinite(xLims)):
            ax.set_xlim(xLims)

        # Add a line legend
        ax.legend(
            handles=linesForLegend,
            ncol=4,
            fontsize=6,
            loc=self.legendLocation,
            framealpha=0.9,
            edgecolor="k",
            borderpad=0.4,
            handlelength=1,
        )

        # Add axes labels
        ax.set_ylabel(self.yAxisLabel, fontsize=10, labelpad=10)
        ax.set_xlabel(self.xAxisLabel, fontsize=10, labelpad=2)

        return ax, histIm

    def _splitAnyVector(self, data: KeyedData, axis: str, suffix: str):
        """Return the vector of all points along ``axis`` and the remaining
        object types that get their own histogram.

        Parameters
        ----------
        data : `KeyedData`
            Input data.
        axis : `str`
            Either ``"x"`` or ``"y"``.
        suffix : `str`
            The configured key suffix for that axis.

        Returns
        -------
        vector_any
            The ``any`` vector when ``"any"`` is configured, else the
            concatenation of all configured types when there is more than
            one, else `None`.
        keys_notany
            The configured object types other than ``"any"``.
        """
        if "any" in self.plotTypes:
            # Look the vector up in the data; the key string itself is not
            # something that can be histogrammed.
            vector_any = data[f"{axis}{self._datatypes['any'].suffix_xy}{suffix}"]
            keys_notany = [key for key in self.plotTypes if key != "any"]
        else:
            vector_any = (
                np.concatenate(
                    [data[f"{axis}{self._datatypes[key].suffix_xy}{suffix}"] for key in self.plotTypes]
                )
                if (len(self.plotTypes) > 1)
                else None
            )
            keys_notany = self.plotTypes
        return vector_any, keys_notany

    def _makeTopHistogram(
        self, data: KeyedData, figure: Figure, gs: gridspec.GridSpec, ax: Axes, **kwargs
    ) -> None:
        """Draw the log-scaled histogram of the x values above the scatter
        plot, sharing its x axis.
        """
        suf_x = self.suffix_x
        # Top histogram
        topHist = figure.add_subplot(gs[0, :-1], sharex=ax)
        x_min, x_max = ax.get_xlim()
        bins = np.linspace(x_min, x_max, 100)

        x_any, keys_notany = self._splitAnyVector(data, "x", suf_x)
        if x_any is not None:
            topHist.hist(x_any, bins=bins, color="grey", alpha=0.3, log=True, label=f"Any ({len(x_any)})")

        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"x{config_datatype.suffix_xy}{suf_x}"]
            topHist.hist(
                vector,
                bins=bins,
                color=config_datatype.color,
                histtype="step",
                log=True,
                label=f"{config_datatype.suffix_stat} ({len(vector)})",
            )
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

    def _makeSideHistogram(
        self,
        data: KeyedData,
        figure: Figure,
        gs: gridspec.GridSpec,
        ax: Axes,
        histIm: Optional[PolyCollection],
        **kwargs,
    ) -> None:
        """Draw the log-scaled histogram of the y values to the right of the
        scatter plot, sharing its y axis.

        Every object type other than ``any`` gets a solid outline for all
        points, a dashed one for the high S/N selection and a dotted one for
        the low S/N selection.  A colorbar for ``histIm`` is attached when
        the 2D histogram is enabled and ``histIm`` is not `None`.
        """
        suf_y = self.suffix_y
        # Side histogram
        sideHist = figure.add_subplot(gs[1:, -1], sharey=ax)
        y_min, y_max = ax.get_ylim()
        bins = np.linspace(y_min, y_max, 100)

        y_any, keys_notany = self._splitAnyVector(data, "y", suf_y)
        if y_any is not None:
            sideHist.hist(
                np.array(y_any),
                bins=bins,
                color="grey",
                alpha=0.3,
                orientation="horizontal",
                log=True,
            )
        kwargs_hist = dict(
            bins=bins,
            histtype="step",
            log=True,
            orientation="horizontal",
        )
        for key in keys_notany:
            config_datatype = self._datatypes[key]
            vector = data[f"y{config_datatype.suffix_xy}{suf_y}"]
            sideHist.hist(
                vector,
                color=config_datatype.color,
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}HighSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                ls="--",
                **kwargs_hist,
            )
            sideHist.hist(
                vector[cast(Vector, data[f"{key}LowSNMask{self.suffix_stat}"])],
                color=config_datatype.color,
                **kwargs_hist,
                ls=":",
            )

        # Add a horizontal reference line at 0 to the side histogram
        sideHist.axhline(0, color=kwargs["hlineColor"], ls=kwargs["hlineStyle"], alpha=0.7, zorder=-2)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")