Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 7%

237 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-11 09:59 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23from typing import List, Tuple 

24 

25import matplotlib 

26import matplotlib.pyplot as plt 

27import numpy as np 

28import scipy.odr as scipyODR 

29from lsst.geom import Box2D, SpherePoint, degrees 

30from matplotlib import colors 

31from matplotlib.collections import PatchCollection 

32from matplotlib.patches import Rectangle 

33 

# Shared tick formatter that renders every tick label as an empty string;
# used by ``get_and_remove_axis_text`` to blank out axis tick labels.
null_formatter = matplotlib.ticker.NullFormatter()

35 

36 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        The data ID of the plotted data. Iterating over it must yield
        objects with a ``name`` attribute, and it must be indexable by
        that name.
    runName : `str`
        Stored under the ``run`` key.
    tableName : `str`
        Stored under the ``tableName`` key.
    bands : iterable of `str`
        Band names; stored under ``bands`` as one comma-separated string.
    plotName : `str`
        Stored under the ``plotName`` key.
    SN
        Signal-to-noise value; stored as-is under the ``SN`` key.

    Returns
    -------
    plotInfo : `dict`
        The values above plus one entry per data ID key. ``tract`` and
        ``visit`` are set to ``"N/A"`` when the data ID lacks them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # Guarantee that the tract and visit keys always exist.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo

64 

65 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        Column name to array mapping. Must provide a ``patch`` column of
        sequential patch indices and at least one of ``y``, ``yStars``,
        ``yGalaxies`` or ``yUnknowns``; the first of these found (checked
        in that order) is summarised.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Skymap used to generate the tract and look up its patches.
    plotInfo : `dict`
        Plot information; its ``tract`` entry selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each patch index to a tuple of (sky coordinates of the
        patch's inner bounding-box corners, nan-median of the summarised
        column over the rows in that patch). The extra key ``"tract"``
        maps to the tract bounding-box corners and ``nan``.

    Raises
    ------
    KeyError
        If ``data`` contains none of the supported y columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    yColCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    availableCols = data.keys()
    yCol = next((col for col in yColCandidates if col in availableCols), None)
    if yCol is None:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise KeyError(f"data has none of the expected columns {yColCandidates}")

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patch ids are the sequential (gen 3) indices 0..maxPatchNum-1, so each
    # id is matched directly against data["patch"] and passed straight to
    # getPatchInfo. A gen 2 "x,y" string conversion can never apply to these
    # integer ids.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

124 

125 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and the column ``colName``.
    colName : `str`
        Column of ``cat`` to summarise.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Per-detector table with ``id``, ``raCorners`` and ``decCorners``
        columns; each corner entry is a sequence of corner coordinates in
        degrees.
    plotInfo : `dict`
        Not used; kept for interface compatibility.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to a tuple of (list of corner
        `lsst.geom.SpherePoint`, nan-median of ``colName`` on that detector).
    """
    # Convert to plain arrays so that "boolean mask then [0]" selects the
    # first *matching row* positionally. On a pandas Series, [0] would be a
    # label lookup and fail whenever the matching row's label is not 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        sumRow = summaryIds == ccd
        cornersOut = [
            SpherePoint(ra, dec, units=degrees)
            for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict

156 

157 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title and axis/legend/tick labels),
        including those of child axes.
    line_xys : `List[numpy.ndarray]`
        A list of all line ``_xy`` attributes (arrays of shape ``(N, 2)``),
        including those of child axes.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend entries are made transparent rather than emptied,
        # presumably to keep the legend layout unchanged.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Axes without a ``zaxis`` attribute (i.e. 2D axes) are skipped.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Child-axis line points belong in the result as well.
        line_xys.extend(lines_child)

    return texts, line_xys

204 

205 

def get_and_remove_figure_text(figure: plt.Figure):
    """Strip all text from a Figure and its Axes, returning what was removed.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        The figure to clear. Its suptitle is reset to ``""``, its
        figure-level texts are dropped, and every Axes is cleared via
        ``get_and_remove_axis_text``.

    Returns
    -------
    texts : `List[str]`
        The suptitle entry first, then figure-level texts, then the texts
        collected from each Axes in ``figure.get_axes()`` order.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line ``_xy`` arrays (shape ``(N, 2)``) gathered from every Axes.
    """
    # NOTE(review): str() of the suptitle artist gives its repr (or "None"
    # when no suptitle exists), not the bare title string; kept unchanged
    # because callers may compare against these values.
    collected = [str(figure._suptitle)]
    figure.suptitle("")

    collected += [artist.get_text() for artist in figure.texts]
    figure.texts = []

    line_points = []
    for axes in figure.get_axes():
        axesTexts, axesLines = get_and_remove_axis_text(axes)
        collected += axesTexts
        line_points += axesLines

    return collected, line_points

232 

233 

def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate.
    plotInfo : `dict`
        Must contain ``plotName``, ``run``, ``tableName`` and ``bands``.
        ``bands`` may be a single pre-joined string (e.g. ``"g, r, i"``)
        or an iterable of band names. ``tract``, ``visit`` and ``SN`` are
        included when present; a missing ``SN`` is shown as ``-``.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with two text blocks added in its top-left corner.
    """

    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # parsePlotInfo stores bands already joined ("g, r, i"). Iterating such a
    # string character by character would garble it, so use it verbatim and
    # only join genuine collections of band names.
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', '-')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

273 

274 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit
        mODR : `float`
            The gradient calculated by the ODR fit
        bODR : `float`
            The intercept calculated by the ODR fit
        yBoxMin : `float`
            The y value of the fitted line at xMin (or yMin when the
            fitted gradient is steeper than 1 in absolute value)
        yBoxMax : `float`
            The y value of the fitted line at xMax (or yMax when the
            fitted gradient is steeper than 1 in absolute value)
        bPerpMin : `float`
            The intercept of the perpendicular line through the "min" end
        bPerpMax : `float`
            The intercept of the perpendicular line through the "max" end
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Raises
    ------
    ZeroDivisionError
        If the first or second fit has a gradient of exactly zero (the
        perpendicular gradient is undefined).

    Notes
    -----
    The code does two rounds of fitting. The first is initiated using the
    hardwired values given in the `paramDict` parameter and is done using
    an Orthogonal Distance Regression fit to the points defined by the
    box of xMin, xMax, yMin and yMax. Once this fitting has been done, a
    perpendicular line is calculated at either end of the fitted line and
    only points that fall between these two lines are used to recalculate
    the fit. This works for both positive and negative gradients.
    """

    # Points to use for the fit
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # First-order polynomial model: beta[0] is the intercept, beta[1] the gradient.
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp * xBoxMax

    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. Which
    # intercept is lower depends on the sign of the gradient (for a shallow
    # negative gradient bPerpMax < bPerpMin), so order them explicitly.
    # Otherwise the selection below would be empty.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut

409 

410 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line
    p2 : `numpy.ndarray`
        Another point on the line
    points : `zip`
        The points to calculate the distance to

    Returns
    -------
    dists : `list`
        The distances from the line to the points, computed as the cross
        product of the line direction with the offset of each point from
        ``p1``, divided by the line length. The sign indicates which side
        of the line the point lies on.
    """
    p1 = np.asarray(p1)
    direction = np.asarray(p2) - p1
    # The line length does not depend on the point; compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point) - p1
        if direction.shape == (2,):
            # Explicit z-component of the 2D cross product; np.cross on
            # 2-element vectors is deprecated in NumPy 2.
            cross = direction[0] * offset[1] - direction[1] * offset[0]
        else:
            cross = np.cross(direction, offset)
        dists.append(cross / length)

    return dists

434 

435 

def mkColormap(colorNames):
    """Make a colormap from the list of color names.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib
        named colors. The colors are spaced evenly from 0 to 1 along the
        colormap; a single color gives a constant colormap.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
    """
    colorNames = list(colorNames)
    # LinearSegmentedColormap requires anchor points at both x=0 and x=1.
    # A lone color would only provide x=0, so repeat it.
    if len(colorNames) == 1:
        colorNames = colorNames * 2

    nums = np.linspace(0, 1, len(colorNames))
    blues = []
    greens = []
    reds = []
    for num, color in zip(nums, colorNames):
        r, g, b = colors.to_rgb(color)
        # (x, value below x, value above x): equal values give a smooth
        # gradient with no discontinuity at the anchor.
        blues.append((num, b, b))
        greens.append((num, g, g))
        reds.append((num, r, r))

    colorDict = {"blue": blues, "red": reds, "green": greens}
    return colors.LinearSegmentedColormap("newCmap", colorDict)

461 

462 

def extremaSort(xs):
    """Order indices so that the values furthest from the median come last.

    Parameters
    ----------
    xs : `np.array`
        Values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` sorted by increasing absolute deviation from
        the median of ``xs``.
    """
    centre = np.median(xs)
    return np.argsort(np.abs(xs - centre))

479 

480 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot,
        can be a gridspec SubplotSpec, a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. This is the same for the tuple
        of int, int, index.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch. The key ``"tract"`` is outlined but not
        annotated, and ``nan`` statistics are drawn transparent.
    label : `str`
        Text written over the colorbar, describing the summarised quantity.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and a colorbar added.
    """

    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color
    # coded rectangles for each patch, the colors show
    # the median of the given value in the patch.
    # (The per-patch values are kept in ``patchStats`` so the module-level
    # ``matplotlib.colors`` import is not shadowed.)
    patches = []
    patchStats = []
    for dataId, (corners, stat) in sumStats.items():
        # NOTE(review): assumes corners[0] and corners[2] are opposite
        # corners of the patch; confirm against the corner ordering.
        ra0 = corners[0][0].asDegrees()
        dec0 = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra0
        height = corners[2][1].asDegrees() - dec0
        patches.append(Rectangle((ra0, dec0), width, height))
        patchStats.append(stat)

        cornerRas = [corner[0].asDegrees() for corner in corners]
        cornerDecs = [corner[1].asDegrees() for corner in corners]
        # Repeat the first corner to close the outline.
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            axCorner.annotate(
                dataId, (ra0 + width / 2, dec0 + height / 2), color="k", fontsize=4, ha="center", va="center"
            )

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(patchStats, mask=np.isnan(patchStats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    plt.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig