Coverage for python/lsst/analysis/drp/scatterPlot.py: 10%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
import copy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import median_absolute_deviation as sigmaMad
from matplotlib import gridspec
from matplotlib.patches import Rectangle
from matplotlib.path import Path
from matplotlib.collections import PatchCollection
from mpl_toolkits.axes_grid1 import make_axes_locatable
from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField
from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction
from lsst.skymap import BaseSkyMap

import lsst.pipe.base as pipeBase
import lsst.pex.config as pexConfig

from . import dataSelectors as dataSelectors
from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap
class ScatterPlotWithTwoHistsTaskConnections(pipeBase.PipelineTaskConnections,
                                             dimensions=("tract", "skymap"),
                                             defaultTemplates={"inputCoaddName": "deep",
                                                               "plotName": "deltaCoords"}):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    One tract-level catalogue and the skymap go in; one plot per tract
    comes out, named after the ``plotName`` template.
    """

    # The catalogue is loaded lazily so that only the needed columns are read.
    catPlot = pipeBase.connectionTypes.Input(
        name="objectTable_tract",
        doc="The tract wide catalog to make plots from.",
        dimensions=("tract", "skymap"),
        storageClass="DataFrame",
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        doc="The skymap for the tract",
        dimensions=("skymap",),
        storageClass="SkyMap",
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        name="scatterTwoHistPlot_{plotName}",
        doc="A scatter plot with histograms for both axes.",
        dimensions=("tract", "skymap"),
        storageClass="Plot",
    )
class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`."""

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: "
            "{'x':, 'y':, 'mag':}. "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Returns
        -------
        bands : `set`
            The required bands. A band is the text before the first
            underscore of a required column name, unless that text is listed
            in ``nonBandColumnPrefixes``; columns without an underscore
            contribute no band.
        columns : `set`
            The required column names. ``"patch"`` is always included.
        """
        columnNames = {"patch"}
        bands = set()
        # Hoisted so each column check is a set lookup.
        nonBandPrefixes = set(self.nonBandColumnPrefixes)
        for actionStruct in [self.axisActions,
                             self.selectorActions,
                             self.highSnStatisticSelectorActions,
                             self.lowSnStatisticSelectorActions,
                             self.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    prefix, separator, _ = col.partition("_")
                    if separator and prefix not in nonBandPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def setDefaults(self):
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of one catalogue quantity against another, with a
    histogram of each axis and a panel summarising the y values per patch.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # catPlot is a deferred load, so only the required columns are read.
        inputs["catPlot"] = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["dataId"] = butlerQC.quantum.dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
        `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap` or `None`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are calculated.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set` [`str`]
            The bands used by the plotted columns; passed to `parsePlotInfo`.
        plotName : `str`
            The name of the output plot; passed to `parsePlotInfo`.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure. If plotting fails, a warning is logged
                and an empty figure is returned instead.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The plotted values are then calculated
        with the actions in `self.config.axisActions`, and every source is
        tagged with its source type and the S/N sample(s) it belongs to.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """
        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        # The source selectors return 1 for a star and 2 for a galaxy rather
        # than a mask; this allows the information about which type of
        # sources are being plotted to be propagated.
        sourceSelectors = list(self.config.sourceSelectorActions)
        if sourceSelectors:
            sourceTypes = np.zeros(len(plotDf))
            for selector in sourceSelectors:
                sourceTypes += selector(catPlot)
        else:
            # No source selection requested: 10 tags every source as "all".
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation:
        # 2 = low S/N sample only, 1 = high S/N sample (overrides 2), 0 = unused.
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used, if an S/N selector was configured
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        try:
            fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        except Exception as e:
            # This is a workaround until scatterPlotWithTwoHists works
            # properly: still write an (empty) output, but record why the
            # real plot could not be made instead of failing silently.
            self.log.warning("Could not make scatter plot %s, writing an empty figure instead: %s",
                             plotName, e)
            fig = plt.figure(dpi=300)

        return pipeBase.Struct(scatterPlot=fig)

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. Must contain the configured
            axis columns plus ``sourceType`` and ``useForStats``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.
        yLims : `Bool` or `tuple`, optional
            The y axis limits to use for the plot. If `False`, they are
            calculated from the data. If given, a tuple of (yMin, yMax); the
            side histogram is binned over the same range.
        xLims : `Bool` or `tuple`, optional
            The x axis limits to use for the plot. If `False`, they span the
            x values of every plotted sample. If given, a tuple of
            (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.
        """
        self.log.info("Plotting %s: the values of %s on a scatter plot.",
                      self.config.connections.plotName, self.config.axisLabels["y"])

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns, 10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == sourceTypeMapper["stars"])
        galaxies = (catPlot["sourceType"] == sourceTypeMapper["galaxies"])
        hasStars = np.any(stars)
        hasGalaxies = np.any(galaxies)

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        # Calculate some statistics
        highStats, highMags, lowStats, lowMags = self._calculateStatistics(
            catPlot, yCol, magCol, [1, 2, 9, 10])

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5
        linesForLegend = []

        # Each entry: (xs, ys, line colour, 2D histogram colormap, source type)
        if hasStars and not hasGalaxies:
            toPlotList = [(xsStars.values, ysStars.values, "midnightblue", newBlues,
                           sourceTypeMapper["stars"])]
        elif hasGalaxies and not hasStars:
            toPlotList = [(xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                           sourceTypeMapper["galaxies"])]
        elif hasStars and hasGalaxies:
            toPlotList = [(xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                           sourceTypeMapper["galaxies"]),
                          (xsStars.values, ysStars.values, "midnightblue", newBlues,
                           sourceTypeMapper["stars"])]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", None, sourceTypeMapper["unknowns"])]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                           sourceTypeMapper["all"])]
        else:
            toPlotList = []

        # The y limits below use meds/sigMadYs from the last sample plotted.
        histIm = None
        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            if len(xs) < 2:
                # Too few points to bin: draw a flat median and sigma MAD.
                medYs = np.nanmedian(ys)
                sigMadYs = sigmaMad(ys, nan_policy="omit")
                meds = np.array([medYs]*len(xs))
                medLine, = ax.plot(xs, medYs, color,
                                   label="Median: {:0.2f}".format(medYs), lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigMadYs]*len(xs))
                sigMadLine, = ax.plot(xs, medYs + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + "{:0.2f}".format(sigMads[0]))
                ax.plot(xs, medYs - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdgeMin = np.nanmin(xs) - xScale
            xEdgeMax = np.nanmax(xs) + xScale
            xEdges = np.arange(xEdgeMin, xEdgeMax, (xEdgeMax - xEdgeMin)/self.config.nBins)
            medYs = np.nanmedian(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins with enough points for a running median
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge fills the first half of the vertices, the
                    # lower edge the second half in reverse, closing the path.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.087, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if highMags[sourceType] != "-":
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.017, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if lowMags[sourceType] != "-":
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1,
                                       zorder=-2)

            else:
                # Not enough populated bins for a running median: plot every
                # point with a flat median and sigma MAD.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.2f}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.2f}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits, centred on the stars where available
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            # Only "unknown" or "all" sources were plotted.
            plotMed = np.nanmedian(catPlot[yCol])
        if yLims:
            yLimMin, yLimMax = yLims[0], yLims[1]
        else:
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            # Widen the limits (up to 10 sigma MAD) until the median line fits.
            while (yLimMax < np.nanmax(meds) or yLimMin > np.nanmin(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs
            # Leave one extra sigma MAD of padding.
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        else:
            # Span every plotted sample, not only the last one drawn.
            allXs = np.concatenate([entry[0] for entry in toPlotList])
            ax.set_xlim(np.nanmin(allXs), np.nanmax(allXs))

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if hasGalaxies:
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({int(np.sum(galaxies))})")
        if hasStars:
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({int(np.sum(stars))})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram, binned over the y axis range of the scatter plot
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        for (typeName, xsType, ysType, typeColor) in [("galaxies", xsGalaxies, ysGalaxies, "firebrick"),
                                                      ("stars", xsStars, ysStars, "midnightblue")]:
            typeCode = sourceTypeMapper[typeName]
            if not np.any(catPlot["sourceType"] == typeCode):
                continue
            finiteYs = np.isfinite(ysType)
            sideHist.hist(ysType[finiteYs], bins=bins, color=typeColor, histtype="step",
                          orientation="horizontal", log=True)
            # Dashed: high S/N magnitude cut; dotted: low S/N magnitude cut.
            if highMags[typeCode] != "-":
                sideHist.hist(ysType[finiteYs & (xsType < float(highMags[typeCode]))],
                              bins=bins, color=typeColor, histtype="step", orientation="horizontal",
                              log=True, ls="--")
            if lowMags[typeCode] != "-":
                sideHist.hist(ysType[finiteYs & (xsType < float(lowMags[typeCode]))],
                              bins=bins, color=typeColor, histtype="step", orientation="horizontal",
                              log=True, ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        self._addPatchPanel(fig, gs, sumStats)

        fig.canvas.draw_idle()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig

    @staticmethod
    def _calculateStatistics(catPlot, yCol, magCol, sourceTypes):
        """Build the printed summary statistics for each source type.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            Catalog with ``sourceType`` and ``useForStats`` columns plus the
            ``yCol`` and ``magCol`` columns.
        yCol : `str`
            Column whose median and sigma MAD are reported.
        magCol : `str`
            Column whose maximum is reported as the sample's limit.
        sourceTypes : `list` [`int`]
            Source type codes to process; codes with no sources are skipped.

        Returns
        -------
        highStats, highMags, lowStats, lowMags : `dict` [`int`, `str`]
            Formatted median / sigma MAD text and maximum ``magCol`` value for
            the high S/N (``useForStats == 1``) and low S/N
            (``useForStats`` 1 or 2) samples, keyed by source type. The
            magnitude is ``"-"`` when the sample is empty.
        """
        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}
        for sourceType in sourceTypes:
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)
            for (sample, stats, mags) in [(highSn, highStats, highMags),
                                          (lowSn, lowStats, lowMags)]:
                sampleMed = np.nanmedian(catPlot.loc[sample, yCol])
                sampleMad = sigmaMad(catPlot.loc[sample, yCol], nan_policy="omit")
                stats[sourceType] = ("Median: {:0.2f} ".format(sampleMed)
                                     + r"$\sigma_{MAD}$: " + "{:0.2f}".format(sampleMad))
                if np.sum(sample) > 0:
                    mags[sourceType] = f"{np.nanmax(catPlot.loc[sample, magCol]):.2f}"
                else:
                    mags[sourceType] = "-"
        return highStats, highMags, lowStats, lowMags

    @staticmethod
    def _addPatchPanel(fig, gs, sumStats):
        """Draw the per-patch summary panel and its colorbar.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure`
            The figure to draw on.
        gs : `matplotlib.gridspec.GridSpec`
            The figure's grid; the panel occupies the top right cell.
        sumStats : `dict`
            Maps patch ids to ``(corners, stat)``. ``corners`` is a sequence
            of (R.A., Dec.) pairs of angle objects with ``asDegrees()``, and
            ``stat`` colours the patch (`nan` is drawn transparent). A
            ``"tract"`` entry is outlined but not labelled.
        """
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for (patchId, (corners, stat)) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
            decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if patchId != "tract":
                axCorner.annotate(patchId, (ra + width / 2, dec + height / 2), color="k",
                                  fontsize=4, ha="center", va="center")

        # Copy the registered colormap so that setting its bad colour does not
        # alter the global coolwarm colormap used by every other plot.
        cmapUse = copy.copy(plt.cm.coolwarm)
        # Set the bad color to transparent and make a masked array
        cmapUse.set_bad(color="none")
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapUse)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)