Coverage for python/lsst/analysis/drp/plotUtils.py: 8%

182 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-14 12:53 +0000

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23import matplotlib.pyplot as plt 

24import scipy.odr as scipyODR 

25import matplotlib 

26from matplotlib import colors 

27from typing import List, Tuple 

28 

29from lsst.geom import Box2D, SpherePoint, degrees 

30 

# Shared NullFormatter instance; get_and_remove_axis_text installs it as the
# major/minor tick formatter to suppress tick labels.
null_formatter = matplotlib.ticker.NullFormatter()

32 

33 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.DataCoordinate`
        Data ID of the plotted data; every entry of its ``mapping`` is copied
        into the returned dictionary.
    runName : `str`
        Stored under the ``run`` key.
    tableName : `str`
        Stored under the ``tractTableType`` key.
    bands : iterable of `str`
        Bands used; joined with ``", "`` and stored under ``bands``.
    plotName : `str`
        Stored under the ``plotName`` key.
    SN : `float` or `str`
        Signal-to-noise threshold; stored under the ``SN`` key.
    SNFlux : `str`
        Stored under the ``SNFlux`` key.

    Returns
    -------
    plotInfo : `dict`
        The plot metadata. ``tract`` and ``visit`` are always present and
        set to ``"N/A"`` when the data ID does not provide them.
    """
    plotInfo = {"run": runName, "tractTableType": tableName, "plotName": plotName, "SN": SN, "SNFlux": SNFlux}

    plotInfo.update(dataId.mapping)

    plotInfo["bands"] = ", ".join(bands)

    # addPlotInfo reads these keys unconditionally and skips "N/A" values,
    # so make sure they exist even when the data ID lacks them.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo

61 

62 

def generateSummaryStats(cat, colName, skymap, plotInfo):
    """Generate a summary statistic (the median) of a column in each patch.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalogue with a ``patch`` column, which is compared against the
        tract's sequential integer patch indices, and the ``colName`` column.
        When a ``sourceType`` column is present, rows with
        ``sourceType == 0`` are excluded.
    colName : `str`
        Name of the column to summarize.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Skymap used to build the tract given by ``plotInfo["tract"]``.
    plotInfo : `dict`
        Plot metadata; only the ``tract`` entry is read.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index of the tract to
        ``(skyCoords, stat)``. ``skyCoords`` are the sky positions of the
        corners of the patch's inner bounding box. ``stat`` is the
        NaN-ignoring median of ``colName`` over the rows on that patch, or NaN
        when the patch has no rows. The extra key ``"tract"`` maps to the sky
        positions of the tract's bounding-box corners and NaN.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    if "sourceType" in cat.columns:
        cat = cat.loc[cat["sourceType"] != 0]

    # Look the columns up once rather than once per patch.
    patchCol = cat["patch"]
    values = cat[colName].values

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x*tractInfo.num_patches.y
    # Patches are enumerated by sequential integer index, which
    # getPatchInfo accepts directly. No gen 2 "x,y" string conversion is
    # needed: integer indices never have a ``split`` method, so the old
    # conversion branch could not run.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = (patchCol == patch)
        patchValues = values[onPatch]
        # np.nanmedian of an empty selection warns and returns NaN; return
        # NaN directly for patches without any rows.
        stat = np.nanmedian(patchValues) if len(patchValues) > 0 else np.nan

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

118 

119 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (the median) of a column on each detector.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalogue with ``detector`` and ``colName`` columns.
    colName : `str`
        Name of the column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Per-detector summary table. Its ``id`` column is matched against
        ``detector``, and each row holds ``raCorners``/``decCorners``
        sequences in degrees.
    plotInfo : `dict`
        Not used; accepted for symmetry with `generateSummaryStats`.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector in ``cat`` to ``(corners, stat)``. ``corners`` is a
        list of the detector's corner positions as
        `lsst.geom.SpherePoint` objects. ``stat`` is the NaN-ignoring median
        of ``colName`` over that detector's rows.
    """
    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = (cat["detector"] == ccd)
        stat = np.nanmedian(cat[colName].values[onCcd])

        sumRow = np.asarray(visitSummaryTable["id"] == ccd)
        # Convert to plain arrays before masking so that ``[0]`` selects the
        # first matching row by position. On an integer-indexed pandas
        # Series, ``[0]`` after a boolean mask is a label lookup and fails
        # unless the matching row carries index label 0.
        raCorners = np.asarray(visitSummaryTable["raCorners"])[sumRow][0]
        decCorners = np.asarray(visitSummaryTable["decCorners"])[sumRow][0]
        cornersOut = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict

152 

153 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axes and its children and return it with line points.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis. It is modified in place.

    Returns
    -------
    texts : `List[str]`
        The title, x/y axis labels, legend entries and free-standing texts of
        ``ax``, followed by those of its child axes (recursively). Tick
        labels are not collected; they are hidden by installing
        ``null_formatter``.
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` attributes (arrays of shape ``(N, 2)``) of all lines on
        ``ax`` followed by those of its child axes (recursively).
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when the axis has no legend.
    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Make legend entries fully transparent instead of removing them.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text('')

    axes_to_blank = [ax.xaxis, ax.yaxis]
    # Axes without a ``zaxis`` attribute only have x and y tick labels.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        axes_to_blank.append(zaxis)
    for axis in axes_to_blank:
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Child-axis lines were previously computed and discarded.
        line_xys.extend(lines_child)

    return texts, line_xys

