Coverage for python/lsst/analysis/drp/plotUtils.py: 8%

183 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-02 11:50 +0000

1# This file is part of analysis_drp. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23import matplotlib.pyplot as plt 

24import scipy.odr as scipyODR 

25import matplotlib 

26from matplotlib import colors 

27from typing import List, Tuple 

28 

29from lsst.geom import Box2D, SpherePoint, degrees 

30 

# Tick formatter instance shared by get_and_remove_axis_text, which installs it
# as the major and minor formatter of every axis it processes (hiding tick labels).
null_formatter = matplotlib.ticker.NullFormatter()

32 

33 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN, SNFlux):
    """Parse plot info from the dataId

    Parameters
    ----------
    dataId : `lsst.daf.butler.DataCoordinate`
        Data ID whose entries are copied into the output. Each item yielded
        by iterating it must have a ``name`` attribute, and
        ``dataId[item.name]`` must return the corresponding value.
    runName : `str`
        Name of the run; stored under ``"run"``.
    tableName : `str`
        Table type; stored under ``"tractTableType"``.
    bands : iterable of `str`
        Bands used; joined with ``", "`` and stored under ``"bands"``.
    plotName : `str`
        Name of the plot; stored under ``"plotName"``.
    SN : `float` or `str`
        Signal-to-noise threshold (or a description of it); stored under
        ``"SN"``.
    SNFlux : `str`
        Flux used for the S/N cut; stored under ``"SNFlux"``.

    Returns
    -------
    plotInfo : `dict`
        The values above plus every data ID entry keyed by its name.
        ``"tract"`` and ``"visit"`` default to ``"N/A"`` when the data ID
        does not provide them.
    """
    plotInfo = {
        "run": runName,
        "tractTableType": tableName,
        "plotName": plotName,
        "SN": SN,
        "SNFlux": SNFlux,
    }

    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # addPlotInfo reads these keys unconditionally, so always provide them.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo

62 

63 

