Coverage for python/lsst/analysis/drp/scatterPlot.py: 11%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
# This file is part of analysis_drp.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
22import matplotlib.pyplot as plt
23import numpy as np
24import pandas as pd
25from matplotlib import gridspec
26from matplotlib.patches import Rectangle
27from matplotlib.path import Path
28from matplotlib.collections import PatchCollection
29from mpl_toolkits.axes_grid1 import make_axes_locatable
30from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField
31from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction
32from lsst.skymap import BaseSkyMap
34import lsst.pipe.base as pipeBase
35import lsst.pex.config as pexConfig
37from . import dataSelectors as dataSelectors
38from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap
39from .statistics import sigmaMad
# Colormap used to colour the per-patch summary panel.  It is a private copy
# of "coolwarm" so that changing its "bad" colour does not touch the shared
# colormap; masked (bad) values are drawn with the colour "none".
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad(color="none")
class ScatterPlotWithTwoHistsTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"},
):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    Inputs are a tract-level catalog (loaded lazily) and the skymap; the
    single output is the figure, whose dataset name is built from the
    ``plotName`` template.
    """

    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        storageClass="DataFrame",
        name="objectTable_tract",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        storageClass="Plot",
        name="scatterTwoHistPlot_{plotName}",
        dimensions=("tract", "skymap"),
    )
class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    The fields select which columns go on each axis, which rows are plotted,
    which rows enter the printed high/low S/N statistics and how sources are
    classified.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: {'x':, 'y':, 'mag':}. "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every non-`None` column requested by any configured action is
        collected.  The text before the first underscore of a column name is
        treated as a band unless it is listed in ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names; always includes ``"patch"``.
        """
        columnNames = {"patch"}
        bands = set()
        actionStructs = (self.axisActions,
                         self.selectorActions,
                         self.highSnStatisticSelectorActions,
                         self.lowSnStatisticSelectorActions,
                         self.sourceSelectorActions)
        for actionStruct in actionStructs:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    band, underscore, _ = col.partition("_")
                    if underscore and band not in self.nonBandColumnPrefixes:
                        bands.add(band)
        return bands, columnNames

    def setDefaults(self):
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Plot one catalog quantity against another as a scatter plot, with a
    histogram of each axis and a panel summarising the y value per patch.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # The catalog connection is deferLoad, so only the needed columns
        # are read here.
        inputs["catPlot"] = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["dataId"] = butlerQC.quantum.dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
            `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries.  If `None` no
            per-patch summary statistics are generated.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set`
            The bands required by the configured actions; passed on to
            `parsePlotInfo`.
        plotName : `str`
            The name of the output plot dataset; passed on to
            `parsePlotInfo`.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`.
        The values for each axis are then calculated with the actions in
        `self.config.axisActions`.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the column named by ``axisLabels["y"]``.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.  If that fails the error is logged and an
        empty figure is returned instead.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Carry over the columns that the statistic/source selectors read
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask this allows the information about which
            # type of sources are being plotted to be propagated
            sourceTypes += selector(catPlot)
        if not list(self.config.sourceSelectorActions):
            # No classification requested: flag everything as "all" (10)
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation:
        # 2 marks the low S/N sample, 1 the high S/N sample (which wins
        # where both apply because it is assigned last).
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        try:
            fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        except Exception as err:
            # Best-effort fallback so that an output is always written, but
            # record why the real plot could not be made.
            self.log.error("Could not make scatter plot %s (%s: %s); writing an empty figure instead.",
                           plotName, type(err).__name__, err)
            fig = plt.figure(dpi=300)

        return pipeBase.Struct(scatterPlot=fig)

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.  Must contain the columns
            named in the ``axisLabels`` config plus ``sourceType`` and
            ``useForStats``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.
        yLims : `bool` or `tuple`, optional
            The y axis limits to use for the plot.  If `False`, they are
            calculated from the data, otherwise a tuple of (yMin, yMax).
        xLims : `bool` or `tuple`, optional
            The x axis limits to use for the plot.  If `False`, they are
            calculated from the data, otherwise a tuple of (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.
        """
        self.log.info("Plotting {}: the values of {} on a scatter plot.".format(
                      self.config.connections.plotName, self.config.axisLabels["y"]))

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        # 10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}
        starType = sourceTypeMapper["stars"]
        galaxyType = sourceTypeMapper["galaxies"]

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == starType)
        galaxies = (catPlot["sourceType"] == galaxyType)
        hasStars = np.any(stars)
        hasGalaxies = np.any(galaxies)

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # Calculate some statistics for every source type that is present
        for sourceType in sourceTypeMapper.values():
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            highSnMed = np.nanmedian(catPlot.loc[highSn, yCol])
            highSnMad = sigmaMad(catPlot.loc[highSn, yCol], nan_policy="omit")

            # The low S/N sample also contains the high S/N sources
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)
            lowSnMed = np.nanmedian(catPlot.loc[lowSn, yCol])
            lowSnMad = sigmaMad(catPlot.loc[lowSn, yCol], nan_policy="omit")

            highStats[sourceType] = ("Median: {:0.2f}    ".format(highSnMed)
                                     + r"$\sigma_{MAD}$: " + "{:0.2f}".format(highSnMad))
            lowStats[sourceType] = ("Median: {:0.2f}    ".format(lowSnMed)
                                    + r"$\sigma_{MAD}$: " + "{:0.2f}".format(lowSnMad))

            # "-" marks a sample with no sources in it
            if np.sum(highSn) > 0:
                highMags[sourceType] = f"{np.nanmax(catPlot.loc[highSn, magCol]):.2f}"
            else:
                highMags[sourceType] = "-"
            if np.sum(lowSn) > 0.0:
                lowMags[sourceType] = f"{np.nanmax(catPlot.loc[lowSn, magCol]):.2f}"
            else:
                lowMags[sourceType] = "-"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5

        yBinsOut = []
        linesForLegend = []

        starsToPlot = (xsStars.values, ysStars.values, "midnightblue", newBlues, starType)
        galaxiesToPlot = (xsGalaxies.values, ysGalaxies.values, "firebrick", newReds, galaxyType)
        if hasStars and not hasGalaxies:
            toPlotList = [starsToPlot]
        elif hasGalaxies and not hasStars:
            toPlotList = [galaxiesToPlot]
        elif hasStars and hasGalaxies:
            # Galaxies first so that the stars are drawn on top
            toPlotList = [galaxiesToPlot, starsToPlot]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", None, sourceTypeMapper["unknowns"])]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                           sourceTypeMapper["all"])]
        else:
            toPlotList = []

        # Values taken from the last source type that reaches the binning
        # stage; they stay `None` if none does (nothing to plot, or fewer
        # than two points per source type).
        histIm = None
        meds = None
        medYs = None
        sigMadYs = None
        xs1 = xs97 = xScale = None

        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            if len(xs) < 2:
                # Too few points to bin: just mark the median and sigma MAD
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label="Median: {:0.2f}".format(np.nanmedian(ys)), lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + "{:0.2f}".format(sigMads[0]))
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                continue

            [xs1, xs25, xs50, xs75, xs95, xs97] = np.nanpercentile(xs, [1, 25, 50, 75, 95, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xEdges = np.arange(np.nanmin(xs) - xScale, np.nanmax(xs) + xScale,
                               (np.nanmax(xs) + xScale - (np.nanmin(xs) - xScale))/self.config.nBins)
            medYs = np.nanmedian(ys)
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, xBins, yBins = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            yBinsOut.append(yBins)
            countsYs = np.sum(counts, axis=1)

            # Only keep the x bins that hold more than binThresh points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge goes left to right in the first half of the
                    # array, lower edge right to left in the second half, so
                    # that the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.087, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if highMags[sourceType] != "-":
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.017, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)
                if lowMags[sourceType] != "-":
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

            else:
                # Not enough populated bins for a running median: plot every
                # point with a flat median and sigma MAD line instead.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.2f}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.2f}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        elif medYs is not None:
            # Only "unknown"/"all" sources: centre on the plotted sample
            plotMed = medYs
        else:
            plotMed = np.nan

        if yLims:
            yLimMin, yLimMax = yLims[0], yLims[1]
        elif meds is not None and np.isfinite(plotMed) and np.isfinite(sigMadYs):
            # Start at +/- 4 sigma MAD and widen until the plotted medians
            # fit (capped at 10), then add one more sigma MAD as a margin.
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs

            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        else:
            # Nothing usable to derive limits from; keep the autoscaled range
            yLimMin, yLimMax = ax.get_ylim()
        ax.set_ylim(yLimMin, yLimMax)

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        elif xScale is not None:
            ax.set_xlim(xs1 - xScale, xs97 + xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if hasGalaxies:
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({len(xsGalaxies)})")
        if hasStars:
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({len(xsStars)})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        if hasGalaxies:
            finiteGalaxies = np.isfinite(ysGalaxies)
            sideHist.hist(ysGalaxies[finiteGalaxies], bins=bins, color="firebrick", histtype="step",
                          orientation="horizontal", log=True)
            if highMags[galaxyType] != "-":
                sideHist.hist(ysGalaxies[finiteGalaxies & (xsGalaxies < float(highMags[galaxyType]))],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls="--")
            if lowMags[galaxyType] != "-":
                sideHist.hist(ysGalaxies[finiteGalaxies & (xsGalaxies < float(lowMags[galaxyType]))],
                              bins=bins, color="firebrick", histtype="step", orientation="horizontal",
                              log=True, ls=":")

        if hasStars:
            finiteStars = np.isfinite(ysStars)
            sideHist.hist(ysStars[finiteStars], bins=bins, color="midnightblue", histtype="step",
                          orientation="horizontal", log=True)
            if highMags[starType] != "-":
                sideHist.hist(ysStars[finiteStars & (xsStars < float(highMags[starType]))], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls="--")
            if lowMags[starType] != "-":
                sideHist.hist(ysStars[finiteStars & (xsStars < float(lowMags[starType]))], bins=bins,
                              color="midnightblue", histtype="step", orientation="horizontal", log=True,
                              ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for patchId, (corners, stat) in sumStats.items():
            # The rectangle spans from corner 0 to the opposite corner 2
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            ras = [corner[0].asDegrees() for corner in corners]
            decs = [corner[1].asDegrees() for corner in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if patchId != "tract":
                axCorner.annotate(patchId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        fig.canvas.draw_idle()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig