Coverage for python/lsst/analysis/drp/scatterPlot.py: 9%

386 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-03-30 03:49 -0700

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21import matplotlib 

22import matplotlib.pyplot as plt 

23import numpy as np 

24import pandas as pd 

25from matplotlib import gridspec 

26from matplotlib.patches import Rectangle 

27from matplotlib.path import Path 

28from matplotlib.collections import PatchCollection 

29from mpl_toolkits.axes_grid1 import make_axes_locatable 

30 

31from lsst.pex.config.configurableActions import ConfigurableActionStructField 

32from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction 

33from lsst.skymap import BaseSkyMap 

34 

35import lsst.pipe.base as pipeBase 

36import lsst.pex.config as pexConfig 

37 

38from .calcFunctors import MagDiff 

39from .dataSelectors import SnSelector, StarIdentifier, CoaddPlotFlagSelector 

40from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap 

41from .statistics import sigmaMad 

42 

# Render with the Agg backend; this module only ever produces figure objects
# and never opens an interactive window.
matplotlib.use("Agg")

# Module-level colormap used for the per-patch summary panel.  A copy of
# "coolwarm" is taken so that registering a transparent "bad" colour (used for
# masked/NaN patch statistics) does not modify the shared matplotlib colormap.
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad(color="none")

46 

47 

class ScatterPlotWithTwoHistsTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"},
):
    """Connections for `ScatterPlotWithTwoHistsTask`.

    Declares two inputs (a deferred-load tract-level ``DataFrame`` catalog and
    the skymap) and one output (the resulting plot), whose dataset type name
    is built from the ``plotName`` template.
    """

    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        storageClass="DataFrame",
        name="objectTable_tract",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        storageClass="Plot",
        name="scatterTwoHistPlot_{plotName}",
        dimensions=("tract", "skymap"),
    )

68 

69 

class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    The config bundles the actions that compute the plotted quantities, the
    selectors that decide which rows are plotted and which are used for the
    printed statistics, and a handful of cosmetic plot options.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: {'x':, 'y':, 'mag':}. "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge", "sky"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": CoaddPlotFlagSelector,
                 "catSnSelector": SnSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": StarIdentifier},
    )

    # NOTE: declared as a float (not int) field; it is only used as a divisor
    # when computing bin widths, so the dtype is kept for compatibility with
    # existing config overrides.
    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    xLims = pexConfig.ListField(
        doc="Minimum and maximum x-axis limit to force (provided as a list of [xMin, xMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    yLims = pexConfig.ListField(
        doc="Minimum and maximum y-axis limit to force (provided as a list of [yMin, yMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    minPointSize = pexConfig.Field(
        doc="When plotting points (as opposed to 2D hist bins), the minimum size they can be. Some "
            "relative scaling will be performed depending on the \"flavor\" of the set of points.",
        default=2.0,
        dtype=float,
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every column requested by any configured action is collected.  For
        columns containing an underscore, the text before the first
        underscore is treated as a band name unless it is listed in
        ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names.  Always contains ``"patch"``.
        """
        columnNames = {"patch"}
        bands = set()
        actionStructs = (self.axisActions,
                         self.selectorActions,
                         self.highSnStatisticSelectorActions,
                         self.lowSnStatisticSelectorActions,
                         self.sourceSelectorActions)
        for actionStruct in actionStructs:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    prefix, separator, _ = col.partition("_")
                    if separator and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    def setDefaults(self):
        """Set i-band defaults: ap12 - psf magnitude difference vs. cModel magnitude."""
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.axisActions.yAction = MagDiff
        self.axisActions.yAction.col1 = "i_ap12Flux"
        self.axisActions.yAction.col2 = "i_psfFlux"
        self.selectorActions.flagSelector.bands = ["i"]
        self.selectorActions.catSnSelector.fluxType = "psfFlux"
        self.highSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
        # The axis labels double as the column names of the DataFrame that
        # the task builds for plotting, so derive them from the flux columns.
        self.axisLabels = {
            "x": self.axisActions.xAction.column.removesuffix("Flux") + " (mag)",
            "mag": self.axisActions.magAction.column.removesuffix("Flux") + " (mag)",
            "y": ("{} - {} (mmag)".format(self.axisActions.yAction.col1.removesuffix("Flux"),
                                          self.axisActions.yAction.col2.removesuffix("Flux")))
        }

198 

199 

class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities with a histogram of each
    axis and a per-patch summary panel.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # The catalog is a deferred-load handle: read only the needed columns.
        dataFrame = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs['catPlot'] = dataFrame
        dataId = butlerQC.quantum.dataId
        inputs["dataId"] = dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        # Instantiate the connections to get the dataset type names with the
        # templates filled in.
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are generated.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set`
            The bands required by the configured actions; passed through to
            `parsePlotInfo`.
        plotName : `str`
            The name of the output plot dataset type; passed through to
            `parsePlotInfo`.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`.
        The actions specified in `self.config.axisActions` are then used to
        calculate the values for each axis, which are stored in a new
        DataFrame under the column names given in `self.config.axisLabels`.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the column `self.config.axisLabels["y"]`.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Carry over the raw columns that the statistic and source selectors
        # need, since they are applied to the plotting DataFrame below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask this allows the information about which
            # type of sources are being plotted to be propagated
            sourceTypes += selector(catPlot)
        if not list(self.config.sourceSelectorActions):
            # No source selectors configured: flag every row as type 10 ("all").
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation.  The resulting
        # flag is 1 for rows passing the high S/N selectors, 2 for rows
        # passing only the low S/N selectors and 0 otherwise.
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        if hasattr(self.config.selectorActions, "catSnSelector"):
            SN = self.config.selectorActions.catSnSelector.threshold
            SNFlux = self.config.selectorActions.catSnSelector.fluxType
        else:
            SN = "N/A"
            SNFlux = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        return pipeBase.Struct(scatterPlot=fig)

    @staticmethod
    def _statsString(values):
        """Format the median, sigma MAD and number of points of ``values``."""
        return (f"Median: {np.nanmedian(values):0.3g} "
                + r"$\sigma_{MAD}$: " + f"{sigmaMad(values, nan_policy='omit'):0.3g} "
                + r"N$_{points}$: " + f"{len(values)}")

    @staticmethod
    def _approxFaintMag(mags):
        """Return the median of the largest ~10% of ``mags`` (NaN if empty).

        With fewer than ten values the slice below covers the whole array
        (``[-0:]`` is the full array), so the median of all values is used.
        """
        if len(mags) == 0:
            return np.nan
        sortedMags = np.sort(mags)
        nFaintest = int(len(sortedMags)/10)
        return np.nanmedian(sortedMags[-nFaintest:])

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. Must contain the columns
            named by the ``axisLabels`` config plus ``sourceType`` and
            ``useForStats``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        The axis limits are set based on the values of `config.xLims` and
        `config.yLims`. If provided (as a `list` of [min, max]), those will
        be used. If `None` (the default), the axis limits will be computed
        and set based on the data.

        If no rows have a recognised source type, a figure containing only a
        "no data" message is returned instead.
        """
        self.log.info("Plotting %s: the values of %s on a scatter plot.",
                      self.config.connections.plotName, self.config.axisLabels['y'])

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        #              10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == sourceTypeMapper["stars"])
        galaxies = (catPlot["sourceType"] == sourceTypeMapper["galaxies"])
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        # Calculate some statistics for every source type that is present.
        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}
        for sourceType in sourceTypeMapper.values():
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            # High S/N points also count towards the low S/N statistics.
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)

            highStats[sourceType] = self._statsString(catPlot.loc[highSn, yCol])
            lowStats[sourceType] = self._statsString(catPlot.loc[lowSn, yCol])
            highMags[sourceType] = f"{self._approxFaintMag(catPlot.loc[highSn, magCol]):.3g}"
            lowMags[sourceType] = f"{self._approxFaintMag(catPlot.loc[lowSn, magCol]):.3g}"

        # Decide what to plot: galaxies and/or stars when present (galaxies
        # first so that stars are drawn on top), otherwise unknowns, otherwise
        # everything flagged as "all".
        toPlotList = []
        if np.any(galaxies):
            toPlotList.append((xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                               sourceTypeMapper["galaxies"]))
        if np.any(stars):
            toPlotList.append((xsStars.values, ysStars.values, "midnightblue", newBlues,
                               sourceTypeMapper["stars"]))
        if not toPlotList:
            if np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
                unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
                toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                               "green", None, sourceTypeMapper["unknowns"])]
            elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
                toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                               sourceTypeMapper["all"])]
            else:
                # Nothing to plot: release the pyplot-managed figure created
                # above and return a placeholder figure instead.
                plt.close(fig)
                noDataFig = plt.Figure()
                noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
                noDataFig = addPlotInfo(noDataFig, plotInfo)
                return noDataFig

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5

        linesForLegend = []

        xMin = None
        histIm = None
        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                # Too few points to bin: just mark the median and sigma MAD.
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            [xs1, xs25, xs50, xs75, xs95, xs97] = np.nanpercentile(xs, [1, 25, 50, 75, 95, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(np.nanmin(xs) - xScale, np.nanmax(xs) + xScale,
                               (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale))/self.config.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/self.config.nBins
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Only keep the x bins that contain enough points for a running
            # median to be meaningful.
            denseBins = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[denseBins][1:]
            xEdges = xEdges[denseBins]

            if len(denseBins) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs >= xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge is filled front to back, lower edge back to
                    # front, so the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=self.config.minPointSize, alpha=0.3,
                        mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                if hasattr(self.config.highSnStatisticSelectorActions, "statSelector"):
                    bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                    highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                    statText = f"S/N > {highThresh} Stats [{magCol} $\\lesssim$ {highMags[sourceType]}]\n"
                    statText += highStats[sourceType]
                    fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                # The low S/N box must be gated on the low S/N selector, since
                # that is the selector whose threshold it reads.
                if hasattr(self.config.lowSnStatisticSelectorActions, "statSelector"):
                    bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                    lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                    statText = f"S/N > {lowThresh} Stats [{magCol} $\\lesssim$ {lowMags[sourceType]}]\n"
                    statText += lowStats[sourceType]
                    fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                sources = (catPlot["sourceType"] == sourceType)
                statInfo = catPlot["useForStats"].loc[sources].values
                if hasattr(self.config.highSnStatisticSelectorActions, "statSelector"):
                    highSn = (statInfo == 1)
                    if 0 < np.sum(highSn) < 100:
                        # White underlay first so the markers stand out.
                        ax.plot(xs[highSn], ys[highSn], marker="x", ms=self.config.minPointSize + 1,
                                mec="w", mew=2, ls="none")
                        highSnLine, = ax.plot(xs[highSn], ys[highSn], color=color, marker="x",
                                              ms=self.config.minPointSize + 1, ls="none", label="High SN")
                        linesForLegend.append(highSnLine)
                        xMin = np.min(xs[highSn])
                    else:
                        ax.axvline(float(highMags[sourceType]), color=color, ls="--")
                if hasattr(self.config.lowSnStatisticSelectorActions, "statSelector"):
                    # Same sample as the printed low S/N statistics: points
                    # passing either the high or the low S/N selection.
                    lowSn = ((statInfo == 1) | (statInfo == 2))
                    if 0 < np.sum(lowSn) < 100:
                        ax.plot(xs[lowSn], ys[lowSn], marker="+", ms=self.config.minPointSize + 1, mec="w",
                                mew=2, ls="none")
                        lowSnLine, = ax.plot(xs[lowSn], ys[lowSn], color=color, marker="+",
                                             ms=self.config.minPointSize + 1, ls="none", label="Low SN")
                        linesForLegend.append(lowSnLine)
                        if xMin is None or xMin > np.min(xs[lowSn]):
                            xMin = np.min(xs[lowSn])
                    else:
                        ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

            else:
                # Too few populated bins for a running median: plot every
                # point with flat median and sigma MAD lines.
                ax.plot(xs, ys, ".", ms=self.config.minPointSize + 3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits.  ``xs``, ``ys``, ``meds`` and
        # ``sigMadYs`` below refer to the last data set plotted above.
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            # Neither stars nor galaxies (only "unknowns" or "all" plotted):
            # centre on the plotted points rather than on an empty selection,
            # whose median would be NaN.
            plotMed = np.nanmedian(ys)
        if len(xs) < 2:
            meds = [np.median(ys)]

        if self.config.yLims is not None:
            yLimMin = self.config.yLims[0]
            yLimMax = self.config.yLims[1]
        else:
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            # Widen the range until it contains all the running medians
            # (capped at 10 sigma), recomputing the limits at every step.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs

            # Then add one more sigma of padding.
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if self.config.xLims is not None:
            ax.set_xlim(self.config.xLims[0], self.config.xLims[1])
        elif len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2*xScale
            ax.set_xlim(xMin, xs97 + 2*xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if np.any(galaxies):
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({np.sum(galaxies)})")
        if np.any(stars):
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({np.sum(stars)})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        # Per source type: all finite points (solid), points flagged as high
        # S/N (dashed) and points flagged as low S/N only (dotted).
        sourceTypeValues = catPlot["sourceType"].values
        useForStats = catPlot["useForStats"].values
        for (typeName, ysType, histColor) in [("galaxies", ysGalaxies, "firebrick"),
                                              ("stars", ysStars, "midnightblue")]:
            sources = (sourceTypeValues == sourceTypeMapper[typeName])
            if not np.any(sources):
                continue
            sideHist.hist(ysType[np.isfinite(ysType)], bins=bins, color=histColor, histtype="step",
                          orientation="horizontal", log=True)
            highSn = (useForStats[sources] == 1)
            lowSn = (useForStats[sources] == 2)
            sideHist.hist(ysType[highSn], bins=bins, color=histColor, histtype="step",
                          orientation="horizontal", log=True, ls="--")
            sideHist.hist(ysType[lowSn], bins=bins, color=histColor, histtype="step",
                          orientation="horizontal", log=True, ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for patchId, (corners, stat) in sumStats.items():
            # Build the rectangle from the first corner and the third corner
            # (presumably the diagonally opposite one - TODO confirm against
            # generateSummaryStats).
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            # Outline the patch, closing the polygon on its first corner.
            ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
            decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if patchId != "tract":
                axCorner.annotate(patchId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        fig.canvas.draw_idle()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig