Coverage for python/lsst/analysis/drp/scatterPlot.py: 9%
386 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-01-27 01:24 +0000
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21import matplotlib
22import matplotlib.pyplot as plt
23import numpy as np
24import pandas as pd
25from matplotlib import gridspec
26from matplotlib.patches import Rectangle
27from matplotlib.path import Path
28from matplotlib.collections import PatchCollection
29from mpl_toolkits.axes_grid1 import make_axes_locatable
31from lsst.pipe.tasks.configurableActions import ConfigurableActionStructField
32from lsst.pipe.tasks.dataFrameActions import MagColumnNanoJansky, SingleColumnAction
33from lsst.skymap import BaseSkyMap
35import lsst.pipe.base as pipeBase
36import lsst.pex.config as pexConfig
38from .calcFunctors import MagDiff
39from .dataSelectors import SnSelector, StarIdentifier, CoaddPlotFlagSelector
40from .plotUtils import generateSummaryStats, parsePlotInfo, addPlotInfo, mkColormap
41from .statistics import sigmaMad
# Select the non-interactive Agg backend so figures can be rendered without a
# display.
matplotlib.use("Agg")

# A private copy of the coolwarm colormap whose "bad" (masked) value is fully
# transparent; the per-patch summary panel masks NaN statistics with it.
cmapPatch = plt.cm.coolwarm.copy()
cmapPatch.set_bad(color="none")
class ScatterPlotWithTwoHistsTaskConnections(pipeBase.PipelineTaskConnections,
                                             dimensions=("tract", "skymap"),
                                             defaultTemplates={"inputCoaddName": "deep",
                                                               "plotName": "deltaCoords"}):
    """Butler connections for `ScatterPlotWithTwoHistsTask`.

    Inputs are a tract-level catalog (loaded lazily so that only the needed
    columns are read) and the skymap; the output is one plot per tract whose
    dataset name is templated on ``plotName``.
    """

    # Tract-wide catalog; deferLoad lets runQuantum request only some columns.
    catPlot = pipeBase.connectionTypes.Input(
        doc="The tract wide catalog to make plots from.",
        name="objectTable_tract",
        storageClass="DataFrame",
        dimensions=("tract", "skymap"),
        deferLoad=True,
    )

    skymap = pipeBase.connectionTypes.Input(
        doc="The skymap for the tract",
        name=BaseSkyMap.SKYMAP_DATASET_TYPE_NAME,
        storageClass="SkyMap",
        dimensions=("skymap",),
    )

    scatterPlot = pipeBase.connectionTypes.Output(
        doc="A scatter plot with histograms for both axes.",
        name="scatterTwoHistPlot_{plotName}",
        storageClass="Plot",
        dimensions=("tract", "skymap"),
    )
class ScatterPlotWithTwoHistsTaskConfig(pipeBase.PipelineTaskConfig,
                                        pipelineConnections=ScatterPlotWithTwoHistsTaskConnections):
    """Configuration for `ScatterPlotWithTwoHistsTask`.

    The axis actions compute the plotted values, the selector actions decide
    which rows are plotted and which rows feed the printed statistics, and
    ``axisLabels`` names the resulting columns (which double as axis labels).
    """

    axisActions = ConfigurableActionStructField(
        doc="The actions to use to calculate the values used on each axis. The defaults for the "
            "column names of xAction and magAction are set to i_cModelFlux.",
        default={"xAction": MagColumnNanoJansky, "yAction": SingleColumnAction,
                 "magAction": MagColumnNanoJansky},
    )

    axisLabels = pexConfig.DictField(
        doc="Name of the dataframe columns to plot, will be used as the axis label: "
            "{'x':, 'y':, 'mag':}. The mag column is used to estimate the magnitude limit quoted "
            "alongside the printed statistics.",
        keytype=str,
        itemtype=str,
    )

    def get_requirements(self):
        """Return inputs required for a Task to run with this config.

        Every column read by the axis, selector, statistic-selector and
        source-selector actions is required, plus ``patch``. A band is
        inferred from each column named ``<prefix>_<rest>`` unless the prefix
        is listed in ``nonBandColumnPrefixes``; columns without an underscore
        contribute no band.

        Returns
        -------
        bands : `set` [`str`]
            The required bands.
        columns : `set` [`str`]
            The required column names.
        """
        columnNames = {"patch"}
        bands = set()
        for actionStruct in [self.axisActions,
                             self.selectorActions,
                             self.highSnStatisticSelectorActions,
                             self.lowSnStatisticSelectorActions,
                             self.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    if col is None:
                        continue
                    columnNames.add(col)
                    # If there's no underscore, it has no band prefix.
                    prefix, underscore, _ = col.partition("_")
                    if underscore and prefix not in self.nonBandColumnPrefixes:
                        bands.add(prefix)
        return bands, columnNames

    nonBandColumnPrefixes = pexConfig.ListField(
        doc="Column prefixes that are not bands and which should not be added to the set of bands",
        dtype=str,
        default=["coord", "extend", "detect", "xy", "merge", "sky"],
    )

    selectorActions = ConfigurableActionStructField(
        doc="Which selectors to use to narrow down the data for QA plotting.",
        default={"flagSelector": CoaddPlotFlagSelector,
                 "catSnSelector": SnSelector},
    )

    highSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the high SN statistics.",
        default={"statSelector": SnSelector},
    )

    lowSnStatisticSelectorActions = ConfigurableActionStructField(
        doc="Selectors to use to decide which points to use for calculating the low SN statistics.",
        default={"statSelector": SnSelector},
    )

    sourceSelectorActions = ConfigurableActionStructField(
        doc="What types of sources to use.",
        default={"sourceSelector": StarIdentifier},
    )

    # Kept as a float for backwards compatibility with existing config files;
    # it is only ever used as a divisor when computing bin widths.
    nBins = pexConfig.Field(
        doc="Number of bins along each axis of the scatter plot used when computing the running "
            "statistics.",
        default=40.0,
        dtype=float,
    )

    plot2DHist = pexConfig.Field(
        doc="Plot a 2D histogram in the densest area of points on the scatter plot. "
            "Doesn't look great if plotting multiple datasets on top of each other.",
        default=True,
        dtype=bool,
    )

    xLims = pexConfig.ListField(
        doc="Minimum and maximum x-axis limit to force (provided as a list of [xMin, xMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    yLims = pexConfig.ListField(
        doc="Minimum and maximum y-axis limit to force (provided as a list of [yMin, yMax]). "
            "If `None`, limits will be computed and set based on the data.",
        dtype=float,
        default=None,
        optional=True,
    )

    minPointSize = pexConfig.Field(
        doc="When plotting points (as opposed to 2D hist bins), the minimum size they can be. Some "
            "relative scaling will be performed depending on the \"flavor\" of the set of points.",
        default=2.0,
        dtype=float,
    )

    def setDefaults(self):
        super().setDefaults()
        self.axisActions.magAction.column = "i_cModelFlux"
        self.axisActions.xAction.column = "i_cModelFlux"
        self.axisActions.yAction = MagDiff
        self.axisActions.yAction.col1 = "i_ap12Flux"
        self.axisActions.yAction.col2 = "i_psfFlux"
        self.selectorActions.flagSelector.bands = ["i"]
        self.selectorActions.catSnSelector.fluxType = "psfFlux"
        self.highSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.highSnStatisticSelectorActions.statSelector.threshold = 2700
        self.lowSnStatisticSelectorActions.statSelector.fluxType = "cModelFlux"
        self.lowSnStatisticSelectorActions.statSelector.threshold = 500
        # The labels double as the column names of the plotting dataframe.
        self.axisLabels = {
            "x": self.axisActions.xAction.column.removesuffix("Flux") + " (mag)",
            "mag": self.axisActions.magAction.column.removesuffix("Flux") + " (mag)",
            "y": ("{} - {} (mmag)".format(self.axisActions.yAction.col1.removesuffix("Flux"),
                                          self.axisActions.yAction.col2.removesuffix("Flux")))
        }
class ScatterPlotWithTwoHistsTask(pipeBase.PipelineTask):
    """Make a scatter plot of two catalog quantities with a histogram of each
    axis and a panel summarising a statistic per patch.
    """

    ConfigClass = ScatterPlotWithTwoHistsTaskConfig
    _DefaultName = "scatterPlotWithTwoHistsTask"

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        # Docs inherited from base class
        bands, columnNames = self.config.get_requirements()
        inputs = butlerQC.get(inputRefs)
        # catPlot is a deferred-load handle: only read the columns we need.
        inputs["catPlot"] = inputs["catPlot"].get(parameters={"columns": columnNames})
        inputs["dataId"] = butlerQC.quantum.dataId
        inputs["runName"] = inputRefs.catPlot.datasetRef.run
        localConnections = self.config.ConnectionsClass(config=self.config)
        inputs["tableName"] = localConnections.catPlot.name
        inputs["plotName"] = localConnections.scatterPlot.name
        inputs["bands"] = bands
        outputs = self.run(**inputs)
        butlerQC.put(outputs, outputRefs)

    def run(self, catPlot, dataId, runName, skymap, tableName, bands, plotName):
        """Prep the catalogue and then make a scatterPlot of the given column.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from.
        dataId :
            `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
            The dimensions that the plot is being made from.
        runName : `str`
            The name of the collection that the plot is written out to.
        skymap : `lsst.skymap`
            The skymap used to define the patch boundaries. If `None`, no
            per-patch summary statistics are computed.
        tableName : `str`
            The type of table used to make the plot.
        bands : `set` [`str`]
            The bands required by the configured actions; passed on to
            `parsePlotInfo`.
        plotName : `str`
            The name of the output plot dataset; passed on to `parsePlotInfo`.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            Struct containing:

            ``scatterPlot``
                The resulting figure (`matplotlib.figure.Figure`).

        Notes
        -----
        The catalogue is first narrowed down using the selectors specified in
        `self.config.selectorActions`. The plotted values are then computed
        with the actions in `self.config.axisActions`, each source is
        classified with `self.config.sourceSelectorActions`, and the points
        used for the high/low S/N statistics are flagged with the statistic
        selector actions. After this the following functions are run:

        `parsePlotInfo` which uses the dataId, runName and tableName to add
        useful information to the plot.

        `generateSummaryStats` which parses the skymap to give the corners of
        the patches for later plotting and calculates some basic statistics
        in each patch for the y-axis column.

        `scatterPlotWithTwoHists` which makes a scatter plot of the points with
        a histogram of each axis.
        """
        # Apply the selectors to narrow down the sources to use
        mask = np.ones(len(catPlot), dtype=bool)
        for selector in self.config.selectorActions:
            mask &= selector(catPlot)
        catPlot = catPlot[mask]

        axisLabels = self.config.axisLabels
        axisActions = self.config.axisActions
        columns = {axisLabels["x"]: axisActions.xAction(catPlot),
                   axisLabels["y"]: axisActions.yAction(catPlot),
                   axisLabels["mag"]: axisActions.magAction(catPlot),
                   "patch": catPlot["patch"]}
        # Copy the raw columns read by the statistic and source selectors; the
        # statistic selectors are evaluated on plotDf below.
        for actionStruct in [self.config.highSnStatisticSelectorActions,
                             self.config.lowSnStatisticSelectorActions,
                             self.config.sourceSelectorActions]:
            for action in actionStruct:
                for col in action.columns:
                    if col is not None:
                        columns[col] = catPlot[col]

        plotDf = pd.DataFrame(columns)

        # The source selectors return 1 for a star and 2 for a galaxy rather
        # than a mask; this allows the information about which type of
        # sources are being plotted to be propagated. Without any source
        # selectors every source is flagged 10 ("all").
        sourceSelectors = list(self.config.sourceSelectorActions)
        if sourceSelectors:
            sourceTypes = np.zeros(len(plotDf))
            for selector in sourceSelectors:
                sourceTypes += selector(catPlot)
        else:
            sourceTypes = [10]*len(plotDf)
        plotDf.loc[:, "sourceType"] = sourceTypes

        # Decide which points to use for stats calculation:
        # 1 = high S/N, 2 = low S/N only, 0 = not used. High S/N overrides.
        useForStats = np.zeros(len(plotDf))
        lowSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.lowSnStatisticSelectorActions:
            lowSnMask &= selector(plotDf)
        useForStats[lowSnMask] = 2

        highSnMask = np.ones(len(plotDf), dtype=bool)
        for selector in self.config.highSnStatisticSelectorActions:
            highSnMask &= selector(plotDf)
        useForStats[highSnMask] = 1
        plotDf.loc[:, "useForStats"] = useForStats

        # Get the S/N cut used
        if hasattr(self.config.selectorActions, "catSnSelector"):
            SN = self.config.selectorActions.catSnSelector.threshold
            SNFlux = self.config.selectorActions.catSnSelector.fluxType
        else:
            SN = "N/A"
            SNFlux = "N/A"

        # Get useful information about the plot
        plotInfo = parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux)
        # Calculate the corners of the patches and some associated stats
        sumStats = {} if skymap is None else generateSummaryStats(
            plotDf, self.config.axisLabels["y"], skymap, plotInfo)
        # Make the plot
        fig = self.scatterPlotWithTwoHists(plotDf, plotInfo, sumStats)
        return pipeBase.Struct(scatterPlot=fig)

    @staticmethod
    def _approxMagLimit(mags):
        """Estimate the faint magnitude reached by a selection of sources.

        Parameters
        ----------
        mags : array-like
            Magnitudes of the selected sources.

        Returns
        -------
        magLimit : `float`
            The NaN-ignoring median of the faintest (largest) 10% of ``mags``,
            or of all of them when there are fewer than ten. NaN if ``mags``
            is empty.
        """
        sortedMags = np.sort(np.asarray(mags, dtype=float))
        if len(sortedMags) == 0:
            return np.nan
        nFaint = len(sortedMags)//10
        # NaNs sort to the end; nanmedian ignores them.
        faintest = sortedMags[-nFaint:] if nFaint > 0 else sortedMags
        return np.nanmedian(faintest)

    @staticmethod
    def _formatMag(mag):
        """Format a magnitude for the stats box, using "-" when undefined."""
        return f"{mag:.3g}" if np.isfinite(mag) else "-"

    @staticmethod
    def _formatStats(ys):
        """Return the median / sigma MAD / number-of-points summary string."""
        median = np.nanmedian(ys)
        sigMad = sigmaMad(ys, nan_policy="omit")
        return (f"Median: {median:0.3g} "
                + r"$\sigma_{MAD}$: " + f"{sigMad:0.3g} "
                + r"N$_{points}$: " + f"{len(ys)}")

    @staticmethod
    def _addPatchSummaryPanel(fig, gs, sumStats):
        """Draw the corner panel colouring each patch by its summary statistic.

        Parameters
        ----------
        fig : `matplotlib.figure.Figure`
            The figure to draw on.
        gs : `matplotlib.gridspec.GridSpec`
            The 4x4 grid of the figure; the panel goes in the top-right cell.
        sumStats : `dict`
            Maps a patch id (or ``"tract"``) to ``(corners, stat)`` where
            ``corners`` is a sequence of (R.A., Dec.) angle pairs.
        """
        axCorner = fig.add_subplot(gs[0, -1])
        axCorner.yaxis.tick_right()
        axCorner.yaxis.set_label_position("right")
        axCorner.xaxis.tick_top()
        axCorner.xaxis.set_label_position("top")
        axCorner.set_aspect("equal")

        patches = []
        colors = []
        for patchId, (corners, stat) in sumStats.items():
            ra = corners[0][0].asDegrees()
            dec = corners[0][1].asDegrees()
            width = corners[2][0].asDegrees() - ra
            height = corners[2][1].asDegrees() - dec
            patches.append(Rectangle((ra, dec), width, height))
            colors.append(stat)
            ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
            decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
            axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
            if patchId != "tract":
                axCorner.annotate(patchId, (ra + width/2, dec + height/2), color="k", fontsize=4,
                                  ha="center", va="center")

        # Set the bad color to transparent and make a masked array
        colors = np.ma.array(colors, mask=np.isnan(colors))
        collection = PatchCollection(patches, cmap=cmapPatch)
        collection.set_array(colors)
        axCorner.add_collection(collection)

        axCorner.set_xlabel("R.A. (deg)", fontsize=7)
        axCorner.set_ylabel("Dec. (deg)", fontsize=7)
        axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
        axCorner.invert_xaxis()

        # Add a colorbar
        pos = axCorner.get_position()
        cax = fig.add_axes([pos.x0, pos.y0 + 0.23, pos.x1 - pos.x0, 0.025])
        fig.colorbar(collection, cax=cax, orientation="horizontal")
        cax.text(0.5, 0.5, "Median Value", color="k", transform=cax.transAxes, rotation="horizontal",
                 horizontalalignment="center", verticalalignment="center", fontsize=6)
        cax.tick_params(axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True,
                        pad=0.5, length=2)

    def scatterPlotWithTwoHists(self, catPlot, plotInfo, sumStats):
        """Makes a generic plot with a 2D histogram and collapsed histograms of
        each axis.

        Parameters
        ----------
        catPlot : `pandas.core.frame.DataFrame`
            The catalog to plot the points from. Must contain the x, y and mag
            columns named in ``config.axisLabels`` plus the ``sourceType`` and
            ``useForStats`` columns added by `run`.
        plotInfo : `dict`
            A dictionary of information about the data being plotted with keys:
                ``"run"``
                    The output run for the plots (`str`).
                ``"skymap"``
                    The type of skymap used for the data (`str`).
                ``"filter"``
                    The filter used for this data (`str`).
                ``"tract"``
                    The tract that the data comes from (`str`).
        sumStats : `dict`
            A dictionary where the patchIds are the keys which store the R.A.
            and dec of the corners of the patch, along with a summary
            statistic for each patch.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure. If no source has a recognised ``sourceType``
            (1, 2, 9 or 10) a figure holding only a "no data" message is
            returned instead.

        Notes
        -----
        Uses the axisLabels config options `x` and `y` and the axisAction
        config options `xAction` and `yAction` to plot a scatter
        plot of the values against each other. A histogram of the points
        collapsed onto each axis is also plotted. A summary panel showing the
        median of the y value in each patch is shown in the upper right corner
        of the resultant plot. The code uses the selectorActions to decide
        which points to plot and the statisticSelector actions to determine
        which points to use for the printed statistics.

        The axis limits are set based on the values of `config.xLims` and
        `config.yLims`. If provided (as a `list` of [min, max]), those will
        be used. If `None` (the default), the axis limits will be computed
        and set based on the data.
        """
        self.log.info("Plotting %s: the values of %s on a scatter plot.",
                      self.config.connections.plotName, self.config.axisLabels['y'])

        xCol = self.config.axisLabels["x"]
        yCol = self.config.axisLabels["y"]
        magCol = self.config.axisLabels["mag"]

        # sourceTypes: 1 - stars, 2 - galaxies, 9 - unknowns, 10 - all
        sourceTypeMapper = {"stars": 1, "galaxies": 2, "unknowns": 9, "all": 10}
        sourceTypeCol = catPlot["sourceType"]

        # Need to separate stars and galaxies
        stars = (sourceTypeCol == sourceTypeMapper["stars"])
        galaxies = (sourceTypeCol == sourceTypeMapper["galaxies"])
        unknowns = (sourceTypeCol == sourceTypeMapper["unknowns"])

        xsGalaxies = catPlot.loc[galaxies, xCol]
        ysGalaxies = catPlot.loc[galaxies, yCol]
        xsStars = catPlot.loc[stars, xCol]
        ysStars = catPlot.loc[stars, yCol]

        newBlues = mkColormap(["paleturquoise", "midnightBlue"])
        newReds = mkColormap(["lemonchiffon", "firebrick"])

        # Choose the populations to draw. Galaxies are listed before stars, so
        # stars are drawn later (on top) and get the second stats-box slot.
        toPlotList = []
        if np.any(galaxies):
            toPlotList.append((xsGalaxies.values, ysGalaxies.values, "firebrick", newReds,
                               sourceTypeMapper["galaxies"]))
        if np.any(stars):
            toPlotList.append((xsStars.values, ysStars.values, "midnightblue", newBlues,
                               sourceTypeMapper["stars"]))
        if not toPlotList:
            if np.any(unknowns):
                toPlotList.append((catPlot.loc[unknowns, xCol].values, catPlot.loc[unknowns, yCol].values,
                                   "green", None, sourceTypeMapper["unknowns"]))
            elif np.any(sourceTypeCol == sourceTypeMapper["all"]):
                toPlotList.append((catPlot[xCol].values, catPlot[yCol].values, "purple", None,
                                   sourceTypeMapper["all"]))
            else:
                # Checked before creating the main pyplot figure so that no
                # unused figure is left open.
                noDataFig = plt.Figure()
                noDataFig.text(0.3, 0.5, "No data to plot after selectors applied")
                return addPlotInfo(noDataFig, plotInfo)

        # Calculate some statistics for every source type present.
        highStats = {}
        highMags = {}
        lowStats = {}
        lowMags = {}
        for sourceType in sourceTypeMapper.values():
            sources = (sourceTypeCol == sourceType)
            if not np.any(sources):
                continue
            highSn = ((catPlot["useForStats"] == 1) & sources)
            # The low S/N sample includes the high S/N points as well.
            lowSn = (((catPlot["useForStats"] == 1) | (catPlot["useForStats"] == 2)) & sources)
            highStats[sourceType] = self._formatStats(catPlot.loc[highSn, yCol])
            lowStats[sourceType] = self._formatStats(catPlot.loc[lowSn, yCol])
            highMags[sourceType] = self._approxMagLimit(catPlot.loc[highSn, magCol])
            lowMags[sourceType] = self._approxMagLimit(catPlot.loc[lowSn, magCol])

        fig = plt.figure(dpi=300)
        gs = gridspec.GridSpec(4, 4)

        # Main scatter plot
        ax = fig.add_subplot(gs[1:, :-1])
        # Minimum number of points an x bin needs to be used for running stats.
        binThresh = 5
        linesForLegend = []
        xMin = None
        histIm = None

        highSelectors = self.config.highSnStatisticSelectorActions
        lowSelectors = self.config.lowSnStatisticSelectorActions
        hasHighStatSelector = hasattr(highSelectors, "statSelector")
        hasLowStatSelector = hasattr(lowSelectors, "statSelector")

        for j, (xs, ys, color, cmap, sourceType) in enumerate(toPlotList):
            sigMadYs = sigmaMad(ys, nan_policy="omit")
            if len(xs) < 2:
                # A single point: show its median and +/- 1 sigma MAD only.
                medY = np.nanmedian(ys)
                medLine, = ax.plot(xs, medY, color, label=f"Median: {medY:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.array([sigMadYs]*len(xs))
                sigMadLine, = ax.plot(xs, medY + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, medY - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None
                meds = np.array([np.median(ys)])
                continue

            xs1, xs97 = np.nanpercentile(xs, [1, 97])
            xScale = (xs97 - xs1)/20.0  # This is ~5% of the data range

            # 40 was used as the default number of bins because it looked good
            xLow = np.nanmin(xs) - xScale
            xHigh = np.nanmax(xs) + xScale
            xEdges = np.arange(xLow, xHigh, (xHigh - xLow)/self.config.nBins)
            medYs = np.nanmedian(ys)
            fiveSigmaHigh = medYs + 5.0*sigMadYs
            fiveSigmaLow = medYs - 5.0*sigMadYs
            binSize = (fiveSigmaHigh - fiveSigmaLow)/self.config.nBins
            yEdges = np.arange(fiveSigmaLow, fiveSigmaHigh, binSize)

            counts, _, _ = np.histogram2d(xs, ys, bins=(xEdges, yEdges))
            countsYs = np.sum(counts, axis=1)

            # Keep only the x bins with more than binThresh points within
            # +/- 5 sigma MAD of the median.
            goodBins = np.where(countsYs > binThresh)[0]
            xEdgesPlot = xEdges[goodBins][1:]
            xEdges = xEdges[goodBins]

            if len(goodBins) > 1:
                nEdges = len(xEdgesPlot)
                # Create the codes needed to turn the sigmaMad lines
                # into a path to speed up checking which points are
                # inside the area.
                codes = np.full(2*nEdges, Path.LINETO, dtype=Path.code_type)
                codes[0] = Path.MOVETO
                codes[-1] = Path.CLOSEPOLY

                meds = np.zeros(nEdges)
                sigMads = np.zeros(nEdges)
                # Upper envelope left-to-right, then lower envelope back.
                threeSigMadVerts = np.zeros((2*nEdges, 2))
                for i, xEdge in enumerate(xEdgesPlot):
                    inBin = (xs < xEdge) & (xs >= xEdges[i]) & np.isfinite(ys)
                    med = np.median(ys[inBin])
                    sigMad = sigmaMad(ys[inBin])
                    meds[i] = med
                    sigMads[i] = sigMad
                    threeSigMadVerts[i, :] = [xEdge, med + 3*sigMad]
                    threeSigMadVerts[-(i + 1), :] = [xEdge, med - 3*sigMad]

                medLine, = ax.plot(xEdgesPlot, meds, color, label="Running Median")
                linesForLegend.append(medLine)

                # Make path to check which points lie within three sigma MAD
                threeSigMadPath = Path(threeSigMadVerts, codes)

                # Add lines for the median +/- 3 * sigma MAD
                threeSigMadLine, = ax.plot(xEdgesPlot, threeSigMadVerts[:nEdges, 1], color,
                                           alpha=0.4, label=r"3$\sigma_{MAD}$")
                ax.plot(xEdgesPlot[::-1], threeSigMadVerts[nEdges:, 1], color, alpha=0.4)

                # Add lines for the median +/- 1 * sigma MAD
                sigMadLine, = ax.plot(xEdgesPlot, meds + 1.0*sigMads, color, alpha=0.8,
                                      label=r"$\sigma_{MAD}$")
                linesForLegend.append(sigMadLine)
                ax.plot(xEdgesPlot, meds - 1.0*sigMads, color, alpha=0.8)

                # Add lines for the median +/- 2 * sigma MAD
                twoSigMadLine, = ax.plot(xEdgesPlot, meds + 2.0*sigMads, color, alpha=0.6,
                                         label=r"2$\sigma_{MAD}$")
                linesForLegend.append(twoSigMadLine)
                linesForLegend.append(threeSigMadLine)
                ax.plot(xEdgesPlot, meds - 2.0*sigMads, color, alpha=0.6)

                # Check which points are outside 3 sigma MAD of the median
                # and plot these as points.
                inside = threeSigMadPath.contains_points(np.array([xs, ys]).T)
                ax.plot(xs[~inside], ys[~inside], ".", ms=self.config.minPointSize, alpha=0.3,
                        mfc=color, mec=color, zorder=-1)

                # Add some stats text
                xPos = 0.65 - 0.4*j
                if hasHighStatSelector:
                    bbox = dict(edgecolor=color, linestyle="--", facecolor="none")
                    highThresh = highSelectors.statSelector.threshold
                    statText = (f"S/N > {highThresh} Stats [{magCol} $\\lesssim$ "
                                f"{self._formatMag(highMags[sourceType])}]\n")
                    statText += highStats[sourceType]
                    fig.text(xPos, 0.090, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if hasLowStatSelector:
                    bbox = dict(edgecolor=color, linestyle=":", facecolor="none")
                    lowThresh = lowSelectors.statSelector.threshold
                    statText = (f"S/N > {lowThresh} Stats [{magCol} $\\lesssim$ "
                                f"{self._formatMag(lowMags[sourceType])}]\n")
                    statText += lowStats[sourceType]
                    fig.text(xPos, 0.020, statText, bbox=bbox, transform=fig.transFigure, fontsize=6)

                if self.config.plot2DHist:
                    histIm = ax.hexbin(xs[inside], ys[inside], gridsize=75, cmap=cmap, mincnt=1, zorder=-2)

                # If there are not many sources being used for the
                # statistics then plot them individually as just
                # plotting a line makes the statistics look wrong
                # as the magnitude estimation is iffy for low
                # numbers of sources.
                statInfo = catPlot.loc[sourceTypeCol == sourceType, "useForStats"].values
                if hasHighStatSelector:
                    highSn = (statInfo == 1)
                    if 0 < np.sum(highSn) < 100:
                        ax.plot(xs[highSn], ys[highSn], marker="x", ms=self.config.minPointSize + 1,
                                mec="w", mew=2, ls="none")
                        highSnLine, = ax.plot(xs[highSn], ys[highSn], color=color, marker="x",
                                              ms=self.config.minPointSize + 1, ls="none", label="High SN")
                        linesForLegend.append(highSnLine)
                        highMin = np.min(xs[highSn])
                        if xMin is None or xMin > highMin:
                            xMin = highMin
                    elif np.isfinite(highMags[sourceType]):
                        ax.axvline(highMags[sourceType], color=color, ls="--")
                if hasLowStatSelector:
                    lowSn = (statInfo == 2)
                    if 0 < np.sum(lowSn) < 100:
                        ax.plot(xs[lowSn], ys[lowSn], marker="+", ms=self.config.minPointSize + 1, mec="w",
                                mew=2, ls="none")
                        lowSnLine, = ax.plot(xs[lowSn], ys[lowSn], color=color, marker="+",
                                             ms=self.config.minPointSize + 1, ls="none", label="Low SN")
                        linesForLegend.append(lowSnLine)
                        lowMin = np.min(xs[lowSn])
                        if xMin is None or xMin > lowMin:
                            xMin = lowMin
                    elif np.isfinite(lowMags[sourceType]):
                        ax.axvline(lowMags[sourceType], color=color, ls=":")

            else:
                # Too few well-populated bins for running statistics: plot
                # every point with a flat median and +/- 1 sigma MAD.
                ax.plot(xs, ys, ".", ms=self.config.minPointSize + 3, alpha=0.3, mfc=color, mec=color,
                        zorder=-1)
                medY = np.nanmedian(ys)
                meds = np.full(len(xs), medY)
                medLine, = ax.plot(xs, meds, color, label=f"Median: {medY:0.3g}", lw=0.8)
                linesForLegend.append(medLine)
                sigMads = np.full(len(xs), sigMadYs)
                sigMadLine, = ax.plot(xs, meds + 1.0*sigMads, color, alpha=0.8, lw=0.8,
                                      label=r"$\sigma_{MAD}$: " + f"{sigMads[0]:0.3g}")
                ax.plot(xs, meds - 1.0*sigMads, color, alpha=0.8)
                linesForLegend.append(sigMadLine)
                histIm = None

        # Set the scatter plot limits, centred on the stars when present.
        if len(ysStars) > 0:
            plotMed = np.nanmedian(ysStars)
        elif len(ysGalaxies) > 0:
            plotMed = np.nanmedian(ysGalaxies)
        else:
            # Only unknown or unclassified ("all") sources were plotted.
            plotMed = np.nanmedian(ys)

        if self.config.yLims is not None:
            yLimMin = self.config.yLims[0]
            yLimMax = self.config.yLims[1]
        else:
            numSig = 4
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
            # Widen the limits (up to 10 sigma MAD) until the running median
            # fits inside them.
            while (yLimMax < np.nanmax(meds) or yLimMin > np.nanmin(meds)) and numSig < 10:
                numSig += 1
                yLimMin = plotMed - numSig*sigMadYs
                yLimMax = plotMed + numSig*sigMadYs
            # One extra sigma MAD of padding.
            numSig += 1
            yLimMin = plotMed - numSig*sigMadYs
            yLimMax = plotMed + numSig*sigMadYs
        ax.set_ylim(yLimMin, yLimMax)

        if self.config.xLims is not None:
            ax.set_xlim(self.config.xLims[0], self.config.xLims[1])
        elif len(xs) > 2:
            if xMin is None:
                xMin = xs1 - 2*xScale
            ax.set_xlim(xMin, xs97 + 2*xScale)

        # Add a line legend
        ax.legend(handles=linesForLegend, ncol=4, fontsize=6, loc="upper left", framealpha=0.9,
                  edgecolor="k", borderpad=0.4, handlelength=1)

        # Add axes labels
        ax.set_ylabel(yCol, fontsize=10, labelpad=10)
        ax.set_xlabel(xCol, fontsize=10, labelpad=2)

        # Top histogram
        topHist = fig.add_subplot(gs[0, :-1], sharex=ax)
        topHist.hist(catPlot[xCol].values, bins=100, color="grey", alpha=0.3, log=True,
                     label=f"All ({len(catPlot)})")
        if np.any(galaxies):
            topHist.hist(xsGalaxies, bins=100, color="firebrick", histtype="step", log=True,
                         label=f"Galaxies ({np.sum(galaxies)})")
        if np.any(stars):
            topHist.hist(xsStars, bins=100, color="midnightblue", histtype="step", log=True,
                         label=f"Stars ({np.sum(stars)})")
        topHist.axes.get_xaxis().set_visible(False)
        topHist.set_ylabel("Number", fontsize=8)
        topHist.legend(fontsize=6, framealpha=0.9, borderpad=0.4, loc="lower left", ncol=3, edgecolor="k")

        # Side histogram
        sideHist = fig.add_subplot(gs[1:, -1], sharey=ax)
        finiteObjs = np.isfinite(catPlot[yCol].values)
        bins = np.linspace(yLimMin, yLimMax)
        sideHist.hist(catPlot[yCol].values[finiteObjs], bins=bins, color="grey", alpha=0.3,
                      orientation="horizontal", log=True)
        # Per population: all finite points (solid), then the high S/N
        # (dashed) and low-S/N-only (dotted) statistic samples.
        for typeMask, ysType, typeColor in ((galaxies, ysGalaxies, "firebrick"),
                                            (stars, ysStars, "midnightblue")):
            if not np.any(typeMask):
                continue
            sideHist.hist(ysType[np.isfinite(ysType)], bins=bins, color=typeColor, histtype="step",
                          orientation="horizontal", log=True)
            for statFlag, lineStyle in ((1, "--"), (2, ":")):
                statMask = typeMask & (catPlot["useForStats"] == statFlag)
                sideHist.hist(catPlot.loc[statMask, yCol], bins=bins, color=typeColor, histtype="step",
                              orientation="horizontal", log=True, ls=lineStyle)

        sideHist.axes.get_yaxis().set_visible(False)
        sideHist.set_xlabel("Number", fontsize=8)
        if self.config.plot2DHist and histIm is not None:
            divider = make_axes_locatable(sideHist)
            cax = divider.append_axes("right", size="8%", pad=0)
            fig.colorbar(histIm, cax=cax, orientation="vertical", label="Number of Points Per Bin")

        # Corner plot of patches showing summary stat in each
        self._addPatchSummaryPanel(fig, gs, sumStats)

        plt.draw()
        fig.subplots_adjust(wspace=0.0, hspace=0.0, bottom=0.22, left=0.21)
        return addPlotInfo(fig, plotInfo)