def generateSummaryStats(cat, colName, skymap, plotInfo):
    """Generate a summary statistic (NaN-ignoring median) in each patch.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``patch`` column and a ``colName`` column. If a
        ``sourceType`` column is present, rows with ``sourceType == 0`` are
        excluded.
    colName : `str`
        Name of the column to summarize.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Sky map used to build the tract.
    plotInfo : `dict`
        Plot information; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index of the tract to
        ``(skyCoords, stat)``. ``skyCoords`` are the sky positions of the
        patch inner-bbox corners. ``stat`` is the median of ``colName`` over
        the catalog rows in that patch, or NaN if there are none. The extra
        key ``"tract"`` maps to the tract bbox corners and NaN.
    """

    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    if "sourceType" in cat.columns:
        cat = cat.loc[cat["sourceType"] != 0]

    # Look the columns up once rather than on every patch iteration.
    patchCol = cat["patch"]
    values = cat[colName].values

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x*tractInfo.num_patches.y
    # Patches are addressed by their sequential index; the catalog's
    # ``patch`` column is compared against that index directly.
    for patch in range(maxPatchNum):
        onPatch = (patchCol == patch).to_numpy()
        # An empty selection would give NaN anyway; skip nanmedian to avoid
        # its "Mean of empty slice" RuntimeWarning.
        stat = np.nanmedian(values[onPatch]) if onPatch.any() else np.nan

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[patch] = (skyCoords, stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict

119 

120 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (NaN-ignoring median) in each detector.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and a ``colName`` column.
    colName : `str`
        Name of the column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Visit summary with ``id``, ``raCorners`` and ``decCorners`` columns
        (corners in degrees). Any table whose columns convert with
        `numpy.asarray` works.
    plotInfo : `dict`
        Plot information (unused here).

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id found in ``cat`` to ``(cornersOut, stat)``.
        ``cornersOut`` is a list of `lsst.geom.SpherePoint` built from the
        detector's first matching summary row. ``stat`` is the median of
        ``colName`` over that detector's rows.

    Raises
    ------
    IndexError
        Raised if a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Convert once, outside the loop. Working on plain arrays makes the row
    # lookup below positional for every table type. On a pandas Series,
    # ``series[mask][0]`` is a *label* lookup and raises KeyError unless the
    # matching row happens to carry index label 0.
    detectorCol = cat["detector"]
    values = cat[colName].values
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = (detectorCol == ccd)
        stat = np.nanmedian(values[onCcd])

        sumRow = (summaryIds == ccd)
        cornersOut = [SpherePoint(ra, dec, units=degrees)
                      for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict

153 

154 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        The removed strings of ``ax`` followed by those of each child axis.
        These are the title, the x/y axis labels, the legend entries and any
        free text artists. Tick labels are hidden (``null_formatter`` is
        installed) but are not returned.
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` attribute (arrays of shape ``(N, 2)``) of every line on
        ``ax``, followed by those of each child axis.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when the axis has no legend.
    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend entries are made fully transparent rather than emptied.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text('')

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only axes that have a ``zaxis`` attribute (e.g. 3D axes) need it blanked.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Collect the children's line points as well as their texts.
        line_xys.extend(lines_child)

    return texts, line_xys

203 

204 

def get_and_remove_figure_text(figure: plt.Figure):
    """Strip the text from a Figure and all of its Axes.

    The figure is modified in place. Its suptitle is set to ``""``, its list
    of figure-level text artists is replaced with an empty list, and every
    Axes from ``figure.get_axes()`` is passed to `get_and_remove_axis_text`.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        The removed strings. The first entry is ``str(figure._suptitle)``:
        the string form of the suptitle artist (``"None"`` if unset), not its
        ``get_text()``. Next come the texts of the figure-level text artists,
        read after the suptitle has been blanked. Last come the texts
        returned for each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line ``_xy`` arrays returned for each Axes, in Axes order.
    """
    # Record the suptitle before it is blanked; the order matters because
    # blanking happens before the figure-level texts are read.
    collectedTexts = [str(figure._suptitle)]
    figure.suptitle("")

    collectedTexts += [figText.get_text() for figText in figure.texts]
    figure.texts = []

    collectedLines = []
    for axis in figure.get_axes():
        axisTexts, axisLines = get_and_remove_axis_text(axis)
        collectedTexts += axisTexts
        collectedLines += axisLines

    return collectedTexts, collectedLines

233 

234 

def addPlotInfo(fig, plotInfo):
    """Annotate a figure with the plot name and provenance details.

    Two text artists are added at the top left, in figure coordinates: the
    plot name, and below it a faded block with the run, datasets, table
    type, data ID, bands and S/N cut.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate.
    plotInfo : `dict`
        Must provide ``"plotName"``, ``"run"``, ``"tractTableType"``,
        ``"tract"``, ``"visit"``, ``"bands"`` and ``"SN"``. ``"SNFlux"`` is
        also required when ``"SN"`` is not a string. A tract or visit whose
        string form is ``"N/A"`` is left out of the text.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure.
    """
    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    runName = plotInfo["run"]
    datasetsLine = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableLine = f"\nTable: {plotInfo['tractTableType']}"

    dataIdPart = "".join(
        f", {label}: {plotInfo[key]}"
        for key, label in (("tract", "Tract"), ("visit", "Visit"))
        if str(plotInfo[key]) != "N/A"
    )

    # All spaces are stripped from the band list ("g, r" -> "g,r").
    bandsPart = f", Bands: {plotInfo['bands'].replace(' ', '')}"

    sn = plotInfo["SN"]
    if isinstance(sn, str):
        snPart = f", S/N: {sn}"
    else:
        # Very large thresholds are shown in compact scientific notation.
        snSpec = "0.1g" if np.abs(sn) > 1e4 else "0.1f"
        snPart = f", S/N > {sn:{snSpec}} ({plotInfo['SNFlux']})"

    infoText = f"\n{runName}{datasetsLine}{tableLine}{dataIdPart}{bandsPart}{snPart}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

276 

277 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus

    Parameters
    ----------
    xs : `numpy.ndarray` [`float`]
        The color on the xaxis.
    ys : `numpy.ndarray` [`float`]
        The color on the yaxis.
    paramDict : `lsst.pex.config.dictField.Dict`
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting.
        xMax : `float`
            The maximum x edge of the box to use for initial fitting.
        yMin : `float`
            The minimum y edge of the box to use for initial fitting.
        yMax : `float`
            The maximum y edge of the box to use for initial fitting.
        mHW : `float`
            The hardwired gradient for the fit.
        bHW : `float`
            The hardwired intercept of the fit.

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters.
        mODR0 : `float`
            The gradient calculated by the initial ODR fit.
        bODR0 : `float`
            The intercept calculated by the initial ODR fit.
        yBoxMin : `float`
            The y value of the initial fitted line at xMin. When the initial
            gradient is steeper than 1 in absolute value, this is yMin
            instead.
        yBoxMax : `float`
            The y value of the initial fitted line at xMax. When the initial
            gradient is steeper than 1 in absolute value, this is yMax
            instead.
        bPerpMin : `float`
            The intercept of the line perpendicular to the initial fit that
            passes through the (xBoxMin, yBoxMin) end point.
        bPerpMax : `float`
            The intercept of the line perpendicular to the initial fit that
            passes through the (xBoxMax, yBoxMax) end point.
        mODR : `float`
            The gradient from the second (and final) round of fitting.
        bODR : `float`
            The intercept from the second (and final) round of fitting.
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit.
        fitPoints : `numpy.ndarray` [`bool`]
            A boolean array indicating which points were used in the final
            fit.

    Notes
    -----
    The code does two rounds of fitting. The first round is initialised with
    the hardwired values given in the `paramDict` parameter. It is an
    Orthogonal Distance Regression (ODR) fit to the points inside the box
    defined by xMin, xMax, yMin and yMax. Lines perpendicular to that fit are
    then drawn through its end points at the box edges. The x edges are used
    for these end points, or the y edges when the gradient is steeper than 1
    in absolute value. Only points that fall between the two perpendicular
    lines are used to recalculate the fit. This selection works for
    gradients of either sign.
    """

    # Initial subselection of points to use for the fit
    fitPoints = ((xs > paramDict["xMin"]) & (xs < paramDict["xMax"])
                 & (ys > paramDict["yMin"]) & (ys < paramDict["yMax"]))

    linear = scipyODR.polynomial(1)

    # scipy's polynomial model orders beta as [intercept, gradient].
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR0 = float(params.beta[1])
    bODR0 = float(params.beta[0])

    paramsOut = {"mODR0": mODR0, "bODR0": bODR0}

    # Having found the initial fit calculate perpendicular ends
    mPerp0 = -1.0/mODR0
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones.
    if np.abs(mODR0) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR0)/mODR0
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR0)/mODR0
    else:
        yBoxMin = mODR0*paramDict["xMin"] + bODR0
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR0*paramDict["xMax"] + bODR0
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp0*xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp0*xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. A point
    # lies between them when its perpendicular intercept (ys - mPerp0*xs)
    # falls between bPerpMin and bPerpMax. For shallow negative gradients
    # (-1 < m < 0), bPerpMin > bPerpMax, so order the bounds first; without
    # this no point would be selected.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = ((ys > mPerp0*xs + bPerpLow) & (ys < mPerp0*xs + bPerpHigh))
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR0, mODR0])
    params = odr.run()

    paramsOut["mODR"] = float(params.beta[1])
    paramsOut["bODR"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0/paramsOut["mODR"]
    paramsOut["fitPoints"] = fitPoints

    return paramsOut

391 

392 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance to a line from points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line, ``(x, y)``.
    p2 : `numpy.ndarray`
        Another point on the line, ``(x, y)``.
    points : `zip`
        Iterable of ``(x, y)`` pairs to calculate the distance to. It is
        consumed once.

    Returns
    -------
    dists : `list`
        The distances from the line to the points, in input order. Each is
        the 2D cross product ``(p2 - p1) x (point - p1)`` divided by
        ``|p2 - p1|``, so points to the left of the direction p1 -> p2 are
        positive. An empty ``points`` gives an empty list.
    """
    p1 = np.asarray(p1, dtype=float)
    direction = np.asarray(p2, dtype=float) - p1
    lineLength = np.linalg.norm(direction)

    pts = np.asarray(list(points), dtype=float)
    if pts.size == 0:
        return []

    offsets = pts - p1
    # Scalar 2D cross product, written out explicitly: np.cross on
    # 2-element vectors is deprecated as of NumPy 2.0.
    cross = direction[0]*offsets[:, 1] - direction[1]*offsets[:, 0]
    return list(cross/lineLength)

418 

419 

def mkColormap(colorNames):
    """Build a colormap that interpolates linearly through the given colors.

    The colors are placed at evenly spaced anchors from 0 to 1, in the order
    given.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib
        named colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        Colormap named ``"newCmap"`` built from red/green/blue segment data.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbValues = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Each channel's segment entries are (anchor, value, value), so the
    # channel is continuous at every anchor.
    channelIndex = {"blue": 2, "red": 0, "green": 1}
    segmentData = {
        channel: [(anchor, rgb[idx], rgb[idx]) for anchor, rgb in zip(anchors, rgbValues)]
        for channel, idx in channelIndex.items()
    }

    return colors.LinearSegmentedColormap("newCmap", segmentData)

447 

448 

def extremaSort(xs):
    """Return the ids of the points reordered so that those
    furthest from the median, in absolute terms, are last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort. NaN values are ignored when
        computing the median and end up last in the ordering.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` ordered by increasing absolute distance from the
        median. Ties keep their original relative order.
    """
    xs = np.asarray(xs)
    if xs.size == 0:
        # Nothing to order; also avoids numpy's empty-slice median warning.
        return np.argsort(xs)

    # nanmedian: with np.median a single NaN makes the median, and therefore
    # every distance, NaN, which would leave the input order unchanged.
    med = np.nanmedian(xs)
    dists = np.abs(xs - med)
    return np.argsort(dists, kind="stable")