202 

203 

def get_and_remove_figure_text(figure: plt.Figure):
    """Strip text from a Figure and its Axes, returning the text and line data.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        The figure to modify in place.

    Returns
    -------
    texts : `List[str]`
        ``str()`` of the figure's suptitle object, then the figure-level
        texts, then the texts gathered from each Axes by
        `get_and_remove_axis_text`.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line ``_xy`` arrays gathered from each Axes.
    """
    # NOTE(review): ``str()`` is applied to the suptitle object itself rather
    # than to its text, so this entry need not be the bare title string.
    # Kept unchanged; confirm whether that is intended.
    collected = [str(figure._suptitle)]
    figure.suptitle("")

    collected += [artist.get_text() for artist in figure.texts]
    figure.texts = []

    xy_arrays = []
    for axis in figure.get_axes():
        axis_texts, axis_lines = get_and_remove_axis_text(axis)
        collected += axis_texts
        xy_arrays += axis_lines

    return collected, xy_arrays

232 

233 

def addPlotInfo(fig, plotInfo):
    """Annotate a figure with its name and a block of provenance text.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; text is placed in figure coordinates.
    plotInfo : `dict`
        Plot metadata. Reads ``plotName``, ``run``, ``tractTableType``,
        ``tract``, ``visit``, ``bands``, ``SN`` and, when ``SN`` is not a
        string, ``SNFlux``. ``tract``/``visit`` values equal to ``"N/A"`` are
        left out of the text.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, annotated in place.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    # Plot name in the top-left corner.
    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    # Data ID fragments, skipping entries marked as not applicable.
    dataIdParts = [f", {label}: {plotInfo[key]}"
                   for label, key in (("Tract", "tract"), ("Visit", "visit"))
                   if str(plotInfo[key]) != "N/A"]

    signalToNoise = plotInfo["SN"]
    if isinstance(signalToNoise, str):
        snText = f", S/N: {signalToNoise}"
    else:
        # Very large thresholds are shown in compact exponent form.
        snFormat = "0.1g" if np.abs(signalToNoise) > 1e4 else "0.1f"
        snText = f", S/N > {signalToNoise:{snFormat}} ({plotInfo['SNFlux']})"

    infoText = (f"\n{plotInfo['run']}"
                f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
                f"\nTable: {plotInfo['tractTableType']}"
                + "".join(dataIdParts)
                + f", Bands: {plotInfo['bands'].replace(' ', '')}"
                + snText)
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

275 

276 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray` [`float`]
        The color on the xaxis.
    ys : `numpy.ndarray` [`float`]
        The color on the yaxis.
    paramDict : `lsst.pex.config.dictField.Dict`
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting.
        xMax : `float`
            The maximum x edge of the box to use for initial fitting.
        yMin : `float`
            The minimum y edge of the box to use for initial fitting.
        yMax : `float`
            The maximum y edge of the box to use for initial fitting.
        mHW : `float`
            The hardwired gradient for the fit.
        bHW : `float`
            The hardwired intercept of the fit.

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters.
        mODR0 : `float`
            The gradient calculated by the initial ODR fit.
        bODR0 : `float`
            The intercept calculated by the initial ODR fit.
        yBoxMin : `float`
            The y value of the low end point of the initial fit. This is the
            fitted line evaluated at xMin, or yMin itself when the initial
            gradient is steeper than 1 in magnitude.
        yBoxMax : `float`
            The y value of the high end point of the initial fit. This is
            the fitted line evaluated at xMax, or yMax itself when the
            initial gradient is steeper than 1 in magnitude.
        bPerpMin : `float`
            The intercept of the line perpendicular to the initial fit that
            goes through its low end point.
        bPerpMax : `float`
            The intercept of the line perpendicular to the initial fit that
            goes through its high end point.
        mODR : `float`
            The gradient from the second (and final) round of fitting.
        bODR : `float`
            The intercept from the second (and final) round of fitting.
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit.
        fitPoints : `numpy.ndarray` [`bool`]
            A boolean array indicating which points were used in the final fit.

    Notes
    -----
    The code does two rounds of fitting. The first is initiated using the
    hardwired values given in the `paramDict` parameter and is an
    Orthogonal Distance Regression fit to the points inside the box defined
    by xMin, xMax, yMin and yMax. Once this fit is done, a line
    perpendicular to it is calculated at either end, and only points that
    fall between these two lines are used to recalculate the fit. The band
    between the lines is selected regardless of which intercept is larger,
    so loci with negative gradients are handled as well as positive ones.
    """

    # Initial subselection of points to use for the fit
    fitPoints = ((xs > paramDict["xMin"]) & (xs < paramDict["xMax"])
                 & (ys > paramDict["yMin"]) & (ys < paramDict["yMax"]))

    linear = scipyODR.polynomial(1)

    # polynomial(1) orders its parameters as [intercept, gradient].
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR0 = float(params.beta[1])
    bODR0 = float(params.beta[0])

    paramsOut = {"mODR0": mODR0, "bODR0": bODR0}

    # Having found the initial fit calculate perpendicular ends
    mPerp0 = -1.0/mODR0
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones.
    if np.abs(mODR0) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR0)/mODR0
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR0)/mODR0
    else:
        yBoxMin = mODR0*paramDict["xMin"] + bODR0
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR0*paramDict["xMax"] + bODR0
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp0*xBoxMin

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp0*xBoxMax

    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. The two
    # lines are parallel, so the band between them is bounded by the
    # smaller and larger intercept. For a negative gradient bPerpMin exceeds
    # bPerpMax, which previously selected no points at all.
    bPerpLow = min(bPerpMin, bPerpMax)
    bPerpHigh = max(bPerpMin, bPerpMax)
    fitPoints = ((ys > mPerp0*xs + bPerpLow) & (ys < mPerp0*xs + bPerpHigh))
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR0, mODR0])
    params = odr.run()

    paramsOut["mODR"] = float(params.beta[1])
    paramsOut["bODR"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0/paramsOut["mODR"]
    paramsOut["fitPoints"] = fitPoints

    return paramsOut

390 

391 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line.
    p2 : `numpy.ndarray`
        Another point on the line.
    points : iterable, e.g. `zip`
        The points to calculate the distance to.

    Returns
    -------
    dists : `list`
        The distances from the line to the points. Each distance is the
        cross product of the line direction ``p2 - p1`` with the point's
        offset from ``p1``, divided by the length of the direction, so its
        sign depends on which side of the line the point lies.
    """
    # The line direction and its length do not depend on the point.
    direction = p2 - p1
    length = np.linalg.norm(direction)
    isPlanar = np.shape(direction) == (2,)

    dists = []
    for point in points:
        offset = np.array(point) - p1
        if isPlanar and offset.shape == (2,):
            # Scalar (z) component of the 2D cross product. np.cross on
            # 2-element vectors is deprecated as of NumPy 2.0.
            cross = direction[0]*offset[1] - direction[1]*offset[0]
        else:
            cross = np.cross(direction, offset)
        dists.append(cross/length)

    return dists

417 

418 

def mkColormap(colorNames):
    """Build a linear segmented colormap through the given named colors.

    The colors are placed at evenly spaced positions between 0 and 1, in the
    order given.

    Parameters
    ----------
    colorNames : `list`
        Strings naming matplotlib colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"``.
    """
    positions = np.linspace(0, 1, len(colorNames))
    rgbValues = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Each channel entry is (position, value below, value above); the two
    # values are equal so the map is continuous at every anchor.
    segmentData = {
        channel: [(pos, rgb[index], rgb[index]) for pos, rgb in zip(positions, rgbValues)]
        for channel, index in (("blue", 2), ("red", 0), ("green", 1))
    }
    return colors.LinearSegmentedColormap("newCmap", segmentData)

446 

447 

448def extremaSort(xs): 

449 """Return the ids of the points reordered so that those 

450 furthest from the median, in absolute terms, are last. 

451 

452 Parameters 

453 ---------- 

454 xs : `np.array` 

455 An array of the values to sort 

456 

457 Returns 

458 ------- 

459 ids : `np.array` 

460 """ 

461 

462 med = np.median(xs) 

463 dists = np.abs(xs - med) 

464 ids = np.argsort(dists) 

465 return ids