Coverage for python/lsst/analysis/drp/scatterPlot.py: 11%

Shortcuts on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

335 statements  

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import matplotlib.pyplot as plt 

23import numpy as np 

24import pandas as pd 

25from matplotlib import gridspec 

26from matplotlib.patches import Rectangle 

27from matplotlib.path import Path 

28from matplotlib.collections import PatchCollection 

29from mpl_toolkits.axes_grid1 import make_axes_locatable 

30from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField 

31from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction 

32from lsst.skymap import BaseSkyMap 

33 

34import lsst.pipe.base as pipeBase 

35import lsst.pex.config as pexConfig 

36 

37from . import dataSelectors as dataSelectors 

38from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap 

39from .statistics import sigmaMad 

40 

# Diverging colormap shared by the per-patch summary panel. Patches whose
# statistic is masked (NaN) are rendered fully transparent.
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad("none")

43 

44 

class ScatterPlotWithTwoHistsTaskConnections(pipeBase.PipelineTaskConnections,
                                             dimensions=("tract", "skymap"),
                                             defaultTemplates={"inputCoaddName": "deep",
                                                               "plotName": "deltaCoords"}):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    Inputs are a tract-level ``DataFrame`` catalog and the skymap; the output
    is one ``Plot`` per (tract, skymap) named ``scatterTwoHistPlot_{plotName}``,
    where ``plotName`` defaults to ``deltaCoords``.
    """

    # Deferred load: the task reads only the columns its config requires
    # (see ``runQuantum``), rather than the full table.
    catPlot = pipeBase.connectionTypes.Input(doc="The tract wide catalog to make plots from.",
                                             storageClass="DataFrame",
                                             name="objectTable_tract",
                                             dimensions=("tract", "skymap"),
                                             deferLoad=True)

    # Used to draw the patch outlines in the summary panel.
    skymap = pipeBase.connectionTypes.Input(doc="The skymap for the tract",
                                            storageClass="SkyMap",
                                            name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
                                            dimensions=("skymap",))

    # The finished figure; the dataset name is templated on ``plotName``.
    scatterPlot = pipeBase.connectionTypes.Output(doc="A scatter plot with histograms for both axes.",
                                                  storageClass="Plot",
                                                  name="scatterTwoHistPlot_{plotName}",
                                                  dimensions=("tract", "skymap"))

65 

66 

class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    Controls which values go on each axis (``axisActions``/``axisLabels``),
    which rows are plotted (``selectorActions``), which rows feed the printed
    high/low S/N statistics, and how sources are classified
    (``sourceSelectorActions``).
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: "
            "{'x':, 'y':, 'mag':}. "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every column read by the axis, selector, statistic-selector and
        source-selector actions is collected, along with ``patch``. A band is
        inferred from each column of the form ``<prefix>_<rest>`` whose
        prefix is not listed in ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names.
        """
        columnNames = {"patch"}
        bands = set()
        for actionStruct in [self.axisActions,
                             self.selectorActions,
                             self.highSnStatisticSelectorActions,
                             self.lowSnStatisticSelectorActions,
                             self.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    prefix, separator, _ = col.partition("_")
                    if separator and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def setDefaults(self):
        super().setDefaults()
        # Both the x axis and the statistics magnitude default to i-band cModel.
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        # S/N thresholds for the "high" and "low" S/N statistics samples.
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500

158 

159 

class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities, with a histogram of
    each axis and a panel summarising the y values in each patch.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # catPlot is a deferred-load handle; read only the columns that the
        # configured actions need.
        dataFrame = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["catPlot"] = dataFrame
        dataId = butlerQC.quantum.dataId
        inputs["dataId"] = dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
        `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are computed.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set` [`str`]
            The bands required by the configured actions.
        plotName : `str`
            The name of the output plot dataset.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure. If plotting fails, a warning is logged
                and an empty figure is returned instead.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The values for each axis are then
        calculated with the actions in `self.config.axisActions`, each source
        is classified with `self.config.sourceSelectorActions`, and the
        statistic selectors flag which points feed the printed statistics.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Carry along the raw columns the statistic/source selectors read so
        # that the statistic selectors can be applied to plotDf below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask this allows the information about which
            # type of sources are being plotted to be propagated
            sourceTypes += selector(catPlot)
        if not list(self.config.sourceSelectorActions):
            # No source selectors configured: label every point as "all" (10).
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation:
        # 1 = passes the high S/N selectors, 2 = passes only the low S/N
        # selectors, 0 = not used for statistics.
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        try:
            fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        except Exception as err:
            # Workaround until scatterPlotWithTwoHists copes with every input:
            # keep the pipeline running with an empty figure, but record why.
            self.log.warning("Failed to make the scatter plot %s; writing an empty figure instead: %s",
                             plotName, err)
            fig = plt.figure(dpi=300)

        return pipeBase.Struct(scatterPlot=fig)

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. It must contain the columns
            named by the ``x``, ``y`` and ``mag`` axis labels, plus the
            ``sourceType`` and ``useForStats`` columns added by `run`.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.
        yLims : `Bool` or `tuple`, optional
            The y axis limits to use for the plot. If `False`, they are
            calculated from the data; otherwise a tuple of (yMin, yMax).
        xLims : `Bool` or `tuple`, optional
            The x axis limits to use for the plot. If `False`, they are
            calculated from the data; otherwise a tuple of (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.
        """
        self.log.info("Plotting {}: the values of {} on a scatter plot.".format(
            self.config.connections.plotName, self.config.axisLabels["y"]))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        # 10 - all
        sourceTypeList = [1, 2, 9, 10]
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == sourceTypeMapper["stars"])
        galaxies = (catPlot["sourceType"] == sourceTypeMapper["galaxies"])

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]
        magsGalaxies = catPlot.loc[galaxies, magCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]
        magsStars = catPlot.loc[stars, magCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # Calculate some statistics. The "high S/N" sample is useForStats == 1;
        # the "low S/N" sample contains both the high (1) and low (2) S/N points.
        for sourceType in sourceTypeList:
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            highSnMed = np.nanmedian(catPlot.loc[highSn, yCol])
            highSnMad = sigmaMad(catPlot.loc[highSn, yCol], nan_policy="omit")

            lowSn = (catPlot["useForStats"].isin([1, 2]) & sources)
            lowSnMed = np.nanmedian(catPlot.loc[lowSn, yCol])
            lowSnMad = sigmaMad(catPlot.loc[lowSn, yCol], nan_policy="omit")

            highStats[sourceType] = ("Median: {:0.2f} ".format(highSnMed)
                                     + r"$\sigma_{MAD}$: " + "{:0.2f}".format(highSnMad))
            lowStats[sourceType] = ("Median: {:0.2f} ".format(lowSnMed)
                                    + r"$\sigma_{MAD}$: " + "{:0.2f}".format(lowSnMad))

            # Largest mag-column value in each statistics sample, or "-" when
            # the sample is empty.
            if np.any(highSn):
                highMags[sourceType] = f"{np.nanmax(catPlot.loc[highSn, magCol]):.2f}"
            else:
                highMags[sourceType] = "-"
            if np.any(lowSn):
                lowMags[sourceType] = f"{np.nanmax(catPlot.loc[lowSn, magCol]):.2f}"
            else:
                lowMags[sourceType] = "-"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        # An x bin needs more than this many points to get running statistics.
        binThresh = 5

        linesForLegend = []

        hasStars = np.any(stars)
        hasGalaxies = np.any(galaxies)
        unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
        if hasStars and not hasGalaxies:
            toPlotList = [(xsStars.values, ysStars.values, "midnightblue", newBlues,
                           sourceTypeMapper["stars"])]
        elif hasGalaxies and not hasStars:
            toPlotList = [(xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                           sourceTypeMapper["galaxies"])]
        elif hasStars and hasGalaxies:
            toPlotList = [(xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                           sourceTypeMapper["galaxies"]),
                          (xsStars.values, ysStars.values, "midnightblue", newBlues,
                           sourceTypeMapper["stars"])]
        elif np.any(unknowns):
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", None, sourceTypeMapper["unknowns"])]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                           sourceTypeMapper["all"])]
        else:
            toPlotList = []

        # Values from the last sample that could be binned; they drive the
        # automatic axis limits and stay None if no sample had >= 2 points.
        xs1 = xs97 = xScale = None
        medYs = sigMadYs = None
        meds = np.array([])
        histIm = None

        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            if len(xs) < 2:
                # Too few points to bin: just draw the median and sigma MAD.
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label="Median: {:0.2f}".format(np.nanmedian(ys)), lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + "{:0.2f}".format(sigMads[0]))
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xMin = np.nanmin(xs) - xScale
            xMax = np.nanmax(xs) + xScale
            xEdges = np.arange(xMin, xMax, (xMax - xMin)/self.config.nBins)
            medYs = np.nanmedian(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins holding enough points within +/- 5 sigma MAD.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge runs forwards, lower edge backwards, so the
                    # vertices trace a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.087, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if highMags[sourceType] != "-":
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.017, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if lowMags[sourceType] != "-":
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

            else:
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.2f}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.2f}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits
        if yLims:
            yLimMin, yLimMax = yLims[0], yLims[1]
            ax.set_ylim(yLimMin, yLimMax)
        elif sigMadYs is not None:
            if len(ysStars) > 0:
                plotMed = np.nanmedian(ysStars)
            elif len(ysGalaxies) > 0:
                plotMed = np.nanmedian(ysGalaxies)
            else:
                # Neither stars nor galaxies (unknowns/all): centre on the
                # last binned sample.
                plotMed = medYs
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            # Widen the limits (up to 10 sigma MAD) until the running median
            # fits inside them.
            finiteMeds = meds[np.isfinite(meds)]
            if finiteMeds.size > 0:
                while ((yLimMax < np.max(finiteMeds) or yLimMin > np.min(finiteMeds))
                       and numSig < 10):
                    numSig += 1
                    yLimMin = plotMed - numSig*sigMadYs
                    yLimMax = plotMed + numSig*sigMadYs

            # One extra sigma MAD of padding.
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            ax.set_ylim(yLimMin, yLimMax)
        else:
            # Nothing could be binned; fall back to matplotlib's autoscaling.
            yLimMin, yLimMax = ax.get_ylim()

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        elif xs1 is not None:
            ax.set_xlim(xs1 - xScale, xs97 + xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if hasGalaxies:
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({len(np.where(galaxies)[0])})")
        if hasStars:
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({len(np.where(stars)[0])})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram; the dashed/dotted subsets are selected on the mag
        # column, matching how the high/low S/N magnitude limits were computed.
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        if hasGalaxies:
            finiteGalaxies = np.isfinite(ysGalaxies)
            sideHist.hist(ysGalaxies[finiteGalaxies], bins=bins, color="firebrick", histtype="step",
                          orientation="horizontal", log=True)
            if highMags[sourceTypeMapper["galaxies"]] != "-":
                highMagGal = float(highMags[sourceTypeMapper["galaxies"]])
                sideHist.hist(ysGalaxies[finiteGalaxies & (magsGalaxies < highMagGal)],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls="--")
            if lowMags[sourceTypeMapper["galaxies"]] != "-":
                lowMagGal = float(lowMags[sourceTypeMapper["galaxies"]])
                sideHist.hist(ysGalaxies[finiteGalaxies & (magsGalaxies < lowMagGal)],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls=":")

        if hasStars:
            finiteStars = np.isfinite(ysStars)
            sideHist.hist(ysStars[finiteStars], bins=bins, color="midnightblue", histtype="step",
                          orientation="horizontal", log=True)
            if highMags[sourceTypeMapper["stars"]] != "-":
                highMagStar = float(highMags[sourceTypeMapper["stars"]])
                sideHist.hist(ysStars[finiteStars & (magsStars < highMagStar)], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls="--")
            if lowMags[sourceTypeMapper["stars"]] != "-":
                lowMagStar = float(lowMags[sourceTypeMapper["stars"]])
                sideHist.hist(ysStars[finiteStars & (magsStars < lowMagStar)], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for patchId, (corners, stat) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            xy = (ra, dec)
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle(xy, width, height))
            colors.append(stat)
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if patchId != "tract":
                axCorner.annotate(patchId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        plt.draw()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig