Coverage for python/lsst/analysis/drp/plotUtils.py: 8%
182 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-17 11:29 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-02-17 11:29 +0000
1# This file is part of analysis_drp.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import matplotlib.pyplot as plt
24import scipy.odr as scipyODR
25import matplotlib
26from matplotlib import colors
27from typing import List, Tuple
29from lsst.geom import Box2D, SpherePoint, degrees
# One shared tick formatter that produces empty tick labels. The
# text-stripping helpers below install it as the major and minor formatter
# of every axis they process, so tick labels vanish from the figure.
null_formatter = matplotlib.ticker.NullFormatter()
def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.DataCoordinate`
        The data ID of the plotted data; every entry of ``dataId.mapping``
        is copied into the returned dictionary.
    runName : `str`
        Stored under the ``"run"`` key.
    tableName : `str`
        The type of table plotted; stored under ``"tractTableType"``.
    bands : iterable of `str`
        The bands used; stored as a single ``", "``-separated string
        under ``"bands"``.
    plotName : `str`
        Stored under ``"plotName"``.
    SN : `float` or `str`
        The signal-to-noise cut; stored under ``"SN"``.
    SNFlux : `str`
        Label of the flux the S/N cut refers to; stored under ``"SNFlux"``.

    Returns
    -------
    plotInfo : `dict`
        The plot information. ``"tract"`` and ``"visit"`` are always
        present and are set to ``"N/A"`` when the data ID lacks them.
    """
    plotInfo = {"run": runName, "tractTableType": tableName, "plotName": plotName,
                "SN": SN, "SNFlux": SNFlux}

    # Data ID entries are merged in; on a name clash they override the
    # fixed keys above.
    plotInfo.update(dataId.mapping)

    plotInfo["bands"] = ", ".join(bands)

    # addPlotInfo reads both keys unconditionally, so always provide them.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo
def generateSummaryStats(cat, colName, skymap, plotInfo):
    """Generate a summary statistic (the NaN-ignoring median of a column)
    in each patch of a tract.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalog. It must have a ``"patch"`` column holding sequential
        (gen 3) patch indices, plus the column ``colName``. If a
        ``"sourceType"`` column is present, rows where it equals 0 are
        dropped first.
    colName : `str`
        The column to summarize.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap that provides the tract and patch geometry.
    plotInfo : `dict`
        Plot information; only ``plotInfo["tract"]`` is used.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to a tuple of (sky coordinates of
        the corners of the patch's inner bounding box, median of
        ``colName`` in that patch). Patches with no rows get NaN. The extra
        key ``"tract"`` maps to (sky coordinates of the tract bounding-box
        corners, NaN).
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    if "sourceType" in cat.columns:
        cat = cat.loc[cat["sourceType"] != 0]

    # Patch ids are the sequential (gen 3) indices 0 .. nx*ny - 1. The
    # former gen 2 "x,y" string conversion could never trigger here,
    # because the ids iterated over are always integers.
    values = cat[colName].values
    numPatches = tractInfo.num_patches.x*tractInfo.num_patches.y

    patchInfoDict = {}
    for patch in range(numPatches):
        onPatch = (cat["patch"] == patch).values
        stat = np.nanmedian(values[onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict
def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (the NaN-ignoring median of a column)
    on each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalog; it must have a ``"detector"`` column and ``colName``.
    colName : `str`
        The column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Per-detector summary table, indexable by column name, with an
        ``"id"`` column (matched against ``cat["detector"]``) and
        ``"raCorners"`` / ``"decCorners"`` columns holding each detector's
        corner coordinates in degrees. Any table whose columns convert
        with `numpy.asarray` can be used.
    plotInfo : `dict`
        Not used.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id in ``cat`` to a tuple of (list of
        `lsst.geom.SpherePoint` corners, median of ``colName`` on that
        detector).

    Raises
    ------
    IndexError
        Raised if a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Convert the summary columns to arrays once so that row selection is
    # positional. Indexing a filtered pandas Series with ``[0]`` is a
    # *label* lookup, which fails unless the matching row has label 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])
    values = cat[colName].values

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = (cat["detector"] == ccd).values
        stat = np.nanmedian(values[onCcd])

        sumRow = (summaryIds == ccd)
        cornersOut = [SpherePoint(ra, dec, units=degrees)
                      for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict
# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    The title, axis labels, legend entries and free text on ``ax`` are
    collected and then blanked or hidden. Tick labels are suppressed by
    installing ``null_formatter`` on every axis, but their text is not
    returned. Child axes are processed recursively.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        Title, x/y labels, legend entries and ``ax.texts`` strings, followed
        by those of any child axes.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` xy data of every line on ``ax``, followed by that of
        its child axes.
    """
    # get_xydata() is the public accessor; it recaches stale data instead
    # of returning the private ``_xy`` as-is.
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Entries are made transparent rather than emptied (presumably so
        # the legend box keeps its layout).
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text('')

    axes = [ax.xaxis, ax.yaxis]
    # Only 3D axes have a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        axes.append(zaxis)
    for axis in axes:
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys
def get_and_remove_figure_text(figure: plt.Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        The suptitle string (``""`` if the figure has none), then the
        string of each figure-level text artist, then the texts collected
        from every Axes by `get_and_remove_axis_text`.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` xy data of every line on every Axes of the figure.
    """
    # ``_suptitle`` is private matplotlib state: a Text artist, or None when
    # no suptitle was set. Record its string, as for every other text,
    # rather than ``str()`` of the artist (whose repr also embeds its
    # position, or gives "None").
    suptitle = getattr(figure, "_suptitle", None)
    texts = ["" if suptitle is None else suptitle.get_text()]
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    lines = []
    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines
def addPlotInfo(fig, plotInfo):
    """Annotate a figure with the plot name and a summary of its inputs.

    Two text blocks are written at the top-left of ``fig`` in figure
    coordinates. The first is the plot name. Beneath it is a faded,
    multi-line summary containing:

    - the run;
    - the (placeholder) calibration datasets;
    - the table type;
    - the tract and visit, when known;
    - the bands;
    - the signal-to-noise cut.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; it is modified in place.
    plotInfo : `dict`
        Must provide ``"plotName"``, ``"run"``, ``"tractTableType"``,
        ``"tract"``, ``"visit"``, ``"bands"``, ``"SN"`` and ``"SNFlux"``.
        A tract or visit equal to ``"N/A"`` is omitted. Spaces are stripped
        from the bands string. A string ``"SN"`` is shown verbatim. A
        numeric ``"SN"`` is shown as a lower limit, in ``g`` format when its
        magnitude exceeds 1e4 and to one decimal place otherwise.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure that was passed in.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    topLeft = dict(transform=fig.transFigure, ha="left", va="top")
    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, **topLeft)

    pieces = [
        f"\n{plotInfo['run']}",
        f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}",
        f"\nTable: {plotInfo['tractTableType']}",
    ]
    for key, label in (("tract", "Tract"), ("visit", "Visit")):
        if str(plotInfo[key]) != "N/A":
            pieces.append(f", {label}: {plotInfo[key]}")
    pieces.append(f", Bands: {plotInfo['bands'].replace(' ', '')}")

    sn = plotInfo["SN"]
    if isinstance(sn, str):
        pieces.append(f", S/N: {sn}")
    else:
        snFormat = "0.1g" if np.abs(sn) > 1e4 else "0.1f"
        pieces.append(f", S/N > {sn:{snFormat}} ({plotInfo['SNFlux']})")

    fig.text(0.01, 0.98, "".join(pieces), fontsize=7, alpha=0.6, **topLeft)

    return fig
def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray` [`float`]
        The color on the xaxis.
    ys : `numpy.ndarray` [`float`]
        The color on the yaxis.
    paramDict : `lsst.pex.config.dictField.Dict`
        A dictionary of parameters for line fitting:

        xMin : `float`
            The minimum x edge of the box to use for initial fitting.
        xMax : `float`
            The maximum x edge of the box to use for initial fitting.
        yMin : `float`
            The minimum y edge of the box to use for initial fitting.
        yMax : `float`
            The maximum y edge of the box to use for initial fitting.
        mHW : `float`
            The hardwired gradient for the fit.
        bHW : `float`
            The hardwired intercept of the fit.

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters:

        mODR0 : `float`
            The gradient calculated by the initial ODR fit.
        bODR0 : `float`
            The intercept calculated by the initial ODR fit.
        yBoxMin : `float`
            The y value at the lower end of the initial fitted segment.
            This is the fitted line evaluated at xMin, or yMin when
            ``|mODR0| > 1``.
        yBoxMax : `float`
            The y value at the upper end of the initial fitted segment.
            This is the fitted line evaluated at xMax, or yMax when
            ``|mODR0| > 1``.
        bPerpMin : `float`
            The intercept of the line perpendicular to the initial fit that
            passes through the lower end of the segment.
        bPerpMax : `float`
            The intercept of the line perpendicular to the initial fit that
            passes through the upper end of the segment.
        mODR : `float`
            The gradient from the second (and final) round of fitting.
        bODR : `float`
            The intercept from the second (and final) round of fitting.
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit.
        fitPoints : `numpy.ndarray` [`bool`]
            A boolean array indicating which points were used in the final
            fit.

    Raises
    ------
    ZeroDivisionError
        Raised if either fit returns a gradient of exactly zero, because
        the perpendicular gradient is then undefined.

    Notes
    -----
    The code does two rounds of fitting.

    The first round is initiated using the hardwired values given in the
    `paramDict` parameter. It is an Orthogonal Distance Regression fit to
    the points inside the box defined by xMin, xMax, yMin and yMax.

    Once this fit has been done, a perpendicular line is calculated at
    either end of the fitted segment. Only points that fall between these
    two lines are used to recalculate the fit.
    """
    # Initial subselection of points to use for the fit
    fitPoints = ((xs > paramDict["xMin"]) & (xs < paramDict["xMax"])
                 & (ys > paramDict["yMin"]) & (ys < paramDict["yMax"]))

    linear = scipyODR.polynomial(1)

    # polynomial(1) parameters are ordered (intercept, gradient).
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR0 = float(params.beta[1])
    bODR0 = float(params.beta[0])

    paramsOut = {"mODR0": mODR0, "bODR0": bODR0}

    # Having found the initial fit, calculate the perpendicular ends.
    mPerp0 = -1.0/mODR0

    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones.
    if np.abs(mODR0) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR0)/mODR0
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR0)/mODR0
    else:
        yBoxMin = mODR0*paramDict["xMin"] + bODR0
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR0*paramDict["xMax"] + bODR0
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp0*xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp0*xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. For a
    # shallow negative gradient bPerpMax < bPerpMin. The intercepts must be
    # ordered first, otherwise no point can satisfy both inequalities.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    perpLine = mPerp0*xs
    fitPoints = ((ys > perpLine + bPerpLow) & (ys < perpLine + bPerpHigh))
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR0, mODR0])
    params = odr.run()

    paramsOut["mODR"] = float(params.beta[1])
    paramsOut["bODR"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0/paramsOut["mODR"]
    paramsOut["fitPoints"] = fitPoints

    return paramsOut
def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to each point.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line, as an (x, y) pair.
    p2 : `numpy.ndarray`
        Another point on the line, as an (x, y) pair; it must differ from
        ``p1``.
    points : iterable of (x, y) pairs, e.g. `zip`
        The points to calculate the distance to.

    Returns
    -------
    dists : `list`
        The distances from the line to the points. Each is the 2D cross
        product of ``p2 - p1`` with ``point - p1``, divided by
        ``|p2 - p1|``. Points to the left of the direction p1 -> p2 are
        positive; points to the right are negative.
    """
    p1 = np.asarray(p1, dtype=float)
    direction = np.asarray(p2, dtype=float) - p1
    # The line length does not depend on the point, so compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point, dtype=float) - p1
        # z component of the cross product, written out explicitly because
        # numpy.cross is deprecated for 2-element vectors (NumPy 2.0).
        dists.append((direction[0]*offset[1] - direction[1]*offset[0])/length)

    return dists
def mkColormap(colorNames):
    """Build a linearly interpolated colormap from matplotlib color names.

    The colors are placed, in list order, at evenly spaced positions from 0
    to 1. Each of the red, green and blue channels is interpolated linearly
    between them.

    Parameters
    ----------
    colorNames : `list`
        Strings naming matplotlib colors, e.g. ``["blue", "white", "red"]``.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        The colormap, named ``"newCmap"``.
    """
    positions = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Segment entries are (position, value below, value above); using the
    # same value on both sides gives a continuous ramp.
    segmentData = {
        channelName: [(pos, rgb[channel], rgb[channel]) for pos, rgb in zip(positions, rgbTriples)]
        for channel, channelName in enumerate(("red", "green", "blue"))
    }

    return colors.LinearSegmentedColormap("newCmap", segmentData)
def extremaSort(xs):
    """Return the ids of the points reordered so that those
    furthest from the median, in absolute terms, are last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort.

    Returns
    -------
    ids : `np.array`
        Indices that order ``xs`` by increasing ``|x - median|``. The median
        ignores NaNs, so a few NaNs do not scramble the ordering. For numpy
        input, NaN values are placed last. Tied points keep their input
        order.
    """
    # nanmedian: with plain median a single NaN makes every distance NaN.
    med = np.nanmedian(xs)
    dists = np.abs(xs - med)
    # Stable sort makes the order of tied points deterministic.
    ids = np.argsort(dists, kind="stable")
    return ids