Coverage for python/lsst/analysis/drp/scatterPlot.py: 10%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

332 statements  

1import matplotlib.pyplot as plt 

2import numpy as np 

3import pandas as pd 

4from scipy.stats import median_absolute_deviation as sigmaMad 

5from matplotlib import gridspec 

6from matplotlib.patches import Rectangle 

7from matplotlib.path import Path 

8from matplotlib.collections import PatchCollection 

9from mpl_toolkits.axes_grid1 import make_axes_locatable 

10from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField 

11from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction 

12from lsst.skymap import BaseSkyMap 

13 

14import lsst.pipe.base as pipeBase 

15import lsst.pex.config as pexConfig 

16 

17from . import dataSelectors as dataSelectors 

18from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap 

19 

20 

class ScatterPlotWithTwoHistsTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"},
):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    One quantum is run per (tract, skymap). The ``plotName`` template is
    substituted into the name of the output plot dataset.
    """

    # Loaded lazily (deferLoad) so the task can ask for a subset of columns.
    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        storageClass="DataFrame",
        name="objectTable_tract",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        storageClass="Plot",
        name="scatterTwoHistPlot_{plotName}",
        dimensions=("tract", "skymap"),
    )

41 

42 

class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    The ``*Actions`` fields hold configurable actions that are called on the
    catalogue: ``axisActions`` produce the plotted values, ``selectorActions``
    produce boolean masks, the two ``*SnStatisticSelectorActions`` decide
    which points enter the printed statistics and ``sourceSelectorActions``
    label each row with a source type.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: {'x':, 'y':, 'mag':} "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every non-`None` column requested by any configured action is
        collected. For a column name containing an underscore, the text
        before the first underscore is treated as a band unless it is listed
        in ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names. Always includes ``"patch"``.
        """
        columnNames = {"patch"}
        bands = set()
        actionStructs = (self.axisActions,
                         self.selectorActions,
                         self.highSnStatisticSelectorActions,
                         self.lowSnStatisticSelectorActions,
                         self.sourceSelectorActions)
        for actionStruct in actionStructs:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    band, underscore, _ = col.partition("_")
                    if underscore and band not in self.nonBandColumnPrefixes:
                        bands.add(band)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    # Kept as a float for backwards compatibility with existing overrides;
    # it is only used as a divisor when computing the x bin width.
    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def setDefaults(self):
        """Set the default columns and signal-to-noise thresholds."""
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500

134 

135 

class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalogue quantities with a histogram
    along each axis and a per-patch summary panel.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # catPlot is a deferred-load handle; only read the columns we need.
        dataFrame = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["catPlot"] = dataFrame
        dataId = butlerQC.quantum.dataId
        inputs["dataId"] = dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        # Instantiate the connections to get the dataset names with the
        # templates filled in.
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
            `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None` the
            per-patch summary statistics are skipped.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set`
            The bands required by the configured actions; passed on to
            `parsePlotInfo`.
        plotName : `str`
            The name of the output plot dataset; passed on to
            `parsePlotInfo`.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The actions in
        `self.config.axisActions` are then used to calculate the values for
        each axis, which are stored under the names given in
        `self.config.axisLabels`.

        Two bookkeeping columns are added to the plotted table:

        ``sourceType``
            The sum of the values returned by the source selectors (1 for a
            star and 2 for a galaxy), or 10 for every row if no source
            selectors are configured.
        ``useForStats``
            1 for rows passing the high S/N statistic selectors, 2 for rows
            passing only the low S/N statistic selectors and 0 otherwise.

        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Carry over the columns that the statistic and source selectors
        # need, as those selectors are applied to the trimmed table below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceSelectors = list(self.config.sourceSelectorActions)
        if sourceSelectors:
            sourceTypes = np.zeros(len(plotDf))
            for selector in sourceSelectors:
                # The source selectors return 1 for a star and 2 for a galaxy
                # rather than a mask this allows the information about which
                # type of sources are being plotted to be propagated
                sourceTypes += selector(catPlot)
        else:
            # No source selection requested: flag every row as "all" (10)
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        # Applied second so that high S/N takes precedence over low S/N
        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)

        return pipeBase.Struct(scatterPlot=fig)

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. Must contain the columns
            named in the ``axisLabels`` config along with ``sourceType`` and
            ``useForStats``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.
        yLims : `Bool` or `tuple`, optional
            The y axis limits to use for the plot.  If `False`, they are
            calculated from the data.  If being given a tuple of
            (yMin, yMax).
        xLims : `Bool` or `tuple`, optional
            The x axis limits to use for the plot.  If `False`, they are
            calculated from the data.
            If being given a tuple of (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        Stars and galaxies are plotted together when both are present. If any
        sources are flagged as unknown (9) only those are plotted, and if any
        are flagged as "all" (10) every row is plotted as a single dataset.
        When the axis limits are calculated from the data they are based on
        the last dataset that was plotted.
        """
        # Local import: only needed to copy the shared colormap below.
        import copy

        self.log.info("Plotting {}: the values of {} on a scatter plot.".format(
                      self.config.connections.plotName, self.config.axisLabels["y"]))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        # 10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == sourceTypeMapper["stars"])
        galaxies = (catPlot["sourceType"] == sourceTypeMapper["galaxies"])
        hasStars = np.any(stars)
        hasGalaxies = np.any(galaxies)

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # Calculate some statistics for every source type that is present
        for sourceType in sourceTypeMapper.values():
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue

            highSn = ((catPlot["useForStats"] == 1) & sources)
            highSnMed = np.nanmedian(catPlot.loc[highSn, yCol])
            highSnMad = sigmaMad(catPlot.loc[highSn, yCol], nan_policy="omit")

            # The low S/N sample includes the high S/N points as well
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)
            lowSnMed = np.nanmedian(catPlot.loc[lowSn, yCol])
            lowSnMad = sigmaMad(catPlot.loc[lowSn, yCol], nan_policy="omit")

            highStats[sourceType] = ("Median: {:0.2f}    ".format(highSnMed)
                                     + r"$\sigma_{MAD}$: " + "{:0.2f}".format(highSnMad))
            lowStats[sourceType] = ("Median: {:0.2f}    ".format(lowSnMed)
                                    + r"$\sigma_{MAD}$: " + "{:0.2f}".format(lowSnMad))

            # "-" marks a sample with no points in it
            if np.sum(highSn) > 0:
                highMags[sourceType] = f"{np.nanmax(catPlot.loc[highSn, magCol]):.2f}"
            else:
                highMags[sourceType] = "-"
            if np.sum(lowSn) > 0.0:
                lowMags[sourceType] = f"{np.nanmax(catPlot.loc[lowSn, magCol]):.2f}"
            else:
                lowMags[sourceType] = "-"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5

        linesForLegend = []

        # Each entry is (xs, ys, line color, 2D histogram colormap, sourceType).
        # Galaxies go first so that stars are drawn on top of them. Starting
        # from an empty list means a catalogue with none of the known source
        # types gives an empty scatter panel rather than a NameError.
        toPlotList = []
        if hasGalaxies:
            toPlotList.append((xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                               sourceTypeMapper["galaxies"]))
        if hasStars:
            toPlotList.append((xsStars.values, ysStars.values, "midnightblue", newBlues,
                               sourceTypeMapper["stars"]))
        # A colormap of None makes hexbin fall back to the matplotlib
        # default. These entries must still be 5-tuples to match the
        # unpacking in the loop below.
        if np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", None, sourceTypeMapper["unknowns"])]
        if np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                           sourceTypeMapper["all"])]

        # Fallbacks for the quantities used after the loop to set the axis
        # limits and colorbar, in case no dataset takes the branch that
        # normally defines them.
        histIm = None
        sigMadYs = None
        meds = None
        xs = catPlot[xCol].values

        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            if len(xs) < 2:
                # Too few points to bin: just mark the median and sigma MAD
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label="Median: {:0.2f}".format(np.nanmedian(ys)), lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + "{:0.2f}".format(sigMads[0]))
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                continue

            [xs1, xs25, xs50, xs75, xs95, xs97] = np.nanpercentile(xs, [1, 25, 50, 75, 95, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow)/self.config.nBins)
            medYs = np.nanmedian(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Only use the x bins that have enough points in them
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper boundary fills the first half of the vertices,
                    # lower boundary fills the second half in reverse order
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.087, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if highMags[sourceType] != "-":
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.017, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if lowMags[sourceType] != "-":
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1,
                                       zorder=-2)

            else:
                # Not enough populated bins for a running median: plot every
                # point with a flat median and sigma MAD instead
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.2f}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.2f}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            # Neither stars nor galaxies (e.g. unknowns or all): use every
            # row, so that the limits are not NaN
            plotMed = np.nanmedian(catPlot[yCol].values)

        if yLims:
            yLimMin, yLimMax = yLims[0], yLims[1]
        else:
            if sigMadYs is None:
                sigMadYs = sigmaMad(catPlot[yCol].values, nan_policy="omit")
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            if meds is not None and len(meds) > 0:
                medsMin = np.min(meds)
                medsMax = np.max(meds)
                # Widen the range until the plotted medians fit, up to a
                # maximum of 10 sigma. The limits have to be recomputed on
                # each pass or the condition can never change.
                while (yLimMax < medsMax or yLimMin > medsMin) and numSig < 10:
                    numSig += 1
                    yLimMin = plotMed - numSig*sigMadYs
                    yLimMax = plotMed + numSig*sigMadYs

            # Pad by one extra sigma
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        # yLimMin and yLimMax are also used for the side histogram bins
        ax.set_ylim(yLimMin, yLimMax)

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        else:
            ax.set_xlim(np.nanmin(xs), np.nanmax(xs))

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if hasGalaxies:
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({len(xsGalaxies)})")
        if hasStars:
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({len(xsStars)})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        # For each source type also show the subsets brighter than the
        # faintest magnitude in the high (dashed) and low (dotted) S/N samples
        if hasGalaxies:
            galType = sourceTypeMapper["galaxies"]
            finiteGalaxies = np.isfinite(ysGalaxies)
            sideHist.hist(ysGalaxies[finiteGalaxies], bins=bins, color="firebrick", histtype="step",
                          orientation="horizontal", log=True)
            if highMags[galType] != "-":
                sideHist.hist(ysGalaxies[finiteGalaxies & (xsGalaxies < float(highMags[galType]))],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls="--")
            if lowMags[galType] != "-":
                sideHist.hist(ysGalaxies[finiteGalaxies & (xsGalaxies < float(lowMags[galType]))],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls=":")

        if hasStars:
            starType = sourceTypeMapper["stars"]
            finiteStars = np.isfinite(ysStars)
            sideHist.hist(ysStars[finiteStars], bins=bins, color="midnightblue", histtype="step",
                          orientation="horizontal", log=True)
            if highMags[starType] != "-":
                sideHist.hist(ysStars[finiteStars & (xsStars < float(highMags[starType]))], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls="--")
            if lowMags[starType] != "-":
                sideHist.hist(ysStars[finiteStars & (xsStars < float(lowMags[starType]))], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for dataId, (corners, stat) in sumStats.items():
            # The first and third corners are used as opposite corners of
            # the rectangle
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            xy = (ra, dec)
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle(xy, width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
            decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
            # Repeat the first corner to close the outline
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if dataId != "tract":
                axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Work on a copy so that the shared colormap object is not modified
        cmapUse = copy.copy(plt.cm.coolwarm)
        # Set the bad color to transparent and make a masked array
        cmapUse.set_bad(color="none")
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapUse)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        plt.draw()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig