Coverage for python/lsst/analysis/drp/scatterPlot.py: 9%

386 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-30 09:43 +0000

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21import matplotlib 

22import matplotlib.pyplot as plt 

23import numpy as np 

24import pandas as pd 

25from matplotlib import gridspec 

26from matplotlib.patches import Rectangle 

27from matplotlib.path import Path 

28from matplotlib.collections import PatchCollection 

29from mpl_toolkits.axes_grid1 import make_axes_locatable 

30 

31from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField 

32from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction 

33from lsst.skymap import BaseSkyMap 

34 

35import lsst.pipe.base as pipeBase 

36import lsst.pex.config as pexConfig 

37 

38from .calcFunctors import MagDiff 

39from .dataSelectors import SnSelector, StarIdentifier, CoaddPlotFlagSelector 

40from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap 

41from .statistics import sigmaMad 

42 

# Colormap for the per-patch summary panel. Masked (NaN) patch statistics
# are drawn fully transparent instead of in the colormap's "bad" colour.
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad("none")
# Select the non-interactive Agg backend; figures are only rendered to files.
matplotlib.use("Agg")

46 

47 

class ScatterPlotWithTwoHistsTaskConnections(
        pipeBase.PipelineTaskConnections,
        dimensions=("tract", "skymap"),
        defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"}):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    The input catalog is declared with ``deferLoad=True`` so the task can
    read only the columns it needs. The skymap supplies patch boundaries,
    and one plot is written per (tract, skymap).
    """

    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        storageClass="DataFrame",
        name="objectTable_tract",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        storageClass="Plot",
        name="scatterTwoHistPlot_{plotName}",
        dimensions=("tract", "skymap"),
    )

68 

69 

class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    With the defaults from `setDefaults`, the task plots the i-band
    ``ap12Flux - psfFlux`` magnitude difference against the i-band
    ``cModelFlux`` magnitude.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names of xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: {'x':, 'y':, 'mag':}. "
            "The mag column is used to estimate the magnitude limit quoted with the printed statistics.",
        keytype=str,
        itemtype=str
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names.
        """
        columnNames = {"patch"}
        bands = set()
        for actionStruct in [self.axisActions,
                             self.selectorActions,
                             self.highSnStatisticSelectorActions,
                             self.lowSnStatisticSelectorActions,
                             self.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # The text before the first underscore is treated as a band
                    # unless it is a known non-band prefix; a column with no
                    # underscore has no band prefix at all.
                    prefix, sep, _ = col.partition("_")
                    if sep and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge", "sky"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": CoaddPlotFlagSelector,
                 "catSnSelector": SnSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins along each axis of the 2D histogram used to compute the running "
            "statistics.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    xLims = pexConfig.ListField(
        doc="Minimum and maximum x-axis limit to force (provided as a list of [xMin, xMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    yLims = pexConfig.ListField(
        doc="Minimum and maximum y-axis limit to force (provided as a list of [yMin, yMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    minPointSize = pexConfig.Field(
        doc="When plotting points (as opposed to 2D hist bins), the minimum size they can be. Some "
            "relative scaling will be performed depending on the \"flavor\" of the set of points.",
        default=2.0,
        dtype=float,
    )

    def setDefaults(self):
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.axisActions.yAction = MagDiff
        self.axisActions.yAction.col1 = "i_ap12Flux"
        self.axisActions.yAction.col2 = "i_psfFlux"
        self.selectorActions.flagSelector.bands = ["i"]
        self.selectorActions.catSnSelector.fluxType = "psfFlux"
        self.highSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
        # Derive the axis labels from the default columns, e.g.
        # "i_cModelFlux" -> "i_cModel (mag)".
        self.axisLabels = {
            "x": self.axisActions.xAction.column.removesuffix("Flux") + " (mag)",
            "mag": self.axisActions.magAction.column.removesuffix("Flux") + " (mag)",
            "y": ("{} - {} (mmag)".format(self.axisActions.yAction.col1.removesuffix("Flux"),
                                          self.axisActions.yAction.col2.removesuffix("Flux")))
        }

198 

199 

class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities with a histogram of
    each axis and a per-patch summary panel.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # catPlot is a deferred-load handle; read only the required columns.
        inputs["catPlot"] = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["dataId"] = butlerQC.quantum.dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
        `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are computed.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set` [`str`]
            The bands used by the configured actions; passed to
            `parsePlotInfo`.
        plotName : `str`
            The name of the output plot; passed to `parsePlotInfo`.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The values on each axis are then
        computed with the actions in `self.config.axisActions`. Each point is
        tagged with a ``sourceType`` (from `self.config.sourceSelectorActions`)
        and a ``useForStats`` flag: 1 for the high S/N selection, 2 for points
        that pass only the low S/N selection, 0 otherwise.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y-axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        axisLabels = self.config.axisLabels
        axisActions = self.config.axisActions
        columns = {axisLabels["x"]: axisActions.xAction(catPlot),
                   axisLabels["y"]: axisActions.yAction(catPlot),
                   axisLabels["mag"]: axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Copy across the columns read by the statistic and source selectors;
        # the statistic selectors are applied to plotDf below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask; this allows the information about which
            # type of sources are being plotted to be propagated.
            sourceTypes += selector(catPlot)
        if not list(self.config.sourceSelectorActions):
            # No source selection configured: tag every point as "all" (10).
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation. Low S/N points
        # are flagged 2 first; the high S/N flag (1) is applied afterwards so
        # it takes precedence for points passing both selections.
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        if hasattr(self.config.selectorActions, "catSnSelector"):
            SN = self.config.selectorActions.catSnSelector.threshold
            SNFlux = self.config.selectorActions.catSnSelector.fluxType
        else:
            SN = "N/A"
            SNFlux = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        return pipeBase.Struct(scatterPlot=fig)

    @staticmethod
    def _formatStats(values, nPoints):
        """Format the median, sigma MAD and point count for a stats text box.

        Parameters
        ----------
        values : array-like
            The y values of the selected points; NaNs are ignored.
        nPoints : `int`
            The number of selected points to quote.

        Returns
        -------
        text : `str`
            The formatted statistics line.
        """
        med = np.nanmedian(values)
        mad = sigmaMad(values, nan_policy="omit")
        return (f"Median: {med:0.3g} "
                + r"$\sigma_{MAD}$: " + f"{mad:0.3g} "
                + r"N$_{points}$: " + f"{nPoints}")

    @staticmethod
    def _approxFaintMag(mags):
        """Estimate the faint magnitude limit of a selection of sources.

        Parameters
        ----------
        mags : array-like
            Magnitudes of the selected sources. Non-finite values are ignored
            so they cannot land in the "faintest" window after sorting.

        Returns
        -------
        approxMag : `float`
            The median of the largest 10% of the finite magnitudes. All of
            them are used when there are fewer than 10 sources, and NaN is
            returned when there are none.
        """
        mags = np.asarray(mags, dtype=float)
        mags = np.sort(mags[np.isfinite(mags)])
        if len(mags) == 0:
            return np.nan
        nFaint = len(mags)//10
        if nFaint == 0:
            # Too few sources for a meaningful top-10% window.
            return float(np.median(mags))
        return float(np.median(mags[-nFaint:]))

    @staticmethod
    def _formatMag(mag):
        """Format a magnitude estimate for display, using "-" if undefined."""
        return f"{mag:.3g}" if np.isfinite(mag) else "-"

    @staticmethod
    def _lowerXMin(currentMin, values):
        """Return the smaller of ``currentMin`` (may be `None`) and the
        smallest finite entry of ``values``. Non-finite entries are ignored,
        and ``currentMin`` is returned unchanged if none are finite.
        """
        finite = values[np.isfinite(values)]
        if len(finite) == 0:
            return currentMin
        valueMin = np.min(finite)
        return valueMin if currentMin is None else min(currentMin, valueMin)

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. Besides the columns named in
            ``config.axisLabels``, it must have the ``"sourceType"`` and
            ``"useForStats"`` columns added by `run`.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If no points of a recognised source type
            remain, a figure containing only a "No data" message is returned.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        The axis limits are set based on the values of `config.xLims` and
        `config.yLims`. If provided (as a `list` of [min, max]), those will
        be used. If `None` (the default), the axis limits will be computed
        and set based on the data.
        """
        self.log.info("Plotting %s: the values of %s on a scatter plot.",
                      self.config.connections.plotName, self.config.axisLabels['y'])

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns, 10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}
        sourceTypeCol = catPlot["sourceType"]
        useForStatsCol = catPlot["useForStats"]

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # Need to separate stars and galaxies
        stars = (sourceTypeCol == sourceTypeMapper["stars"])
        galaxies = (sourceTypeCol == sourceTypeMapper["galaxies"])
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # Calculate some statistics for every source type present. The high
        # S/N sample is useForStats == 1; the low S/N sample is 1 or 2.
        # The *Mags dicts hold numeric estimates (NaN when undefined).
        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}
        for sourceType in sourceTypeMapper.values():
            sources = (sourceTypeCol == sourceType)
            if not np.any(sources):
                continue
            highSn = (useForStatsCol == 1) & sources
            lowSn = ((useForStatsCol == 1) | (useForStatsCol == 2)) & sources
            highStats[sourceType] = self._formatStats(catPlot.loc[highSn, yCol], np.sum(highSn))
            lowStats[sourceType] = self._formatStats(catPlot.loc[lowSn, yCol], np.sum(lowSn))
            highMags[sourceType] = self._approxFaintMag(catPlot.loc[highSn, magCol])
            lowMags[sourceType] = self._approxFaintMag(catPlot.loc[lowSn, magCol])

        # Decide what to plot: galaxies and/or stars if present, otherwise
        # the unknowns, otherwise the points tagged "all".
        toPlotList = []
        if np.any(galaxies):
            toPlotList.append((xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                               sourceTypeMapper["galaxies"]))
        if np.any(stars):
            toPlotList.append((xsStars.values, ysStars.values, "midnightblue", newBlues,
                               sourceTypeMapper["stars"]))
        if not toPlotList:
            unknowns = (sourceTypeCol == sourceTypeMapper["unknowns"])
            if np.any(unknowns):
                toPlotList.append((catPlot.loc[unknowns, xCol].values,
                                   catPlot.loc[unknowns, yCol].values,
                                   "green", None, sourceTypeMapper["unknowns"]))
            elif np.any(sourceTypeCol == sourceTypeMapper["all"]):
                toPlotList.append((catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                                   sourceTypeMapper["all"]))

        if not toPlotList:
            noDataFig = plt.Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        # Create the pyplot-managed figure only once there is something to
        # draw, so the no-data path does not leave an unused figure behind.
        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5
        linesForLegend = []
        xMin = None
        histIm = None
        highSnEnabled = hasattr(self.config.highSnStatisticSelectorActions, "statSelector")
        lowSnEnabled = hasattr(self.config.lowSnStatisticSelectorActions, "statSelector")

        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                # Too few points to bin: show the median and +/- 1 sigma MAD.
                medY = np.nanmedian(ys)
                medLine, = ax.plot(xs, medY, color, label=f"Median: {medY:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.full(len(xs), sigMadYs)
                sigMadLine, = ax.plot(xs, medY + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, medY - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good.
            # x bins span the data padded by xScale; y bins span +/- 5 sigma
            # MAD about the median.
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xStep = (xHigh - xLow)/self.config.nBins
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/self.config.nBins

            if xStep > 0 and binSize > 0:
                xEdges = np.arange(xLow, xHigh, xStep)
                yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)
                counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
                # Keep only the x bins holding more than binThresh points
                ids = np.where(np.sum(counts, axis=1) > binThresh)[0]
            else:
                # Zero or NaN spread in x or y: np.arange cannot build bin
                # edges, so fall through to the plain point plot below.
                xEdges = np.array([])
                ids = np.array([], dtype=int)
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs >= xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge runs forwards, lower edge backwards, so the
                    # vertices trace a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=self.config.minPointSize, alpha=0.3,
                        mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                if highSnEnabled:
                    bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                    highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                    statText = (f"S/N > {highThresh} Stats [{magCol} $\\lesssim$ "
                                f"{self._formatMag(highMags[sourceType])}]\n")
                    statText += highStats[sourceType]
                    fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if lowSnEnabled:
                    bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                    lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                    statText = (f"S/N > {lowThresh} Stats [{magCol} $\\lesssim$ "
                                f"{self._formatMag(lowMags[sourceType])}]\n")
                    statText += lowStats[sourceType]
                    fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1,
                                       zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                sources = (sourceTypeCol == sourceType)
                statInfo = useForStatsCol.loc[sources].values
                if highSnEnabled:
                    highSn = (statInfo == 1)
                    nHighSn = np.sum(highSn)
                    if 0 < nHighSn < 100:
                        ax.plot(xs[highSn], ys[highSn], marker="x", ms=self.config.minPointSize + 1,
                                mec="w", mew=2, ls="none")
                        highSnLine, = ax.plot(xs[highSn], ys[highSn], color=color, marker="x",
                                              ms=self.config.minPointSize + 1, ls="none",
                                              label="High SN")
                        linesForLegend.append(highSnLine)
                        xMin = self._lowerXMin(xMin, xs[highSn])
                    elif np.isfinite(highMags[sourceType]):
                        ax.axvline(highMags[sourceType], color=color, ls="--")
                if lowSnEnabled:
                    lowSn = (statInfo == 2)
                    nLowSn = np.sum(lowSn)
                    if 0 < nLowSn < 100:
                        ax.plot(xs[lowSn], ys[lowSn], marker="+", ms=self.config.minPointSize + 1,
                                mec="w", mew=2, ls="none")
                        lowSnLine, = ax.plot(xs[lowSn], ys[lowSn], color=color, marker="+",
                                             ms=self.config.minPointSize + 1, ls="none",
                                             label="Low SN")
                        linesForLegend.append(lowSnLine)
                        xMin = self._lowerXMin(xMin, xs[lowSn])
                    elif np.isfinite(lowMags[sourceType]):
                        ax.axvline(lowMags[sourceType], color=color, ls=":")

            else:
                ax.plot(xs, ys, ".", ms=self.config.minPointSize + 3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)
                medY = np.nanmedian(ys)
                meds = np.full(len(xs), medY)
                medLine, = ax.plot(xs, meds, color, label=f"Median: {medY:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.full(len(xs), sigMadYs)
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits, centred on the stars if present, else
        # the galaxies, else whatever was plotted (unknowns / all).
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            plotMed = np.nanmedian(ys)
        if len(xs) < 2:
            meds = [np.median(ys)]

        if self.config.yLims is not None:
            yLimMin = self.config.yLims[0]
            yLimMax = self.config.yLims[1]
        else:
            # Start at +/- 4 sigma MAD and widen (up to 10) until the running
            # median fits inside, then add one more sigma of padding.
            medMax = np.nanmax(meds)
            medMin = np.nanmin(meds)
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            while (yLimMax < medMax or yLimMin > medMin) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs

            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if self.config.xLims is not None:
            ax.set_xlim(self.config.xLims[0], self.config.xLims[1])
        elif len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2*xScale
            ax.set_xlim(xMin, xs97 + 2*xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if np.any(galaxies):
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({np.sum(galaxies)})")
        if np.any(stars):
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({np.sum(stars)})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        # Positional (ndarray) masks so the per-type selections line up with
        # the per-type y values regardless of the DataFrame index.
        highSnAll = (useForStatsCol.values == 1)
        lowSnAll = (useForStatsCol.values == 2)
        for typeMask, typeYs, typeColor in [(galaxies, ysGalaxies, "firebrick"),
                                            (stars, ysStars, "midnightblue")]:
            if not np.any(typeMask):
                continue
            typeSel = typeMask.values
            typeYsValues = typeYs.values
            sideHist.hist(typeYsValues[np.isfinite(typeYsValues)], bins=bins, color=typeColor,
                          histtype="step", orientation="horizontal", log=True)
            sideHist.hist(typeYsValues[highSnAll[typeSel]], bins=bins, color=typeColor,
                          histtype="step", orientation="horizontal", log=True, ls="--")
            sideHist.hist(typeYsValues[lowSnAll[typeSel]], bins=bins, color=typeColor,
                          histtype="step", orientation="horizontal", log=True, ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for patchId, (corners, stat) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
            decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if patchId != "tract":
                axCorner.annotate(patchId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        plt.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = plt.gcf()
        fig = addPlotInfo(fig, plotInfo)

        return fig