Coverage for python/lsst/analysis/drp/plotUtils.py: 8%
183 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-15 12:24 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-15 12:24 +0000
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import matplotlib.pyplot as plt
24import scipy.odr as scipyODR
25import matplotlib
26from matplotlib import colors
27from typing import List, Tuple
29from lsst.geom import Box2D, SpherePoint, degrees
# Shared tick formatter that renders no labels; get_and_remove_axis_text
# installs it on every axis to blank out tick-label text.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.DataCoordinate`
        Data ID. Every element produced by iterating it must have a
        ``name`` attribute, and ``dataId[name]`` must return its value.
    runName : `str`
        Run name, stored under ``"run"``.
    tableName : `str`
        Table type, stored under ``"tractTableType"``.
    bands : iterable of `str`
        Bands, joined with ``", "`` and stored under ``"bands"``.
    plotName : `str`
        Plot name, stored under ``"plotName"``.
    SN : `float` or `str`
        Signal-to-noise threshold, stored under ``"SN"``.
    SNFlux : `str`
        Flux used for the S/N cut, stored under ``"SNFlux"``.

    Returns
    -------
    plotInfo : `dict`
        The values above plus one entry per dataId element. ``"tract"``
        and ``"visit"`` are set to ``"N/A"`` when the dataId lacks them.
    """
    plotInfo = {"run": runName, "tractTableType": tableName, "plotName": plotName, "SN": SN, "SNFlux": SNFlux}

    for dimension in dataId:
        plotInfo[dimension.name] = dataId[dimension.name]

    # Equivalent to prefixing every band with ", " and dropping the first
    # separator, without quadratic string concatenation.
    plotInfo["bands"] = ", ".join(bands)

    # Downstream code (addPlotInfo) always reads these keys.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(cat, colName, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``patch`` column holding integer (sequential) patch
        indices and a ``colName`` column. If a ``sourceType`` column is
        present, rows with ``sourceType == 0`` are excluded.
    colName : `str`
        Name of the column to summarize.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Sky map used to generate the tract geometry.
    plotInfo : `dict`
        Plot information; only ``plotInfo["tract"]`` is read.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each patch index ``0 .. num_patches.x*num_patches.y - 1`` to
        ``(skyCoords, stat)``. ``skyCoords`` is ``tractWcs.pixelToSky`` of
        the patch inner-bbox corners. ``stat`` is the NaN-ignoring median of
        ``colName`` over the rows in that patch, or NaN if there are none.
        The key ``"tract"`` maps to the tract bbox corners and NaN.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    if "sourceType" in cat.columns:
        cat = cat.loc[cat["sourceType"] != 0]

    # Extract the summarized column once instead of on every iteration.
    values = cat[colName].values

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x*tractInfo.num_patches.y
    # Patch ids are sequential integer indices. The former gen-2 "x,y"
    # string conversion could never run, because these ids are integers.
    for patch in range(maxPatchNum):
        onPatch = (cat["patch"] == patch)
        patchValues = values[onPatch]
        # nanmedian of an empty selection warns and returns NaN; return NaN
        # directly for patches without any rows.
        stat = np.nanmedian(patchValues) if patchValues.size > 0 else np.nan

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[patch] = (skyCoords, stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and a ``colName`` column.
    colName : `str`
        Name of the column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Table with ``id``, ``raCorners`` and ``decCorners`` columns; each
        row's corner entries are sequences of RA/Dec values in degrees.
        Array-like catalogs with the same columns also work.
    plotInfo : `dict`
        Plot information (currently unused).

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id in ``cat`` (``None`` is skipped) to
        ``(cornersOut, stat)``. ``cornersOut`` is a list of `SpherePoint`
        built from the matching summary row's corners. ``stat`` is the
        NaN-ignoring median of ``colName`` for that detector.

    Raises
    ------
    IndexError
        If a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Convert the columns to plain arrays so that selecting the matching row
    # is positional. For a pandas Series, ``series[mask][0]`` is a *label*
    # lookup that fails unless the matching row happens to be labelled 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCornersAll = np.asarray(visitSummaryTable["raCorners"])
    decCornersAll = np.asarray(visitSummaryTable["decCorners"])
    values = cat[colName].values

    visitInfoDict = {}
    for ccd in cat["detector"].unique():
        if ccd is None:
            continue
        onCcd = (cat["detector"] == ccd)
        stat = np.nanmedian(values[onCcd])

        sumRow = (summaryIds == ccd)
        raCorners = raCornersAll[sumRow][0]
        decCorners = decCornersAll[sumRow][0]
        cornersOut = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis. It is modified in place.

    Returns
    -------
    texts : `List[str]`
        The removed strings: the title, x/y axis labels, legend texts and
        the text artists in ``ax.texts``, followed by those of each child
        axis. Tick labels are not collected; their formatters are replaced
        with ``null_formatter`` so they render blank.
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` attributes (arrays of shape ``(N, 2)``) of the lines
        directly on ``ax``. Lines on child axes are not included.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when there is no legend. Check for that
    # explicitly instead of catching AttributeError, which could also hide
    # genuine errors partway through the loop.
    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend entries are made invisible rather than emptied.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text('')

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only some axes (e.g. 3D ones) have a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        # Only child texts are collected; child line data is discarded.
        # NOTE(review): confirm that child-axis lines should not be returned.
        texts_child, _lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: plt.Figure):
    """Strip the text from a Figure and from each of its Axes.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        Figure to strip, modified in place. Its suptitle is set to an empty
        string and its figure-level texts list is emptied.

    Returns
    -------
    texts : `List[str]`
        The removed strings, in order: ``str(figure._suptitle)``, the
        figure-level texts, then the texts returned by
        ``get_and_remove_axis_text`` for each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line ``_xy`` arrays returned by ``get_and_remove_axis_text``
        for each Axes.
    """
    # NOTE(review): str() of a Text artist is not the same as get_text(),
    # and this is presumably "None" when no suptitle exists; confirm this
    # is intended before changing it, since callers may compare against it.
    collected = [str(figure._suptitle)]
    figure.suptitle("")

    collected += [text.get_text() for text in figure.texts]
    figure.texts = []

    collected_lines = []
    for axis in figure.get_axes():
        axis_texts, axis_lines = get_and_remove_axis_text(axis)
        collected += axis_texts
        collected_lines += axis_lines

    return collected, collected_lines
def addPlotInfo(fig, plotInfo):
    """Write the plot name and a provenance summary onto a figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate; two text artists are added with ``fig.text``.
    plotInfo : `dict`
        Must provide ``"plotName"``, ``"run"``, ``"tractTableType"``,
        ``"tract"``, ``"visit"``, ``"bands"`` and ``"SN"``, plus
        ``"SNFlux"`` when ``"SN"`` is not a string. A ``"tract"`` or
        ``"visit"`` whose string form is ``"N/A"`` is left out of the text.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, annotated.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    # Plot name in the top-left corner.
    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    pieces = [
        f"\n{plotInfo['run']}",
        f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}",
        f"\nTable: {plotInfo['tractTableType']}",
    ]
    if str(plotInfo["tract"]) != "N/A":
        pieces.append(f", Tract: {plotInfo['tract']}")
    if str(plotInfo["visit"]) != "N/A":
        pieces.append(f", Visit: {plotInfo['visit']}")
    # Drop every space from the band list ("g, r" -> "g,r").
    pieces.append(f", Bands: {plotInfo['bands'].replace(' ', '')}")

    sn = plotInfo["SN"]
    if isinstance(sn, str):
        pieces.append(f", S/N: {sn}")
    else:
        # Very large thresholds are shown in compact exponent notation.
        spec = "0.1g" if np.abs(sn) > 1e4 else "0.1f"
        pieces.append(f", S/N > {sn:{spec}} ({plotInfo['SNFlux']})")

    fig.text(0.01, 0.98, "".join(pieces), fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus

    Parameters
    ----------
    xs : `numpy.ndarray` [`float`]
        The color on the xaxis.
    ys : `numpy.ndarray` [`float`]
        The color on the yaxis.
    paramDict : `lsst.pex.config.dictField.Dict`
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting.
        xMax : `float`
            The maximum x edge of the box to use for initial fitting.
        yMin : `float`
            The minimum y edge of the box to use for initial fitting.
        yMax : `float`
            The maximum y edge of the box to use for initial fitting.
        mHW : `float`
            The hardwired gradient for the fit.
        bHW : `float`
            The hardwired intercept of the fit.

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters.
        mODR0 : `float`
            The gradient calculated by the initial ODR fit.
        bODR0 : `float`
            The intercept calculated by the initial ODR fit.
        yBoxMin : `float`
            The y value of the fitted line at the lower end of the box.
        yBoxMax : `float`
            The y value of the fitted line at the upper end of the box.
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower end.
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper end.
        mODR : `float`
            The gradient from the second (and final) round of fitting.
        bODR : `float`
            The intercept from the second (and final) round of fitting.
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit.
        fitPoints : `numpy.ndarray` [`bool`]
            A boolean array indicating which points were used in the final fit.

    Notes
    -----
    The code does two rounds of fitting. The first starts from the
    hardwired values given in the `paramDict` parameter and is an Orthogonal
    Distance Regression fit to the points inside the box defined by xMin,
    xMax, yMin and yMax. A perpendicular line is then calculated at each
    end of the fitted segment, and only points lying between these two
    lines are used to recalculate the fit. The two perpendicular intercepts
    are ordered before selecting, so the selection also works for shallow
    negative slopes, where ``bPerpMin > bPerpMax``.
    """

    # Initial subselection of points to use for the fit
    fitPoints = ((xs > paramDict["xMin"]) & (xs < paramDict["xMax"])
                 & (ys > paramDict["yMin"]) & (ys < paramDict["yMax"]))

    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    # polynomial(1) parameters are ordered [intercept, gradient].
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR0 = float(params.beta[1])
    bODR0 = float(params.beta[0])

    paramsOut = {"mODR0": mODR0, "bODR0": bODR0}

    # Having found the initial fit calculate perpendicular ends
    mPerp0 = -1.0/mODR0
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones.
    if np.abs(mODR0) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR0)/mODR0
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR0)/mODR0
    else:
        yBoxMin = mODR0*paramDict["xMin"] + bODR0
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR0*paramDict["xMax"] + bODR0
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp0*xBoxMin

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp0*xBoxMax

    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. Which
    # intercept is the lower bound depends on the sign of the slope (for
    # -1 < mODR0 < 0, bPerpMin > bPerpMax), so order them first; otherwise
    # the selection would be empty.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = ((ys > mPerp0*xs + bPerpLow) & (ys < mPerp0*xs + bPerpHigh))
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR0, mODR0])
    params = odr.run()

    paramsOut["mODR"] = float(params.beta[1])
    paramsOut["bODR"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0/paramsOut["mODR"]
    paramsOut["fitPoints"] = fitPoints

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance to a line from points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line.
    p2 : `numpy.ndarray`
        Another point on the line.
    points : iterable (e.g. `zip`)
        The points to calculate the distance to.

    Returns
    -------
    dists : `list`
        The distances from the line to the points. Each is the cross
        product of ``p2 - p1`` with ``point - p1`` divided by
        ``|p2 - p1|``, so the sign indicates the side of the line.
    """
    p1 = np.asarray(p1)
    direction = np.asarray(p2) - p1
    # The line length is the same for every point; compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point) - p1
        if direction.shape == (2,) and offset.shape == (2,):
            # z-component of the 2D cross product, computed explicitly
            # because np.cross on 2-element vectors is deprecated in NumPy 2.
            cross = direction[0]*offset[1] - direction[1]*offset[0]
        else:
            cross = np.cross(direction, offset)
        dists.append(cross/length)

    return dists
def mkColormap(colorNames):
    """Build a linear colormap that passes evenly through the named colors.

    Parameters
    ----------
    colorNames : `list`
        Strings naming matplotlib colors. They are placed at evenly spaced
        positions from 0 to 1, in order.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"`` that interpolates linearly between
        the colors.
    """
    positions = np.linspace(0, 1, len(colorNames))
    rgbs = [colors.to_rgb(name) for name in colorNames]

    # Each segment entry is (position, value_below, value_above); the two
    # values are equal, so the map is continuous at every anchor.
    segmentData = {
        "blue": [(pos, b, b) for pos, (_, _, b) in zip(positions, rgbs)],
        "red": [(pos, r, r) for pos, (r, _, _) in zip(positions, rgbs)],
        "green": [(pos, g, g) for pos, (_, g, _) in zip(positions, rgbs)],
    }
    return colors.LinearSegmentedColormap("newCmap", segmentData)
449def extremaSort(xs):
450 """Return the ids of the points reordered so that those
451 furthest from the median, in absolute terms, are last.
453 Parameters
454 ----------
455 xs : `np.array`
456 An array of the values to sort
458 Returns
459 -------
460 ids : `np.array`
461 """
463 med = np.median(xs)
464 dists = np.abs(xs - med)
465 ids = np.argsort(dists)
466 return ids