Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 11%
247 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-12-09 02:45 -0800
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("PanelConfig",)
25from typing import List, Tuple
27import matplotlib
28import matplotlib.pyplot as plt
29import numpy as np
30import scipy.odr as scipyODR
31from lsst.geom import Box2D, SpherePoint, degrees
32from lsst.pex.config import Config, Field
33from matplotlib import colors
34from matplotlib.collections import PatchCollection
35from matplotlib.patches import Rectangle
# Shared module-level NullFormatter; get_and_remove_axis_text installs it as
# the major and minor formatter of the x/y (and z, when present) axes so that
# tick labels are blanked out.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions.`
             `_coordinate._ExpandedTupleDataCoordinate`
        Data coordinate. Iterating it yields dimension objects whose
        ``name`` is used both as the output key and to look up the value
        in ``dataId``.
    runName : `str`
        Name of the run; stored under ``"run"``.
    tableName : `str`
        Name of the input table; stored under ``"tableName"``.
    bands : iterable of `str`
        Bands used; joined with ``", "`` and stored under ``"bands"``.
    plotName : `str`
        Name of the plot; stored under ``"plotName"``.
    SN : any
        Signal-to-noise information; stored under ``"SN"``.

    Returns
    -------
    plotInfo : `dict`
        The values above plus one entry per dimension of ``dataId``.
        ``"tract"`` and ``"visit"`` are set to ``"N/A"`` when ``dataId``
        does not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    # Same result as prefixing every band with ", " and stripping the first
    # separator, without repeated string concatenation.
    plotInfo["bands"] = ", ".join(bands)

    # Placeholders so that downstream labelling always finds these keys.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    The statistic is the NaN-ignoring median (`numpy.nanmedian`) of the
    selected y column over the rows whose ``"patch"`` value equals each
    patch index of the tract.

    Parameters
    ----------
    data : `dict`
        Column-name -> array mapping. Must contain ``"patch"`` and at least
        one of ``"y"``, ``"yStars"``, ``"yGalaxies"`` or ``"yUnknowns"``;
        the first of these found (in that order) is summarized.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Skymap used to generate the tract given by ``plotInfo["tract"]``.
    plotInfo : `dict`
        Plot information; only ``plotInfo["tract"]`` is read.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each patch id to ``(skyCorners, stat)``, where ``skyCorners``
        are the sky coordinates of the corners of the patch inner bbox.
        The extra key ``"tract"`` holds the tract bbox corners and
        ``numpy.nan``.

    Raises
    ------
    KeyError
        If ``data`` contains none of the supported y columns.
    """
    # Resolve the y column before touching the skymap so that a missing
    # column fails fast with a clear message instead of an
    # UnboundLocalError later on.
    yCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((name for name in yCandidates if name in data.keys()), None)
    if yCol is None:
        raise KeyError(f"None of the columns {yCandidates} found in data.")

    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    for patch in np.arange(0, maxPatchNum, 1):
        # Once the objectTable_tract catalogues are using gen 3 patches
        # this will go away
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])
        try:
            # Gen 2 style "x,y" string patch ids are converted to gen 3
            # sequential indices.
            patchTuple = (int(patch.split(",")[0]), int(patch.split(",")[-1]))
            patchInfo = tractInfo.getPatchInfo(patchTuple)
            gen3PatchId = tractInfo.getSequentialPatchIndex(patchInfo)
        except AttributeError:
            # For native gen 3 tables the patches don't need converting.
            # The integer ids produced by np.arange have no ``split`` so
            # this branch is the one taken here.
            # When we are no longer looking at the gen 2 -> gen 3
            # converted repos we can tidy this up
            gen3PatchId = patch
            patchInfo = tractInfo.getPatchInfo(patch)

        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[gen3PatchId] = (skyCoords, stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic for each detector of a visit.

    For every detector present in ``cat`` the NaN-ignoring median of
    ``cat[colName]`` is computed and paired with that detector's sky
    corners taken from ``visitSummaryTable``.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Source catalog; must have a ``"detector"`` column and ``colName``.
    colName : `str`
        Column of ``cat`` to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Table with an ``"id"`` (detector id) column and ``"raCorners"`` /
        ``"decCorners"`` columns holding per-row sequences of corner
        coordinates in degrees.
    plotInfo : `dict`
        Not used; kept for interface compatibility.

    Returns
    -------
    visitInfoDict : `dict`
        Maps detector id to ``(corners, stat)`` where ``corners`` is a list
        of `lsst.geom.SpherePoint`.

    Raises
    ------
    IndexError
        If a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        # Select the detector's row *positionally*. On a pandas Series,
        # ``series[mask][0]`` is a label lookup and raises KeyError for any
        # detector whose row label is not 0; plain ndarrays avoid that.
        sumRow = np.asarray(visitSummaryTable["id"] == ccd)
        raCorners = np.asarray(visitSummaryTable["raCorners"])[sumRow][0]
        decCorners = np.asarray(visitSummaryTable["decCorners"])[sumRow][0]
        cornersOut = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        The title, x/y axis labels, legend texts and ``ax.texts`` strings of
        ``ax``, followed by those of its child axes. Tick labels are not
        collected; they are hidden by installing a null formatter.
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` attributes (arrays of shape ``(N, 2)``) of every line on
        ``ax`` followed by those of its child axes.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend entries are made invisible rather than blanked.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only 3D axes carry a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Previously discarded; child lines belong in the result too.
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: plt.Figure):
    """Strip all text from a figure and its axes, returning what was removed.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        The figure to clean; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        Removed text: the suptitle entry, then the figure-level texts, then
        the texts of every axis (see `get_and_remove_axis_text`).
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The ``_xy`` point arrays of the lines of every axis.

    Notes
    -----
    The first entry is ``str(figure._suptitle)``, i.e. the string form of
    the suptitle object itself, not only its text (presumably ``"None"``
    when no suptitle was ever set).
    """
    removedTexts = [str(figure._suptitle)]
    figure.suptitle("")

    removedTexts += [figText.get_text() for figText in figure.texts]
    figure.texts = []

    linePoints = []
    for axis in figure.get_axes():
        axisTexts, axisLines = get_and_remove_axis_text(axis)
        removedTexts += axisTexts
        linePoints += axisLines

    return removedTexts, linePoints
def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate.
    plotInfo : `dict`
        Must contain ``"plotName"``, ``"run"``, ``"tableName"`` and
        ``"bands"``; ``"tract"``, ``"visit"`` and ``"SN"`` are optional.
        ``"bands"`` may be an iterable of band names or an already
        formatted string such as the ``"g, r, i"`` built by `parsePlotInfo`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure with the plot name and an information block added
        in the top-left corner.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # A string is treated as already formatted (e.g. "g, r, i" from
    # parsePlotInfo); iterating it would split it into single characters
    # and produce text like "g, ,,  , r".
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin, xMax, yMin, yMax, mHW, bHW : `float`
            Copied from ``paramDict``
        mODR : `float`
            The gradient calculated by the first ODR fit
        bODR : `float`
            The intercept calculated by the first ODR fit
        yBoxMin : `float`
            The y value of the first-fit line at xMin, or yMin when the
            first-fit gradient is steeper than 1 in absolute value
        yBoxMax : `float`
            The y value of the first-fit line at xMax, or yMax when the
            first-fit gradient is steeper than 1 in absolute value
        bPerpMin : `float`
            The intercept of the perpendicular line through
            (xBoxMin, yBoxMin)
        bPerpMax : `float`
            The intercept of the perpendicular line through
            (xBoxMax, yBoxMax)
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Raises
    ------
    ZeroDivisionError
        If either fitted gradient is exactly zero (the perpendicular
        gradient is ``-1 / m``).

    Notes
    -----
    The code does two rounds of fitting, the first is initiated using the
    hardwired values given in the `paramDict` parameter and is done using
    an Orthogonal Distance Regression fit to the points defined by the
    box of xMin, xMax, yMin and yMax. Once this fitting has been done a
    perpendicular line is calculated at either end of the fitted line and
    only points that fall between these lines are used to recalculate the
    fit.
    """
    # Points to use for the initial fit: strictly inside the box.
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # First-order polynomial model: beta[0] is the intercept, beta[1] the
    # gradient.
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    bPerpMax = yBoxMax - mPerp * xBoxMax

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. For a
    # shallow negative gradient bPerpMax < bPerpMin, so the bounds must be
    # ordered explicitly or the selection would always be empty.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance to a line from points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A 2D point on the line
    p2 : `numpy.ndarray`
        Another 2D point on the line
    points : `zip`
        The 2D points to calculate the distance to

    Returns
    -------
    dists : `list`
        The distances from the line to the points, computed from the 2D
        cross product of ``p2 - p1`` with ``point - p1`` divided by
        ``|p2 - p1|``. Points on the counter-clockwise (left) side of the
        direction p1 -> p2 are positive.
    """
    # Loop invariants: line direction and its length.
    direction = p2 - p1
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.array(point) - p1
        # z component of the 2D cross product; identical to what np.cross
        # returns for 2-element vectors, a form deprecated in NumPy 2.0.
        cross = direction[0] * offset[1] - direction[1] * offset[0]
        dists.append(cross / length)

    return dists
def mkColormap(colorNames):
    """Build a linear colormap that passes through the given named colors.

    The colors are placed at evenly spaced anchor points between 0 and 1,
    in the order given, with no discontinuity at any anchor.

    Parameters
    ----------
    colorNames : `list`
        Matplotlib color names, in the order they should appear.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        Colormap named ``"newCmap"``.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbValues = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Each channel's segment list holds (anchor, value_below, value_above)
    # triples; equal values give a continuous ramp.
    segmentData = {
        channel: [(anchor, rgb[index], rgb[index]) for anchor, rgb in zip(anchors, rgbValues)]
        for index, channel in ((2, "blue"), (0, "red"), (1, "green"))
    }
    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Return indices ordering ``xs`` by absolute distance from its median.

    The median ignores NaNs. Values closest to the median come first and the
    most extreme values come last, so plotting in this order draws outliers
    on top.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` as produced by `numpy.argsort`.
    """
    return np.argsort(np.abs(xs - np.nanmedian(xs)))
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot,
        can be a gridspec SubplotSpec, a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. This is the same for the tuple
        of int, int, index.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch. An entry keyed ``"tract"`` is outlined
        but not annotated.
    label : `str`
        Text written across the colorbar, describing the statistic.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and a colorbar added.
    """
    # Add the subplot to the relevant place in the figure and sort the axis
    # out. add_subplot rejects a single tuple argument, so the documented
    # (rows, cols, index) form has to be unpacked.
    axCorner = fig.add_subplot(*loc) if isinstance(loc, tuple) else fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color coded rectangles
    # for each patch; the colors show the summary statistic of the patch.
    patches = []
    stats = []  # not "colors": that would shadow the matplotlib.colors import
    for dataId, (corners, stat) in sumStats.items():
        # NOTE(review): assumes corners[0] and corners[2] are diagonally
        # opposite corners of the patch.
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        stats.append(stat)

        # Close the outline by repeating the first corner.
        ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
        decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
        axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)

        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array so NaN
    # statistics are not drawn.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(stats, mask=np.isnan(stats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes, attached to this figure rather
    # than to whatever pyplot currently considers the active figure.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig
class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to matplotlib.pyplot.subplot2grid.
    """

    # Visibility of the four panel spines (and their ticks).
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # subplot2grid ``shape`` (rows, columns).
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    # subplot2grid ``loc`` (row, column).
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    # subplot2grid ``rowspan`` and ``colspan``.
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )