Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 11%
300 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-09 03:19 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-09 03:19 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("PanelConfig",)
25from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple
27import matplotlib
28import matplotlib.pyplot as plt
29import numpy as np
30import scipy.odr as scipyODR
31from lsst.geom import Box2D, SpherePoint, degrees
32from lsst.pex.config import Config, Field
33from matplotlib import colors
34from matplotlib.collections import PatchCollection
35from matplotlib.patches import Rectangle
36from scipy.stats import binned_statistic_2d
38from ...statistics import nansigmaMad
40if TYPE_CHECKING: 40 ↛ 41line 40 didn't jump to line 41, because the condition on line 40 was never true
41 from matplotlib.figure import Figure
# One shared formatter that renders empty tick labels; the text-stripping
# helpers install it on every x/y(/z) axis they process.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.DataCoordinate`
        The data ID of the plotted data. Every dimension yielded when
        iterating it is copied into the result under the dimension's
        ``name``.
    runName : `str`
        Name of the run; stored under ``"run"``.
    tableName : `str`
        Name of the table the data came from; stored under ``"tableName"``.
    bands : `Iterable` [`str`]
        Band names; joined into a single ``", "``-separated string stored
        under ``"bands"``.
    plotName : `str`
        Name of the plot; stored under ``"plotName"``.
    SN : `object`
        Signal-to-noise information; stored unchanged under ``"SN"``.

    Returns
    -------
    plotInfo : `dict`
        The collected plot information. ``"tract"`` and ``"visit"`` are
        always present and are set to ``"N/A"`` when the dataId lacks them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dimension in dataId:
        plotInfo[dimension.name] = dataId[dimension.name]

    plotInfo["bands"] = ", ".join(bands)

    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def _findYColumn(data):
    """Return the first available y-value column name in ``data``.

    Raises
    ------
    ValueError
        If none of ``y``, ``yStars``, ``yGalaxies`` or ``yUnknowns`` is a key.
    """
    for candidate in ("y", "yStars", "yGalaxies", "yUnknowns"):
        if candidate in data.keys():
            return candidate
    raise ValueError(
        "Data must contain one of the columns 'y', 'yStars', 'yGalaxies' or 'yUnknowns'; "
        f"found {list(data.keys())}."
    )


def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic (the NaN-ignoring median) in each patch.

    Parameters
    ----------
    data : `dict`
        Mapping holding a ``"patch"`` array of patch indices and one of the
        y-value columns ``y``, ``yStars``, ``yGalaxies`` or ``yUnknowns``
        (checked in that order).
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap used to look up the tract and its patches.
    plotInfo : `dict`
        Plot information; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each patch index to ``(skyCorners, median)`` where
        ``skyCorners`` are the sky positions of the patch inner-bbox corners.
        The key ``"tract"`` maps to the tract corners and ``nan``.

    Raises
    ------
    ValueError
        If ``data`` contains none of the supported y-value columns.
    """
    yCol = _findYColumn(data)

    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are sequential (gen 3) indices; np.arange keeps them as numpy
    # integers, matching the keys produced previously.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def _firstSelectedRow(column, mask):
    """Return the first element of ``column`` selected by boolean ``mask``.

    The lookup is positional, so it behaves the same for numpy arrays,
    table columns and pandas Series; indexing a masked pandas Series with
    ``[0]`` would instead be a *label* lookup.
    """
    return np.asarray(column[mask])[0]


def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (the NaN-ignoring median) per detector.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and the column ``colName``.
    colName : `str`
        Name of the column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Table with ``id``, ``raCorners`` and ``decCorners`` columns; the
        corner values are in degrees.
    plotInfo : `dict`
        Plot information (currently unused).

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to ``(corners, median)`` where ``corners`` is
        a list of `lsst.geom.SpherePoint`.
    """
    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        sumRow = visitSummaryTable["id"] == ccd
        raCorners = _firstSelectedRow(visitSummaryTable["raCorners"], sumRow)
        decCorners = _firstSelectedRow(visitSummaryTable["decCorners"], sumRow)
        corners = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[ccd] = (corners, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        The title, axis labels, legend texts and free text strings of ``ax``
        followed by those of its child axes (recursively).
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` attribute (arrays of shape ``(N, 2)``) of every line on
        ``ax`` followed by those of its child axes (recursively).
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        # Legend entries are made fully transparent rather than emptied.
        for text in legend.texts:
            texts.append(text.get_text())
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only axes with a z axis (e.g. 3D axes) have this attribute.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Previously discarded; the child's lines belong in the result too.
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: Figure):
    """Strip text from a Figure and all of its Axes, returning what was removed.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        First ``str(figure._suptitle)`` (the string form of the suptitle
        attribute, e.g. ``"None"`` when unset), then the figure-level texts,
        then the texts of each Axes as returned by `get_and_remove_axis_text`.
    line_xys : `List[numpy.ndarray]`
        The line ``_xy`` arrays (shape ``(N, 2)``) of every Axes.
    """
    collectedTexts = [str(figure._suptitle)]
    figure.suptitle("")

    # Figure-level texts are read only after resetting the suptitle, keeping
    # the original ordering (resetting may create a new suptitle Text).
    collectedTexts += [text.get_text() for text in figure.texts]
    figure.texts = []

    collectedLines = []
    for axis in figure.get_axes():
        axisTexts, axisLines = get_and_remove_axis_text(axis)
        collectedTexts += axisTexts
        collectedLines += axisLines

    return collectedTexts, collectedLines
def _formatBands(bands):
    """Return ``bands`` as a ``", "``-separated string for the plot label.

    A string that already contains a comma (the format `parsePlotInfo`
    produces, e.g. ``"g, r"``) is used as-is. Any other iterable, including
    a string of single-character band names such as ``"gri"``, is joined
    element-wise as before.
    """
    if isinstance(bands, str) and "," in bands:
        return bands
    return ", ".join(bands)


def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; text is added to it in place.
    plotInfo : `dict`
        Must contain ``plotName``, ``run``, ``tableName`` and ``bands``;
        ``tract``, ``visit`` and ``SN`` are shown when present.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The annotated figure.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    # Iterating a pre-joined "g, r" string character by character produced
    # garbled text such as "g, ,,  , r"; format it robustly instead.
    bandsText = f", Bands: {_formatBands(plotInfo['bands'])}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the x axis.
    ys : `numpy.ndarray`
        The color on the y axis.
    paramDict : `lsst.pex.config.dictField.Dict`
        A dictionary of parameters for line fitting:

        ``xMin`` : `float`
            The minimum x edge of the box to use for initial fitting.
        ``xMax`` : `float`
            The maximum x edge of the box to use for initial fitting.
        ``yMin`` : `float`
            The minimum y edge of the box to use for initial fitting.
        ``yMax`` : `float`
            The maximum y edge of the box to use for initial fitting.
        ``mHW`` : `float`
            The hardwired gradient for the fit.
        ``bHW`` : `float`
            The hardwired intercept of the fit.

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters. It contains the six
        input entries above, plus:

        ``mODR`` : `float`
            The gradient calculated by the first ODR fit.
        ``bODR`` : `float`
            The intercept calculated by the first ODR fit.
        ``yBoxMin`` : `float`
            The y value of the first fitted line at ``xMin``, or ``yMin``
            itself when ``abs(mODR) > 1``.
        ``yBoxMax`` : `float`
            The y value of the first fitted line at ``xMax``, or ``yMax``
            itself when ``abs(mODR) > 1``.
        ``bPerpMin`` : `float`
            The intercept of the perpendicular line through the lower end
            of the first fitted line within the box.
        ``bPerpMax`` : `float`
            The intercept of the perpendicular line through the upper end
            of the first fitted line within the box.
        ``mODR2`` : `float`
            The gradient from the second round of fitting.
        ``bODR2`` : `float`
            The intercept from the second round of fitting.
        ``mPerp`` : `float`
            The gradient of the line perpendicular to the line from the
            second fit.

    Notes
    -----
    The code does two rounds of fitting. The first is initialized with the
    hardwired values given in ``paramDict`` and uses an Orthogonal Distance
    Regression fit to the points inside the box defined by xMin, xMax, yMin
    and yMax. Once this fit is done, a perpendicular line is calculated at
    either end of the fitted line. Only points that fall between these lines
    are used to recalculate the fit.
    """
    # Points inside the initial box are used for the first fit.
    inBox = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # For a first-order polynomial model, beta is (intercept, gradient).
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[inBox], ys[inBox])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit, calculate the perpendicular ends.
    mPerp = -1.0 / mODR
    # When the gradient is really steep, use the y limits of the box rather
    # than the x ones to find the ends of the line.
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    bPerpMax = yBoxMax - mPerp * xBoxMax

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use the perpendicular lines to choose the data and refit, starting
    # from the first-round solution.
    betweenPerps = (ys > mPerp * xs + bPerpMin) & (ys < mPerp * xs + bPerpMax)
    data = scipyODR.Data(xs[betweenPerps], ys[betweenPerps])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()
    mODR2 = float(params.beta[1])
    bODR2 = float(params.beta[0])

    paramsOut["mODR2"] = mODR2
    paramsOut["bODR2"] = bODR2
    paramsOut["mPerp"] = -1.0 / mODR2

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point (x, y) on the line.
    p2 : `numpy.ndarray`
        Another point (x, y) on the line.
    points : iterable of (x, y) pairs
        The points to calculate the distance to (e.g. a `zip` of x and y).

    Returns
    -------
    dists : `list`
        The distances from the line to the points. The sign is that of the
        2D cross product ``(p2 - p1) x (point - p1)``, i.e. positive for
        points to the left of the direction p1 -> p2.
    """
    direction = p2 - p1
    length = np.linalg.norm(direction)
    dists = []
    for point in points:
        offset = np.array(point) - p1
        # z component of the 2D cross product; identical to np.cross for
        # 2-element vectors, which NumPy 2 deprecates.
        cross = direction[0] * offset[1] - direction[1] * offset[0]
        dists.append(cross / length)

    return dists
def mkColormap(colorNames):
    """Build a linearly interpolated colormap from a list of named colors.

    Parameters
    ----------
    colorNames : `list`
        Strings naming matplotlib colors; they are placed at evenly spaced
        positions from 0 to 1, in order.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"``.
    """
    positions = np.linspace(0, 1, len(colorNames))
    rgbs = [colors.to_rgb(name) for name in colorNames]

    # Each channel gets (position, value_below, value_above) anchors with no
    # discontinuities, so both values are the same.
    segmentData = {
        "blue": [(pos, b, b) for pos, (_, _, b) in zip(positions, rgbs)],
        "red": [(pos, r, r) for pos, (r, _, _) in zip(positions, rgbs)],
        "green": [(pos, g, g) for pos, (_, g, _) in zip(positions, rgbs)],
    }
    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Order indices by absolute distance from the median, nearest first.

    Parameters
    ----------
    xs : `np.array`
        The values to order. NaNs are ignored when computing the median.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` such that the values furthest from the median
        come last.
    """
    distanceFromMedian = np.abs(xs - np.nanmedian(xs))
    return np.argsort(distanceFromMedian)
def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder several arrays together using the extrema ordering of one.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        Arrays to reorder together. The ordering comes from `extremaSort`
        applied to the array at position ``sortArrayIndex``.
    sortArrayIndex : `int`, optional
        Zero-based position of the array that defines the ordering.

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, with every entry replaced in place by its
        reordered version.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    # Slice assignment updates the caller's list in place.
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot. It can
        be a gridspec SubplotSpec, or a 3 digit integer where the first digit
        is the number of rows, the second is the number of columns and the
        third is the index. The tuple form (int, int, index) has the same
        meaning.
    sumStats : `dict`
        Maps each patch/detector id to ``(corners, stat)``. ``corners`` is a
        sequence of (R.A., Dec.) angle pairs (each exposing ``asDegrees()``)
        and ``stat`` is the summary statistic used to color that region. The
        entry keyed ``"tract"`` is outlined but not annotated.
    label : `str`
        Text drawn over the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the summary subplot and its colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color coded rectangles for
    # each patch; the colors show the summary statistic of the patch.
    patches = []
    # Named ``statValues`` (not ``colors``) so the module-level
    # ``matplotlib.colors`` import is not shadowed here.
    statValues = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        statValues.append(stat)

        ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)

        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and mask NaN statistics so those
    # regions are not filled.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(statValues, mask=np.isnan(statValues))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the subplot
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    plt.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig
def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers.

    Consecutive integers are collapsed into ranges; duplicates are ignored.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the list.

    Examples
    --------
    >>> shorten_list([1,2,3,5,6,8])
    '1-3,5-6,8'

    >>> shorten_list((1,2,3,5,6,8,9,10,11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'

    >>> shorten_list([4, 4, 6])
    '4,6'
    """
    # Deduplicate before sorting: a repeated value would otherwise be
    # reported as a degenerate range such as "4-4".
    values = sorted(set(numbers))
    if not values:
        return ""

    def formatRun(first, last):
        # A run of one number is printed on its own, longer runs as a range.
        if first == last:
            return str(first)
        return range_indicator.join((str(first), str(last)))

    result = []
    start = previous = values[0]
    for value in values[1:]:
        if value != previous + 1:
            result.append(formatRun(start, previous))
            start = value
        previous = value
    result.append(formatRun(start, previous))

    return range_separator.join(result)
class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the arguments of
    `matplotlib.pyplot.subplot2grid`: the grid shape (rows, columns), the
    location (row, column) within it, and the row/column spans.
    """

    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )
def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color limits symmetric about zero, using the
        larger of ``abs(vmin)`` and ``abs(vmax)``.
    nPointBinThresh : `int`, optional
        Threshold number of points at or above which binning will be
        implemented for the plotting. If the number of data points is lower
        than this threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. Defaults are the median of
        ``zs`` minus/plus twice its sigma-MAD.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection` or `matplotlib.image.AxesImage`
        The last artist added to ``ax``: the scatter of extreme points (or
        of all points), or the binned image if no scatter was drawn.
    """
    med = np.nanmedian(zs)
    mad = nansigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    if len(xs) < nPointBinThresh:
        # Few enough points to show every one of them as a plain scatter.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # The binned statistics are only needed for the image, so compute them
    # only on this path.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )

    s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
    lw = (s**0.5) / 10
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )

    if not isSorted:
        zs, xs, ys = sortAllArrays([zs, xs, ys])

    if len(xs) > 1:
        # The arrays are ordered by distance from the median, so the last 15%
        # are the most extreme points; overplot them individually. Integer
        # arithmetic gives exactly 85% of the points (the previous
        # floor(n / 100) * 85 undercounted and was 0 for n < 100).
        extremes = (len(xs) * 85) // 100
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )

    return plotOut