Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 11%

250 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-28 11:25 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("PanelConfig",) 

24 

25from typing import List, Tuple 

26 

27import matplotlib 

28import matplotlib.pyplot as plt 

29import numpy as np 

30import scipy.odr as scipyODR 

31from lsst.geom import Box2D, SpherePoint, degrees 

32from lsst.pex.config import Config, Field 

33from matplotlib import colors 

34from matplotlib.collections import PatchCollection 

35from matplotlib.patches import Rectangle 

36 

37null_formatter = matplotlib.ticker.NullFormatter() 

38 

39 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Gather the metadata describing a plot into one dictionary.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        The data ID. It is iterated over, and the ``name`` attribute of
        every item is used both as the output key and as the index into
        ``dataId`` to fetch the value.
    runName : `str`
        Stored under the ``"run"`` key.
    tableName : `str`
        Stored under the ``"tableName"`` key.
    bands : iterable of `str`
        Band names; stored under ``"bands"`` as one ``", "`` separated
        string.
    plotName : `str`
        Stored under the ``"plotName"`` key.
    SN : `object`
        Stored unchanged under the ``"SN"`` key.

    Returns
    -------
    plotInfo : `dict`
        The collected information. ``"tract"`` and ``"visit"`` are always
        present; they are set to ``"N/A"`` when the data ID did not
        provide them.
    """
    plotInfo = dict(run=runName, tableName=tableName, plotName=plotName, SN=SN)

    # Copy every entry of the data ID across, keyed by its name.
    plotInfo.update({dim.name: dataId[dim.name] for dim in dataId})

    plotInfo["bands"] = ", ".join(bands)

    # Only fill in the placeholders when the data ID left them out.
    for missingKey in ("tract", "visit"):
        plotInfo.setdefault(missingKey, "N/A")

    return plotInfo

67 

68 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic (the median) for each patch of a tract.

    Parameters
    ----------
    data : `dict`
        The data to summarise. It must provide a ``"patch"`` entry and at
        least one of the ``"y"``, ``"yStars"``, ``"yGalaxies"`` or
        ``"yUnknowns"`` entries; the first one present, in that order,
        is the quantity that gets summarised.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap; ``generateTract`` is called on it with
        ``plotInfo["tract"]``.
    plotInfo : `dict`
        Must contain the ``"tract"`` key.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each patch id to a tuple of (sky coordinates of the corners
        of the patch inner bounding box, NaN-ignoring median of the
        quantity on that patch). An additional ``"tract"`` key holds the
        corners of the tract bounding box paired with NaN.

    Raises
    ------
    ValueError
        Raised if ``data`` contains none of the recognised y columns.
    """
    # Fail early with a clear message rather than hitting an unbound
    # variable part way through the patch loop.
    yCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in yCandidates if col in data.keys()), None)
    if yCol is None:
        raise ValueError(f"data must contain one of the columns {', '.join(yCandidates)}")

    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])
        # For now also convert the gen 2 patchIds to gen 3. Once the
        # objectTable_tract catalogues are using gen 3 patches this
        # will go away.
        try:
            patchTuple = (int(patch.split(",")[0]), int(patch.split(",")[-1]))
            patchInfo = tractInfo.getPatchInfo(patchTuple)
            gen3PatchId = tractInfo.getSequentialPatchIndex(patchInfo)
        except AttributeError:
            # For native gen 3 tables the patches don't need converting
            # (an integer id has no ``split``). When we are no longer
            # looking at the gen 2 -> gen 3 converted repos we can tidy
            # this up.
            gen3PatchId = patch
            patchInfo = tractInfo.getPatchInfo(patch)

        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[gen3PatchId] = (skyCoords, stat)

    # The tract outline carries no statistic of its own.
    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict

127 

128 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (the median) for each detector.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        The catalogue; must have a ``detector`` column and the column
        named by ``colName``.
    colName : `str`
        The column whose NaN-ignoring median is computed per detector.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Table with ``id``, ``raCorners`` and ``decCorners`` entries; the
        row whose ``id`` equals the detector supplies the corners.
    plotInfo : `dict`
        Not used by this function.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to a tuple of (list of
        `lsst.geom.SpherePoint` corners, median of ``colName``).
    """
    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue

        stat = np.nanmedian(cat[colName].values[cat["detector"] == ccd])

        # Select the summary row of this detector and take its corners.
        rowMask = visitSummaryTable["id"] == ccd
        raCorners = visitSummaryTable["raCorners"][rowMask][0]
        decCorners = visitSummaryTable["decCorners"][rowMask][0]
        cornerPoints = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[ccd] = (cornerPoints, stat)

    return visitInfoDict

159 

160 

161# Inspired by matplotlib.testing.remove_ticks_and_titles 

def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title and axis/legend/tick labels),
        including those collected from child axes.
    line_xys : `List[numpy.ndarray]`
        A list of all line ``_xy`` attributes (arrays of shape ``(N, 2)``),
        including those collected from child axes.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Legend entries are hidden (made transparent) rather than blanked.
    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    ax.xaxis.set_major_formatter(null_formatter)
    ax.xaxis.set_minor_formatter(null_formatter)
    ax.yaxis.set_major_formatter(null_formatter)
    ax.yaxis.set_minor_formatter(null_formatter)
    # Only some axes expose a ``zaxis`` attribute.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Previously the child lines were computed and then discarded.
        line_xys.extend(lines_child)

    return texts, line_xys

207 

208 

def get_and_remove_figure_text(figure: plt.Figure):
    """Strip the text from a Figure and all of its Axes.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        The collected text. The first entry is ``str(figure._suptitle)``,
        followed by the strings of ``figure.texts`` and then the text
        returned by `get_and_remove_axis_text` for every axis.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line data returned by `get_and_remove_axis_text` for every
        axis of the figure.
    """
    collectedTexts = [str(figure._suptitle)]
    # The title is blanked before ``figure.texts`` is read, matching the
    # order in which the entries have always been reported.
    figure.suptitle("")

    for artist in figure.texts:
        collectedTexts.append(artist.get_text())
    figure.texts = []

    collectedLines = []
    for axis in figure.get_axes():
        axisTexts, axisLines = get_and_remove_axis_text(axis)
        collectedTexts.extend(axisTexts)
        collectedLines.extend(axisLines)

    return collectedTexts, collectedLines

235 

236 

def addPlotInfo(fig, plotInfo):
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; two text elements are added to it.
    plotInfo : `dict`
        Must contain ``"plotName"``, ``"run"``, ``"tableName"`` and
        ``"bands"``. ``"tract"``, ``"visit"`` and ``"SN"`` are optional.
        ``"bands"`` may be an iterable of band names or a string; a
        string containing commas (the form written by `parsePlotInfo`)
        is treated as an already delimited list of bands.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure that was passed in.
    """

    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    if isinstance(bands, str) and "," in bands:
        # Already a delimited string: iterating it character by
        # character would interleave separators with the separators.
        bands = [band.strip() for band in bands.split(",")]
    bandsText = f", Bands: {', '.join(bands)}"

    if "SN" in plotInfo:
        SNText = f", S/N: {plotInfo['SN']}"
    else:
        SNText = ", S/N: -"

    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

276 

277 

def stellarLocusFit(xs, ys, paramDict):
    """Fit a straight line to the stellar locus in two passes.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin, xMax, yMin, yMax, mHW, bHW : `float`
            Copied unchanged from ``paramDict``
        mODR : `float`
            The gradient calculated by the first ODR fit
        bODR : `float`
            The intercept calculated by the first ODR fit
        yBoxMin : `float`
            The y value of the lower end point of the first fitted line
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower
            end point
        yBoxMax : `float`
            The y value of the upper end point of the first fitted line
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper
            end point
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Notes
    -----
    The first Orthogonal Distance Regression fit starts from the
    hardwired values and uses only the points strictly inside the box
    defined by xMin, xMax, yMin and yMax. Lines perpendicular to that
    fit are then drawn through its two end points, and the points lying
    between those two lines are used for the second fit.
    """
    xLow, xHigh = paramDict["xMin"], paramDict["xMax"]
    yLow, yHigh = paramDict["yMin"], paramDict["yMax"]

    # First pass: points inside the configured box.
    insideBox = np.where((xs > xLow) & (xs < xHigh) & (ys > yLow) & (ys < yHigh))[0]

    linear = scipyODR.polynomial(1)

    def fitLine(selection, intercept, slope):
        """Run an ODR fit on the selected points; return (slope, intercept)."""
        fitData = scipyODR.Data(xs[selection], ys[selection])
        result = scipyODR.ODR(fitData, linear, beta0=[intercept, slope]).run()
        return float(result.beta[1]), float(result.beta[0])

    mODR, bODR = fitLine(insideBox, paramDict["bHW"], paramDict["mHW"])

    paramsOut = {key: paramDict[key] for key in ("xMin", "xMax", "yMin", "yMax", "mHW", "bHW")}
    paramsOut["mODR"] = mODR
    paramsOut["bODR"] = bODR

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR

    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin, yBoxMax = yLow, yHigh
        xBoxMin = (yBoxMin - bODR) / mODR
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        xBoxMin, xBoxMax = xLow, xHigh
        yBoxMin = mODR * xLow + bODR
        yBoxMax = mODR * xHigh + bODR

    bPerpMin = yBoxMin - mPerp * xBoxMin
    bPerpMax = yBoxMax - mPerp * xBoxMax

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Second pass: points between the two perpendicular lines.
    betweenEnds = (ys > mPerp * xs + bPerpMin) & (ys < mPerp * xs + bPerpMax)
    mODR2, bODR2 = fitLine(betweenEnds, bODR, mODR)

    paramsOut["mODR2"] = mODR2
    paramsOut["bODR2"] = bODR2
    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut

412 

413 

def perpDistance(p1, p2, points):
    """Calculate the perpendicular distance to a line from each point.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line
    p2 : `numpy.ndarray`
        Another point on the line
    points : `zip`
        The points to calculate the distance to

    Returns
    -------
    dists : `list`
        The signed distances from the line to the points, one entry per
        point. Each is the cross product of the line direction with the
        offset of the point from ``p1``, divided by the length of
        ``p2 - p1``.
    """
    origin = np.asarray(p1)
    # The direction and its length are the same for every point.
    lineVec = np.asarray(p2) - origin
    lineLength = np.linalg.norm(lineVec)
    planar = lineVec.shape == (2,)

    dists = []
    for point in points:
        offset = np.array(point) - origin
        if planar and offset.shape == (2,):
            # np.cross is deprecated for 2-element vectors in NumPy 2.0,
            # so the planar cross product is written out explicitly.
            signedArea = lineVec[0] * offset[1] - lineVec[1] * offset[0]
        else:
            signedArea = np.cross(lineVec, offset)
        dists.append(signedArea / lineLength)

    return dists

437 

438 

def mkColormap(colorNames):
    """Build a colormap that steps through the given colors.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib
        named colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"`` whose anchor points are spread
        evenly between 0 and 1, one per entry of ``colorNames``.
    """
    stops = np.linspace(0, 1, len(colorNames))

    # One (position, value, value) anchor per color for each channel.
    segments = {"blue": [], "red": [], "green": []}
    for stop, name in zip(stops, colorNames):
        red, green, blue = colors.colorConverter.to_rgb(name)
        segments["blue"].append((stop, blue, blue))
        segments["green"].append((stop, green, green))
        segments["red"].append((stop, red, red))

    return colors.LinearSegmentedColormap("newCmap", segments)

464 

465 

def extremaSort(xs):
    """Order point indices by increasing absolute distance from the median.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort

    Returns
    -------
    ids : `np.array`
        The indices that sort ``xs`` so that the values furthest from
        the median, in absolute terms, come last.
    """
    return np.argsort(np.abs(xs - np.median(xs)))

482 

483 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot,
        can be a gridspec SubplotSpec, a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. This is the same for the tuple
        of int, int, index.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch.
    label : `str`
        The text written onto the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure that was passed in, with the summary axis and its
        colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # One rectangle per entry, colored by its statistic, plus a thin
    # outline drawn through all of its corners.
    rectangles = []
    statValues = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        rectangles.append(Rectangle((ra, dec), width, height))
        statValues.append(stat)

        ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
        decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
        # Repeat the first corner so that the outline is closed.
        axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)

        if dataId != "tract":
            center = (ra + width / 2, dec + height / 2)
            axCorner.annotate(dataId, center, color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(statValues, mask=np.isnan(statValues))
    collection = PatchCollection(rectangles, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axis, offset by a third of the
    # axis height.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    plt.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

573 

574 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to matplotlib.pyplot.subplot2grid.
    """

    # Which of the four panel edges are drawn.
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Overall grid shape.
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    # Location of the panel within the grid.
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    # Extent of the panel, measured in grid cells.
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )