Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 12%
249 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-23 09:29 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-23 09:29 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("PanelConfig",)
25from typing import TYPE_CHECKING, List, Mapping, Tuple
27import matplotlib
28import matplotlib.pyplot as plt
29import numpy as np
30import scipy.odr as scipyODR
31from lsst.geom import Box2D, SpherePoint, degrees
32from lsst.pex.config import Config, Field
33from matplotlib import colors
34from matplotlib.collections import PatchCollection
35from matplotlib.patches import Rectangle
37if TYPE_CHECKING: 37 ↛ 38line 37 didn't jump to line 38, because the condition on line 37 was never true
38 from matplotlib.figure import Figure
# A tick formatter that produces no labels; get_and_remove_axis_text installs
# it as the major and minor formatter of every axis whose tick labels it strips.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        The data ID. Iterating it yields dimensions whose ``name`` is used
        both as the output key and to look up the value in ``dataId``.
    runName : `str`
        The run name, stored under ``"run"``.
    tableName : `str`
        The table name, stored under ``"tableName"``.
    bands : iterable of `str`
        The band names; joined with ``", "`` and stored under ``"bands"``.
    plotName : `str`
        The plot name, stored under ``"plotName"``.
    SN : `str` or `float`
        Signal-to-noise information, stored under ``"SN"``.

    Returns
    -------
    plotInfo : `dict`
        The collected information. ``"tract"`` and ``"visit"`` are set to
        ``"N/A"`` when the data ID does not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # Fill in placeholders so these keys are always present.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        Column data keyed by name. Must contain a ``"patch"`` column of
        sequential patch indices and one of ``"y"``, ``"yStars"``,
        ``"yGalaxies"`` or ``"yUnknowns"`` (searched in that order).
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap used to look up the tract geometry.
    plotInfo : `dict`
        Plot information; only ``plotInfo["tract"]`` is used.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to a tuple of the sky coordinates
        of the patch inner-bbox corners and the NaN-ignoring median of the
        y column for rows on that patch (``nan`` for empty patches). The
        extra key ``"tract"`` holds the tract corner sky coordinates and
        ``nan``.

    Raises
    ------
    ValueError
        Raised if ``data`` contains none of the supported y columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    yCol = next(
        (col for col in ("y", "yStars", "yGalaxies", "yUnknowns") if col in data.keys()),
        None,
    )
    if yCol is None:
        raise ValueError("data must contain one of the columns 'y', 'yStars', 'yGalaxies' or 'yUnknowns'.")

    patchInfoDict = {}
    numPatches = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are iterated by their sequential (gen 3) index, which is what
    # getPatchInfo is given here; no "x,y" string conversion is needed.
    for patch in range(numPatches):
        onPatch = data["patch"] == patch
        # Empty patches give nan (numpy emits a RuntimeWarning).
        stat = np.nanmedian(data[yCol][onPatch])
        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic for each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and the column ``colName``.
    colName : `str`
        Name of the column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Per-detector table with ``id``, ``raCorners`` and ``decCorners``
        columns. NOTE(review): ``column[mask][0]`` is used below to take the
        first matching row; for a pandas Series ``[0]`` is label-based, so
        confirm the actual table type.
    plotInfo : `dict`
        Unused.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to a tuple of the list of corner
        `lsst.geom.SpherePoint` objects and the NaN-ignoring median of
        ``colName`` for rows on that detector.
    """
    visitInfoDict = {}
    for detectorId in cat.detector.unique():
        if detectorId is None:
            continue
        onDetector = cat["detector"] == detectorId
        median = np.nanmedian(cat[colName].values[onDetector])

        summaryRow = visitSummaryTable["id"] == detectorId
        raCorners = visitSummaryTable["raCorners"][summaryRow][0]
        decCorners = visitSummaryTable["decCorners"][summaryRow][0]
        skyCorners = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[detectorId] = (skyCorners, median)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axes and its children and return with line points.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis. It is modified in place: the title, axis
        labels and free text are blanked, legend text is hidden and tick
        labels are replaced by a null formatter.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title, axis labels, legend and free
        text), including those of any child axes.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` data arrays of all lines, including those of any
        child axes.
    """
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when the axes has no legend.
    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend text is hidden (fully transparent) rather than cleared.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    # Suppress tick labels by installing a formatter that renders nothing.
    # Only 3D axes have a z axis.
    axes = [ax.xaxis, ax.yaxis]
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        axes.append(zaxis)
    for axis in axes:
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        A matplotlib figure. It is modified in place: the suptitle is
        blanked, figure-level text artists are dropped and each Axes is
        stripped by `get_and_remove_axis_text`.

    Returns
    -------
    texts : `List[str]`
        All text strings: first the suptitle (``""`` if the figure had
        none), then the figure-level texts, then the text of each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line data of every Axes, as returned by
        `get_and_remove_axis_text`.
    """
    # ``_suptitle`` is a Text artist or None. Use its string content:
    # ``str()`` of the artist gives its repr, e.g. "Text(0.5, 0.98, 'title')".
    suptitle = getattr(figure, "_suptitle", None)
    texts = [suptitle.get_text() if suptitle is not None else ""]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines
def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; text is added at its top-left corner.
    plotInfo : `dict`
        Must contain ``"plotName"``, ``"run"``, ``"tableName"`` and
        ``"bands"``; ``"tract"``, ``"visit"`` and ``"SN"`` are optional.
        ``"bands"`` may be an already formatted string (as produced by
        `parsePlotInfo`) or an iterable of band names.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the text added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # A string is already formatted; iterating over it would split it into
    # single characters.
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit
        mODR : `float`
            The gradient calculated by the ODR fit
        bODR : `float`
            The intercept calculated by the ODR fit
        yBoxMin : `float`
            The y value of the fitted line at xMin
        yBoxMax : `float`
            The y value of the fitted line at xMax
        bPerpMin : `float`
            The intercept of the perpendicular line that goes through xMin
        bPerpMax : `float`
            The intercept of the perpendicular line that goes through xMax
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Raises
    ------
    ZeroDivisionError
        Raised if a fitted gradient is exactly zero, since the
        perpendicular gradient ``-1 / m`` is then undefined.

    Notes
    -----
    The code does two rounds of fitting, the first is initiated using the
    hardwired values given in the `paramDict` parameter and is done using
    an Orthogonal Distance Regression fit to the points defined by the
    box of xMin, xMax, yMin and yMax. Once this fitting has been done a
    perpendicular line is calculated at either end of the fitted line and
    only points that fall between these lines are used to recalculate the
    fit. When the first-fit gradient has magnitude greater than one the
    ends are taken at yMin and yMax instead of xMin and xMax.
    """
    # Points to use for the first fit: those strictly inside the box.
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # For polynomial(1) the parameters are beta = [intercept, gradient].
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    bPerpMax = yBoxMax - mPerp * xBoxMax

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. Which
    # intercept is the lower bound depends on the sign of the gradient (for
    # a shallow negative gradient bPerpMin > bPerpMax), so order them
    # explicitly rather than assuming bPerpMin is the lower one.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point ``(x, y)`` on the line.
    p2 : `numpy.ndarray`
        Another point ``(x, y)`` on the line; must differ from ``p1``.
    points : `zip`
        Iterable of ``(x, y)`` pairs to calculate the distance to.

    Returns
    -------
    dists : `list`
        The distances from the line to the points, one per point. The sign
        is that of the 2D cross product ``(p2 - p1) x (point - p1)``, i.e.
        positive for points to the left of the direction ``p1 -> p2``.
    """
    p1 = np.asarray(p1, dtype=float)
    direction = np.asarray(p2, dtype=float) - p1
    # reshape(-1, 2) also copes with an empty ``points`` iterable.
    offsets = np.asarray(list(points), dtype=float).reshape(-1, 2) - p1
    # z component of the 2D cross product, written out explicitly because
    # np.cross on 2-element vectors is deprecated in NumPy 2.0.
    cross = direction[0] * offsets[:, 1] - direction[1] * offsets[:, 0]
    return list(cross / np.linalg.norm(direction))
def mkColormap(colorNames):
    """Build a linear segmented colormap from a list of named colors.

    The colors are spaced evenly over ``[0, 1]`` in the order given.

    Parameters
    ----------
    colorNames : `list`
        Strings that matplotlib recognises as named colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"``.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Each channel entry is (x, y0, y1); y0 == y1 gives a continuous ramp.
    segmentData = {}
    for channel, index in (("blue", 2), ("red", 0), ("green", 1)):
        segmentData[channel] = [(anchor, rgb[index], rgb[index]) for anchor, rgb in zip(anchors, rgbTriples)]

    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Order indices so that values furthest from the median come last.

    Parameters
    ----------
    xs : `np.array`
        The values to order. NaNs are ignored when computing the median.

    Returns
    -------
    ids : `np.array`
        Indices that sort ``xs`` by absolute distance from its median.
    """
    return np.argsort(np.abs(xs - np.nanmedian(xs)))
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot,
        can be a gridspec SubplotSpec, a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. This is the same for the tuple
        of int, int, index.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch. The key ``"tract"`` is outlined but not
        annotated.
    label : `str`
        Text drawn on the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and colorbar added.
    """
    # Add the subplot to the relevant place in the figure and sort the axis
    # out. A (rows, columns, index) tuple must be unpacked: add_subplot only
    # accepts a single positional argument if it is a SubplotSpec or a
    # three digit integer.
    if isinstance(loc, tuple):
        axCorner = fig.add_subplot(*loc)
    else:
        axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color coded rectangles
    # for each patch; the colors show the summary statistic of the patch.
    patches = []
    stats = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        # Corners 0 and 2 are taken as opposite corners of the rectangle.
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        stats.append(stat)
        cornerRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        cornerDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        # Close the outline by returning to the first corner.
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            axCorner.annotate(
                dataId, (ra + width / 2, dec + height / 2), color="k", fontsize=4, ha="center", va="center"
            )

    # Set the bad color to transparent and mask NaN statistics so those
    # patches are left unfilled.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    stats = np.ma.array(stats, mask=np.isnan(stats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(stats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes. Use the figure's own method:
    # plt.colorbar acts on pyplot's current figure, which need not be
    # ``fig`` and is created if none exists.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig
class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to matplotlib.pyplot.subplot2grid:
    ShapeRow/ShapeColumn give ``shape``, LocRow/LocColumn give ``loc`` and
    Rowspan/Colspan give ``rowspan``/``colspan``.
    """

    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )