Coverage for python/lsst/analysis/drp/scatterPlot.py: 10%

368 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-21 08:11 +0000

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import matplotlib.pyplot as plt 

23import numpy as np 

24import pandas as pd 

25from matplotlib import gridspec 

26from matplotlib.patches import Rectangle 

27from matplotlib.path import Path 

28from matplotlib.collections import PatchCollection 

29from mpl_toolkits.axes_grid1 import make_axes_locatable 

30from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField 

31from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction 

32from lsst.skymap import BaseSkyMap 

33 

34import lsst.pipe.base as pipeBase 

35import lsst.pex.config as pexConfig 

36 

37from . import dataSelectors as dataSelectors 

38from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap 

39from .statistics import sigmaMad 

40 

# Colour map used for the per-patch summary panel.  A private copy of the
# stock "coolwarm" map is taken so that setting the "bad" colour does not
# alter the globally registered colormap; masked (bad) entries are drawn
# fully transparent.
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad("none")

43 

44 

class ScatterPlotWithTwoHistsTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"},
):
    """Butler connections for the scatter plot with two histograms task.

    The task reads a tract-level catalogue (deferred load, so that only the
    needed columns have to be read) and the skymap, and writes a single plot
    whose dataset type name is templated on ``plotName``.
    """

    # Tract-wide input catalogue.
    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        storageClass="DataFrame",
        name="objectTable_tract",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    # Skymap for the tract being plotted.
    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    # Output figure.
    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        storageClass="Plot",
        name="scatterTwoHistPlot_{plotName}",
        dimensions=("tract", "skymap"),
    )

65 

66 

class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for the scatter plot with two histograms task.

    The config holds the actions used to compute the plotted quantities, the
    selectors used to pick which rows are plotted and which are used for the
    printed statistics, and a few plot appearance options.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to iCModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: {'x':, 'y':, 'mag':}. "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every column named by any configured action is collected. For column
        names that contain an underscore, the text before the first
        underscore is taken to be a band unless it is listed in
        ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names.
        """
        # The patch column is always needed (it is copied into the plotting
        # table by the task).
        columnNames = {"patch"}
        bands = set()
        actionStructs = (self.axisActions,
                         self.selectorActions,
                         self.highSnStatisticSelectorActions,
                         self.lowSnStatisticSelectorActions,
                         self.sourceSelectorActions)
        for actionStruct in actionStructs:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    prefix, separator, _ = col.partition("_")
                    if separator and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    def setDefaults(self):
        """Set the default flux columns and the S/N thresholds used for the
        high and low S/N statistics.
        """
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500

158 

159 

class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalogue quantities, with a histogram of
    each axis and a per-patch summary panel.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # The catalogue is a deferred-load handle: only read the columns
        # that the configured actions actually need.
        inputs["catPlot"] = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["dataId"] = butlerQC.quantum.dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
        `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None` the
            per-patch summary statistics are not computed.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set`
            The bands required by the configured actions, as returned by
            ``self.config.get_requirements``.
        plotName : `str`
            The name of the output plot dataset type.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The actions in
        `self.config.axisActions` are then used to calculate the values for
        the x axis, the y axis and the magnitude column.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Carry over the raw columns that the statistic and source selectors
        # need, since they are applied to the reduced table below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask this allows the information about which
            # type of sources are being plotted to be propagated
            sourceTypes += selector(catPlot)
        if list(self.config.sourceSelectorActions) == []:
            # No source selector configured: flag everything as "all" (10)
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation.
        # useForStats: 0 - not used, 2 - passes the low S/N selectors,
        # 1 - passes the high S/N selectors (assigned last, so it wins).
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)

        return pipeBase.Struct(scatterPlot=fig)

    @staticmethod
    def _approxFaintMag(mags):
        """Estimate the faint magnitude limit of a set of magnitudes.

        Parameters
        ----------
        mags : array-like of `float`
            The magnitudes of the selected points.

        Returns
        -------
        approxMag : `float`
            The median of the faintest 10% of the finite magnitudes, the
            median of all of them if there are fewer than ten, or NaN if
            there are none.
        """
        mags = np.asarray(mags, dtype=float)
        # Drop non-finite values first so that NaNs (which sort to the end)
        # cannot fill up the faintest-10% slice.
        mags = np.sort(mags[np.isfinite(mags)])
        if len(mags) == 0:
            return np.nan
        nFaint = len(mags)//10
        if nFaint == 0:
            return np.median(mags)
        return np.median(mags[-nFaint:])

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. Must contain the columns
            named in the ``axisLabels`` config plus ``sourceType`` and
            ``useForStats``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.
        yLims : `Bool` or `tuple`, optional
            The y axis limits to use for the plot.  If `False`, they are
            calculated from the data.  If being given a tuple of
            (yMin, yMax).
        xLims : `Bool` or `tuple`, optional
            The x axis limits to use for the plot.  If `False`, they are
            calculated from the data.
            If being given a tuple of (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.
        """
        self.log.info("Plotting %s: the values of %s on a scatter plot.",
                      self.config.connections.plotName, self.config.axisLabels['y'])

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == 1)
        galaxies = (catPlot["sourceType"] == 2)

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        # 10 - all
        sourceTypeList = [1, 2, 9, 10]
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}
        # Calculate some statistics
        for sourceType in sourceTypeList:
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            highSnMed = np.nanmedian(catPlot.loc[highSn, yCol])
            highSnMad = sigmaMad(catPlot.loc[highSn, yCol], nan_policy="omit")

            # The low S/N sample is everything above the low threshold, so
            # it includes the high S/N points as well.
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)
            lowSnMed = np.nanmedian(catPlot.loc[lowSn, yCol])
            lowSnMad = sigmaMad(catPlot.loc[lowSn, yCol], nan_policy="omit")

            highStatsStr = (f"Median: {highSnMed:0.3g}    "
                            + r"$\sigma_{MAD}$: " + f"{highSnMad:0.3g}    "
                            + r"N$_{points}$: " + f"{np.sum(highSn)}")
            highStats[sourceType] = highStatsStr

            lowStatsStr = (f"Median: {lowSnMed:0.3g}    "
                           + r"$\sigma_{MAD}$: " + f"{lowSnMad:0.3g}    "
                           + r"N$_{points}$: " + f"{np.sum(lowSn)}")
            lowStats[sourceType] = lowStatsStr

            # Approximate magnitude corresponding to each S/N threshold;
            # formatted as "nan" when no points pass the selection.
            highMags[sourceType] = f"{self._approxFaintMag(catPlot.loc[highSn, magCol]):.3g}"
            lowMags[sourceType] = f"{self._approxFaintMag(catPlot.loc[lowSn, magCol]):.3g}"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5

        linesForLegend = []

        hasStars = np.any(stars)
        hasGalaxies = np.any(galaxies)
        starsToPlot = (xsStars.values, ysStars.values, "midnightblue", newBlues,
                       sourceTypeMapper["stars"])
        galaxiesToPlot = (xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                          sourceTypeMapper["galaxies"])
        if hasStars and hasGalaxies:
            toPlotList = [galaxiesToPlot, starsToPlot]
        elif hasStars:
            toPlotList = [starsToPlot]
        elif hasGalaxies:
            toPlotList = [galaxiesToPlot]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", None, sourceTypeMapper["unknowns"])]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                           sourceTypeMapper["all"])]
        else:
            # Nothing survived the selectors: release the unused pyplot
            # figure and return a placeholder instead.
            plt.close(fig)
            noDataFig = plt.Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        xMin = None
        # Only set when a 2D histogram is actually drawn; used for the
        # colorbar further down.
        histIm = None
        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                # Too few points to bin: just mark the median and sigma MAD
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            [xs1, xs25, xs50, xs75, xs95, xs97] = np.nanpercentile(xs, [1, 25, 50, 75, 95, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(np.nanmin(xs) - xScale, np.nanmax(xs) + xScale,
                               (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale))/self.config.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Only keep the x bins that contain enough points for a
            # meaningful running statistic.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper boundary runs forwards through the array and the
                    # lower boundary backwards, so that together they trace
                    # a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                sources = (catPlot["sourceType"] == sourceType)
                statInfo = catPlot["useForStats"].loc[sources].values
                highSn = (statInfo == 1)
                # Same definition as used for the printed statistics: the
                # low S/N sample includes the high S/N points.
                lowSn = ((statInfo == 1) | (statInfo == 2))
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    ax.plot(xs[highSn], ys[highSn], marker="x", ms=4, mec="w", mew=2, ls="none")
                    highSnLine, = ax.plot(xs[highSn], ys[highSn], color=color, marker="x", ms=4, ls="none",
                                          label="High SN")
                    linesForLegend.append(highSnLine)
                    xMin = np.min(xs[highSn])
                else:
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(xs[lowSn], ys[lowSn], marker="+", ms=4, mec="w", mew=2, ls="none")
                    lowSnLine, = ax.plot(xs[lowSn], ys[lowSn], color=color, marker="+", ms=4, ls="none",
                                         label="Low SN")
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(xs[lowSn]):
                        xMin = np.min(xs[lowSn])
                else:
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

            else:
                # Not enough populated bins for running statistics: plot
                # every point with a flat median and sigma MAD instead.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits.  From here on xs, ys, meds and
        # sigMadYs refer to the last data set plotted in the loop above.
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            # Neither stars nor galaxies (unknown or unclassified sources):
            # centre on what was actually plotted rather than on an empty
            # selection, whose median would be NaN.
            plotMed = np.nanmedian(ys)
        if len(xs) < 2:
            meds = [np.median(ys)]
        if yLims:
            yLimMin, yLimMax = yLims[0], yLims[1]
        else:
            # Start at +/- 4 sigma MAD and widen, up to 10 sigma, until the
            # running medians fit inside the limits.
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs

            # One extra sigma of padding
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        elif len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2*xScale
            ax.set_xlim(xMin, xs97 + 2*xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = plt.gcf().add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if np.any(catPlot["sourceType"] == 2):
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({len(np.where(galaxies)[0])})")
        if np.any(catPlot["sourceType"] == 1):
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({len(np.where(stars)[0])})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = plt.gcf().add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        # For each of galaxies and stars draw the full distribution (solid)
        # plus the points flagged 1 (dashed) and flagged 2 (dotted) in
        # useForStats.
        flaggedHigh = (catPlot["useForStats"].values == 1)
        flaggedLow = (catPlot["useForStats"].values == 2)
        for (typeName, ysType, typeColor) in [("galaxies", ysGalaxies, "firebrick"),
                                              ("stars", ysStars, "midnightblue")]:
            sources = (catPlot["sourceType"].values == sourceTypeMapper[typeName])
            if not np.any(sources):
                continue
            sideHist.hist(ysType[np.isfinite(ysType)], bins=bins, color=typeColor, histtype="step",
                          orientation="horizontal", log=True)
            sideHist.hist(ysType[flaggedHigh[sources]], bins=bins, color=typeColor, histtype="step",
                          orientation="horizontal", log=True, ls="--")
            sideHist.hist(ysType[flaggedLow[sources]], bins=bins, color=typeColor, histtype="step",
                          orientation="horizontal", log=True, ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            plt.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = plt.gcf().add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for dataId in sumStats.keys():
            (corners, stat) = sumStats[dataId]
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            xy = (ra, dec)
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle(xy, width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
            decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
            # Repeat the first corner to close the outline
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if dataId != "tract":
                axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        plt.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        plt.draw()
        plt.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = plt.gcf()
        fig = addPlotInfo(fig, plotInfo)

        return fig