Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 7%
234 statements
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-11 02:06 -0700
« prev ^ index » next coverage.py v6.4.4, created at 2022-09-11 02:06 -0700
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23from typing import List, Tuple
25import matplotlib
26import matplotlib.pyplot as plt
27import numpy as np
28import scipy.odr as scipyODR
29from lsst.geom import Box2D, SpherePoint, degrees
30from matplotlib import colors
31from matplotlib.collections import PatchCollection
32from matplotlib.patches import Rectangle
# Shared tick formatter; get_and_remove_axis_text installs it as the major and
# minor formatter of each axis to remove tick-label text.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        The data ID. Each element yielded by iterating it must have a
        ``name`` attribute, and ``dataId[element.name]`` must return the
        value for that dimension.
    runName : `str`
        Name of the run; stored under ``"run"``.
    tableName : `str`
        Name of the input table; stored under ``"tableName"``.
    bands : iterable of `str`
        Band names; stored joined as a single ``", "``-separated string
        under ``"bands"``.
    plotName : `str`
        Name of the plot; stored under ``"plotName"``.
    SN
        Signal-to-noise value, stored unchanged under ``"SN"``.

    Returns
    -------
    plotInfo : `dict`
        The values above plus one entry per data ID dimension. ``"tract"``
        and ``"visit"`` are set to ``"N/A"`` when the data ID lacks them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # Downstream text (e.g. addPlotInfo) expects these keys to exist.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        Column data. Must contain a ``"patch"`` column of sequential
        (gen 3) patch indices and at least one of the ``"y"``, ``"yStars"``,
        ``"yGalaxies"`` or ``"yUnknowns"`` columns. The first of these found
        (in that order) is summarised.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Skymap used to look up the tract geometry.
    plotInfo : `dict`
        Plot metadata; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to ``(skyCoords, stat)``.
        ``skyCoords`` are the sky positions of the corners of the patch
        inner bounding box, and ``stat`` is the ``np.nanmedian`` of the
        selected y column over the rows on that patch. The extra key
        ``"tract"`` maps to the tract corners and ``np.nan``.

    Raises
    ------
    ValueError
        Raised if ``data`` contains none of the supported y columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    yColumns = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in yColumns if col in data), None)
    if yCol is None:
        raise ValueError(f"None of the columns {yColumns} were found in data.")

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are addressed by their sequential (gen 3) index, so no
    # gen 2 "x,y" string conversion is needed here.
    for patch in range(maxPatchNum):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic for each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Source catalog with a ``detector`` column and the column
        ``colName``.
    colName : `str`
        Name of the column in ``cat`` to summarise.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Table with ``id``, ``raCorners`` and ``decCorners`` columns. Each
        corners entry is a sequence of coordinates in degrees.
    plotInfo : `dict`
        Plot metadata (currently unused).

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to ``(corners, stat)``. ``corners`` is a list
        of `lsst.geom.SpherePoint` built from that detector's summary row,
        and ``stat`` is the ``np.nanmedian`` of ``colName`` over the
        detector's rows.
    """
    # Convert to plain arrays once so row selection below is positional.
    # With a pandas Series, ``series[mask][0]`` looks up *label* 0 and
    # fails for every detector not stored in the first row.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        sumRow = summaryIds == ccd
        cornersOut = [
            SpherePoint(ra, dec, units=degrees)
            for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axes and its children and return with line points.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axes.

    Returns
    -------
    texts : `List[str]`
        Every text string collected, in this order: the title, the x and y
        axis labels, the legend entries, the free-floating texts, and then
        the same for each child axes.
    line_xys : `List[numpy.ndarray]`
        The xy data (arrays of shape ``(N, 2)``) of every line on ``ax``
        and on its child axes.
    """
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        for text in legend.texts:
            texts.append(text.get_text())
            # Legend entries are hidden (alpha 0) rather than emptied.
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    ax.xaxis.set_major_formatter(null_formatter)
    ax.xaxis.set_minor_formatter(null_formatter)
    ax.yaxis.set_major_formatter(null_formatter)
    ax.yaxis.set_minor_formatter(null_formatter)
    # Only 3D axes have a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: plt.Figure):
    """Strip all text from a Figure and its Axes, returning it with line data.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure. It is modified in place.

    Returns
    -------
    texts : `List[str]`
        The collected text. The first entry is ``str()`` of the figure's
        suptitle artist (``"None"`` when there was none). It is followed by
        the figure-level texts and then the text of every axes, as returned
        by `get_and_remove_axis_text`.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line data of every axes, as returned by
        `get_and_remove_axis_text`.
    """
    # NOTE(review): str() of a Text artist is its repr, not the bare string;
    # preserved as-is because callers may compare against stored values.
    collectedTexts = [str(figure._suptitle)]
    figure.suptitle("")

    collectedTexts += [artist.get_text() for artist in figure.texts]
    figure.texts = []

    collectedLines = []
    for axes in figure.get_axes():
        axesTexts, axesLines = get_and_remove_axis_text(axes)
        collectedTexts += axesTexts
        collectedLines += axesLines

    return collectedTexts, collectedLines
def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate.
    plotInfo : `dict`
        Must contain ``"plotName"``, ``"run"``, ``"tableName"``,
        ``"bands"`` and ``"SN"``. ``"tract"`` and ``"visit"`` are shown when
        present. ``"bands"`` may be either an already-joined string (as
        produced by `parsePlotInfo`) or an iterable of band names.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the plot name and an info line added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # parsePlotInfo stores the bands already joined into one string; joining
    # that string again would split it into single characters.
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo['SN']}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit
        mODR : `float`
            The gradient calculated by the ODR fit
        bODR : `float`
            The intercept calculated by the ODR fit
        yBoxMin : `float`
            The y value of the fitted line at xMin (or yMin when the first
            fit is steeper than 1)
        yBoxMax : `float`
            The y value of the fitted line at xMax (or yMax when the first
            fit is steeper than 1)
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower end
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper end
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Notes
    -----
    The code does two rounds of fitting. The first is initialised with the
    hardwired values given in the `paramDict` parameter and is an
    Orthogonal Distance Regression fit to the points inside the box
    defined by xMin, xMax, yMin and yMax. A perpendicular line is then
    calculated at either end of the fitted line. Only points that fall
    between these two lines are used to recalculate the fit.
    """
    linear = scipyODR.polynomial(1)

    # First pass: points strictly inside the initial box.
    inBox = (
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )
    data = scipyODR.Data(xs[inBox], ys[inBox])
    # polynomial(1) orders its parameters as [intercept, gradient].
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp * xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. For a
    # shallow negative gradient bPerpMin > bPerpMax, so order the bounds.
    # Otherwise the selection would be empty.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    betweenPerps = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[betweenPerps], ys[betweenPerps])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance to a line from points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line
    p2 : `numpy.ndarray`
        Another point on the line
    points : `zip`
        The points to calculate the distance to

    Returns
    -------
    dists : `list`
        The distances from the line to the points. For 2D input each value
        is the cross product ``(p2 - p1) x (point - p1)`` divided by
        ``|p2 - p1|``. Points to the left of the direction p1 -> p2 are
        positive.
    """
    p1 = np.asarray(p1)
    p2 = np.asarray(p2)
    direction = p2 - p1
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point) - p1
        if direction.shape == (2,) and offset.shape == (2,):
            # Explicit 2D cross product: np.cross on 2-element vectors is
            # deprecated as of NumPy 2.0.
            cross = direction[0] * offset[1] - direction[1] * offset[0]
        else:
            cross = np.cross(direction, offset)
        dists.append(cross / length)

    return dists
def mkColormap(colorNames):
    """Build a linearly interpolated colormap from named colors.

    Parameters
    ----------
    colorNames : `list`
        Matplotlib color names, spread evenly from 0 to 1 in the order
        given.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        Colormap named ``"newCmap"``.
    """
    positions = np.linspace(0, 1, len(colorNames))
    rgbValues = [colors.to_rgb(name) for name in colorNames]

    # Each channel is a list of (position, value, value) anchor triples.
    segmentData = {}
    for channel, index in (("blue", 2), ("red", 0), ("green", 1)):
        segmentData[channel] = [(pos, rgb[index], rgb[index]) for pos, rgb in zip(positions, rgbValues)]

    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Return indices ordering the points by absolute distance from the median.

    The points closest to the median come first and the most extreme ones
    last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort

    Returns
    -------
    ids : `np.array`
        Indices into ``xs``, as produced by ``np.argsort``.
    """
    return np.argsort(np.abs(xs - np.median(xs)))
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index`
        Describes the location in the figure to put the summary plot.
        It can be a gridspec SubplotSpec, or a 3 digit integer where the
        first digit is the number of rows, the second the number of columns
        and the third the index. A tuple of (int, int, index) works the same
        way.
    sumStats : `dict`
        Maps each patch (or detector) id to ``(corners, stat)``.
        ``corners`` are the sky positions of the region's corners, with
        ``corners[0]`` and ``corners[2]`` used as opposite corners of the
        drawn rectangle. ``stat`` is the summary statistic that sets the
        fill color. An entry keyed ``"tract"`` is outlined but not labelled.
    label : `str`
        Text drawn on top of the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure with the summary axes and colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color
    # coded rectangles for each patch, the colors show
    # the median of the given value in the patch
    patches = []
    # Named ``stats`` so it does not shadow the module-level
    # ``matplotlib.colors`` import.
    stats = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        stats.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        cornerDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)

        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(stats, mask=np.isnan(stats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar. Use the figure we were given rather than pyplot's
    # "current figure", which may be a different figure.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig