Coverage for python/lsst/analysis/drp/plotUtils.py: 8%
184 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-21 01:40 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-03-21 01:40 -0700
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import matplotlib.pyplot as plt
24import scipy.odr as scipyODR
25import matplotlib
26from matplotlib import colors
27from typing import List, Tuple
29from lsst.geom import Box2D, SpherePoint, degrees
# Tick formatter that renders no labels; get_and_remove_axis_text installs it
# on every axis to blank out tick text.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux):
    """Parse plot info from the dataId

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions.`
             `_coordinate._ExpandedTupleDataCoordinate`
        The data ID. Iterating it must yield objects with a ``name``
        attribute, and ``dataId[name]`` must return that dimension's value.
    runName : `str`
        The run name, stored under ``"run"``.
    tableName : `str`
        The table type, stored under ``"tractTableType"``.
    bands : iterable of `str`
        The bands used; stored under ``"bands"`` as one ``", "``-separated
        string (empty if there are no bands).
    plotName : `str`
        The plot name, stored under ``"plotName"``.
    SN : `float` or `str`
        The signal-to-noise value, stored under ``"SN"``.
    SNFlux : `str`
        Stored under ``"SNFlux"``; shown next to ``SN`` by `addPlotInfo`.

    Returns
    -------
    plotInfo : `dict`
        The parameters above plus one entry per data ID dimension.
        ``"tract"`` and ``"visit"`` are always present and set to ``"N/A"``
        when the data ID does not provide them.
    """
    plotInfo = {"run": runName, "tractTableType": tableName, "plotName": plotName, "SN": SN, "SNFlux": SNFlux}

    for dimension in dataId:
        plotInfo[dimension.name] = dataId[dimension.name]

    plotInfo["bands"] = ", ".join(bands)

    # addPlotInfo reads these keys unconditionally, so always provide them.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(cat, colName, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalog. Must contain a ``patch`` column of sequential patch
        indices and the column ``colName``. If a ``sourceType`` column is
        present, rows where it equals 0 are excluded.
    colName : `str`
        Name of the column to summarize.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap containing the tract.
    plotInfo : `dict`
        Plot information; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index of the tract to
        ``(skyCoords, stat)``. ``skyCoords`` are the sky positions of the
        patch's inner-bbox corners, and ``stat`` is the NaN-ignoring median
        of ``colName`` for rows in that patch (NaN if there are none). The
        extra key ``"tract"`` maps to the tract's corner sky positions and
        NaN.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    if "sourceType" in cat.columns:
        cat = cat.loc[cat["sourceType"] != 0]

    # Extract the columns once instead of on every patch iteration.
    patchIds = cat["patch"].values
    values = cat[colName].values

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x*tractInfo.num_patches.y
    # The patches are integers from np.arange, so they are already the
    # sequential (gen 3) patch indices and need no conversion.
    for patch in np.arange(0, maxPatchNum, 1):
        stat = np.nanmedian(values[patchIds == patch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic on each detector of a visit

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalog. Must contain a ``detector`` column and ``colName``.
    colName : `str`
        Name of the column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Per-detector summary with ``id``, ``raCorners`` and ``decCorners``
        columns. Each corner entry holds one sequence of corner coordinates
        per row, interpreted here as degrees.
    plotInfo : `dict`
        Plot information (not used by this function).

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id found in ``cat`` to ``(corners, stat)``.
        ``corners`` is a list of `lsst.geom.SpherePoint` built from the
        summary row whose ``id`` matches, and ``stat`` is the NaN-ignoring
        median of ``colName`` on that detector.

    Raises
    ------
    IndexError
        Raised if a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Work on plain arrays so that ``column[mask][0]`` always picks the
    # first matching row by position. On a pandas Series with an integer
    # index, ``[0]`` is a label lookup and fails for any detector whose row
    # label is not 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])
    detectors = cat["detector"].values
    values = cat[colName].values

    visitInfoDict = {}
    for ccd in cat["detector"].unique():
        if ccd is None:
            continue
        stat = np.nanmedian(values[detectors == ccd])

        sumRow = (summaryIds == ccd)
        cornersOut = [SpherePoint(ra, dec, units=degrees)
                      for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    The title, axis labels, legend entries and free-standing text artists
    are blanked in place, and tick labels are suppressed with
    ``null_formatter``. The same is done recursively for every child axis.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title and axis/legend/tick labels)
        from ``ax`` followed by those of its child axes.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` point arrays of every line on ``ax`` followed by
        those of its child axes.
    """
    # get_xydata() recomputes the cached points if the line data were
    # changed since they were last cached; the private ``_xy`` attribute
    # does not.
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        for text in legend.texts:
            texts.append(text.get_text())
            # Legend entries are made transparent rather than cleared.
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text('')

    # Only 3D axes have a zaxis.
    for axis in (ax.xaxis, ax.yaxis, getattr(ax, "zaxis", None)):
        if axis is not None:
            axis.set_major_formatter(null_formatter)
            axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: plt.Figure):
    """Strip the text from a Figure and all its Axes, returning what was removed.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        The removed strings, in order:
        - ``str()`` of the figure's ``_suptitle`` attribute;
        - the text of each figure-level text artist;
        - the strings returned by `get_and_remove_axis_text` for each axis.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line point arrays returned by `get_and_remove_axis_text` for
        each axis, concatenated.
    """
    # The suptitle is captured via str() of the stored attribute, not
    # get_text(), before it is blanked.
    removedTexts = [str(figure._suptitle)]
    figure.suptitle("")

    removedTexts += [artist.get_text() for artist in figure.texts]
    figure.texts = []

    removedLines = []
    for subAx in figure.get_axes():
        axTexts, axLines = get_and_remove_axis_text(subAx)
        removedTexts += axTexts
        removedLines += axLines

    return removedTexts, removedLines
def addPlotInfo(fig, plotInfo):
    """Write the plot name and its provenance onto a figure

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; two text artists are added to it.
    plotInfo : `dict`
        Must provide ``plotName``, ``run``, ``tractTableType``, ``tract``,
        ``visit``, ``bands``, ``SN`` and ``SNFlux``. A ``tract`` or
        ``visit`` whose string form is ``"N/A"`` is left out. A string
        ``SN`` is shown as-is; a numeric one is shown as a lower limit,
        followed by ``SNFlux``.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure that was passed in.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    pieces = [
        f"\n{plotInfo['run']}",
        f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}",
        f"\nTable: {plotInfo['tractTableType']}",
    ]
    if str(plotInfo["tract"]) != "N/A":
        pieces.append(f", Tract: {plotInfo['tract']}")
    if str(plotInfo["visit"]) != "N/A":
        pieces.append(f", Visit: {plotInfo['visit']}")
    # Strip every space from the band list, e.g. "g, r" -> "g,r".
    pieces.append(f", Bands: {plotInfo['bands'].replace(' ', '')}")

    signalToNoise = plotInfo["SN"]
    if isinstance(signalToNoise, str):
        pieces.append(f", S/N: {plotInfo['SN']}")
    elif np.abs(signalToNoise) > 1e4:
        # Very large thresholds are shown in compact general notation.
        pieces.append(f", S/N > {plotInfo['SN']:0.1g} ({plotInfo['SNFlux']})")
    else:
        pieces.append(f", S/N > {plotInfo['SN']:0.1f} ({plotInfo['SNFlux']})")

    fig.text(0.01, 0.98, "".join(pieces), fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : `lsst.pex.config.dictField.Dict`
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit
        mODR : `float`
            The gradient calculated by the ODR fit
        bODR : `float`
            The intercept calculated by the ODR fit
        yBoxMin : `float`
            The y value of the fitted line at the lower end of the box
        yBoxMax : `float`
            The y value of the fitted line at the upper end of the box
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower end
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper end
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Notes
    -----
    The code does two rounds of fitting. The first is initialised with the
    hardwired values given in the `paramDict` parameter and is an
    Orthogonal Distance Regression fit to the points inside the box defined
    by xMin, xMax, yMin and yMax. Once this fit has been done, a
    perpendicular line is calculated at either end of the fitted line
    within the box, and only points that fall between these two lines are
    used to recalculate the fit. When ``|mODR| > 1``, the line ends are
    taken at yMin and yMax rather than at xMin and xMax.

    A fitted gradient of exactly zero makes the perpendicular gradient
    undefined, and a `ZeroDivisionError` is raised.
    """
    # Points inside the initial box are used for the first fit
    fitPoints = np.where((xs > paramDict["xMin"]) & (xs < paramDict["xMax"])
                         & (ys > paramDict["yMin"]) & (ys < paramDict["yMax"]))[0]

    # The linear model's parameters are ordered (intercept, gradient)
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {"xMin": paramDict["xMin"], "xMax": paramDict["xMax"], "yMin": paramDict["yMin"],
                 "yMax": paramDict["yMax"], "mHW": paramDict["mHW"], "bHW": paramDict["bHW"],
                 "mODR": mODR, "bODR": bODR}

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0/mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR)/mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR)/mODR
    else:
        yBoxMin = mODR*paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR*paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp*xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp*xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. For a
    # shallow negative gradient bPerpMax < bPerpMin, so sort the two
    # intercepts; otherwise the selection between the lines would be empty.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = ((ys > mPerp*xs + bPerpLow) & (ys < mPerp*xs + bPerpHigh))
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    # Start the second fit from the first-round solution
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0/paramsOut["mODR2"]

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the perpendicular distance to a line from a point

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line
    p2 : `numpy.ndarray`
        Another point on the line
    points : `zip`
        The points to calculate the distance to

    Returns
    -------
    dists : `list`
        The signed distances from the line to the points. Each one is the
        cross product of the line direction ``p2 - p1`` with the offset of
        the point from ``p1``, divided by the line length. For 2D points the
        value is positive for points to the left of the direction
        ``p1 -> p2``.
    """
    direction = p2 - p1
    # The line length is the same for every point, so compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.array(point) - p1
        if direction.shape == (2,) and offset.shape == (2,):
            # z-component of the 2D cross product, written out because
            # np.cross on 2-element vectors is deprecated in NumPy 2.0.
            cross = direction[0]*offset[1] - direction[1]*offset[0]
        else:
            cross = np.cross(direction, offset)
        dists.append(cross/length)

    return dists
def mkColormap(colorNames):
    """Make a colormap from the list of color names.

    The colors are placed at evenly spaced positions from 0 to 1, and the
    colormap interpolates linearly between neighbouring colors.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib
        named colors. A single color gives a constant colormap.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`

    Raises
    ------
    ValueError
        Raised if ``colorNames`` is empty.
    """
    if len(colorNames) == 0:
        raise ValueError("mkColormap needs at least one color name.")
    if len(colorNames) == 1:
        # Segment data must have anchor points at both 0 and 1, which a
        # single color cannot provide, so repeat it at both ends.
        colorNames = [colorNames[0], colorNames[0]]

    nums = np.linspace(0, 1, len(colorNames))
    blues = []
    greens = []
    reds = []
    for (num, color) in zip(nums, colorNames):
        r, g, b = colors.to_rgb(color)
        # Identical left/right values give a continuous colormap.
        blues.append((num, b, b))
        greens.append((num, g, g))
        reds.append((num, r, r))

    colorDict = {"blue": blues, "red": reds, "green": greens}
    cmap = colors.LinearSegmentedColormap("newCmap", colorDict)
    return cmap
def extremaSort(xs):
    """Order indices so that the values farthest from the median come last.

    Parameters
    ----------
    xs : `np.array`
        The values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs``, sorted by increasing absolute distance of the
        corresponding value from the median of ``xs``.
    """
    centre = np.median(xs)
    return np.argsort(np.abs(xs - centre))