Coverage for python/lsst/analysis/drp/scatterPlot.py: 10%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

336 statements  

1import matplotlib.pyplot as plt 

2import numpy as np 

3import pandas as pd 

4from scipy.stats import median_absolute_deviation as sigmaMad 

5from matplotlib import gridspec 

6from matplotlib.patches import Rectangle 

7from matplotlib.path import Path 

8from matplotlib.collections import PatchCollection 

9from mpl_toolkits.axes_grid1 import make_axes_locatable 

10from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField 

11from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction 

12from lsst.skymap import BaseSkyMap 

13 

14import lsst.pipe.base as pipeBase 

15import lsst.pex.config as pexConfig 

16 

17from . import dataSelectors as dataSelectors 

18from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap 

19 

20 

class ScatterPlotWithTwoHistsTaskConnections(pipeBase.PipelineTaskConnections,
                                             dimensions=("tract", "skymap"),
                                             defaultTemplates={"inputCoaddName": "deep",
                                                               "plotName": "deltaCoords"}):
    """Connections for `ScatterPlotWithTwoHistsTask`.

    One tract-level catalog and the skymap go in; one plot, whose dataset
    name is built from the ``plotName`` template, comes out.
    """

    # The catalog is loaded lazily (deferLoad) so the task can decide which
    # columns to read.
    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        name="objectTable_tract",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        name="scatterTwoHistPlot_{plotName}",
        storageClass="Plot",
        dimensions=("tract", "skymap"),
    )

41 

42 

class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`."""

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: "
            "{'x':, 'y':, 'mag':}. The mag column is used to report the magnitude limit of the "
            "points included in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def setDefaults(self):
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every non-`None` column read by the axis actions or any of the
        selector actions is required. A column named ``<prefix>_<rest>``
        also implies the band ``<prefix>``, unless that prefix is listed in
        ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names; always includes ``"patch"``.
        """
        columnNames = {"patch"}
        bands = set()
        for actionStruct in [self.axisActions,
                             self.selectorActions,
                             self.highSnStatisticSelectorActions,
                             self.lowSnStatisticSelectorActions,
                             self.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    prefix, sep, _ = col.partition("_")
                    if sep and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

134 

135 

class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities with a histogram of each
    axis and a panel showing a summary statistic of the y values per patch.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # catPlot is a deferred-load handle; read only the columns that the
        # configured actions need.
        inputs["catPlot"] = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["dataId"] = butlerQC.quantum.dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
        `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are calculated.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set` [`str`]
            The bands required by the configured actions.
        plotName : `str`
            The name of the output plot dataset.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure. If making the plot fails, the failure
                is logged and an empty figure is returned instead.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The actions in
        `self.config.axisActions` then calculate the plotted values. The
        source selectors label every point with a source type, and the
        statistic selectors set ``useForStats`` to 1 for points passing the
        high S/N selectors, otherwise 2 for points passing the low S/N
        selectors, otherwise 0.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        axisLabels = self.config.axisLabels
        axisActions = self.config.axisActions
        columns = {axisLabels["x"]: axisActions.xAction(catPlot),
                   axisLabels["y"]: axisActions.yAction(catPlot),
                   axisLabels["mag"]: axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Carry over the raw columns the statistic/source selectors need, as
        # they are re-applied to plotDf below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    # Skip unset columns, as get_requirements does.
                    if col is not None:
                        columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask this allows the information about which
            # type of sources are being plotted to be propagated
            sourceTypes += selector(catPlot)
        if not list(self.config.sourceSelectorActions):
            # No source selectors: label every point as type 10 ("all").
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation. The high S/N
        # flag is applied last so it overrides the low S/N flag.
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        try:
            fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        except Exception as e:
            # This is a workaround until scatterPlotWithTwoHists works
            # properly: still write a (blank) figure, but do not hide why.
            self.log.warning(f"Failed to make the plot {plotName}; writing an empty figure "
                             f"instead: {e!r}")
            fig = plt.figure(dpi=300)

        return pipeBase.Struct(scatterPlot=fig)

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. It must contain the columns
            named by the ``x``, ``y`` and ``mag`` axis labels, plus
            ``sourceType`` (1 - stars, 2 - galaxies, 9 - unknowns, 10 - all)
            and ``useForStats`` (1 - high S/N, 2 - low S/N, 0 - not used).
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch, as ``(corners, stat)`` tuples. A key of
            ``"tract"`` is drawn but not annotated.
        yLims : `Bool` or `tuple`, optional
            The y axis limits to use for the plot. If `False`, they are
            calculated from the data. If being given a tuple of
            (yMin, yMax).
        xLims : `Bool` or `tuple`, optional
            The x axis limits to use for the plot. If `False`, they are
            calculated from the data.
            If being given a tuple of (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        The magnitude limits of the statistics samples are drawn as vertical
        lines on the x axis, which is only meaningful when x is the
        magnitude (as with the defaults, where xAction and magAction read the
        same column).
        """
        # Local import: copying the colormap avoids mutating the global one.
        import copy

        self.log.info("Plotting {}: the values of {} on a scatter plot.".format(
            self.config.connections.plotName, self.config.axisLabels["y"]))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == 1)
        galaxies = (catPlot["sourceType"] == 2)

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        # 10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}
        # Calculate some statistics
        for sourceType in sourceTypeMapper.values():
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            highSnMed = np.nanmedian(catPlot.loc[highSn, yCol])
            highSnMad = sigmaMad(catPlot.loc[highSn, yCol], nan_policy="omit")

            # The low S/N sample includes the high S/N points.
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)
            lowSnMed = np.nanmedian(catPlot.loc[lowSn, yCol])
            lowSnMad = sigmaMad(catPlot.loc[lowSn, yCol], nan_policy="omit")

            highStats[sourceType] = ("Median: {:0.2f} ".format(highSnMed)
                                     + r"$\sigma_{MAD}$: " + "{:0.2f}".format(highSnMad))
            lowStats[sourceType] = ("Median: {:0.2f} ".format(lowSnMed)
                                    + r"$\sigma_{MAD}$: " + "{:0.2f}".format(lowSnMad))

            # Faintest magnitude in each statistics sample, "-" if empty.
            if np.any(highSn):
                highMags[sourceType] = f"{np.nanmax(catPlot.loc[highSn, magCol]):.2f}"
            else:
                highMags[sourceType] = "-"
            if np.any(lowSn):
                lowMags[sourceType] = f"{np.nanmax(catPlot.loc[lowSn, magCol]):.2f}"
            else:
                lowMags[sourceType] = "-"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        # Only x bins with more than this many points get a running median.
        binThresh = 5

        linesForLegend = []
        histIm = None

        hasStars = np.any(stars)
        hasGalaxies = np.any(galaxies)
        starsToPlot = (xsStars.values, ysStars.values, "midnightblue", newBlues,
                       sourceTypeMapper["stars"])
        galaxiesToPlot = (xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                          sourceTypeMapper["galaxies"])
        if hasStars and hasGalaxies:
            # Galaxies first so the stars are drawn on top.
            toPlotList = [galaxiesToPlot, starsToPlot]
        elif hasStars:
            toPlotList = [starsToPlot]
        elif hasGalaxies:
            toPlotList = [galaxiesToPlot]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", None, sourceTypeMapper["unknowns"])]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                           sourceTypeMapper["all"])]
        else:
            toPlotList = []

        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            if len(xs) < 2:
                # Too few points for binned statistics: only show the median
                # and +/- 1 sigma MAD.
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label="Median: {:0.2f}".format(np.nanmedian(ys)), lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + "{:0.2f}".format(sigMads[0]))
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # Pad the x range by xScale on each side and split it into
            # config.nBins bins (40 by default because it looked good).
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow)/self.config.nBins)
            medYs = np.nanmedian(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            # Number of points per x bin within +/- 5 sigma MAD in y.
            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                # Upper boundary in the first half, lower boundary in the
                # second half in reverse order, forming a closed polygon.
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.087, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if highMags[sourceType] != "-":
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.017, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if lowMags[sourceType] != "-":
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

                if self.config.plot2DHist:
                    # Show the density of the points inside 3 sigma MAD.
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1,
                                       zorder=-2)

            else:
                # Too few well-populated bins for a running median: plot all
                # the points with a flat median and +/- 1 sigma MAD.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.2f}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.2f}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits, centred on the stars when present.
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            # Neither stars nor galaxies (e.g. no source selectors are
            # configured), so centre on all the points in the catalog.
            plotMed = np.nanmedian(catPlot[yCol])
        if yLims:
            yLimMin, yLimMax = yLims[0], yLims[1]
        else:
            # Widen the limits (up to 10 sigma MAD) until the running median
            # of the last plotted dataset fits, then add one more sigma MAD.
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            while (yLimMax < np.nanmax(meds) or yLimMin > np.nanmin(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs

            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        else:
            # Limits follow the last plotted dataset.
            ax.set_xlim(np.nanmin(xs), np.nanmax(xs))

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if np.any(galaxies):
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({np.count_nonzero(galaxies)})")
        if np.any(stars):
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({np.count_nonzero(stars)})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        for (sourceMask, ysType, color, sourceType) in [
                (galaxies, ysGalaxies, "firebrick", sourceTypeMapper["galaxies"]),
                (stars, ysStars, "midnightblue", sourceTypeMapper["stars"])]:
            if not np.any(sourceMask):
                continue
            finiteYs = np.isfinite(ysType)
            sideHist.hist(ysType[finiteYs], bins=bins, color=color, histtype="step",
                          orientation="horizontal", log=True)
            # Objects brighter than the faintest object in each statistics
            # sample; the cut is on the magnitude column, with the same
            # linestyle as the matching vertical line on the scatter plot.
            magsType = catPlot.loc[sourceMask, magCol]
            for (magLimits, lineStyle) in [(highMags, "--"), (lowMags, ":")]:
                if magLimits[sourceType] != "-":
                    brighter = finiteYs & (magsType < float(magLimits[sourceType]))
                    sideHist.hist(ysType[brighter], bins=bins, color=color, histtype="step",
                                  orientation="horizontal", log=True, ls=lineStyle)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for (patchId, (corners, stat)) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if patchId != "tract":
                axCorner.annotate(patchId, (ra + width / 2, dec + height / 2), color="k", fontsize=4,
                                  ha="center", va="center")

        # Copy the colormap so that setting its "bad" colour does not modify
        # the globally registered colormap used by every other plot.
        cmapUse = copy.copy(plt.cm.coolwarm)
        # Set the bad color to transparent and make a masked array
        cmapUse.set_bad(color="none")
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapUse)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        plt.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        plt.draw()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig