Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 14%
275 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-09 13:16 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-09 13:16 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("PanelConfig",)
25from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple
27import matplotlib
28import matplotlib.pyplot as plt
29import numpy as np
30from lsst.geom import Box2D, SpherePoint, degrees
31from lsst.pex.config import Config, Field
32from matplotlib import colors
33from matplotlib.collections import PatchCollection
34from matplotlib.patches import Rectangle
35from scipy.stats import binned_statistic_2d
37from ...math import nanMedian, nanSigmaMad
39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true
40 from matplotlib.figure import Figure
# Shared tick formatter that renders no tick labels; used below to blank the
# tick labels of axes when their text is stripped.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `dict`
        The dataId of the data to be plotted. Keys may be plain strings or
        objects exposing a ``name`` attribute (presumably butler dimension
        objects); in both cases the value is looked up by that name.
    runName : `str`
        The name of the run.
    tableName : `str`
        The name of the table.
    bands : `list` [`str`]
        The bands to be plotted.
    plotName : `str`
        The name of the plot.
    SN : `str`
        The signal to noise of the data.

    Returns
    -------
    plotInfo : `dict`
        A dictionary of the plot information. ``bands`` is stored as a single
        ``", "``-separated string, and ``tract``/``visit`` are set to
        ``"N/A"`` when absent from ``dataId``.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for key in dataId:
        # Dimension-like keys carry their string name in ``.name``; plain
        # string keys are used as they are.
        name = getattr(key, "name", key)
        plotInfo[name] = dataId[name]

    plotInfo["bands"] = ", ".join(bands)

    # Make sure downstream labelling can always rely on these keys.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic (the median) in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        A dictionary of the data to be plotted. It must contain a ``patch``
        column and one of the ``y``, ``yStars``, ``yGalaxies`` or
        ``yUnknowns`` columns (the first one present, in that order, is
        used).
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap associated with the data.
    plotInfo : `dict`
        A dictionary of the plot information; ``plotInfo["tract"]`` selects
        the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to ``(skyCorners, median)``; the
        median is NaN for patches without data. The extra key ``"tract"``
        maps to the sky corners of the whole tract and NaN.

    Raises
    ------
    KeyError
        Raised if ``data`` contains none of the supported value columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    yColCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in yColCandidates if col in data.keys()), None)
    if yCol is None:
        raise KeyError(f"data must contain one of the columns {yColCandidates}; got {list(data.keys())}")

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are enumerated by their sequential (gen 3) index, so the ids
    # need no conversion before being passed to getPatchInfo.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        stat = nanMedian(data[yCol][onPatch]) if np.any(onPatch) else np.nan

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable):
    """Generate a summary statistic (the median) in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        A dataframe of the data to be plotted; it must have a ``detector``
        column.
    colName : `str`
        The name of the column to be plotted.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        A dataframe of the visit summary table, with ``id``, ``raCorners``
        and ``decCorners`` columns.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector to ``(corners, median)``, where ``corners`` is a
        list of `lsst.geom.SpherePoint` built from the first summary row
        whose ``id`` matches the detector.
    """
    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = nanMedian(cat[colName].values[onCcd])

        sumRow = visitSummaryTable["id"] == ccd
        # Pick the first matching row *positionally*: indexing a filtered
        # pandas Series with ``[0]`` is a label lookup and fails for any
        # detector whose row is not labelled 0.
        raCorners = np.asarray(visitSummaryTable["raCorners"][sumRow])[0]
        decCorners = np.asarray(visitSummaryTable["decCorners"][sumRow])[0]
        cornersOut = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axes and its children and return with line points.

    The title and axis labels are cleared, legend entries are made fully
    transparent, free-standing text artists are blanked and tick labels are
    replaced by a null formatter. Child axes are processed recursively.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib axes.

    Returns
    -------
    texts : `List[str]`
        All text strings removed (title, axis labels, legend entries and text
        artists) from ``ax`` and its child axes.
    line_xys : `List[numpy.ndarray]`
        The xy data (arrays of shape ``(N, 2)``) of every line in ``ax`` and
        its child axes.
    """
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        for text in legend.texts:
            texts.append(text.get_text())
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only some axes (e.g. 3D axes) have a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: Figure):
    """Strip all text from a Figure and its Axes, returning text and lines.

    The suptitle is cleared, the figure-level text artists are dropped and
    every Axes is processed with `get_and_remove_axis_text`.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        The collected text: first the string form of the suptitle, then the
        figure-level texts (read after the suptitle was cleared), then the
        text of each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        A list of all line xy arrays (shape ``(N, 2)``) from every Axes.
    """
    # NOTE(review): ``str()`` of a matplotlib ``Text`` is presumably its repr
    # (including position) rather than the bare string, and "None" when no
    # suptitle exists — confirm whether ``get_text()`` was intended.
    collectedTexts = [str(figure._suptitle)]
    figure.suptitle("")
    collectedTexts += [artist.get_text() for artist in figure.texts]
    figure.texts = []

    collectedLines = []
    for axes in figure.get_axes():
        axesTexts, axesLines = get_and_remove_axis_text(axes)
        collectedTexts += axesTexts
        collectedLines += axesLines

    return collectedTexts, collectedLines
def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    The plot name is written in the top-left corner of the figure. Beneath
    it goes a line block with the run, datasets, table, tract/visit (when
    present), bands and, when present, the first S/N and magnitude entries
    of ``plotInfo``.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to add the information to.
    plotInfo : `dict`
        A dictionary of the plot information. It must contain ``plotName``,
        ``run``, ``tableName`` and ``bands``. ``bands`` may be an iterable of
        band names or an already comma-separated string such as the one
        produced by `parsePlotInfo`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the information added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=7, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    if isinstance(bands, str) and "," in bands:
        # A pre-joined string like "g, r" would otherwise be iterated one
        # character at a time and rendered as "g, ,,  , r".
        bands = [band.strip() for band in bands.split(",")]
    bandsText = f", Bands: {', '.join(bands)}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}"

    # Find the first S/N and mag keys, if present.
    snKey = next((key for key in plotInfo if "SN" in key or "S/N" in key), None)
    magKey = next((key for key in plotInfo if "Mag" in key), None)
    # Add S/N and mag values to label, if present; S/N moves to its own line
    # when a magnitude entry follows it.
    if snKey is not None:
        separator = ", " if magKey is None else "\n"
        infoText += f"{separator}{snKey}{plotInfo[snKey]}"
    if magKey is not None:
        infoText += f", {magKey}{plotInfo[magKey]}"
    fig.text(0.01, 0.984, infoText, fontsize=6, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def mkColormap(colorNames):
    """Build a linear segmented colormap that steps through named colors.

    The colors are anchored at evenly spaced positions between 0 and 1, in
    the order they are given.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib named colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"`` stepping through the supplied colors.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.colorConverter.to_rgb(name) for name in colorNames]

    def _channel(index):
        # Segment rows are (position, value, value): the same value on both
        # sides of each anchor yields a continuous ramp.
        return [(pos, rgb[index], rgb[index]) for pos, rgb in zip(anchors, rgbTriples)]

    segmentData = {"blue": _channel(2), "red": _channel(0), "green": _channel(1)}
    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Return indices ordering ``xs`` by increasing distance from its median.

    The points furthest from the median (as computed by ``nanMedian``), in
    absolute terms, come last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort.

    Returns
    -------
    ids : `np.array`
        Indices that reorder ``xs`` from closest to furthest from the median.
    """
    return np.argsort(np.abs(xs - nanMedian(xs)))
def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder every array in a list by the extrema order of one of them.

    The order comes from `extremaSort` applied to the array at
    ``sortArrayIndex``. That same index order is then applied to every array.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        A list of arrays to be simultaneously sorted based on the array in
        the list position given by ``sortArrayIndex``.
    sortArrayIndex : `int`, optional
        Zero-based index of the array on which to base the sorting
        (default: the first array).

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, with each entry replaced in place by its
        reordered version.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot. It
        can be a gridspec SubplotSpec, or a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. The tuple of int, int, index works the
        same way.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch. Entries are drawn as outlines, and all
        keys except ``"tract"`` are also annotated with their id.
    label : `str`
        The label to be used for the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The input figure with the summary axes and its colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color
    # coded rectangles for each patch, the colors show
    # the median of the given value in the patch
    patches = []
    stats = []  # Named to avoid shadowing the module-level ``colors`` import.
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        # The colored rectangle spans corners 0 and 2 (presumably opposite
        # corners of the region).
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        stats.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        cornerDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        # Close the outline by returning to the first corner.
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(stats, mask=np.isnan(stats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes. Attach it through ``fig``
    # explicitly: ``plt.colorbar`` goes via pyplot's current figure, which
    # need not be ``fig``.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig
def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers.

    Each run of consecutive integers is collapsed into
    ``first{range_indicator}last``, and lone values are kept as they are.
    Duplicate values are ignored.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the list; empty for an empty
        input.

    Examples
    --------
    >>> shorten_list([1,2,3,5,6,8])
    '1-3,5-6,8'

    >>> shorten_list((1,2,3,5,6,8,9,10,11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'

    >>> shorten_list([3, 3, 5])
    '3,5'
    """
    # De-duplicate and sort so consecutive runs are detected by comparing
    # each value with its predecessor.
    values = sorted(set(numbers))
    if not values:  # empty container
        return ""

    # Collect the (first, last) value of every run of consecutive integers.
    runs = []
    runStart = previous = values[0]
    for value in values[1:]:
        if value != previous + 1:
            runs.append((runStart, previous))
            runStart = value
        previous = value
    runs.append((runStart, previous))

    return range_separator.join(
        str(first) if first == last else range_indicator.join((str(first), str(last)))
        for first, last in runs
    )
class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the grid shape, location,
    rowspan and colspan arguments of `matplotlib.pyplot.subplot2grid`.
    """

    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )
def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    showExtremeOutliers=True,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color limits symmetric about zero, using the
        larger of ``|vmin|`` and ``|vmax|`` as the half-range.
    nPointBinThresh : `int`, optional
        Threshold number of points above which binning will be implemented
        for the plotting. If the number of data points is lower than this
        threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. Each defaults to the median
        of ``zs`` minus/plus twice its sigma-MAD.
    showExtremeOutliers: `bool`, default True
        Use overlaid scatter points to show the x-y positions of the
        (approximately) 15% most extreme values.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection` or `matplotlib.image.AxesImage`
        The last artist added to ``ax``. This is the scatter plot, or the
        outlier overlay when binning; it is the binned image only when no
        overlay is drawn.
    """
    # Robust statistics are only needed to fill in missing color limits.
    if vmin is None or vmax is None:
        med = nanMedian(zs)
        mad = nanSigmaMad(zs)
        if vmin is None:
            vmin = med - 2 * mad
        if vmax is None:
            vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    if len(xs) < nPointBinThresh:
        # Few enough points to draw each one individually.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # Binned path: show the per-bin median of zs as an image.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )

    if showExtremeOutliers and len(xs) > 1:
        if not isSorted:
            zs, xs, ys = sortAllArrays([zs, xs, ys])
        # The arrays are ordered by distance from the median, so the most
        # extreme ~15% of points sit at the tail.
        extremes = int(np.floor((len(xs) / 100)) * 85)
        # Marker size and edge width shrink as the number of points grows.
        s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
        lw = (s**0.5) / 10
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )
    return plotOut