Coverage for python/lsst/analysis/drp/scatterPlot.py: 10%
368 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-15 02:49 -0700
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-15 02:49 -0700
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import matplotlib.pyplot as plt
23import numpy as np
24import pandas as pd
25from matplotlib import gridspec
26from matplotlib.patches import Rectangle
27from matplotlib.path import Path
28from matplotlib.collections import PatchCollection
29from mpl_toolkits.axes_grid1 import make_axes_locatable
30from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField
31from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction
32from lsst.skymap import BaseSkyMap
34import lsst.pipe.base as pipeBase
35import lsst.pex.config as pexConfig
37from . import dataSelectors as dataSelectors
38from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap
39from .statistics import sigmaMad
# Colormap used to colour the per-patch summary panel. Masked ("bad") values
# are drawn fully transparent so patches without a statistic show no fill.
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad(color="none")
class ScatterPlotWithTwoHistsTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"},
):
    """Connections for the scatter plot with two histograms task.

    The task reads a tract-level catalog (loaded lazily) and the skymap, and
    writes one plot per tract whose dataset name is built from the
    ``plotName`` template.
    """

    # Deferred so that only the required columns need to be read.
    catPlot = pipeBase.connectionTypes.Input(
        name="objectTable_tract",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
        doc="The tract wide catalog to make plots from.",
    )

    skymap = pipeBase.connectionTypes.Input(
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
        doc="The skymap for the tract",
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        name="scatterTwoHistPlot_{plotName}",
        storageClass="Plot",
        dimensions=("tract", "skymap"),
        doc="A scatter plot with histograms for both axes.",
    )
class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    The action struct fields decide which rows are plotted, how the axis
    values are computed, how sources are classified and which points feed the
    high and low S/N statistics.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: {'x':, 'y':, 'mag':} "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every column named by any configured action is collected, together
        with the ``patch`` column. A column's band is taken to be the text
        before its first underscore, unless that prefix is listed in
        ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names.
        """
        columnNames = {"patch"}
        bands = set()
        actionStructs = (self.axisActions,
                         self.selectorActions,
                         self.highSnStatisticSelectorActions,
                         self.lowSnStatisticSelectorActions,
                         self.sourceSelectorActions)
        for actionStruct in actionStructs:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    prefix, separator, _ = col.partition("_")
                    if separator and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": dataSelectors.CoaddPlotFlagSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": dataSelectors.SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": dataSelectors.StarIdentifier},
    )

    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    def setDefaults(self):
        """Set the default columns and the S/N thresholds used for statistics."""
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities with a histogram of each
    axis and a per-patch summary panel.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # The catalog is a deferred-load handle: read only the needed columns.
        dataFrame = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs['catPlot'] = dataFrame
        dataId = butlerQC.quantum.dataId
        inputs["dataId"] = dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
        `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None` no
            per-patch summary statistics are generated.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set`
            The bands required by the configured actions; passed on to
            `parsePlotInfo`.
        plotName : `str`
            The name of the output plot dataset; passed on to
            `parsePlotInfo`.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The axis values are then calculated
        with the actions in `self.config.axisActions` and stored under the
        names given in `self.config.axisLabels`.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which is given the skymap and the y axis
        column to produce the per-patch information for later plotting.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # The statistic selectors are applied to the plotting table below, so
        # it has to carry every column that they (and the source selectors)
        # read.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceTypes = np.zeros(len(plotDf))
        for selector in self.config.sourceSelectorActions:
            # The source selectors return 1 for a star and 2 for a galaxy
            # rather than a mask this allows the information about which
            # type of sources are being plotted to be propagated
            sourceTypes += selector(catPlot)
        if not list(self.config.sourceSelectorActions):
            # No classification requested: flag every source as "all" (10)
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation
        # 0 - not used, 1 - passes the high S/N selectors,
        # 2 - passes only the low S/N selectors
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        try:
            SN = self.config.selectorActions.SnSelector.threshold
        except AttributeError:
            SN = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)

        return pipeBase.Struct(scatterPlot=fig)

    @staticmethod
    def _statsText(values):
        """Format the median, sigma MAD and number of ``values`` for printing
        on the plot.
        """
        median = np.nanmedian(values)
        mad = sigmaMad(values, nan_policy="omit")
        return (f"Median: {median:0.3g} "
                + r"$\sigma_{MAD}$: " + f"{mad:0.3g} "
                + r"N$_{points}$: " + f"{len(values)}")

    @staticmethod
    def _approxMagLimit(mags):
        """Return the median of the largest tenth of ``mags``.

        With fewer than ten values the median of all of them is returned, and
        NaN is returned when there are none.
        """
        nMags = len(mags)
        if nMags >= 10:
            sortedMags = np.sort(mags)
            return np.nanmedian(sortedMags[-(nMags // 10):])
        if nMags > 0:
            return np.nanmedian(mags)
        return np.nan

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats, yLims=False, xLims=False):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. As well as the axis columns
            it must contain the ``sourceType`` and ``useForStats`` columns
            that `run` adds.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.
        yLims : `Bool` or `tuple`, optional
            The y axis limits to use for the plot.  If `False`, they are
            calculated from the data.  Otherwise a tuple of (yMin, yMax).
        xLims : `Bool` or `tuple`, optional
            The x axis limits to use for the plot.  If `False`, they are
            calculated from the data.  Otherwise a tuple of (xMin, xMax).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.
        """
        self.log.info("Plotting %s: the values of %s on a scatter plot.",
                      self.config.connections.plotName, self.config.axisLabels['y'])

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        # 10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == sourceTypeMapper["stars"])
        galaxies = (catPlot["sourceType"] == sourceTypeMapper["galaxies"])

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # Calculate some statistics
        for sourceType in sourceTypeMapper.values():
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            passesHigh = (catPlot["useForStats"] == 1)
            highSn = passesHigh & sources
            # Anything passing the high S/N cut also counts as low S/N
            lowSn = (passesHigh | (catPlot["useForStats"] == 2)) & sources

            highStats[sourceType] = self._statsText(catPlot.loc[highSn, yCol])
            lowStats[sourceType] = self._statsText(catPlot.loc[lowSn, yCol])
            highMags[sourceType] = f"{self._approxMagLimit(catPlot.loc[highSn, magCol]):.3g}"
            lowMags[sourceType] = f"{self._approxMagLimit(catPlot.loc[lowSn, magCol]):.3g}"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5

        linesForLegend = []

        # Galaxies are listed first so that, when both are present, they are
        # handled before the stars.
        toPlotList = []
        if np.any(galaxies):
            toPlotList.append((xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                               sourceTypeMapper["galaxies"]))
        if np.any(stars):
            toPlotList.append((xsStars.values, ysStars.values, "midnightblue", newBlues,
                               sourceTypeMapper["stars"]))
        if not toPlotList:
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            if np.any(unknowns):
                toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                               "green", None, sourceTypeMapper["unknowns"])]
            elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
                toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                               sourceTypeMapper["all"])]
            else:
                noDataFig = plt.Figure()
                noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
                noDataFig = addPlotInfo(noDataFig, plotInfo)
                # The pyplot-managed figure is not returned on this path, so
                # release it rather than leaving it registered with pyplot.
                plt.close(fig)
                return noDataFig

        xMin = None
        histIm = None
        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow)/self.config.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/101.0
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Only keep the x bins that contain enough points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs > xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # Upper edge runs forwards, lower edge runs backwards so
                    # that the vertices trace out a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma mad
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=3, alpha=0.3, mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {highThresh} Stats ({magCol} < {highMags[sourceType]})\n"
                statText += highStats[sourceType]
                fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                statText = f"S/N > {lowThresh} Stats ({magCol} < {lowMags[sourceType]})\n"
                statText += lowStats[sourceType]
                fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                sources = (catPlot["sourceType"] == sourceType)
                statInfo = catPlot["useForStats"].loc[sources].values
                highSn = (statInfo == 1)
                # Same definition as for the printed statistics: the low S/N
                # sample includes the high S/N sources.
                lowSn = ((statInfo == 1) | (statInfo == 2))
                if np.sum(highSn) < 100 and np.sum(highSn) > 0:
                    # White underlay so the markers stand out from the hexbin
                    ax.plot(xs[highSn], ys[highSn], marker="x", ms=4, mec="w", mew=2, ls="none")
                    highSnLine, = ax.plot(xs[highSn], ys[highSn], color=color, marker="x", ms=4, ls="none",
                                          label="High SN")
                    linesForLegend.append(highSnLine)
                    xMin = np.min(xs[highSn])
                else:
                    ax.axvline(float(highMags[sourceType]), color=color, ls="--")

                if np.sum(lowSn) < 100 and np.sum(lowSn) > 0:
                    ax.plot(xs[lowSn], ys[lowSn], marker="+", ms=4, mec="w", mew=2, ls="none")
                    lowSnLine, = ax.plot(xs[lowSn], ys[lowSn], color=color, marker="+", ms=4, ls="none",
                                         label="Low SN")
                    linesForLegend.append(lowSnLine)
                    if xMin is None or xMin > np.min(xs[lowSn]):
                        xMin = np.min(xs[lowSn])
                else:
                    ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

            else:
                # Too few populated bins for running statistics: plot every
                # point with a single overall median and sigma MAD.
                ax.plot(xs, ys, ".", ms=5, alpha=0.3, mfc=color, mec=color, zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits
        # NOTE: xs, ys, sigMadYs and meds below refer to the last dataset
        # handled by the loop above.
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            # Neither stars nor galaxies (unknowns/all): use the plotted
            # points so that the limits are not NaN.
            plotMed = np.nanmedian(ys)
        if len(xs) < 2:
            meds = [np.median(ys)]
        if yLims:
            # Also needed for the side histogram bins further down
            yLimMin, yLimMax = yLims[0], yLims[1]
        else:
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            # Widen the range until the medians fit inside it. The limits
            # have to be recomputed on every pass, otherwise the condition
            # can never change.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs

            # Then pad by one further sigma MAD
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if xLims:
            ax.set_xlim(xLims[0], xLims[1])
        elif len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2*xScale
            ax.set_xlim(xMin, xs97 + 2*xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if np.any(galaxies):
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({len(np.where(galaxies)[0])})")
        if np.any(stars):
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({len(np.where(stars)[0])})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        highSn = (catPlot["useForStats"].values == 1)
        lowSn = (catPlot["useForStats"].values == 2)
        for (typeName, ysType, color) in [("galaxies", ysGalaxies, "firebrick"),
                                          ("stars", ysStars, "midnightblue")]:
            sources = (catPlot["sourceType"].values == sourceTypeMapper[typeName])
            if not np.any(sources):
                continue
            sideHist.hist(ysType[np.isfinite(ysType)], bins=bins, color=color, histtype="step",
                          orientation="horizontal", log=True)
            # Dashed: high S/N sample, dotted: low S/N only sample
            sideHist.hist(ysType[highSn[sources]], bins=bins, color=color, histtype="step",
                          orientation="horizontal", log=True, ls="--")
            sideHist.hist(ysType[lowSn[sources]], bins=bins, color=color, histtype="step",
                          orientation="horizontal", log=True, ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for (dataId, (corners, stat)) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            xy = (ra, dec)
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle(xy, width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
            decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
            # Repeat the first corner to close the outline
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if dataId != "tract":
                axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        plt.draw()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig