Coverage for python/lsst/analysis/drp/plotUtils.py: 8%

184 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-03-25 02:50 -0700

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

from typing import List, Tuple

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import scipy.odr as scipyODR
from matplotlib import colors

from lsst.geom import Box2D, SpherePoint, degrees

30 

# Tick formatter that renders every tick label as an empty string.
# get_and_remove_axis_text installs it on the x, y (and z, when present)
# axes to blank tick labels.
null_formatter = matplotlib.ticker.NullFormatter()

32 

33 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        The data ID of the plotted data. Iterating over it must yield
        objects with a ``name`` attribute, and it must be indexable by
        those names.
    runName : `str`
        The run name, stored under ``"run"``.
    tableName : `str`
        The table type, stored under ``"tractTableType"``.
    bands : iterable of `str`
        The bands used; joined as ``"g, r, i"`` under ``"bands"``.
    plotName : `str`
        The plot name, stored under ``"plotName"``.
    SN : `float` or `str`
        The signal-to-noise threshold, stored under ``"SN"``.
    SNFlux : `str`
        The flux used for the S/N cut, stored under ``"SNFlux"``.

    Returns
    -------
    plotInfo : `dict`
        The values above plus one entry per data ID dimension (keyed by the
        dimension name). ``"tract"`` and ``"visit"`` are always present and
        are ``"N/A"`` when the data ID does not provide them.
    """
    plotInfo = {"run": runName, "tractTableType": tableName, "plotName": plotName, "SN": SN, "SNFlux": SNFlux}

    # Copy each dimension of the data ID into the info dictionary.
    for dimension in dataId:
        plotInfo[dimension.name] = dataId[dimension.name]

    plotInfo["bands"] = ", ".join(bands)

    # addPlotInfo reads these keys unconditionally, so make sure they exist.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo

63 

64 

def generateSummaryStats(cat, colName, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalog; must contain a ``"patch"`` column and ``colName``.
        If a ``"sourceType"`` column exists, rows where it equals 0 are
        excluded.
    colName : `str`
        The column to summarize.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap used to look up the tract and patch geometry.
    plotInfo : `dict`
        Plot information; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to ``(skyCoords, stat)``, where
        ``skyCoords`` are the sky positions of the corners of the patch inner
        bounding box and ``stat`` is the NaN-ignoring median of ``colName``
        for rows in that patch (NaN when the patch has no rows). The extra
        key ``"tract"`` maps to the tract bounding-box corners and ``np.nan``.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    if "sourceType" in cat.columns:
        cat = cat.loc[cat["sourceType"] != 0]

    # These lookups do not depend on the patch, so do them once.
    patchColumn = cat["patch"]
    values = cat[colName].values

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x*tractInfo.num_patches.y
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = (patchColumn == patch)
        stat = np.nanmedian(values[onPatch])

        # Gen 2 patch ids are "x,y" strings that must be converted to the
        # gen 3 sequential index; native gen 3 ids have no ``split`` and are
        # used as-is. Only the ``split`` is guarded so that AttributeErrors
        # raised inside the skymap calls are not silently swallowed.
        # NOTE(review): ``patch`` comes from ``np.arange`` here, so it is
        # always an integer and the gen 2 branch is currently unreachable.
        # Once the objectTable_tract catalogues are using gen 3 patches
        # this conversion will go away.
        try:
            patchParts = patch.split(",")
        except AttributeError:
            gen3PatchId = patch
            patchInfo = tractInfo.getPatchInfo(patch)
        else:
            patchTuple = (int(patchParts[0]), int(patchParts[-1]))
            patchInfo = tractInfo.getPatchInfo(patchTuple)
            gen3PatchId = tractInfo.getSequentialPatchIndex(patchInfo)

        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[gen3PatchId] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

120 

121 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalog; must contain a ``"detector"`` column and ``colName``.
    colName : `str`
        The column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Per-detector summary with ``"id"``, ``"raCorners"`` and
        ``"decCorners"`` columns; the corners are interpreted as degrees.
    plotInfo : `dict`
        Unused.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to ``(corners, stat)``, where ``corners`` is a
        list of `lsst.geom.SpherePoint` built from the first matching summary
        row and ``stat`` is the NaN-ignoring median of ``colName`` on that
        detector.
    """
    visitInfoDict = {}
    values = cat[colName].values
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = (cat["detector"] == ccd)
        stat = np.nanmedian(values[onCcd])

        sumRow = (visitSummaryTable["id"] == ccd)
        # Take the first matching row *positionally*. Indexing a masked
        # pandas Series with ``[0]`` is a label lookup and fails unless the
        # matching row happens to carry index label 0; ``np.asarray(...)[0]``
        # is positional for both pandas columns and plain numpy arrays.
        raCorners = np.asarray(visitSummaryTable["raCorners"][sumRow])[0]
        decCorners = np.asarray(visitSummaryTable["decCorners"][sumRow])[0]
        corners = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[ccd] = (corners, stat)

    return visitInfoDict

154 

155 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axes and its child Axes and return it with line points.

    The title, axis labels, legend entries and ``ax.texts`` strings are
    recorded and then blanked (legend entries are made fully transparent
    rather than cleared). Tick labels are hidden by installing
    ``null_formatter`` but are not recorded.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        The recorded text strings, including those of child axes.
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` attribute (arrays of shape ``(N, 2)``) of each line
        directly on ``ax``.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when the axes has no legend.
    legend = ax.get_legend()
    if legend is not None:
        texts_legend = legend.texts
        texts.extend(text.get_text() for text in texts_legend)
        for text in texts_legend:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text('')

    ax.xaxis.set_major_formatter(null_formatter)
    ax.xaxis.set_minor_formatter(null_formatter)
    ax.yaxis.set_major_formatter(null_formatter)
    ax.yaxis.set_minor_formatter(null_formatter)
    # Axes without a ``zaxis`` attribute (e.g. 2D axes) are skipped.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        # NOTE(review): line data from child axes is discarded; only their
        # text is collected. Confirm this is intended.
        texts_child, _ = get_and_remove_axis_text(child)
        texts.extend(texts_child)

    return texts, line_xys

204 

205 

def get_and_remove_figure_text(figure: plt.Figure):
    """Strip all text from a Figure and its Axes, returning it with line data.

    The suptitle is reset to ``""`` and ``figure.texts`` is emptied. Each
    Axes returned by ``figure.get_axes()`` is then processed with
    `get_and_remove_axis_text`.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        The figure to strip; modified in place.

    Returns
    -------
    texts : `List[str]`
        The collected text strings. The first entry is ``str`` of the
        suptitle artist (``"None"`` if there is none), not its bare text.
        It is followed by the figure-level texts and then each axis' texts.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The ``_xy`` arrays of the lines on each axis, in axis order.
    """
    collected_texts = [str(figure._suptitle)]
    # Blank the suptitle *before* reading figure.texts, as the original order did.
    figure.suptitle("")

    collected_texts += [artist.get_text() for artist in figure.texts]
    figure.texts = []

    collected_lines = []
    for axis in figure.get_axes():
        axis_texts, axis_lines = get_and_remove_axis_text(axis)
        collected_texts += axis_texts
        collected_lines += axis_lines

    return collected_texts, collected_lines

234 

235 

def addPlotInfo(fig, plotInfo):
    """Annotate a figure with its name and a provenance summary.

    Two text blocks are written in the top-left corner of ``fig``, in figure
    coordinates. The first is the plot name. Beneath it are the run, the
    (placeholder) calibration datasets, the table type, any tract or visit
    that is not ``"N/A"``, the bands and the S/N selection.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; modified in place.
    plotInfo : `dict`
        Must provide ``"plotName"``, ``"run"``, ``"tractTableType"``,
        ``"tract"``, ``"visit"``, ``"bands"`` and ``"SN"``. ``"SNFlux"`` is
        also read when ``"SN"`` is not a string.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure that was passed in.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    segments = [
        f"\n{plotInfo['run']}",
        f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}",
        f"\nTable: {plotInfo['tractTableType']}",
    ]
    for key, label in (("tract", "Tract"), ("visit", "Visit")):
        if str(plotInfo[key]) != "N/A":
            segments.append(f", {label}: {plotInfo[key]}")

    # Remove every space, e.g. "g, r, i" is shown as "g,r,i".
    segments.append(f", Bands: {plotInfo['bands'].replace(' ', '')}")

    snValue = plotInfo["SN"]
    if isinstance(snValue, str):
        segments.append(f", S/N: {snValue}")
    else:
        # Thresholds with magnitude above 1e4 use compact exponent notation.
        snFormat = "0.1g" if np.abs(snValue) > 1e4 else "0.1f"
        segments.append(f", S/N > {snValue:{snFormat}} ({plotInfo['SNFlux']})")

    fig.text(0.01, 0.98, "".join(segments), fontsize=7, transform=fig.transFigure, alpha=0.6,
             ha="left", va="top")

    return fig

277 

278 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit
        mODR : `float`
            The gradient calculated by the ODR fit
        bODR : `float`
            The intercept calculated by the ODR fit
        yBoxMin : `float`
            The y value of the fitted line at the lower end of the box
        yBoxMax : `float`
            The y value of the fitted line at the upper end of the box
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower end
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper end
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Raises
    ------
    ZeroDivisionError
        Raised if either fitted gradient is exactly zero, because the
        perpendicular gradient ``-1/m`` is then undefined.

    Notes
    -----
    The code does two rounds of fitting. The first is initialized with the
    hardwired values given in the `paramDict` parameter. It is an
    Orthogonal Distance Regression fit to the points inside the box
    defined by xMin, xMax, yMin and yMax. Once this fit is done, a
    perpendicular line is calculated at either end of the fitted line.
    The fit is then recalculated using only points that fall between these
    perpendicular lines, i.e. whose intercept ``y - mPerp*x`` lies strictly
    between ``bPerpMin`` and ``bPerpMax`` (whichever is smaller first).
    When the first gradient has magnitude above 1, the ends are taken at
    yMin/yMax rather than xMin/xMax.
    """
    # Points to use for the fit
    fitPoints = np.where((xs > paramDict["xMin"]) & (xs < paramDict["xMax"])
                         & (ys > paramDict["yMin"]) & (ys < paramDict["yMax"]))[0]

    linear = scipyODR.polynomial(1)

    # polynomial(1) parameters are ordered [intercept, gradient].
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {"xMin": paramDict["xMin"], "xMax": paramDict["xMax"], "yMin": paramDict["yMin"],
                 "yMax": paramDict["yMax"], "mHW": paramDict["mHW"], "bHW": paramDict["bHW"],
                 "mODR": mODR, "bODR": bODR}

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0/mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR)/mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR)/mODR
    else:
        yBoxMin = mODR*paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR*paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp*xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp*xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. Which
    # intercept is larger depends on the sign of the gradient: for a
    # negative gradient with |mODR| <= 1 (and xMin < xMax), bPerpMin is the
    # larger one. Order them explicitly, otherwise the selection is empty.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = ((ys > mPerp*xs + bPerpLow) & (ys < mPerp*xs + bPerpHigh))
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0/paramsOut["mODR2"]

    return paramsOut

405 

406 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from points to a line.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point ``(x, y)`` on the line.
    p2 : `numpy.ndarray`
        Another point on the line. If it coincides with ``p1``, the
        results are not finite.
    points : iterable of ``(x, y)`` pairs, e.g. a `zip`
        The points to calculate the distance to.

    Returns
    -------
    dists : `list`
        One distance per point: the 2D cross product of ``p2 - p1`` and
        ``point - p1`` divided by ``|p2 - p1|``. Points to the left of the
        direction ``p1 -> p2`` have positive distance.
    """
    p1 = np.asarray(p1)
    direction = np.asarray(p2) - p1
    # The line length is the same for every point, so compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point) - p1
        # z-component of the 2D cross product, written out because np.cross
        # on 2-element vectors is deprecated as of NumPy 2.0.
        cross = direction[0]*offset[1] - direction[1]*offset[0]
        dists.append(cross/length)

    return dists

432 

433 

def mkColormap(colorNames):
    """Make a colormap that interpolates linearly between named colors.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib named colors. They
        are placed at evenly spaced positions from 0 to 1. A single color
        gives a uniform colormap.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`

    Raises
    ------
    ValueError
        Raised if ``colorNames`` is empty or contains an invalid color.
    """
    colorNames = list(colorNames)
    if not colorNames:
        raise ValueError("mkColormap needs at least one color name.")
    if len(colorNames) == 1:
        # LinearSegmentedColormap needs anchor points at both x=0 and x=1.
        colorNames = colorNames*2

    nums = np.linspace(0, 1, len(colorNames))
    reds = []
    greens = []
    blues = []
    for num, color in zip(nums, colorNames):
        r, g, b = colors.to_rgb(color)
        # Each entry is (x, value below x, value above x); equal values give
        # continuous interpolation.
        reds.append((num, r, r))
        greens.append((num, g, g))
        blues.append((num, b, b))

    colorDict = {"blue": blues, "red": reds, "green": greens}
    return colors.LinearSegmentedColormap("newCmap", colorDict)

461 

462 

def extremaSort(xs):
    """Order point indices by increasing absolute distance from the median.

    Parameters
    ----------
    xs : `np.array`
        The values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs``. Values closest to the median come first and
        those furthest from it, in absolute terms, come last.
    """
    return np.argsort(np.abs(xs - np.median(xs)))