Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 7%
237 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-05 12:17 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-05 12:17 -0700
# This file is part of analysis_tools.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23from typing import List, Tuple
25import matplotlib
26import matplotlib.pyplot as plt
27import numpy as np
28import scipy.odr as scipyODR
29from lsst.geom import Box2D, SpherePoint, degrees
30from matplotlib import colors
31from matplotlib.collections import PatchCollection
32from matplotlib.patches import Rectangle
# Shared tick formatter that renders no labels; used by
# get_and_remove_axis_text to blank out tick labels.
# NOTE(review): this relies on ``matplotlib.ticker`` already being loaded as a
# submodule (presumably by the ``matplotlib.pyplot`` import above); an explicit
# ``import matplotlib.ticker`` would not depend on that side effect.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        Data ID of the plotted data. Every element yielded by iterating it
        must have a ``name`` attribute, which is used to index ``dataId``.
    runName : `str`
        Run name, stored under the ``"run"`` key.
    tableName : `str`
        Name of the table the data came from, stored under ``"tableName"``.
    bands : iterable of `str`
        Band names; stored joined with ``", "`` under ``"bands"``.
    plotName : `str`
        Name of the plot, stored under ``"plotName"``.
    SN : `str` or `float`
        Value stored under ``"SN"`` (reported as ``S/N`` on the plot).

    Returns
    -------
    plotInfo : `dict`
        The values above plus one entry per data ID dimension. ``"tract"``
        and ``"visit"`` are always present and set to ``"N/A"`` when the
        data ID does not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    # Copy every dimension value of the data ID into the info dict.
    for dimension in dataId:
        plotInfo[dimension.name] = dataId[dimension.name]

    # Comma separated band list, e.g. ["g", "r"] -> "g, r".
    plotInfo["bands"] = ", ".join(bands)

    # Guarantee these keys exist without overwriting real data ID values.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        Column data. Must contain a ``"patch"`` column of sequential
        (gen 3) patch indices and at least one of the value columns
        ``"y"``, ``"yStars"``, ``"yGalaxies"`` or ``"yUnknowns"``; the first
        one present, in that order, is summarised.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Skymap used to look up the tract and patch geometry.
    plotInfo : `dict`
        Plot information; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each patch index to ``(skyCoords, stat)``, where ``skyCoords``
        are the sky positions of the corners of the patch inner bounding box
        and ``stat`` is the NaN-ignoring median of the value column on that
        patch (NaN when the patch has no data). The extra key ``"tract"``
        maps to the corners of the tract bounding box and ``np.nan``.

    Raises
    ------
    ValueError
        Raised if ``data`` contains none of the supported value columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Use the first value column that is present; previously a missing
    # column surfaced as an UnboundLocalError further down.
    for yCol in ("y", "yStars", "yGalaxies", "yUnknowns"):
        if yCol in data.keys():
            break
    else:
        raise ValueError(
            "Data has none of the columns 'y', 'yStars', 'yGalaxies' or 'yUnknowns' to summarise."
        )

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are enumerated by their sequential (gen 3) integer index, so
    # no gen 2 "x,y" string conversion is ever required here.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and the column ``colName``.
    colName : `str`
        Name of the column to summarise.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Visit summary with an ``id`` column (matched against the detector
        ids in ``cat``) and per-row ``raCorners``/``decCorners`` sequences
        of corner coordinates in degrees.
    plotInfo : `dict`
        Plot information (not used here).

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to ``(corners, stat)``, where ``corners`` is a
        list of `lsst.geom.SpherePoint` detector corners and ``stat`` is the
        NaN-ignoring median of ``colName`` on that detector.
    """
    # Convert the summary columns once. ``np.asarray`` makes the ``[0]``
    # below positional for both pandas columns (where ``series[mask][0]``
    # would be a *label* lookup that only works for the row labelled 0)
    # and numpy-like table columns.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        # First summary row describing this detector.
        sumRow = summaryIds == ccd
        cornersOut = [
            SpherePoint(ra, dec, units=degrees)
            for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title, axis labels, legend texts and
        free text) of the axis and, recursively, its child axes.
    line_xys : `List[numpy.ndarray]`
        The xy data (arrays of shape ``(N, 2)``) of every line on the axis
        and, recursively, its child axes.
    """
    # get_xydata() refreshes the cached xy array if the line data changed
    # since it was last computed; reading the private ``_xy`` could return
    # stale points.
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when the axis has no legend.
    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend entries are made transparent rather than emptied.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    # Blank the tick labels; only 3D axes have a z axis.
    tickAxes = [ax.xaxis, ax.yaxis]
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        tickAxes.append(zaxis)
    for axis in tickAxes:
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)

    # Child axes (e.g. insets) are not returned by Figure.get_axes(), so
    # collect both their text and their lines here.
    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: plt.Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (suptitle, figure texts, and the title,
        axis/legend labels and free text of every axis). The first entry is
        the suptitle string, or ``""`` if the figure had no suptitle.
    line_xys : `List[numpy.ndarray]`
        The xy data (arrays of shape ``(N, 2)``) of every line on the
        figure's axes.
    """
    # ``str()`` of the suptitle artist would give its repr
    # (e.g. "Text(0.5, 0.98, 'title')") or "None", not the title itself.
    suptitle = figure._suptitle
    texts = [suptitle.get_text() if suptitle is not None else ""]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines
def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate.
    plotInfo : `dict`
        Plot information. Must contain ``plotName``, ``run``, ``tableName``
        and ``bands``. ``tract`` and ``visit`` are added to the text when
        present; ``SN`` is shown as ``-`` when absent. ``bands`` may be an
        already joined string (e.g. ``"g, r"``, as produced by
        ``parsePlotInfo``) or a sequence of band names.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the plot name and info text added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    # A string of bands must be used as-is: iterating it would join its
    # individual characters (e.g. "g, r" -> "g, ,,  , r").
    bands = plotInfo["bands"]
    if not isinstance(bands, str):
        bands = ", ".join(bands)
    bandsText = f", Bands: {bands}"
    SNText = f", S/N: {plotInfo.get('SN', '-')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting

        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters

        xMin, xMax, yMin, yMax, mHW, bHW : `float`
            Copied from ``paramDict``.
        mODR : `float`
            The gradient calculated by the first ODR fit
        bODR : `float`
            The intercept calculated by the first ODR fit
        yBoxMin : `float`
            The y value of the lower end of the fitted segment (the fit
            evaluated at xMin, or yMin when the gradient is steep)
        yBoxMax : `float`
            The y value of the upper end of the fitted segment (the fit
            evaluated at xMax, or yMax when the gradient is steep)
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower end
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper end
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Notes
    -----
    The code does two rounds of fitting. The first is initialised using the
    hardwired values given in the `paramDict` parameter and is an
    Orthogonal Distance Regression fit to the points inside the box defined
    by xMin, xMax, yMin and yMax. Once this fit has been done, the lines
    perpendicular to it at either end of the fitted segment are calculated,
    and only points that fall between these two lines are used to
    recalculate the fit. This works for gradients of either sign.
    """
    # Points to use for the initial fit: those strictly inside the box.
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # A degree 1 polynomial model; its parameters are [intercept, gradient].
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    # Intercepts of the perpendicular lines through each end of the segment.
    bPerpMin = yBoxMin - mPerp * xBoxMin
    bPerpMax = yBoxMax - mPerp * xBoxMax

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. For a
    # shallow negative gradient bPerpMin > bPerpMax, so order the intercepts
    # first; otherwise no point could satisfy both conditions.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point ``(x, y)`` on the line
    p2 : `numpy.ndarray`
        Another point ``(x, y)`` on the line
    points : iterable of ``(x, y)`` pairs, e.g. a `zip` of x and y values
        The points to calculate the distance to

    Returns
    -------
    dists : `list`
        The distances from the line to the points: the 2D cross product of
        ``p2 - p1`` and ``point - p1`` divided by ``|p2 - p1|``. Points to
        the left of the direction ``p1 -> p2`` get positive distances.
    """
    p1 = np.asarray(p1)
    direction = np.asarray(p2) - p1
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point) - p1
        # z-component of the 2D cross product, written out explicitly because
        # np.cross on 2-element vectors is deprecated as of NumPy 2.0.
        dists.append((direction[0] * offset[1] - direction[1] * offset[0]) / length)

    return dists
def mkColormap(colorNames):
    """Build a linearly interpolated colormap from named colors.

    Parameters
    ----------
    colorNames : `list`
        Matplotlib named colors, placed at evenly spaced positions from 0 to 1
        along the colormap in the order given.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap called ``"newCmap"``.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbs = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Segment data: per channel, a list of (position, value below, value
    # above) entries; equal below/above values give a continuous ramp.
    segmentData = {
        channel: [(pos, rgb[idx], rgb[idx]) for pos, rgb in zip(anchors, rgbs)]
        for idx, channel in enumerate(("red", "green", "blue"))
    }
    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Order point indices so that the outliers come last.

    Parameters
    ----------
    xs : `np.array`
        The values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` sorted by increasing absolute distance from the
        median of ``xs``.
    """
    deviations = np.abs(xs - np.median(xs))
    return np.argsort(deviations)
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot.
        It can be a gridspec SubplotSpec, or a 3 digit integer where the
        first digit is the number of rows, the second is the number of
        columns and the third is the index. The tuple of
        ``(int, int, index)`` has the same meaning.
    sumStats : `dict`
        A dictionary where the patchIds (or detector ids) are the keys. Each
        value is ``(corners, stat)``: the corners as a sequence of
        ``(ra, dec)`` angle pairs (objects with ``asDegrees()``) and a
        summary statistic. ``corners[0]`` and ``corners[2]`` are used as
        opposite corners of the colored rectangle. An entry keyed
        ``"tract"`` is drawn but not annotated.
    label : `str`
        Text written on the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color coded rectangles
    # for each patch; the colors show the summary statistic of the patch.
    # (The statistics list is deliberately not called ``colors`` so that the
    # module-level ``matplotlib.colors`` import is not shadowed.)
    patches = []
    stats = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        stats.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        cornerDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        # Close the outline by returning to the first corner.
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            axCorner.annotate(
                dataId, (ra + width / 2, dec + height / 2), color="k", fontsize=4, ha="center", va="center"
            )

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    stats = np.ma.array(stats, mask=np.isnan(stats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(stats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes. Use the figure's own colorbar
    # method instead of pyplot's, which goes through the current pyplot
    # figure (gcf) rather than ``fig``.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig