Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 11%

247 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-06 02:44 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

__all__ = ("PanelConfig",)

from typing import List, Tuple

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker  # explicit: ``import matplotlib`` alone does not guarantee ``matplotlib.ticker``
import numpy as np
import scipy.odr as scipyODR
from lsst.geom import Box2D, SpherePoint, degrees
from lsst.pex.config import Config, Field
from matplotlib import colors
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

# Shared formatter that renders every tick label as an empty string; it is
# applied to all axes whose tick labels get_and_remove_axis_text strips.
null_formatter = matplotlib.ticker.NullFormatter()

38 

39 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        Data ID of the plotted data. Iterating it must yield objects with a
        ``name`` attribute, and it must be indexable by that name.
    runName : `str`
        Stored as ``plotInfo["run"]``.
    tableName : `str`
        Stored as ``plotInfo["tableName"]``.
    bands : iterable of `str`
        Bands used; stored as a single comma-separated string.
    plotName : `str`
        Stored as ``plotInfo["plotName"]``.
    SN : `object`
        Stored as ``plotInfo["SN"]`` (addPlotInfo displays it as "S/N").

    Returns
    -------
    plotInfo : `dict`
        The values above plus one entry per data ID dimension. ``"bands"``
        is the ``", "``-joined band string, and ``"tract"``/``"visit"``
        default to ``"N/A"`` when the data ID does not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dimension in dataId:
        plotInfo[dimension.name] = dataId[dimension.name]

    plotInfo["bands"] = ", ".join(bands)

    # Fill in placeholders for dimensions the data ID does not carry.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo

67 

68 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        Column name to array mapping. Must contain a ``"patch"`` column,
        compared against the sequential (gen 3) patch index, and at least
        one of the y-value columns ``"y"``, ``"yStars"``, ``"yGalaxies"`` or
        ``"yUnknowns"``; the first one present, in that order, is used.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Sky map used to build the tract geometry.
    plotInfo : `dict`
        Plot metadata; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps every sequential patch index of the tract to
        ``(skyCoords, stat)``. ``skyCoords`` are the sky positions of the
        corners of the patch inner bounding box. ``stat`` is the
        NaN-ignoring median of the y column for rows on that patch, or NaN
        when there are none. The extra key ``"tract"`` maps to the tract
        corners and ``nan``.

    Raises
    ------
    ValueError
        Raised if ``data`` contains none of the recognised y-value columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Summarise the first y column available, in order of preference.
    yCol = next(
        (name for name in ("y", "yStars", "yGalaxies", "yUnknowns") if name in data.keys()),
        None,
    )
    if yCol is None:
        raise ValueError(
            "data must contain one of the columns 'y', 'yStars', 'yGalaxies' or 'yUnknowns'; "
            f"got {list(data.keys())}"
        )

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patch ids here are the sequential (gen 3) index, so no gen 2 "x,y"
    # string conversion is needed.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])
        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    # The whole-tract outline is stored with a NaN statistic.
    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

127 

128 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic for each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalogue with a ``detector`` column and the column ``colName``.
    colName : `str`
        Column of ``cat`` to summarise.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Table with ``id``, ``raCorners`` and ``decCorners`` columns. Each
        corner entry holds that detector's corner coordinates in degrees.
    plotInfo : `dict`
        Plot metadata; not used by this function.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to ``(corners, stat)``, where ``corners`` is a
        list of `lsst.geom.SpherePoint` and ``stat`` is the NaN-ignoring
        median of ``colName`` on that detector.

    Raises
    ------
    IndexError
        Raised if a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Convert to numpy once so that the row selection below is positional.
    # On a boolean-filtered pandas Series, ``[0]`` is a *label* lookup and
    # raises KeyError unless the matching row happens to have label 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        sumRow = summaryIds == ccd
        cornersOut = [
            SpherePoint(ra, dec, units=degrees)
            for (ra, dec) in zip(raCorners[sumRow][0], decCorners[sumRow][0])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict

159 

160 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis. It is modified in place: the title, axis
        labels and free texts are blanked, legend texts are made fully
        transparent, and all tick labels are replaced by a null formatter.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title and axis/legend/tick labels),
        followed by those of any child axes.
    line_xys : `List[numpy.ndarray]`
        The xy data (arrays of shape ``(N, 2)``) of every line on this axis
        and on its child axes.
    """
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when the axis has no legend.
    legend = ax.get_legend()
    if legend is not None:
        texts_legend = legend.texts
        texts.extend(text.get_text() for text in texts_legend)
        for text in texts_legend:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    # Only 3D axes have a ``zaxis``.
    for axis in (ax.xaxis, ax.yaxis, getattr(ax, "zaxis", None)):
        if axis is not None:
            axis.set_major_formatter(null_formatter)
            axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys

207 

208 

def get_and_remove_figure_text(figure: plt.Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure. It is modified in place: the suptitle is set to
        ``""``, figure-level texts are dropped, and every Axes is passed to
        `get_and_remove_axis_text`.

    Returns
    -------
    texts : `List[str]`
        The suptitle text (``""`` if the figure has no suptitle), then the
        figure-level texts, then the texts removed from each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        A list of all line xy arrays (shape ``(N, 2)``) from every Axes.
    """
    # ``_suptitle`` is a private attribute that may be None. Take the text
    # itself: ``str()`` of a Text artist gives its repr, e.g.
    # "Text(0.5, 0.98, 'title')", not the title string.
    suptitle = figure._suptitle
    texts = [suptitle.get_text() if suptitle is not None else ""]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines

235 

236 

def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate; two text artists are added in figure coordinates
        at its top-left corner.
    plotInfo : `dict`
        Must contain ``"plotName"``, ``"run"``, ``"tableName"`` and
        ``"bands"``. ``"tract"``, ``"visit"`` and ``"SN"`` are optional.
        ``"bands"`` may be an already-formatted string (as produced by
        parsePlotInfo) or an iterable of band names.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, annotated.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    # A string is used verbatim: iterating it would join single characters
    # (e.g. "g, r" -> "g, ,,  , r"). Other iterables are joined with ", ".
    bands = plotInfo["bands"]
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

273 

274 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit
        mODR : `float`
            The gradient calculated by the ODR fit
        bODR : `float`
            The intercept calculated by the ODR fit
        yBoxMin : `float`
            The y value of the fitted line at xMin
        yBoxMax : `float`
            The y value of the fitted line at xMax
        bPerpMin : `float`
            The intercept of the perpendicular line that goes through xMin
        bPerpMax : `float`
            The intercept of the perpendicular line that goes through xMax
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Raises
    ------
    ZeroDivisionError
        Raised if a fitted gradient is exactly zero, because the
        perpendicular gradient ``-1 / m`` is undefined.

    Notes
    -----
    The code does two rounds of fitting. The first is initiated using the
    hardwired values given in the `paramDict` parameter and is done using
    an Orthogonal Distance Regression fit to the points inside the box
    defined by xMin, xMax, yMin and yMax. A perpendicular to the fitted line
    is then constructed at either end of it, and only points that fall
    between these two perpendiculars are used to recalculate the fit. When
    the gradient is steeper than 1, the line ends are taken at yMin and
    yMax instead of xMin and xMax.
    """
    # Points strictly inside the initial box.
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # First-order polynomial model; its parameters are [intercept, gradient].
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp * xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. Both lines
    # share the gradient mPerp, so "between them" means between the smaller
    # and larger intercept. Which end yields the larger intercept depends on
    # the sign of the gradient (it is bPerpMin for a shallow negative slope).
    bPerpLow = min(bPerpMin, bPerpMax)
    bPerpHigh = max(bPerpMin, bPerpMax)
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut

409 

410 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line (any 2-element array-like is accepted).
    p2 : `numpy.ndarray`
        Another point on the line (any 2-element array-like is accepted).
    points : `zip`
        Iterable of ``(x, y)`` points to calculate the distance to.

    Returns
    -------
    dists : `list`
        The distances from the line to the points, one per point. They are
        computed from the 2D cross product of ``p2 - p1`` with
        ``point - p1``, so the sign tells which side of the line the point
        lies on (positive to the left of the direction ``p1 -> p2``).
    """
    p1 = np.asarray(p1)
    direction = np.asarray(p2) - p1
    # The line's length is the same for every point, so compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.array(point) - p1
        # The z-component of the 2D cross product, written out explicitly
        # because np.cross on 2-element vectors is deprecated in NumPy 2.0.
        cross = direction[0] * offset[1] - direction[1] * offset[0]
        dists.append(cross / length)

    return dists

434 

435 

def mkColormap(colorNames):
    """Build a colormap that interpolates linearly between named colors.

    Parameters
    ----------
    colorNames : `list`
        Matplotlib color names. They are placed at evenly spaced positions
        from 0 (first name) to 1 (last name).

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        The resulting colormap, named ``"newCmap"``.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.to_rgb(name) for name in colorNames]

    def _segments(channel):
        # Entries are (x, y_below, y_above); equal y values make the
        # channel continuous at each anchor.
        return [(anchor, rgb[channel], rgb[channel]) for anchor, rgb in zip(anchors, rgbTriples)]

    segmentData = {"blue": _segments(2), "red": _segments(0), "green": _segments(1)}
    return colors.LinearSegmentedColormap("newCmap", segmentData)

461 

462 

def extremaSort(xs):
    """Order indices so that the values furthest from the median come last.

    Parameters
    ----------
    xs : `np.array`
        The values to order. NaNs are ignored when computing the median.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs``, sorted by increasing absolute deviation from the
        median.
    """
    deviation = np.abs(xs - np.nanmedian(xs))
    return np.argsort(deviation)

479 

480 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot. It
        can be a gridspec SubplotSpec, or a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. The tuple form of (int, int, index) has
        the same meaning.
    sumStats : `dict`
        A dictionary keyed by patch (or detector) id. Each value is
        ``(corners, stat)``: the R.A./Dec. corners of the region and a
        summary statistic for it. The key ``"tract"`` is drawn without an
        id annotation.
    label : `str`
        Text drawn on the colorbar to describe the summary statistic.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and their colorbar added.
    """
    # Add the subplot to the relevant place in the figure, with ticks and
    # axis labels on the top and right.
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Outline each region and make a color-coded rectangle for it. The
    # colors show the region's summary statistic. ``statValues`` is used
    # instead of ``colors`` to avoid shadowing the ``matplotlib.colors``
    # module import.
    patches = []
    statValues = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        statValues.append(stat)

        ras = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        decs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
        if dataId != "tract":
            axCorner.annotate(
                dataId, (ra + width / 2, dec + height / 2), color="k", fontsize=4, ha="center", va="center"
            )

    # Set the bad color to transparent and mask NaN statistics so those
    # regions are left uncolored.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    statValues = np.ma.array(statValues, mask=np.isnan(statValues))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(statValues)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a horizontal colorbar above the summary axes. Use fig.colorbar
    # rather than plt.colorbar: the pyplot version goes through plt.gcf(),
    # which need not be ``fig`` and may create a new, stray figure.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

570 

571 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the grid shape, location and
    span arguments of matplotlib.pyplot.subplot2grid.
    """

    # Spine visibility for each side of the panel.
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Grid shape (rows, columns).
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    # Location of the axis within the grid (row, column).
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    # Extent of the axis within the grid.
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )