Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 15%
259 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 13:17 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-08 13:17 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("PanelConfig",)
25from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple
27import matplotlib
28import matplotlib.pyplot as plt
29import numpy as np
30from lsst.geom import Box2D, SpherePoint, degrees
31from lsst.pex.config import Config, Field
32from matplotlib import colors
33from matplotlib.collections import PatchCollection
34from matplotlib.patches import Rectangle
35from scipy.stats import binned_statistic_2d
37from ...statistics import nansigmaMad
39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true
40 from matplotlib.figure import Figure
# Single shared formatter that renders no tick labels; get_and_remove_axis_text
# installs it on the x/y (and z, if present) axes whose text it strips.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `dict`-like
        The dataId of the data to be plotted. Iterating over it must yield
        objects with a ``name`` attribute, and it must support lookup by that
        name (presumably a butler data coordinate; confirm against callers).
    runName : `str`
        The name of the run.
    tableName : `str`
        The name of the table.
    bands : `list` [`str`]
        The bands to be plotted.
    plotName : `str`
        The name of the plot.
    SN : `str`
        The signal to noise of the data.

    Returns
    -------
    plotInfo : `dict`
        A dictionary of the plot information. ``"bands"`` holds a single
        string with the bands separated by ``", "`` (e.g. ``"g, r, i"``).
        ``"tract"`` and ``"visit"`` are always present and default to
        ``"N/A"`` when ``dataId`` does not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dimension in dataId:
        plotInfo[dimension.name] = dataId[dimension.name]

    # Join once instead of repeated string concatenation plus slicing.
    plotInfo["bands"] = ", ".join(bands)

    # Guarantee these keys exist even when the dataId lacks them.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    The statistic is the NaN-ignoring median of the plotted value column over
    the rows whose ``"patch"`` entry equals the patch index.

    Parameters
    ----------
    data : `dict`
        A dictionary of the data to be plotted. It must have a ``"patch"``
        entry. The value column is the first of ``"y"``, ``"yStars"``,
        ``"yGalaxies"`` and ``"yUnknowns"`` present in ``data``.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap associated with the data.
    plotInfo : `dict`
        A dictionary of the plot information. ``plotInfo["tract"]`` selects
        the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to ``(skyCoords, stat)``.
        ``skyCoords`` are the sky positions of the corners of the patch inner
        bounding box. ``stat`` is the median value in that patch, or NaN if
        the patch has no rows. The extra key ``"tract"`` maps to the sky
        corners of the tract bounding box and NaN.

    Raises
    ------
    KeyError
        Raised if a patch contains rows but ``data`` has none of the
        recognised value columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Choose the value column. Use None when it is absent so that an error is
    # only raised if a statistic actually has to be computed.
    valueColumns = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in valueColumns if col in data.keys()), None)

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are addressed by their sequential (gen 3) integer index. The
    # old gen 2 "x,y" string conversion could never run here, because numpy
    # integers have no ``split`` method, so it has been dropped.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        if np.any(onPatch):
            if yCol is None:
                raise KeyError(
                    f"None of the columns {valueColumns} were found in the data; "
                    "cannot compute per-patch statistics."
                )
            stat = np.nanmedian(data[yCol][onPatch])
        else:
            stat = np.nan

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable):
    """Compute the median of a column on each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        A dataframe of the data to be plotted; must have a ``detector``
        column.
    colName : `str`
        The name of the column to be summarised.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        A table of the visit summary, with ``id``, ``raCorners`` and
        ``decCorners`` columns.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector ID to ``(corners, stat)``, where ``corners`` is a
        list of `lsst.geom.SpherePoint` built from the detector's RA/Dec
        corners (in degrees) and ``stat`` is the NaN-ignoring median of
        ``colName`` on that detector. Detector IDs that are `None` are
        skipped.
    """
    visitInfoDict = {}
    for detectorId in cat.detector.unique():
        if detectorId is None:
            continue
        medianValue = np.nanmedian(cat[colName].values[cat["detector"] == detectorId])

        # Take the first summary-table row matching this detector.
        rowMask = visitSummaryTable["id"] == detectorId
        raList = visitSummaryTable["raCorners"][rowMask][0]
        decList = visitSummaryTable["decCorners"][rowMask][0]
        cornerPoints = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raList, decList)]

        visitInfoDict[detectorId] = (cornerPoints, medianValue)
    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axes and its children and return with line points.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title, axis labels, legend labels and
        free-floating texts), including those of child axes, recursively.
    line_xys : `List[numpy.ndarray]`
        A list of all line ``_xy`` attributes (arrays of shape ``(N, 2)``),
        including the lines of child axes, recursively.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend labels are hidden (fully transparent) rather than cleared.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    axes_to_blank = [ax.xaxis, ax.yaxis]
    # Only 3D axes have a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        axes_to_blank.append(zaxis)
    for axis in axes_to_blank:
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Previously discarded; child-axes lines belong in the result too.
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings. The first entry is the suptitle string
        (``""`` if the figure has none), followed by the figure-level texts
        and then the texts of each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        A list of all line ``_xy`` attributes (arrays of shape ``(N, 2)``).
    """
    # ``_suptitle`` is a Text artist or None. ``str()`` of a Text returns its
    # repr (e.g. "Text(0.5, 0.98, 'title')"), so read the string explicitly.
    suptitle = figure._suptitle
    texts = [suptitle.get_text() if suptitle is not None else ""]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines
def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to add the information to.
    plotInfo : `dict`
        A dictionary of the plot information. Must contain ``"plotName"``,
        ``"run"``, ``"tableName"`` and ``"bands"``. ``"bands"`` may be either
        an already-joined string (as built by `parsePlotInfo`) or a sequence
        of band names. ``"tract"``, ``"visit"`` and ``"SN"`` are optional.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the information added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # Iterating over a string would split it into single characters and
    # garble an already-joined band list ("g, r" -> "g, ,,  , r"), so only
    # join genuine sequences.
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def mkColormap(colorNames):
    """Build a linearly interpolated colormap from a list of named colors.

    The colors are placed at evenly spaced positions from 0 to 1, in the
    order given, and interpolated linearly in between.

    Parameters
    ----------
    colorNames : `list`
        Strings naming matplotlib colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap called ``"newCmap"`` stepping through ``colorNames``.
    """
    positions = np.linspace(0, 1, len(colorNames))
    segments = {"blue": [], "red": [], "green": []}
    for position, name in zip(positions, colorNames):
        red, green, blue = colors.colorConverter.to_rgb(name)
        for channel, level in (("blue", blue), ("green", green), ("red", red)):
            segments[channel].append((position, level, level))
    return colors.LinearSegmentedColormap("newCmap", segments)
def extremaSort(xs):
    """Return indices ordering ``xs`` by absolute distance from its median.

    The median ignores NaNs. Points closest to the median come first and the
    most extreme points come last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` in order of increasing distance from the median.
    """
    distanceFromMedian = np.abs(xs - np.nanmedian(xs))
    return np.argsort(distanceFromMedian)
def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder every array by the extrema ordering of one of them.

    The ordering comes from `extremaSort` applied to
    ``arrsToSort[sortArrayIndex]``. The list is updated in place and also
    returned.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        Arrays to reorder together.
    sortArrayIndex : `int`, optional
        Zero-based position in ``arrsToSort`` of the array that defines the
        ordering (default: the first array).

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, with every array reordered.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index`
        Describes the location in the figure to put the summary plot. It can
        be a gridspec SubplotSpec, or a 3 digit integer where the first digit
        is the number of rows, the second is the number of columns and the
        third is the index. A tuple of (int, int, index) works the same way.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch.
    label : `str`
        The label to be used for the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and its colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Outline each region and build a color-coded rectangle for it; the
    # color shows the summary statistic of the region. Names avoid shadowing
    # the module-level ``matplotlib.colors`` import.
    rectangles = []
    statValues = []
    for dataId, (corners, stat) in sumStats.items():
        raStart = corners[0][0].asDegrees()
        decStart = corners[0][1].asDegrees()
        # Assumes corners[2] is diagonally opposite corners[0] (corner order
        # as returned by Box2D.getCorners) — TODO confirm for all callers.
        width = corners[2][0].asDegrees() - raStart
        height = corners[2][1].asDegrees() - decStart
        rectangles.append(Rectangle((raStart, decStart), width, height))
        statValues.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        cornerDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)

        if dataId != "tract":
            center = (raStart + width / 2, decStart + height / 2)
            axCorner.annotate(dataId, center, color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array so that NaN
    # statistics are not drawn.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    statArray = np.ma.array(statValues, mask=np.isnan(statValues))
    collection = PatchCollection(rectangles, cmap=cmapPatch)
    collection.set_array(statArray)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes. Use the figure's own method
    # rather than pyplot's, which targets the *current* pyplot figure and can
    # create a stray one when ``fig`` is not managed by pyplot.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig
def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers.

    Runs of consecutive values are collapsed into
    ``<first><range_indicator><last>``, and the pieces are joined with
    ``range_separator``. Duplicate values are ignored.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the list; ``""`` if it is
        empty.

    Examples
    --------
    >>> shorten_list([1,2,3,5,6,8])
    '1-3,5-6,8'

    >>> shorten_list((1,2,3,5,6,8,9,10,11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'
    """
    # Sort ascending and drop duplicates. Otherwise a repeated lone value
    # would be reported as a range (e.g. [1, 1] -> "1-1").
    numbers = sorted(set(numbers))

    if not numbers:  # empty container
        return ""

    def _format(first: int, last: int) -> str:
        # Render one run: a single value or a first/last range.
        if first == last:
            return str(numbers[first])
        return range_indicator.join((str(numbers[first]), str(numbers[last])))

    result = []
    start = 0  # index where the current run begins
    for end in range(1, len(numbers)):
        # A gap larger than one closes the current run.
        if numbers[end] > numbers[end - 1] + 1:
            result.append(_format(start, end - 1))
            start = end

    # Close the final run.
    result.append(_format(start, len(numbers) - 1))

    return range_separator.join(result)
class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the arguments of
    matplotlib.pyplot.subplot2grid: shape (rows, columns), loc (row, column),
    rowspan and colspan.
    """

    # Visibility of each of the four spines (axis lines and their ticks).
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Grid geometry and placement of this panel within it.
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )
def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color limits symmetric about zero, using the
        larger of ``|vmin|`` and ``|vmax|`` as the limit.
    nPointBinThresh : `int`, optional
        Threshold number of points above which binning will be implemented
        for the plotting. If the number of data points is lower than this
        threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. When `None`, they default to
        the median of ``zs`` minus/plus twice its robust sigma.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection` or `matplotlib.image.AxesImage`
        The last artist added to ``ax``. This is the plain scatter below the
        threshold. When binning, it is the scatter of the extreme points, or
        the binned median image if only a single point was supplied.
    """
    med = np.nanmedian(zs)
    mad = nansigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    if len(xs) < nPointBinThresh:
        # Few enough points to draw every one of them directly.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # Too many points: show the per-bin median of zs as an image. The binning
    # is only needed here, so it is skipped for the plain-scatter case.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )
    # Marker size shrinks as the number of points grows, clamped to [0.5, 10].
    s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
    lw = (s**0.5) / 10
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )
    if not isSorted:
        zs, xs, ys = sortAllArrays([zs, xs, ys])

    # Overplot the most extreme 15% of points on top of the image. The arrays
    # are ordered by increasing distance from the median, so these are the
    # last 15% of entries (both high and low outliers).
    if len(xs) > 1:
        # Integer arithmetic gives an exact 85% cut for any length. The former
        # ``floor(len / 100) * 85`` kept more than 15% of points unless the
        # length was a multiple of 100.
        extremes = (len(xs) * 85) // 100
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )
    return plotOut