Coverage for python/lsst/analysis/drp/scatterPlot.py: 9%
386 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-06 13:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-06 13:57 +0000
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21import matplotlib
22import matplotlib.pyplot as plt
23import numpy as np
24import pandas as pd
25from matplotlib import gridspec
26from matplotlib.patches import Rectangle
27from matplotlib.path import Path
28from matplotlib.collections import PatchCollection
29from mpl_toolkits.axes_grid1 import make_axes_locatable
31from lsst.pex.config.configurableActions import ConfigurableActionStructField
32from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction
33from lsst.skymap import BaseSkyMap
35import lsst.pipe.base as pipeBase
36import lsst.pex.config as pexConfig
38from .calcFunctors import MagDiff
39from .dataSelectors import SnSelector, StarIdentifier, CoaddPlotFlagSelector
40from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap
41from .statistics import sigmaMad
# Select the Agg backend for every figure created by this module.
matplotlib.use("Agg")

# Module-level copy of the coolwarm colormap, used to colour the per-patch
# summary statistics. "Bad" (masked) values are drawn with the colour "none".
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad(color="none")
class ScatterPlotWithTwoHistsTaskConnections(
    pipeBase.PipelineTaskConnections,
    dimensions=("tract", "skymap"),
    defaultTemplates={"inputCoaddName": "deep", "plotName": "deltaCoords"},
):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    Inputs are a per-tract catalog (declared with ``deferLoad=True``) and the
    skymap; the single output is a per-tract plot whose dataset type name is
    built from the ``plotName`` template.
    """

    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        storageClass="DataFrame",
        name="objectTable_tract",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        storageClass="SkyMap",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        dimensions=("skymap",),
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        storageClass="Plot",
        name="scatterTwoHistPlot_{plotName}",
        dimensions=("tract", "skymap"),
    )
class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    Holds the actions used to compute the plotted quantities, the selectors
    used to choose which rows are plotted and which are used for the printed
    statistics, and a handful of plot-appearance options.
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: {'x':, 'y':, 'mag':}. "
            "The mag column is used to decide which points to include in the printed statistics.",
        keytype=str,
        itemtype=str
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        The columns requested by every configured action are collected (the
        ``patch`` column is always included). For each column name containing
        an underscore, the text before the first underscore is treated as a
        band unless it is listed in ``nonBandColumnPrefixes``.

        Returns
        -------
        bands : `set`
            The required bands.
        columns : `set`
            The required column names.
        """
        columnNames = {"patch"}
        bands = set()
        actionStructs = [self.axisActions,
                         self.selectorActions,
                         self.highSnStatisticSelectorActions,
                         self.lowSnStatisticSelectorActions,
                         self.sourceSelectorActions]
        for actionStruct in actionStructs:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix
                    prefix, sep, _ = col.partition("_")
                    if sep and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge", "sky"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": CoaddPlotFlagSelector,
                 "catSnSelector": SnSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": StarIdentifier},
    )

    # NOTE(review): declared as a float although it counts bins; it is only
    # used as a divisor when computing bin widths, so the dtype is left as is
    # to stay compatible with existing config overrides.
    nBins = pexConfig.Field(
        doc="Number of bins to put on the x axis.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    xLims = pexConfig.ListField(
        doc="Minimum and maximum x-axis limit to force (provided as a list of [xMin, xMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    yLims = pexConfig.ListField(
        doc="Minimum and maximum y-axis limit to force (provided as a list of [yMin, yMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    minPointSize = pexConfig.Field(
        doc="When plotting points (as opposed to 2D hist bins), the minimum size they can be. Some "
            "relative scaling will be performed depending on the \"flavor\" of the set of points.",
        default=2.0,
        dtype=float,
    )

    def setDefaults(self):
        """Set the default columns, selectors, S/N thresholds and axis labels.

        The axis labels are derived from the default column names by removing
        a trailing ``Flux``.
        """
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.axisActions.yAction = MagDiff
        self.axisActions.yAction.col1 = "i_ap12Flux"
        self.axisActions.yAction.col2 = "i_psfFlux"
        self.selectorActions.flagSelector.bands = ["i"]
        self.selectorActions.catSnSelector.fluxType = "psfFlux"
        self.highSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
        self.axisLabels = {
            "x": self.axisActions.xAction.column.removesuffix("Flux") + " (mag)",
            "mag": self.axisActions.magAction.column.removesuffix("Flux") + " (mag)",
            "y": ("{} - {} (mmag)".format(self.axisActions.yAction.col1.removesuffix("Flux"),
                                          self.axisActions.yAction.col2.removesuffix("Flux")))
        }
class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities with a histogram of each
    axis and a per-patch summary panel.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # The catalog connection is declared with deferLoad=True, so only the
        # columns the configured actions ask for are read here.
        inputs["catPlot"] = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["dataId"] = butlerQC.quantum.dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId : `lsst.daf.butler.DataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are generated.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set`
            The bands required by the config; passed on to `parsePlotInfo`.
        plotName : `str`
            The name of the output plot dataset; passed on to
            `parsePlotInfo`.

        Returns
        -------
        `pipeBase.Struct` containing:
            scatterPlot : `matplotlib.figure.Figure`
                The resulting figure.

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The actions specified in
        `self.config.axisActions` are then used to calculate the values for
        each axis.
        After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """

        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        columns = {self.config.axisLabels["x"]: self.config.axisActions.xAction(catPlot),
                   self.config.axisLabels["y"]: self.config.axisActions.yAction(catPlot),
                   self.config.axisLabels["mag"]: self.config.axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Carry along the raw columns that the statistic and source selectors
        # need, because those selectors are applied to plotDf below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        sourceSelectors = list(self.config.sourceSelectorActions)
        if sourceSelectors:
            sourceTypes = np.zeros(len(plotDf))
            for selector in sourceSelectors:
                # The source selectors return 1 for a star and 2 for a galaxy
                # rather than a mask this allows the information about which
                # type of sources are being plotted to be propagated
                sourceTypes += selector(catPlot)
        else:
            # No source selector configured: flag every row as type 10 ("all")
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation:
        # 0 - not used, 2 - passes the low S/N selectors,
        # 1 - passes the high S/N selectors (overrides 2).
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        if hasattr(self.config.selectorActions, "catSnSelector"):
            SN = self.config.selectorActions.catSnSelector.threshold
            SNFlux = self.config.selectorActions.catSnSelector.fluxType
        else:
            SN = "N/A"
            SNFlux = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux)
        # Calculate the corners of the patches and some associated stats
        if skymap is None:
            sumStats = {}
        else:
            sumStats = generateSummaryStats(plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        return pipeBase.Struct(scatterPlot=fig)

    @staticmethod
    def _makeStatsString(values):
        """Format the median, sigma MAD and number of points of ``values``
        for display in a stats text box.
        """
        med = np.nanmedian(values)
        mad = sigmaMad(values, nan_policy="omit")
        return (f"Median: {med:0.3g} "
                + r"$\sigma_{MAD}$: " + f"{mad:0.3g} "
                + r"N$_{points}$: " + f"{len(values)}")

    @staticmethod
    def _approxMagLimit(mags):
        """Return the median of the largest ~10% of the finite ``mags``.

        Returns NaN when there are no finite values. With fewer than 10
        values the median of all of them is returned.
        """
        mags = np.asarray(mags, dtype=float)
        # np.sort places NaNs last, so without this cut they would fill the
        # tail that is used for the estimate.
        mags = mags[np.isfinite(mags)]
        if len(mags) == 0:
            return np.nan
        sortedMags = np.sort(mags)
        nTail = int(len(sortedMags)/10)
        # When nTail is 0 the slice [-0:] is the whole array.
        return np.nanmedian(sortedMags[-nTail:])

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. It must contain the columns
            named by the ``axisLabels`` config plus ``sourceType`` and
            ``useForStats``.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        The axis limits are set based on the values of `config.xLims` and
        `config.yLims`. If provided (as a `list` of [min, max]), those will
        be used. If `None` (the default), the axis limits will be computed
        and set based on the data.
        """
        self.log.info("Plotting %s: the values of %s on a scatter plot.",
                      self.config.connections.plotName, self.config.axisLabels['y'])

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns
        #              10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}

        # Need to separate stars and galaxies
        stars = (catPlot["sourceType"] == sourceTypeMapper["stars"])
        galaxies = (catPlot["sourceType"] == sourceTypeMapper["galaxies"])

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # For galaxies
        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]

        # For stars
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}

        # Calculate some statistics for each source type that is present
        for sourceType in sourceTypeMapper.values():
            sources = (catPlot["sourceType"] == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            # The low S/N statistics also include the high S/N points
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)

            highStats[sourceType] = self._makeStatsString(catPlot.loc[highSn, yCol])
            lowStats[sourceType] = self._makeStatsString(catPlot.loc[lowSn, yCol])

            highMags[sourceType] = f"{self._approxMagLimit(catPlot.loc[highSn, magCol]):.3g}"
            lowMags[sourceType] = f"{self._approxMagLimit(catPlot.loc[lowSn, magCol]):.3g}"

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        binThresh = 5

        linesForLegend = []

        # Each entry is (xs, ys, line colour, 2D histogram colormap, source
        # type). Galaxies are listed before stars when both are present.
        hasStars = np.any(stars)
        hasGalaxies = np.any(galaxies)
        if hasStars or hasGalaxies:
            toPlotList = []
            if hasGalaxies:
                toPlotList.append((xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                                   sourceTypeMapper["galaxies"]))
            if hasStars:
                toPlotList.append((xsStars.values, ysStars.values, "midnightblue", newBlues,
                                   sourceTypeMapper["stars"]))
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["unknowns"]):
            unknowns = (catPlot["sourceType"] == sourceTypeMapper["unknowns"])
            toPlotList = [(catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                           "green", None, sourceTypeMapper["unknowns"])]
        elif np.any(catPlot["sourceType"] == sourceTypeMapper["all"]):
            toPlotList = [(catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                           sourceTypeMapper["all"])]
        else:
            # Nothing to plot: release the unused pyplot-managed figure and
            # return a placeholder figure instead.
            plt.close(fig)
            noDataFig = plt.Figure()
            noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
            noDataFig = addPlotInfo(noDataFig, plotInfo)
            return noDataFig

        xMin = None
        histIm = None
        for (j, (xs, ys, color, cmap, sourceType)) in enumerate(toPlotList):
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                medLine, = ax.plot(xs, np.nanmedian(ys), color,
                                   label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, np.nanmedian(ys) + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, np.nanmedian(ys) - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow)/self.config.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/self.config.nBins
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Only keep the x bins that contain more than binThresh points
            ids = np.where((countsYs > binThresh))[0]
            xEdgesPlot = xEdges[ids][1:]
            xEdges = xEdges[ids]

            if len(ids) > 1:
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.ones(len(xEdgesPlot)*2)*Path.LINETO
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(len(xEdgesPlot))
                threeSigMadVerts = np.zeros((len(xEdgesPlot)*2, 2))
                sigMads = np.zeros(len(xEdgesPlot))

                for (i, xEdge) in enumerate(xEdgesPlot):
                    inBin = np.where((xs < xEdge) & (xs >= xEdges[i]) & (np.isfinite(ys)))[0]
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    # The upper boundary fills the first half of the vertex
                    # array and the lower boundary fills the second half in
                    # reverse, so the vertices trace a closed outline.
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within 3 sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:len(xEdgesPlot), 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[len(xEdgesPlot):, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=self.config.minPointSize, alpha=0.3,
                        mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                if hasattr(self.config.highSnStatisticSelectorActions, "statSelector"):
                    bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                    highThresh = self.config.highSnStatisticSelectorActions.statSelector.threshold
                    statText = f"S/N > {highThresh} Stats [{magCol} $\\lesssim$ {highMags[sourceType]}]\n"
                    statText += highStats[sourceType]
                    fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                # Guard on the low S/N struct, which is the one read below
                if hasattr(self.config.lowSnStatisticSelectorActions, "statSelector"):
                    bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                    lowThresh = self.config.lowSnStatisticSelectorActions.statSelector.threshold
                    statText = f"S/N > {lowThresh} Stats [{magCol} $\\lesssim$ {lowMags[sourceType]}]\n"
                    statText += lowStats[sourceType]
                    fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                sources = (catPlot["sourceType"] == sourceType)
                statInfo = catPlot["useForStats"].loc[sources].values
                if hasattr(self.config.highSnStatisticSelectorActions, "statSelector"):
                    highSn = (statInfo == 1)
                    if 0 < np.sum(highSn) < 100:
                        # White markers underneath give the coloured ones
                        # an outline.
                        ax.plot(xs[highSn], ys[highSn], marker="x", ms=self.config.minPointSize + 1,
                                mec="w", mew=2, ls="none")
                        highSnLine, = ax.plot(xs[highSn], ys[highSn], color=color, marker="x",
                                              ms=self.config.minPointSize + 1, ls="none", label="High SN")
                        linesForLegend.append(highSnLine)
                        xMin = np.min(xs[highSn])
                    else:
                        ax.axvline(float(highMags[sourceType]), color=color, ls="--")
                if hasattr(self.config.lowSnStatisticSelectorActions, "statSelector"):
                    # NOTE(review): the original condition was
                    # ``(statInfo == 2) | (statInfo == 2)``. The printed low
                    # S/N statistics use ``1 | 2``; confirm whether the high
                    # S/N points should be included here as well.
                    lowSn = (statInfo == 2)
                    if 0 < np.sum(lowSn) < 100:
                        ax.plot(xs[lowSn], ys[lowSn], marker="+", ms=self.config.minPointSize + 1, mec="w",
                                mew=2, ls="none")
                        lowSnLine, = ax.plot(xs[lowSn], ys[lowSn], color=color, marker="+",
                                             ms=self.config.minPointSize + 1, ls="none", label="Low SN")
                        linesForLegend.append(lowSnLine)
                        if xMin is None or xMin > np.min(xs[lowSn]):
                            xMin = np.min(xs[lowSn])
                    else:
                        ax.axvline(float(lowMags[sourceType]), color=color, ls=":")

            else:
                # Too few populated bins for running statistics: plot every
                # point plus the overall median and +/- 1 sigma MAD lines.
                ax.plot(xs, ys, ".", ms=self.config.minPointSize + 3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)
                meds = np.array([np.nanmedian(ys)]*len(xs))
                medLine, = ax.plot(xs, meds, color, label=f"Median: {np.nanmedian(ys):0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigmaMad(ys, nan_policy="omit")]*len(xs))
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits. xs, ys, meds, sigMadYs, xs1, xs97 and
        # xScale below are those of the last data set plotted.
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        else:
            plotMed = np.nanmedian(ysGalaxies)
        if len(xs) < 2:
            meds = [np.median(ys)]

        if self.config.yLims is not None:
            yLimMin = self.config.yLims[0]
            yLimMax = self.config.yLims[1]
        else:
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            # Widen the range one sigma MAD at a time (up to 10) until the
            # running medians fit. The limits are recomputed on every pass so
            # that the loop stops as soon as they do.
            while (yLimMax < np.max(meds) or yLimMin > np.min(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs

            # Add one more sigma MAD of padding
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if self.config.xLims is not None:
            ax.set_xlim(self.config.xLims[0], self.config.xLims[1])
        elif len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2*xScale
            ax.set_xlim(xMin, xs97 + 2*xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if hasGalaxies:
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({len(np.where(galaxies)[0])})")
        if hasStars:
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({len(np.where(stars)[0])})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        if hasGalaxies:
            sideHist.hist(ysGalaxies[np.isfinite(ysGalaxies)], bins=bins, color="firebrick", histtype="step",
                          orientation="horizontal", log=True)
            sources = (catPlot["sourceType"].values == sourceTypeMapper["galaxies"])
            highSn = (catPlot["useForStats"].values == 1)
            lowSn = (catPlot["useForStats"].values == 2)
            sideHist.hist(ysGalaxies[highSn[sources]], bins=bins, color="firebrick", histtype="step",
                          orientation="horizontal", log=True, ls="--")
            sideHist.hist(ysGalaxies[lowSn[sources]], bins=bins, color="firebrick", histtype="step",
                          orientation="horizontal", log=True, ls=":")

        if hasStars:
            sideHist.hist(ysStars[np.isfinite(ysStars)], bins=bins, color="midnightblue", histtype="step",
                          orientation="horizontal", log=True)
            sources = (catPlot["sourceType"] == sourceTypeMapper["stars"])
            highSn = (catPlot["useForStats"] == 1)
            lowSn = (catPlot["useForStats"] == 2)
            sideHist.hist(ysStars[highSn[sources]], bins=bins, color="midnightblue", histtype="step",
                          orientation="horizontal", log=True, ls="--")
            sideHist.hist(ysStars[lowSn[sources]], bins=bins, color="midnightblue", histtype="step",
                          orientation="horizontal", log=True, ls=":")

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for dataId, (corners, stat) in sumStats.items():
            # The rectangle is built from corners[0] and corners[2]
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
            decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
            # Repeat the first corner to close the outline
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            cenX = ra + width / 2
            cenY = dec + height / 2
            if dataId != "tract":
                axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

        # Operate on ``fig`` directly rather than on pyplot's current figure
        fig.canvas.draw_idle()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        fig = addPlotInfo(fig, plotInfo)

        return fig