Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 11%
250 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-28 11:25 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23__all__ = ("PanelConfig",)
25from typing import List, Tuple
27import matplotlib
28import matplotlib.pyplot as plt
29import numpy as np
30import scipy.odr as scipyODR
31from lsst.geom import Box2D, SpherePoint, degrees
32from lsst.pex.config import Config, Field
33from matplotlib import colors
34from matplotlib.collections import PatchCollection
35from matplotlib.patches import Rectangle
# ``import matplotlib`` alone does not guarantee that the ``ticker``
# submodule is loaded, so import it explicitly rather than relying on the
# side effect of importing ``matplotlib.pyplot``.
import matplotlib.ticker

# Shared formatter that renders no tick-label text; installed on every axis
# by ``get_and_remove_axis_text`` to blank out tick labels.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        The data ID of the plotted data. Iterating over it must yield
        objects with a ``name`` attribute, and it must be indexable by
        that name.
    runName : `str`
        The name of the run the data came from.
    tableName : `str`
        The name of the table the plotted data were read from.
    bands : iterable of `str`
        The bands used in the plot.
    plotName : `str`
        The name of the plot.
    SN
        The signal-to-noise information to record for the plot.

    Returns
    -------
    plotInfo : `dict`
        A dictionary with the keys ``run``, ``tableName``, ``plotName``,
        ``SN`` and ``bands`` (the bands joined with ``", "``), one entry per
        dimension of ``dataId``, and ``tract``/``visit`` entries that are
        set to ``"N/A"`` when the data ID does not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    # Copy every dimension of the data ID into the plot info.
    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # Plot annotations expect these keys to always be present.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic (the NaN-ignoring median) in each patch.

    Parameters
    ----------
    data : `dict`
        The plotted data. Must contain a ``"patch"`` column holding the
        sequential patch index of each point, and at least one of the
        y-value columns ``"y"``, ``"yStars"``, ``"yGalaxies"`` or
        ``"yUnknowns"`` (the first present, in that order, is used).
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The sky map the tract belongs to.
    plotInfo : `dict`
        Plot information; must contain the ``"tract"`` key.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to a tuple of the sky coordinates
        of the patch's inner-bbox corners and the median y value of the
        points on that patch (NaN if the patch has no points). The extra
        key ``"tract"`` maps to the tract's corner sky coordinates and NaN.

    Raises
    ------
    ValueError
        Raised if ``data`` contains none of the recognised y-value columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Use the first y-value column that is present, in order of preference.
    columns = data.keys()
    for candidate in ("y", "yStars", "yGalaxies", "yUnknowns"):
        if candidate in columns:
            yCol = candidate
            break
    else:
        raise ValueError(
            "None of the columns 'y', 'yStars', 'yGalaxies' or 'yUnknowns' were found in the data."
        )

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are identified by their sequential (gen 3) index.
    for patch in np.arange(maxPatchNum):
        onPatch = data["patch"] == patch
        # Skip nanmedian on empty patches: it would warn and return NaN anyway.
        stat = np.nanmedian(data[yCol][onPatch]) if np.any(onPatch) else np.nan

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[patch] = (skyCoords, stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (the NaN-ignoring median) per detector.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalog; must have a ``detector`` column and the column
        ``colName``.
    colName : `str`
        The column to summarise.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        The visit summary; must have ``id``, ``raCorners`` and
        ``decCorners`` columns, with the corners in degrees.
    plotInfo : `dict`
        Unused; kept for interface compatibility.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector to a tuple of the list of its corner
        `lsst.geom.SpherePoint` objects and the median of ``colName``
        over the catalog rows on that detector.

    Raises
    ------
    IndexError
        Raised if a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Convert to plain arrays once so that row selection below is
    # positional. Boolean-masking a pandas Series keeps the original index
    # labels, so ``series[mask][0]`` would look up the *label* 0 and fail
    # for any detector that is not in the first row.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCornersAll = np.asarray(visitSummaryTable["raCorners"])
    decCornersAll = np.asarray(visitSummaryTable["decCorners"])
    detectors = cat["detector"].values
    values = cat[colName].values

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        stat = np.nanmedian(values[detectors == ccd])

        rows = np.flatnonzero(summaryIds == ccd)
        if len(rows) == 0:
            raise IndexError(f"Detector {ccd} is not present in visitSummaryTable.")
        row = rows[0]
        cornersOut = [
            SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCornersAll[row], decCornersAll[row])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        The title, x/y axis labels, legend entries and free-standing text
        strings of ``ax`` and, recursively, of its child axes.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` data arrays of every line in ``ax`` and, recursively,
        in its child axes.

    Notes
    -----
    Tick labels are not returned. They are hidden by installing
    ``null_formatter`` as the major and minor formatter of every axis.
    Legend entries are made fully transparent rather than cleared.
    """
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        for text in legend.texts:
            texts.append(text.get_text())
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    # 3-D axes also have a z axis; 2-D axes do not.
    for axis in (ax.xaxis, ax.yaxis, getattr(ax, "zaxis", None)):
        if axis is None:
            continue
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: "plt.Figure"):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        The suptitle string (``""`` if the figure has none), followed by the
        figure-level texts and the texts of every axis (see
        ``get_and_remove_axis_text``).
    line_xys : `List[numpy.ndarray]`
        A list of all line data arrays (each of shape ``(N, 2)``).
    """
    # ``_suptitle`` is a Text artist or None. Take its string content; the
    # ``str()`` of a Text is its repr (e.g. "Text(0.5, 0.98, 'title')").
    suptitle = getattr(figure, "_suptitle", None)
    texts = [suptitle.get_text() if suptitle is not None else ""]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines
def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate.
    plotInfo : `dict`
        Must contain ``plotName``, ``run``, ``tableName`` and ``bands``.
        ``bands`` may be a pre-joined string (as made by ``parsePlotInfo``)
        or an iterable of band names. ``tract``, ``visit`` and ``SN`` are
        optional; ``SN`` is shown as ``-`` when absent.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the plot name and an info line added at the
        top left.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # parsePlotInfo stores the bands already joined into one string.
    # Iterating such a string would split it into single characters, so only
    # join when we were given a collection of band names.
    if not isinstance(bands, str):
        bands = ", ".join(bands)
    bandsText = f", Bands: {bands}"
    SNText = f", S/N: {plotInfo.get('SN', '-')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit
        mODR : `float`
            The gradient calculated by the ODR fit
        bODR : `float`
            The intercept calculated by the ODR fit
        yBoxMin : `float`
            The y value of the fitted line at the lower end of the box
        yBoxMax : `float`
            The y value of the fitted line at the upper end of the box
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower end
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper end
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Notes
    -----
    The code does two rounds of fitting. The first is initialised with the
    hardwired values given in ``paramDict`` and is an Orthogonal Distance
    Regression fit to the points inside the box defined by xMin, xMax, yMin
    and yMax. The line's ends are taken at xMin/xMax, or at yMin/yMax when
    the fitted gradient is steeper than 1. A perpendicular line is then
    placed through each end, and only the points lying between the two
    perpendiculars are used to recalculate the fit. This works for both
    positive and negative gradients.
    """
    # Points inside the box are used for the initial fit.
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    linear = scipyODR.polynomial(1)

    # For a first-order polynomial model beta is (intercept, gradient).
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit, calculate the perpendicular ends.
    mPerp = -1.0 / mODR
    # When the gradient is really steep, use the y limits of the box rather
    # than the x ones to define the ends of the line.
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp * xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use the perpendicular lines to choose the data and refit. Which
    # intercept is the lower one depends on the sign of the gradient: for a
    # shallow negative gradient bPerpMax < bPerpMin. Order them explicitly,
    # otherwise no points would be selected.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance to a line from points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line, ``(x, y)``.
    p2 : `numpy.ndarray`
        Another point on the line, ``(x, y)``.
    points : `zip`
        An iterable of ``(x, y)`` points to calculate the distance to.

    Returns
    -------
    dists : `list`
        The signed distances from the line to the points, worked out from
        the 2-D cross product of ``p2 - p1`` with ``point - p1`` divided by
        ``|p2 - p1|``. Points to the left of the direction ``p1 -> p2`` are
        positive.
    """
    p1 = np.asarray(p1)
    direction = np.asarray(p2) - p1
    # The line's length is the same for every point, so compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point) - p1
        # z-component of the 2-D cross product. This is written out by hand
        # because np.cross on 2-element vectors is deprecated in NumPy 2.0.
        cross = direction[0] * offset[1] - direction[1] * offset[0]
        dists.append(cross / length)

    return dists
def mkColormap(colorNames):
    """Build a linear colormap that interpolates between named colors.

    Parameters
    ----------
    colorNames : `list`
        Strings naming matplotlib colors. They are placed at evenly spaced
        positions from 0 to 1 along the colormap, in the given order.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        The resulting colormap, named ``"newCmap"``.
    """
    positions = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Each channel's segment list holds (position, value, value) entries,
    # where value is the channel's component in the RGB triple.
    channelIndex = {"blue": 2, "red": 0, "green": 1}
    segmentData = {
        channel: [(pos, rgb[idx], rgb[idx]) for pos, rgb in zip(positions, rgbTriples)]
        for channel, idx in channelIndex.items()
    }

    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Order point indices by their absolute distance from the median.

    Parameters
    ----------
    xs : `np.array`
        The values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs``. The points closest to the median come first
        and the most extreme come last.
    """
    deviation = np.abs(xs - np.median(xs))
    return np.argsort(deviation)
def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot. It can
        be a gridspec SubplotSpec, or a 3 digit integer where the first digit
        is the number of rows, the second is the number of columns and the
        third is the index. The tuple form (int, int, index) works the same
        way.
    sumStats : `dict`
        A dictionary keyed by patch/detector id. Each value is a tuple of the
        corners (R.A., Dec. angle pairs) of the region and a summary
        statistic for it. The key ``"tract"`` is drawn as an outline but is
        not labelled.
    label : `str`
        Text drawn over the colorbar, describing the summarised quantity.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and its colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out.
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of each region and make a color-coded rectangle for
    # it; the color shows the region's summary statistic. The list is named
    # ``patchStats`` so that it does not shadow the ``matplotlib.colors``
    # module imported at file level.
    patches = []
    patchStats = []
    for dataId, (corners, stat) in sumStats.items():
        # The rectangle spans corner 0 to corner 2 (opposite corners).
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        patchStats.append(stat)

        ras = [corner[0].asDegrees() for corner in corners]
        decs = [corner[1].asDegrees() for corner in corners]
        axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)

        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Make NaN statistics transparent by masking them and setting the
    # colormap's "bad" color to none.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(patchStats, mask=np.isnan(patchStats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    # R.A. increases to the left on the sky.
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes. Attach it through ``fig``
    # rather than pyplot's "current figure", which may be a different (or
    # newly created) figure.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig
class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to matplotlib.pyplot.subplot2grid.
    """

    # Visibility of each of the four panel spines (line plus ticks).
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Grid shape, i.e. the ``shape`` argument of subplot2grid.
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    # Axis location within the grid, i.e. the ``loc`` argument.
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    # Axis extent, i.e. the ``rowspan``/``colspan`` arguments.
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )