Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 16%
257 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 03:25 -0700
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 03:25 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
# Only ``PanelConfig`` is listed for star-imports; the plotting helper
# functions defined in this module are not included here.
__all__ = ("PanelConfig",)
25from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple
27import matplotlib
28import matplotlib.pyplot as plt
29import numpy as np
30from lsst.geom import Box2D, SpherePoint, degrees
31from lsst.pex.config import Config, Field
32from matplotlib import colors
33from matplotlib.collections import PatchCollection
34from matplotlib.patches import Rectangle
35from scipy.stats import binned_statistic_2d
37from ...statistics import nansigmaMad
39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true
40 from matplotlib.figure import Figure
# Shared NullFormatter instance; get_and_remove_axis_text installs it on the
# x/y(/z) axes to blank out tick labels.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `dict`
        The dataId of the data to be plotted. Iterating it must yield
        objects with a ``name`` attribute, and it must be indexable by that
        name (presumably a butler data coordinate; confirm with callers).
    runName : `str`
        The name of the run.
    tableName : `str`
        The name of the table.
    bands : `list` [`str`]
        The bands to be plotted.
    plotName : `str`
        The name of the plot.
    SN : `str`
        The signal to noise of the data.

    Returns
    -------
    plotInfo : `dict`
        A dictionary of the plot information. Every dimension in ``dataId``
        is copied in by name, ``"bands"`` holds the bands joined with
        ``", "``, and ``"tract"``/``"visit"`` are set to ``"N/A"`` when
        ``dataId`` does not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # Placeholders for dimensions that this dataId does not carry.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    The summarised column is the first of ``y``, ``yStars``, ``yGalaxies``
    and ``yUnknowns`` present in ``data``; for every patch the NaN-ignoring
    median of that column over the rows whose ``patch`` value matches the
    patch index is computed.

    Parameters
    ----------
    data : `dict`
        A dictionary of the data to be plotted. Must contain a ``patch``
        column and at least one of the y columns listed above.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap associated with the data.
    plotInfo : `dict`
        A dictionary of the plot information; ``plotInfo["tract"]`` selects
        the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to ``(skyCoords, stat)`` where
        ``skyCoords`` are the sky positions of the corners of the patch inner
        bounding box and ``stat`` is the median (NaN when no rows fall on the
        patch). The extra key ``"tract"`` maps to the tract corners and NaN.

    Raises
    ------
    KeyError
        Raised if ``data`` contains none of the recognised y columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Choose the column to summarise, failing clearly if none is present
    # rather than leaving the column name unbound.
    yCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in yCandidates if col in data), None)
    if yCol is None:
        raise KeyError(f"None of the columns {yCandidates} were found in data.")

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are enumerated by their sequential index and data["patch"] is
    # compared against these integers, so only integer patch ids contribute.
    for patch in range(maxPatchNum):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[patch] = (skyCoords, stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable):
    """Generate a summary statistic in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        A dataframe of the data to be plotted; must have a ``detector``
        column.
    colName : `str`
        The name of the column to be plotted.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        A dataframe of the visit summary table with ``id``, ``raCorners``
        and ``decCorners`` columns.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector to ``(corners, stat)`` where ``corners`` is a list
        of `lsst.geom.SpherePoint` built from the summary-table corners and
        ``stat`` is the NaN-ignoring median of ``colName`` on that detector.
    """
    values = cat[colName].values
    # Convert to arrays so the ``[0]`` below is positional. On a pandas
    # Series ``[0]`` would look up the index *label* 0 and fail for every
    # detector row not labelled 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(values[onCcd])

        sumRow = summaryIds == ccd
        cornersOut = [
            SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Titles and axis labels are blanked, legend texts are made fully
    transparent, free-standing texts are emptied and tick labels are replaced
    by a null formatter. Child axes are processed recursively.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title, axis labels, legend texts and
        free-standing texts) from ``ax`` and its child axes.
    line_xys : `List[numpy.ndarray]`
        A list of all line ``_xy`` attributes (arrays of shape ``(N, 2)``)
        from ``ax`` and its child axes.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only 3D axes carry a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: Figure):
    """Strip all text from a Figure and its Axes, returning text and lines.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        Collected strings: first the ``str()`` of the suptitle artist, then
        every figure-level text, then everything gathered from each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        A list of all line ``_xy`` attributes (arrays of shape ``(N, 2)``).
    """
    # NOTE(review): this records ``str()`` of the suptitle artist (its repr,
    # or "None" if presumably no suptitle was set), not only its bare text.
    collectedTexts = [str(figure._suptitle)]
    figure.suptitle("")

    collectedTexts += [figText.get_text() for figText in figure.texts]
    figure.texts = []

    collectedLines = []
    for axis in figure.get_axes():
        axisTexts, axisLines = get_and_remove_axis_text(axis)
        collectedTexts += axisTexts
        collectedLines += axisLines

    return collectedTexts, collectedLines
def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    Writes the plot name in the top-left corner of the figure and, beneath
    it, a block listing the run, datasets, table, tract/visit (when present),
    bands and S/N.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to add the information to.
    plotInfo : `dict`
        A dictionary of the plot information. Must contain ``plotName``,
        ``run``, ``tableName`` and ``bands``; ``bands`` may be an already
        joined string or a sequence of band names.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the information added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    # parsePlotInfo stores the bands as an already joined string, while other
    # callers may pass a sequence of band names. Iterating a string would
    # split it into single characters, so only sequences are joined.
    bands = plotInfo["bands"]
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def mkColormap(colorNames):
    """Build a linear colormap that steps evenly through ``colorNames``.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib named colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"`` whose anchor points are spaced evenly
        on [0, 1], one per supplied color.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.colorConverter.to_rgb(name) for name in colorNames]

    def _channel(index):
        # Each segment entry is (x, value below x, value above x); using the
        # same value twice gives a continuous ramp between anchors.
        return [(anchor, rgb[index], rgb[index]) for anchor, rgb in zip(anchors, rgbTriples)]

    segmentData = {"blue": _channel(2), "red": _channel(0), "green": _channel(1)}
    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Order indices so that values furthest from the median come last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` sorted by increasing absolute distance from the
        NaN-ignoring median of ``xs``.
    """
    # argsort of the absolute deviation puts the most extreme values last.
    return np.argsort(np.abs(xs - np.nanmedian(xs)))
def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder several arrays together using the extrema order of one of them.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        Arrays to reorder together. The order is taken from
        ``extremaSort`` applied to the array at ``sortArrayIndex``.
        The list is updated in place.
    sortArrayIndex : `int`, optional
        Zero-based index of the array that defines the ordering (defaults to
        the first array).

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, now holding the reordered arrays.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    # Slice assignment keeps the caller's list object and fills it in place.
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `tuple` [`int`, `int`, `int`]
        Describes the location in the figure to put the summary plot. It
        can be a gridspec SubplotSpec or a 3-digit integer, where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. A tuple of (rows, columns, index) works
        in the same way.
    sumStats : `dict`
        A dictionary where the patchIds are the keys. Each value holds the
        R.A. and Dec. of the corners of the patch, along with a summary
        statistic for that patch. An entry keyed ``"tract"`` is outlined
        but not annotated.
    label : `str`
        The label to be used for the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the summary subplot and its colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Outline each region and build one color-coded rectangle per region.
    # The color shows that region's summary statistic. Corners 0 and 2 are
    # treated as opposite corners of the rectangle.
    patches = []
    patchStats = []
    for dataId, (corners, stat) in sumStats.items():
        raCorner = corners[0][0].asDegrees()
        decCorner = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - raCorner
        height = corners[2][1].asDegrees() - decCorner
        patches.append(Rectangle((raCorner, decCorner), width, height))
        patchStats.append(stat)

        ras = [ra.asDegrees() for (ra, dec) in corners]
        decs = [dec.asDegrees() for (ra, dec) in corners]
        axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)

        if dataId != "tract":
            cenX = raCorner + width / 2
            cenY = decCorner + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and mask NaN statistics so those
    # regions are drawn without fill.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(patchStats, mask=np.isnan(patchStats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes. Use the target figure directly
    # rather than pyplot's "current figure", which may not be ``fig``.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig
def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers.

    Consecutive runs are collapsed to ``first<range_indicator>last``, and
    duplicate values are treated as a single value.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the list (empty for empty
        input).

    Examples
    --------
    >>> shorten_list([1,2,3,5,6,8])
    '1-3,5-6,8'

    >>> shorten_list((1,2,3,5,6,8,9,10,11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'

    >>> shorten_list([4, 4, 7])
    '4,7'
    """
    numbers = sorted(numbers)
    if not numbers:  # empty container
        return ""

    result = []

    def _emit(first, last):
        # Compare values, not positions, so runs of duplicates such as
        # [5, 5] give "5" rather than "5-5".
        if first == last:
            result.append(str(first))
        else:
            result.append(range_indicator.join((str(first), str(last))))

    start = previous = numbers[0]
    for number in numbers[1:]:
        # ``>`` (rather than ``!=``) lets duplicates continue the current run.
        if number > previous + 1:
            _emit(start, previous)
            start = number
        previous = number
    _emit(start, previous)

    return range_separator.join(result)
class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the arguments of
    matplotlib.pyplot.subplot2grid.
    """

    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    # Doc corrected: this is the column span, not a row span.
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )
def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color limits symmetric about zero, using the
        larger of ``|vmin|`` and ``|vmax|``.
    nPointBinThresh : `int`, optional
        Threshold number of points above which binning will be implemented
        for the plotting. If the number of data points is lower than this
        threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. They default to the median
        of ``zs`` minus/plus twice its sigma MAD.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection` or `matplotlib.image.AxesImage`
        The last artist added to ``ax``. This is the plain scatter plot when
        unbinned, or the overplotted extreme points when binned. The binned
        image is returned only if there is at most one point to overplot.
    """
    med = np.nanmedian(zs)
    mad = nansigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    if len(xs) < nPointBinThresh:
        # Few enough points to show them all; no binning needed.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # Only compute the 2D median binning when the image is actually drawn.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )

    # Marker size shrinks as the number of points grows (clamped to [0.5, 10]).
    s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
    lw = (s**0.5) / 10
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )
    if not isSorted:
        zs, xs, ys = sortAllArrays([zs, xs, ys])
    # Overplot the most extreme 15% of points. After sorting, the points are
    # ordered by distance from the median, so these are the last 15%. Integer
    # arithmetic gives floor(0.85 * len) exactly.
    if len(xs) > 1:
        extremes = (len(xs) * 85) // 100
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )
    return plotOut