Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 12%

249 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-23 09:29 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("PanelConfig",) 

24 

25from typing import TYPE_CHECKING, List, Mapping, Tuple 

26 

27import matplotlib 

28import matplotlib.pyplot as plt 

29import numpy as np 

30import scipy.odr as scipyODR 

31from lsst.geom import Box2D, SpherePoint, degrees 

32from lsst.pex.config import Config, Field 

33from matplotlib import colors 

34from matplotlib.collections import PatchCollection 

35from matplotlib.patches import Rectangle 

36 

37if TYPE_CHECKING: 37 ↛ 38line 37 didn't jump to line 38, because the condition on line 37 was never true

38 from matplotlib.figure import Figure 

39 

# Shared NullFormatter instance; installing it as an axis tick formatter
# blanks that axis's tick labels (used by get_and_remove_axis_text).
null_formatter = matplotlib.ticker.NullFormatter()

41 

42 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions._coordinate._ExpandedTupleDataCoordinate`
        Data ID of the plotted data. Iterating it must yield objects with
        a ``name`` attribute, and ``dataId[name]`` must return the value
        for that dimension.
    runName : `str`
        Run name, stored under ``"run"``.
    tableName : `str`
        Name of the input table, stored under ``"tableName"``.
    bands : iterable of `str`
        Band names; stored under ``"bands"`` as one comma separated string.
    plotName : `str`
        Name of the plot, stored under ``"plotName"``.
    SN : `str`
        Signal-to-noise information, stored under ``"SN"``.

    Returns
    -------
    plotInfo : `dict`
        Keys ``"run"``, ``"tableName"``, ``"plotName"``, ``"SN"`` and
        ``"bands"``, plus one entry per dimension of ``dataId``.
        ``"tract"`` and ``"visit"`` are set to ``"N/A"`` when ``dataId``
        does not supply them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # Downstream labelling expects these keys to always be present.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo

70 

71 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic (median) in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        Column name -> array. Must contain a ``"patch"`` column of
        sequential patch indices and at least one of the value columns
        ``"y"``, ``"yStars"``, ``"yGalaxies"`` or ``"yUnknowns"``; the
        first one present, in that order, is summarised.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Sky map used to generate the tract geometry.
    plotInfo : `dict`
        Must contain ``"tract"``.

    Returns
    -------
    patchInfoDict : `dict`
        Keys are the sequential patch indices of the tract plus
        ``"tract"``. Values are ``(skyCoords, stat)``: ``skyCoords`` are
        the sky positions of the corners of the patch inner bounding box
        (or of the tract bounding box), and ``stat`` is the NaN-ignoring
        median of the value column on that patch (NaN for ``"tract"``).

    Raises
    ------
    KeyError
        If none of the recognised value columns is present in ``data``.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    valueColumns = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in valueColumns if col in data.keys()), None)
    if yCol is None:
        raise KeyError(f"data contains none of the value columns {valueColumns}")

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are addressed by their sequential (gen 3) index, so no
    # gen 2 "x,y" string conversion is needed.
    for patch in range(maxPatchNum):
        onPatch = data["patch"] == patch
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[patch] = (skyCoords, stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict

130 

131 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Generate a summary statistic (median) for each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``detector`` column and a ``colName`` column.
    colName : `str`
        Column whose NaN-ignoring median is computed per detector.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Table with an ``id`` column (matched against the detector) and
        per-row ``raCorners``/``decCorners`` sequences in degrees.
    plotInfo : `dict`
        Not used; kept for interface compatibility.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector to ``(corners, stat)`` where ``corners`` is a
        list of `lsst.geom.SpherePoint` and ``stat`` the median.

    Raises
    ------
    IndexError
        If a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Convert the summary columns to plain arrays once. This makes the
    # ``[0]`` row lookup below positional: on a boolean-masked pandas
    # Series, ``[0]`` is a *label* lookup and fails for any detector whose
    # row is not labelled 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        sumRow = summaryIds == ccd
        corners = zip(raCorners[sumRow][0], decCorners[sumRow][0])
        cornersOut = [SpherePoint(ra, dec, units=degrees) for (ra, dec) in corners]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict

162 

163 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Strip the text from an Axes (and its child Axes) and return it.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axes to clean; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        Title, x/y axis labels, legend entries and free text artists of
        ``ax``, followed by those of each child Axes (recursively). Tick
        labels are hidden but not returned.
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` data of every line on ``ax`` itself (arrays of shape
        ``(N, 2)``).
    """
    # Capture the line data before anything is modified.
    line_xys = [line._xy for line in ax.lines]

    texts = []
    for artist in (ax.title, ax.xaxis.label, ax.yaxis.label):
        texts.append(artist.get_text())
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Without a legend, get_legend() returns None and ``.texts`` raises
    # AttributeError. Legend text is hidden via alpha rather than cleared.
    try:
        legendEntries = ax.get_legend().texts
        texts.extend(entry.get_text() for entry in legendEntries)
        for entry in legendEntries:
            entry.set_alpha(0)
    except AttributeError:
        pass

    for annotation in ax.texts:
        texts.append(annotation.get_text())
        annotation.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Axes without a ``zaxis`` raise AttributeError here; that is expected.
    try:
        ax.zaxis.set_major_formatter(null_formatter)
        ax.zaxis.set_minor_formatter(null_formatter)
    except AttributeError:
        pass

    for child in ax.child_axes:
        childTexts, _ = get_and_remove_axis_text(child)
        texts.extend(childTexts)
    # NOTE(review): line data from child Axes is discarded; confirm that
    # callers do not expect it in line_xys.

    return texts, line_xys

210 

211 

def get_and_remove_figure_text(figure: Figure):
    """Strip the text from a Figure and all of its Axes and return it.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        Figure to clean; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        ``str()`` of the figure's suptitle attribute, then the text of each
        figure-level text artist, then the texts returned by
        `get_and_remove_axis_text` for every Axes, in Axes order.
    line_xys : `List[numpy.ndarray]`
        Line ``_xy`` arrays (shape ``(N, 2)``) from every Axes, in Axes
        order.
    """
    # NOTE(review): this records str() of the suptitle artist rather than
    # its text; confirm that is intended.
    suptitleRepr = str(figure._suptitle)
    figure.suptitle("")

    # The figure-level texts are read only after the suptitle is blanked.
    figureTexts = [artist.get_text() for artist in figure.texts]
    figure.texts = []

    axesTexts = []
    axesLines = []
    for axes in figure.get_axes():
        oneAxesTexts, oneAxesLines = get_and_remove_axis_text(axes)
        axesTexts.extend(oneAxesTexts)
        axesLines.extend(oneAxesLines)

    return [suptitleRepr] + figureTexts + axesTexts, axesLines

238 

239 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate; the text goes in the top-left corner.
    plotInfo : `dict`
        Must contain ``"plotName"``, ``"run"``, ``"tableName"`` and
        ``"bands"``; ``"tract"``, ``"visit"`` and ``"SN"`` are optional.
        ``"bands"`` may be an iterable of band names, or a string that is
        already comma separated (as produced by `parsePlotInfo`).

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the text added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    if isinstance(bands, str) and "," in bands:
        # Already formatted (e.g. by parsePlotInfo); joining it again would
        # treat every character, including the separators, as a band.
        bandText = bands
    else:
        bandText = ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

276 

277 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis.
    ys : `numpy.ndarray`
        The color on the yaxis.
    paramDict : `lsst.pex.config.dictField.Dict`
        A dictionary of parameters for line fitting

        ``xMin`` : `float`
            The minimum x edge of the box to use for initial fitting.
        ``xMax`` : `float`
            The maximum x edge of the box to use for initial fitting.
        ``yMin`` : `float`
            The minimum y edge of the box to use for initial fitting.
        ``yMax`` : `float`
            The maximum y edge of the box to use for initial fitting.
        ``mHW`` : `float`
            The hardwired gradient for the fit.
        ``bHW`` : `float`
            The hardwired intercept of the fit.

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters

        ``xMin``, ``xMax``, ``yMin``, ``yMax``, ``mHW``, ``bHW`` : `float`
            Copied from ``paramDict``.
        ``mODR`` : `float`
            The gradient calculated by the first ODR fit.
        ``bODR`` : `float`
            The intercept calculated by the first ODR fit.
        ``yBoxMin`` : `float`
            The y value of the first-fit line at its lower box end.
        ``yBoxMax`` : `float`
            The y value of the first-fit line at its upper box end.
        ``bPerpMin`` : `float`
            The intercept of the perpendicular line through the lower end.
        ``bPerpMax`` : `float`
            The intercept of the perpendicular line through the upper end.
        ``mODR2`` : `float`
            The gradient from the second round of fitting.
        ``bODR2`` : `float`
            The intercept from the second round of fitting.
        ``mPerp`` : `float`
            The gradient of the line perpendicular to the line from the
            second fit.

    Notes
    -----
    The code does two rounds of fitting. The first is initialised with the
    hardwired values given in the `paramDict` parameter and uses an
    Orthogonal Distance Regression fit to the points inside the box defined
    by xMin, xMax, yMin and yMax. The box ends are taken at the x limits,
    or at the y limits when ``|mODR| > 1``. A perpendicular line is then
    calculated at either end of the fitted line, and only points that fall
    between these lines are used to recalculate the fit. The perpendicular
    lines are used as lower/upper bounds in whichever order they fall. For
    shallow negative gradients ``bPerpMin`` is above ``bPerpMax``.
    """
    # Points to use for the first fit: strictly inside the box.
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # First-order polynomial model; beta is ordered [intercept, gradient].
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    bPerpMax = yBoxMax - mPerp * xBoxMax

    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. The two
    # intercepts are not ordered in general: for a shallow negative
    # gradient bPerpMin > bPerpMax, so assuming an order would select
    # no points at all.
    bPerpLow = min(bPerpMin, bPerpMax)
    bPerpHigh = max(bPerpMin, bPerpMax)
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    # Start the second fit from the first fit's solution.
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut

412 

413 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance to a line from points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point (x, y) on the line.
    p2 : `numpy.ndarray`
        Another point (x, y) on the line.
    points : `zip` or iterable of (x, y)
        The points to calculate the distance to.

    Returns
    -------
    dists : `list`
        The distances from the line to the points. Each is the 2D cross
        product of ``p2 - p1`` with ``point - p1`` divided by
        ``|p2 - p1|``. Points to the left of the direction ``p1 -> p2``
        are positive.
    """
    p1 = np.asarray(p1, dtype=float)
    direction = np.asarray(p2, dtype=float) - p1
    # The line length does not depend on the point, so compute it once.
    length = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.asarray(point, dtype=float) - p1
        # Explicit z-component of the 2D cross product; np.cross on
        # 2-element vectors is deprecated in NumPy 2.0.
        cross = direction[0] * offset[1] - direction[1] * offset[0]
        dists.append(cross / length)

    return dists

437 

438 

def mkColormap(colorNames):
    """Make a colormap from the list of color names.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib
        named colors. They are spaced evenly from 0 to 1 in the
        colormap. A single color gives a constant colormap.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`

    Raises
    ------
    ValueError
        If ``colorNames`` is empty.
    """
    colorNames = list(colorNames)
    if not colorNames:
        raise ValueError("mkColormap needs at least one color name.")
    if len(colorNames) == 1:
        # Segment data must start at x=0 and end at x=1. Repeat a lone
        # color so the result is a valid constant map instead of failing
        # on first use.
        colorNames = colorNames * 2

    segments = {"red": [], "green": [], "blue": []}
    for num, color in zip(np.linspace(0, 1, len(colorNames)), colorNames):
        r, g, b = colors.to_rgb(color)
        # No discontinuities: value below and above each anchor are equal.
        segments["red"].append((num, r, r))
        segments["green"].append((num, g, g))
        segments["blue"].append((num, b, b))

    return colors.LinearSegmentedColormap("newCmap", segments)

464 

465 

def extremaSort(xs):
    """Order indices so that the values furthest from the median come last.

    Parameters
    ----------
    xs : `np.array`
        The values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs``, sorted by increasing absolute distance from
        the NaN-ignoring median of ``xs`` (NaN distances sort last, per
        numpy's sort order).
    """
    distanceFromMedian = np.abs(xs - np.nanmedian(xs))
    return np.argsort(distanceFromMedian)

482 

483 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot.
        It can be a gridspec SubplotSpec, or a 3 digit integer where the
        first digit is the number of rows, the second is the number of
        columns and the third is the index. A tuple of
        (int, int, index) has the same meaning.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch, as ``(corners, stat)``. ``corners[0]``
        and ``corners[2]`` are used as opposite corners of the drawn
        rectangle. The ``"tract"`` entry is outlined but not labelled, and
        NaN statistics are drawn transparent.
    label : `str`
        Text written over the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out. A (rows, cols, index) tuple has to be
    # unpacked; add_subplot rejects a single tuple argument.
    if isinstance(loc, tuple):
        axCorner = fig.add_subplot(*loc)
    else:
        axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color
    # coded rectangles for each patch, the colors show
    # the median of the given value in the patch
    patches = []
    patchStats = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        patchStats.append(stat)

        ras = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
        decs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
        # Close the outline by returning to the first corner.
        axCorner.plot(ras + [ras[0]], decs + [decs[0]], "k", lw=0.5)
        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    patchStats = np.ma.array(patchStats, mask=np.isnan(patchStats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(patchStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the panel. Attach it through ``fig`` rather
    # than pyplot, which would go via the "current" figure and may create
    # a stray one when ``fig`` was not made with pyplot.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

573 

574 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the arguments of
    matplotlib.pyplot.subplot2grid: the grid shape (rows, columns), the
    axis location (row, column), and the row/column spans.
    """

    # Spine visibility for each side of the panel.
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Grid shape.
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    # Axis location within the grid.
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    # Axis extent within the grid.
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )