Coverage for python/lsst/analysis/drp/scatterPlot.py: 10%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import matplotlib.pyplot as plt
2import numpy as np
3import pandas as pd
4from scipy.stats import median_absolute_deviation as sigmaMad
5from matplotlib import gridspec
6from matplotlib.patches import Rectangle
7from matplotlib.path import Path
8from matplotlib.collections import PatchCollection
9from mpl_toolkits.axes_grid1 import make_axes_locatable
10from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField
11from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction
12from lsst.skymap import BaseSkyMap
14import lsst.pipe.base as pipeBase
15import lsst.pex.config as pexConfig
17from . import dataSelectors as dataSelectors
18from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap
class ScatterPlotWithTwoHistsTaskConnections(
        pipeBase.PipelineTaskConnections,
        dimensions=("tract", "skymap"),
        defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"}):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    One tract-level catalog (``catPlot``, loaded lazily) and the skymap are
    read; one plot named ``scatterTwoHistPlot_{plotName}`` is written per
    (tract, skymap).
    """

    # Deferred load so that only the needed columns are read later.
    catPlot = pipeBase.connectionTypes.Input(
        name="objectTable_tract",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="The tract wide catalog to make plots from.",
    )

    skymap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="The skymap for the tract",
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        name="scatterTwoHistPlot_{plotName}",
        storageClass="Plot",
        dimensions=("tract", "skymap"),
        doc="A scatter plot with histograms for both axes.",
    )
class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    Defines the actions that compute the plotted quantities, the selectors
    used to choose which points are plotted and which feed the printed
    statistics, and a few plot appearance options.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: "
            "{'x':, 'y':, 'mag':}. "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every column read by the axis, selector, statistic-selector and
        source-selector actions is collected. A column whose name contains an
        underscore is assumed to carry a band prefix (the text before the
        first underscore) unless that prefix is listed in
        ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names (always including ``"patch"``).
        """
        columnNames = {"patch"}
        bands = set()
        for actionStruct in [self.axisActions,
                             self.selectorActions,
                             self.highSnStatisticSelectorActions,
                             self.lowSnStatisticSelectorActions,
                             self.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    if col is not None:
                        columnNames.add(col)
                        column_split = col.split("_")
                        # If there's no underscore, it has no band prefix
                        if len(column_split) > 1:
                            band = column_split[0]
                            if band not in self.nonBandColumnPrefixes:
                                bands.add(band)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def setDefaults(self):
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities with a histogram of
    each axis, a per-patch summary panel and printed summary statistics.

    The plotted quantities are computed by the ``axisActions`` config option
    and labelled using the ``axisLabels`` config option.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # catPlot is a deferred-load connection; read only the columns the
        # configured actions need.
        dataFrame = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["catPlot"] = dataFrame
        dataId = butlerQC.quantum.dataId
        inputs["dataId"] = dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
            `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are calculated.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set` [`str`]
            The bands required by the configured actions.
        plotName : `str`
            The name of the output plot dataset.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`.
        The values to plot are then calculated with the actions in
        `self.config.axisActions`, and every point is tagged with a
        ``sourceType`` (from `self.config.sourceSelectorActions`) and a
        ``useForStats`` flag (1 for high S/N, 2 for low S/N, 0 otherwise).
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # The statistic and source selectors are applied to plotDf later, so
        # copy across every column they read.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask this allows the information about which
            # type of sources are being plotted to be propagated
            sourceTypes += selector(catPlot)
        if not list(self.config.sourceSelectorActions):
            # No source selection: every point belongs to the "all" type (10).
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        # High S/N is applied second so it takes precedence over low S/N.
        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)

        return pipeBase.Struct(scatterPlot=fig)

    @staticmethod
    def _autoYLimits(plotMed, sigMad, meds, minSig=4, maxSig=10):
        """Choose symmetric y axis limits around a reference median.

        The half-width starts at ``minSig*sigMad`` and is widened one
        ``sigMad`` at a time (while it is below ``maxSig*sigMad``) until every
        finite value in ``meds`` lies inside the limits; one further
        ``sigMad`` of padding is then added.

        Parameters
        ----------
        plotMed : `float`
            The value to centre the limits on.
        sigMad : `float`
            The scale used to widen the limits.
        meds : array-like or `None`
            Running medians that should fit inside the limits; `None` or
            an empty/all-NaN array skips the widening step.
        minSig, maxSig : `int`, optional
            The starting and maximum number of ``sigMad`` before padding.

        Returns
        -------
        yLimMin, yLimMax : `float`
            The lower and upper limits.
        """
        numSig = minSig
        if meds is not None and len(meds) > 0 and np.any(np.isfinite(meds)):
            medMax = np.nanmax(meds)
            medMin = np.nanmin(meds)
            # Recompute the limits every iteration so the loop stops as soon
            # as the running medians fit.
            while ((plotMed + numSig*sigMad < medMax or plotMed - numSig*sigMad > medMin)
                   and numSig < maxSig):
                numSig += 1
        numSig += 1
        return plotMed - numSig*sigMad, plotMed + numSig*sigMad

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. Besides the configured axis
            columns it must contain ``"sourceType"`` (1 stars, 2 galaxies,
            9 unknowns, 10 all) and ``"useForStats"`` (1 high S/N, 2 low S/N)
            columns, as produced by `run`.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.
        yLims : `Bool` or `tuple`, optional
            The y axis limits to use for the plot. If `False`, they are
            calculated from the data. If being given a tuple of
            (yMin, yMax).
        xLims : `Bool` or `tuple`, optional
            The x axis limits to use for the plot. If `False`, they are
            calculated from the data.
            If being given a tuple of (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.
        """
        import copy

        self.log.info("Plotting {}: the values of {} on a scatter plot.".format(
                      self.config.connections.plotName, self.config.axisLabels["y"]))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])
        # Colormaps for the unknown / all-sources populations, matching the
        # line colours used for them below.
        newGreens = mkColormap(["honeydew", "green"])
        newPurples = mkColormap(["lavender", "purple"])

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == 1)
        galaxies = (catPlot["sourceType"] == 2)

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        # 10 - all
        sourceTypeList = [1, 2, 9, 10]
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}
        # Calculate some statistics
        for sourceType in sourceTypeList:
            if np.any(catPlot["sourceType"] == sourceType):
                sources = (catPlot["sourceType"] == sourceType)
                highSn = ((catPlot["useForStats"] == 1) & sources)
                highSnMed = np.nanmedian(catPlot.loc[highSn, yCol])
                highSnMad = sigmaMad(catPlot.loc[highSn, yCol], nan_policy="omit")

                # The low S/N statistics include the high S/N points as well.
                lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)
                lowSnMed = np.nanmedian(catPlot.loc[lowSn, yCol])
                lowSnMad = sigmaMad(catPlot.loc[lowSn, yCol], nan_policy="omit")

                highStatsStr = ("Median: {:0.2f} ".format(highSnMed)
                                + r"$\sigma_{MAD}$: " + "{:0.2f}".format(highSnMad))
                highStats[sourceType] = highStatsStr

                lowStatsStr = ("Median: {:0.2f} ".format(lowSnMed)
                               + r"$\sigma_{MAD}$: " + "{:0.2f}".format(lowSnMad))
                lowStats[sourceType] = lowStatsStr

                # The faintest magnitude among the points used for the stats.
                if np.sum(highSn) > 0:
                    highMags[sourceType] = f"{np.nanmax(catPlot.loc[highSn, magCol]):.2f}"
                else:
                    highMags[sourceType] = "-"
                if np.sum(lowSn) > 0.0:
                    lowMags[sourceType] = f"{np.nanmax(catPlot.loc[lowSn, magCol]):.2f}"
                else:
                    lowMags[sourceType] = "-"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        # Minimum number of points an x bin needs to get a running median.
        binThresh = 5

        linesForLegend = []

        hasStars = np.any(catPlot["sourceType"] == sourceTypeMapper["stars"])
        hasGalaxies = np.any(catPlot["sourceType"] == sourceTypeMapper["galaxies"])
        # Each entry is (xs, ys, line colour, 2D histogram colormap, sourceType).
        toPlotList = []
        if hasStars and hasGalaxies:
            toPlotList = [(xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                           sourceTypeMapper["galaxies"]),
                          (xsStars.values, ysStars.values, "midnightblue", newBlues,
                           sourceTypeMapper["stars"])]
        elif hasStars:
            toPlotList = [(xsStars.values, ysStars.values, "midnightblue", newBlues,
                           sourceTypeMapper["stars"])]
        elif hasGalaxies:
            toPlotList = [(xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                           sourceTypeMapper["galaxies"])]
        # Unknown and all-source populations replace whatever was chosen above.
        if np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", newGreens, sourceTypeMapper["unknowns"])]
        if np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", newPurples,
                           sourceTypeMapper["all"])]

        # Values needed after the loop for the axis limits and the colorbar;
        # they stay `None` if no population reaches the code that sets them.
        sigMadYs = None
        meds = None
        histIm = None

        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            if len(xs) < 2:
                # Too few points for binning: just draw the median and sigma MAD.
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label="Median: {:0.2f}".format(np.nanmedian(ys)), lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + "{:0.2f}".format(sigMads[0]))
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow)/self.config.nBins)
            medYs = np.nanmedian(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins with enough points for a running median.
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper envelope fills the path forwards, lower backwards,
                    # so the vertices trace a closed polygon.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.087, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if highMags[sourceType] != "-":
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.017, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if lowMags[sourceType] != "-":
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

                if self.config.plot2DHist:
                    # Density of the points inside the 3 sigma MAD envelope.
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1,
                                       zorder=-2)

            else:
                # Not enough well populated bins: plot every point with a
                # single median and sigma MAD.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.2f}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.2f}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits. The reference median comes from the
        # stars if present, otherwise the galaxies, otherwise every point.
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            plotMed = np.nanmedian(catPlot[yCol])
        if sigMadYs is None:
            # No population went through the binned path; use the spread of
            # every point instead.
            sigMadYs = sigmaMad(catPlot[yCol], nan_policy="omit")

        if yLims:
            yLimMin, yLimMax = yLims[0], yLims[1]
        else:
            yLimMin, yLimMax = self._autoYLimits(plotMed, sigMadYs, meds)
        ax.set_ylim(yLimMin, yLimMax)

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        elif toPlotList:
            # Use the range of the last population drawn.
            lastXs = toPlotList[-1][0]
            ax.set_xlim(np.nanmin(lastXs), np.nanmax(lastXs))

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if np.any(catPlot["sourceType"] == sourceTypeMapper["galaxies"]):
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({len(np.where(galaxies)[0])})")
        if np.any(catPlot["sourceType"] == sourceTypeMapper["stars"]):
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({len(np.where(stars)[0])})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        galaxyType = sourceTypeMapper["galaxies"]
        if np.any(catPlot["sourceType"] == galaxyType):
            sideHist.hist(ysGalaxies[np.isfinite(ysGalaxies)], bins=bins, color="firebrick", histtype="step",
                          orientation="horizontal", log=True)
            # Dashed: only points brighter than the high S/N magnitude limit.
            if highMags[galaxyType] != "-":
                sideHist.hist(ysGalaxies[np.isfinite(ysGalaxies) & (xsGalaxies < float(highMags[galaxyType]))],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls="--")
            # Dotted: only points brighter than the low S/N magnitude limit.
            if lowMags[galaxyType] != "-":
                sideHist.hist(ysGalaxies[np.isfinite(ysGalaxies) & (xsGalaxies < float(lowMags[galaxyType]))],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls=":")

        starType = sourceTypeMapper["stars"]
        if np.any(catPlot["sourceType"] == starType):
            sideHist.hist(ysStars[np.isfinite(ysStars)], bins=bins, color="midnightblue", histtype="step",
                          orientation="horizontal", log=True)
            if highMags[starType] != "-":
                sideHist.hist(ysStars[np.isfinite(ysStars) & (xsStars < float(highMags[starType]))], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls="--")
            if lowMags[starType] != "-":
                sideHist.hist(ysStars[np.isfinite(ysStars) & (xsStars < float(lowMags[starType]))], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for patchId, (corners, stat) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            # corners[0] and corners[2] are treated as opposite corners.
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
            decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if patchId != "tract":
                axCorner.annotate(patchId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Copy the registered colormap so setting its "bad" colour does not
        # modify the global matplotlib colormap.
        cmapUse = copy.copy(plt.cm.coolwarm)
        # Set the bad color to transparent and make a masked array
        cmapUse.set_bad(color="none")
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapUse)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        plt.draw()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig