Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 7%

234 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-08-20 02:31 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from typing import List, Tuple 

24 

25import matplotlib 

26import matplotlib.pyplot as plt 

27import numpy as np 

28import scipy.odr as scipyODR 

29from lsst.geom import Box2D, SpherePoint, degrees 

30from matplotlib import colors 

31from matplotlib.collections import PatchCollection 

32from matplotlib.patches import Rectangle 

33 

# Shared matplotlib NullFormatter. get_and_remove_axis_text installs it as the
# major and minor formatter of each axis to strip the tick labels.
null_formatter = matplotlib.ticker.NullFormatter()

35 

36 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Collect the metadata describing a plot into a single dictionary.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions.`
        `_coordinate._ExpandedTupleDataCoordinate`
        Data ID of the plotted data. Every dimension yielded by iterating it
        is copied into the output under the dimension's ``name``.
    runName : `str`
        Name of the run the data came from.
    tableName : `str`
        Name of the table being plotted.
    bands : iterable of `str`
        Bands used in the plot; stored as one ``", "``-separated string.
    plotName : `str`
        Name of the plot.
    SN : `float` or `str`
        Signal-to-noise information to display with the plot.

    Returns
    -------
    plotInfo : `dict`
        Contains ``run``, ``tableName``, ``plotName``, ``SN``, one entry per
        data ID dimension, ``bands``, and ``tract`` / ``visit`` (set to
        ``"N/A"`` when the data ID does not provide them).
    """
    info = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}
    info.update({dimension.name: dataId[dimension.name] for dimension in dataId})
    info["bands"] = ", ".join(bands)
    for key in ("tract", "visit"):
        info.setdefault(key, "N/A")
    return info

64 

65 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic (median of the y column) in each patch.

    Parameters
    ----------
    data : `dict`-like
        Columns of the plotted data. Must contain a ``"patch"`` column of
        sequential patch indices and at least one of the y columns ``"y"``,
        ``"yStars"``, ``"yGalaxies"`` or ``"yUnknowns"``. These are checked
        in that order and the first one present is used.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Sky map used to look up the tract and patch geometry.
    plotInfo : `dict`
        Plot metadata; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index of the tract to ``(skyCoords, stat)``.
        ``skyCoords`` are the sky positions of the corners of the patch's
        inner bounding box. ``stat`` is the NaN-ignoring median of the y
        column over that patch's rows (NaN when the patch has no data). The
        extra key ``"tract"`` maps to the tract's corner coordinates and
        ``np.nan``.

    Raises
    ------
    KeyError
        Raised if ``data`` contains none of the recognized y columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Pick the first available y column. Fail loudly here instead of with an
    # UnboundLocalError inside the patch loop.
    availableCols = data.keys()
    yCol = next(
        (col for col in ("y", "yStars", "yGalaxies", "yUnknowns") if col in availableCols),
        None,
    )
    if yCol is None:
        raise KeyError("data must contain one of the columns 'y', 'yStars', 'yGalaxies' or 'yUnknowns'")

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are enumerated as sequential (gen 3) indices, so each index can
    # be passed straight to getPatchInfo and used as the output key.
    for patch in range(maxPatchNum):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

124 

125 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (median of a column) on each detector.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and the column ``colName``.
    colName : `str`
        Column of ``cat`` to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Visit summary with ``id``, ``raCorners`` and ``decCorners`` columns;
        rows are matched to detectors through ``id``.
    plotInfo : `dict`
        Plot metadata (not used by this function).

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector ID found in ``cat`` (``None`` excluded) to
        ``(corners, stat)``. ``corners`` is a list of
        `lsst.geom.SpherePoint` built from the detector's RA/Dec corners in
        degrees. ``stat`` is the NaN-ignoring median of ``colName`` over
        that detector's rows.
    """
    visitInfoDict = {}
    for detectorId in cat.detector.unique():
        if detectorId is None:
            continue
        stat = np.nanmedian(cat[colName].values[cat["detector"] == detectorId])

        matchesRow = visitSummaryTable["id"] == detectorId
        # NOTE(review): ``[0]`` is positional for array-like tables but
        # label-based for a pandas Series. Confirm the table type.
        raCorners = visitSummaryTable["raCorners"][matchesRow][0]
        decCorners = visitSummaryTable["decCorners"][matchesRow][0]
        cornerPoints = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[detectorId] = (cornerPoints, stat)

    return visitInfoDict

156 

157 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    The title, x/y labels and free-standing text artists are emptied, and
    tick labels are suppressed by installing ``null_formatter``. Legend
    texts keep their strings but are made fully transparent. Child axes are
    processed recursively.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        The removed strings, in this order: title, x label, y label, legend
        texts, other text artists of ``ax``, then the strings of each child
        axis. Tick-label text is not collected.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` data points of every line on ``ax`` followed by those
        of its child axes.
    """
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Legend entries are hidden by setting alpha to zero rather than emptied.
    legend = ax.get_legend()
    if legend is not None:
        for text in legend.texts:
            texts.append(text.get_text())
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    # Only 3D axes have a zaxis.
    for axis in (ax.xaxis, ax.yaxis, getattr(ax, "zaxis", None)):
        if axis is not None:
            axis.set_major_formatter(null_formatter)
            axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys

204 

205 

def get_and_remove_figure_text(figure: plt.Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        The suptitle string (``""`` if the figure has none), the strings of
        the figure-level text artists, then the texts returned by
        `get_and_remove_axis_text` for each of the figure's axes.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` line points collected from each of the figure's axes.
    """
    # ``_suptitle`` holds a Text artist (or None). Record the string itself,
    # not the artist's repr.
    suptitle = figure._suptitle
    texts = ["" if suptitle is None else suptitle.get_text()]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines

232 

233 

def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Writes the plot name at the top-left of the figure. Just below it, adds
    a block of text listing the run, the datasets used, the table, the
    tract/visit (when present), the bands and the S/N.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate.
    plotInfo : `dict`
        Plot metadata, e.g. as returned by `parsePlotInfo`. Must contain
        ``plotName``, ``run``, ``tableName``, ``bands`` and ``SN``.
        ``tract`` and ``visit`` are shown when present. ``bands`` may be a
        single pre-joined string (as `parsePlotInfo` produces) or an
        iterable of band names.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, annotated.
    """

    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    # parsePlotInfo stores the bands as one ", "-joined string. Iterating that
    # string would split it into single characters, so only join when given a
    # collection of band names.
    bands = plotInfo["bands"]
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo['SN']}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

270 

271 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin, xMax, yMin, yMax, mHW, bHW : `float`
            Copied from ``paramDict``
        mODR : `float`
            The gradient calculated by the first ODR fit
        bODR : `float`
            The intercept calculated by the first ODR fit
        yBoxMin : `float`
            The y value of the first fit at xMin, or yMin itself when
            ``abs(mODR) > 1``
        yBoxMax : `float`
            The y value of the first fit at xMax, or yMax itself when
            ``abs(mODR) > 1``
        bPerpMin : `float`
            The intercept of the line perpendicular to the first fit that
            passes through the fit at its low end (x = xMin, or y = yMin
            when ``abs(mODR) > 1``)
        bPerpMax : `float`
            The same for the high end (x = xMax, or y = yMax when
            ``abs(mODR) > 1``)
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Raises
    ------
    ZeroDivisionError
        Raised if either fit gives a gradient of exactly zero, because the
        perpendicular gradient ``-1 / m`` is then undefined.

    Notes
    -----
    The code does two rounds of fitting. The first is initialized with the
    hardwired values given in the `paramDict` parameter. It is an Orthogonal
    Distance Regression fit to the points inside the box defined by xMin,
    xMax, yMin and yMax. Once this fit has been done, a perpendicular line is
    calculated at either end of the fitted line. Only points that fall
    between these lines are used to recalculate the fit. When the first fit
    is steep (``abs(mODR) > 1``) the ends are taken at the y limits of the
    box rather than the x limits.
    """

    # Points to use for the fit
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # The first-order polynomial model's parameters are (intercept, gradient).
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp * xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit, starting
    # from the first fit's parameters.
    fitPoints = (ys > mPerp * xs + bPerpMin) & (ys < mPerp * xs + bPerpMax)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut

406 

407 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point ``(x, y)`` on the line.
    p2 : `numpy.ndarray`
        Another point ``(x, y)`` on the line. It must differ from ``p1``,
        otherwise the results are NaN/inf.
    points : iterable of ``(x, y)`` pairs, e.g. a `zip`
        The points to calculate the distance to.

    Returns
    -------
    dists : `list`
        The distance from the line to each point: the 2D cross product of
        ``p2 - p1`` and ``point - p1`` divided by ``|p2 - p1|``. Points to
        the left of the direction ``p1 -> p2`` get positive values and points
        to the right get negative ones.
    """
    p1 = np.asarray(p1, dtype=float)
    direction = np.asarray(p2, dtype=float) - p1
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point, dtype=float) - p1
        # z-component of the 2D cross product, written out explicitly because
        # np.cross on 2-element vectors is deprecated since NumPy 2.0.
        cross = direction[0] * offset[1] - direction[1] * offset[0]
        dists.append(cross / length)

    return dists

431 

432 

def mkColormap(colorNames):
    """Build a linearly interpolated colormap from a list of named colors.

    Parameters
    ----------
    colorNames : `list`
        Strings naming matplotlib colors. They are placed at evenly spaced
        positions from 0 to 1, in the order given.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        Colormap named ``"newCmap"`` that interpolates between the colors.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.colorConverter.to_rgb(name) for name in colorNames]

    # Each channel's segment entry is (position, value below, value above);
    # using the same value on both sides gives a continuous ramp.
    channelIndex = (("blue", 2), ("red", 0), ("green", 1))
    colorDict = {
        channel: [(anchor, rgb[idx], rgb[idx]) for anchor, rgb in zip(anchors, rgbTriples)]
        for channel, idx in channelIndex
    }
    return colors.LinearSegmentedColormap("newCmap", colorDict)

458 

459 

def extremaSort(xs):
    """Return indices that order ``xs`` by absolute distance from its median.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` such that values closest to the median come
        first and those furthest from it, in absolute terms, come last.
    """
    distanceFromMedian = np.abs(xs - np.median(xs))
    return np.argsort(distanceFromMedian)

476 

477 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Each entry of ``sumStats`` is drawn as a rectangle colored by its
    statistic (NaN statistics are transparent) and outlined in black. Every
    entry except ``"tract"`` is labelled with its key. A horizontal colorbar
    is added above the subplot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot. It can
        be a gridspec SubplotSpec or a 3 digit integer, where the first digit
        is the number of rows, the second is the number of columns and the
        third is the index. The tuple of (int, int, index) works the same way.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch.
    label : `str`
        Text written across the middle of the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the summary subplot and its colorbar added.
    """

    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color
    # coded rectangles for each patch, the colors show
    # the median of the given value in the patch.
    # (``statValues`` avoids shadowing the module-level ``colors`` import.)
    patches = []
    statValues = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        statValues.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
        cornerDecs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(statValues, mask=np.isnan(statValues))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the subplot. Call ``fig.colorbar`` directly:
    # ``plt.colorbar`` goes through pyplot's current figure, which need not
    # be ``fig``.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig