Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 11%

283 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-12 03:15 -0800

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("PanelConfig",) 

24 

25from typing import TYPE_CHECKING, List, Mapping, Tuple 

26 

27import matplotlib 

28import matplotlib.pyplot as plt 

29import numpy as np 

30import scipy.odr as scipyODR 

31from lsst.geom import Box2D, SpherePoint, degrees 

32from lsst.pex.config import Config, Field 

33from matplotlib import colors 

34from matplotlib.collections import PatchCollection 

35from matplotlib.patches import Rectangle 

36from scipy.stats import binned_statistic_2d 

37 

38from ...statistics import nansigmaMad 

39 

40if TYPE_CHECKING: 40 ↛ 41line 40 didn't jump to line 41, because the condition on line 40 was never true

41 from matplotlib.figure import Figure 

42 

# Module-level matplotlib NullFormatter instance, installed as the major and
# minor tick formatter of every axis by get_and_remove_axis_text so that tick
# labels produce no text.
null_formatter = matplotlib.ticker.NullFormatter()

44 

45 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Collect the metadata used to annotate a plot into a single dict.

    Parameters
    ----------
    dataId : `lsst.daf.butler.core.dimensions.`
             `_coordinate._ExpandedTupleDataCoordinate`
        Data ID whose iteration yields dimension objects with a ``name``
        attribute; each value is looked up with ``dataId[name]``.
    runName : `str`
        Stored under the ``"run"`` key.
    tableName : `str`
        Stored under the ``"tableName"`` key.
    bands : iterable of `str`
        Joined with ``", "`` and stored under the ``"bands"`` key.
    plotName : `str`
        Stored under the ``"plotName"`` key.
    SN : `str`
        Stored under the ``"SN"`` key.

    Returns
    -------
    plotInfo : `dict`
        The collected metadata. ``"tract"`` and ``"visit"`` are set to
        ``"N/A"`` when the data ID does not provide them.
    """
    info = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    # Copy every dimension value of the data ID, keyed by dimension name.
    info.update({dimension.name: dataId[dimension.name] for dimension in dataId})

    info["bands"] = ", ".join(bands)

    # Guarantee the keys that addPlotInfo-style consumers look for.
    for key in ("tract", "visit"):
        info.setdefault(key, "N/A")

    return info

73 

74 

def generateSummaryStats(data, skymap, plotInfo):
    """Compute the median of a y-value column in every patch of a tract.

    Parameters
    ----------
    data : `dict`
        Column mapping. Must contain a ``"patch"`` column and at least one of
        ``"y"``, ``"yStars"``, ``"yGalaxies"`` or ``"yUnknowns"``; the first
        of these present (in that order) is summarized.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        Sky map used to generate the tract and look up its patches.
    plotInfo : `dict`
        Plot metadata; ``plotInfo["tract"]`` selects the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to ``(skyCoords, stat)``, where
        ``skyCoords`` are the sky positions of the corners of the patch inner
        bounding box and ``stat`` is the NaN-ignoring median of the y column
        on that patch (NaN for a patch without data). The additional key
        ``"tract"`` holds the tract corners paired with NaN.

    Raises
    ------
    ValueError
        Raised if ``data`` contains none of the supported y columns.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    yCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in yCandidates if col in data.keys()), None)
    if yCol is None:
        raise ValueError(f"None of the columns {yCandidates} are present in the data.")

    patchInfoDict = {}
    numPatches = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are addressed by their sequential (gen3) index. The old gen2
    # "x,y" string conversion could never apply here, because the indices
    # are generated as integers.
    for patch in range(numPatches):
        onPatch = data["patch"] == patch
        # np.nanmedian yields NaN for a patch with no rows.
        stat = np.nanmedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

133 

134 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable, plotInfo):
    """Compute the median of a column on every detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        Catalog with a ``"detector"`` column and the ``colName`` column.
    colName : `str`
        Name of the column to summarize.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        Per-detector summary with ``"id"``, ``"raCorners"`` and
        ``"decCorners"`` columns; the corners are interpreted in degrees.
    plotInfo : `dict`
        Plot metadata. Not used; kept for interface compatibility.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to ``(corners, stat)``, where ``corners`` is a
        list of `lsst.geom.SpherePoint` and ``stat`` is the NaN-ignoring
        median of ``colName`` on that detector.
    """
    # Convert the summary columns once, up front. Plain arrays make the row
    # selection below positional: label-based ``Series[0]`` would fail for
    # any detector whose matching row does not carry the index label 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = np.nanmedian(cat[colName].values[onCcd])

        sumRow = summaryIds == ccd
        corners = [
            SpherePoint(ra, dec, units=degrees)
            for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])
        ]
        visitInfoDict[ccd] = (corners, stat)

    return visitInfoDict

165 

166 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        All text strings found: title, x/y axis labels, legend texts and the
        free-standing ``ax.texts``, followed by those of each child axis.
    line_xys : `List[numpy.ndarray]`
        The ``(N, 2)`` data arrays of every line on ``ax`` and on its child
        axes.

    Notes
    -----
    Tick labels are not returned; they are blanked by installing a shared
    null formatter on every axis. Legend texts are hidden by setting their
    alpha to 0 rather than being cleared.
    """
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    # zaxis only exists on 3D axes.
    for axis in (ax.xaxis, ax.yaxis, getattr(ax, "zaxis", None)):
        if axis is not None:
            axis.set_major_formatter(null_formatter)
            axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Child axes (e.g. insets) are not in figure.get_axes(), so their
        # lines would otherwise be lost.
        line_xys.extend(lines_child)

    return texts, line_xys

213 

214 

def get_and_remove_figure_text(figure: Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        All text strings: the suptitle first (an empty string if the figure
        had none), then the free-standing figure texts, then the texts of
        each axis as returned by `get_and_remove_axis_text`.
    line_xys : `List[numpy.ndarray]`
        A list of all line ``(N, 2)`` data arrays from the figure's axes.
    """
    # ``_suptitle`` is a Text artist or None; ``str()`` of it gives its repr
    # (e.g. "Text(0.5, 0.98, 'title')"), not the title string itself.
    suptitle = getattr(figure, "_suptitle", None)
    texts = [suptitle.get_text() if suptitle is not None else ""]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines

241 

242 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate; two text blocks are added at its top-left corner
        in figure coordinates.
    plotInfo : `dict`
        Plot metadata. Must contain ``"plotName"``, ``"run"``,
        ``"tableName"`` and ``"bands"``; ``"tract"``, ``"visit"`` and
        ``"SN"`` are used when present. ``"bands"`` may be an
        already-formatted string (as produced by `parsePlotInfo`) or an
        iterable of band names, which are joined with ``", "``.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, annotated.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=8, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # A string is already formatted; iterating over it would join its
    # individual characters (e.g. "g, r" -> "g, ,,  , r").
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    SNText = f", S/N: {plotInfo.get('SN', 'N/A')}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}{SNText}"
    fig.text(0.01, 0.98, infoText, fontsize=7, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

279 

280 

def stellarLocusFit(xs, ys, paramDict):
    """Make a fit to the stellar locus.

    Parameters
    ----------
    xs : `numpy.ndarray`
        The color on the xaxis
    ys : `numpy.ndarray`
        The color on the yaxis
    paramDict : lsst.pex.config.dictField.Dict
        A dictionary of parameters for line fitting
        xMin : `float`
            The minimum x edge of the box to use for initial fitting
        xMax : `float`
            The maximum x edge of the box to use for initial fitting
        yMin : `float`
            The minimum y edge of the box to use for initial fitting
        yMax : `float`
            The maximum y edge of the box to use for initial fitting
        mHW : `float`
            The hardwired gradient for the fit
        bHW : `float`
            The hardwired intercept of the fit

    Returns
    -------
    paramsOut : `dict`
        A dictionary of the calculated fit parameters
        xMin, xMax, yMin, yMax, mHW, bHW : `float`
            Copied from ``paramDict``.
        mODR : `float`
            The gradient calculated by the first ODR fit
        bODR : `float`
            The intercept calculated by the first ODR fit
        yBoxMin : `float`
            The y value of the first fitted line at the lower box end
        yBoxMax : `float`
            The y value of the first fitted line at the upper box end
        bPerpMin : `float`
            The intercept of the perpendicular line through the lower end
        bPerpMax : `float`
            The intercept of the perpendicular line through the upper end
        mODR2 : `float`
            The gradient from the second round of fitting
        bODR2 : `float`
            The intercept from the second round of fitting
        mPerp : `float`
            The gradient of the line perpendicular to the line from the
            second fit

    Raises
    ------
    ZeroDivisionError
        Raised if an ODR fit returns a gradient of exactly zero, since the
        perpendicular gradient ``-1 / m`` is then undefined.

    Notes
    -----
    The code does two rounds of fitting, the first is initiated using the
    hardwired values given in the `paramDict` parameter and is done using
    an Orthogonal Distance Regression fit to the points defined by the
    box of xMin, xMax, yMin and yMax. Once this fitting has been done a
    perpendicular bisector is calculated at either end of the line and
    only points that fall between these lines are used to recalculate the
    fit. When the first gradient is steeper than 1 in absolute value, the
    line ends are taken at yMin and yMax instead of xMin and xMax.
    """
    # Points inside the initial box are used for the first fit.
    fitPoints = np.where(
        (xs > paramDict["xMin"])
        & (xs < paramDict["xMax"])
        & (ys > paramDict["yMin"])
        & (ys < paramDict["yMax"])
    )[0]

    # First-order polynomial model: beta[0] is the intercept, beta[1] the
    # gradient (hence the beta0 ordering below).
    linear = scipyODR.polynomial(1)

    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[paramDict["bHW"], paramDict["mHW"]])
    params = odr.run()
    mODR = float(params.beta[1])
    bODR = float(params.beta[0])

    paramsOut = {
        "xMin": paramDict["xMin"],
        "xMax": paramDict["xMax"],
        "yMin": paramDict["yMin"],
        "yMax": paramDict["yMax"],
        "mHW": paramDict["mHW"],
        "bHW": paramDict["bHW"],
        "mODR": mODR,
        "bODR": bODR,
    }

    # Having found the initial fit calculate perpendicular ends
    mPerp = -1.0 / mODR
    # When the gradient is really steep we need to use
    # the y limits of the box rather than the x ones
    if np.abs(mODR) > 1:
        yBoxMin = paramDict["yMin"]
        xBoxMin = (yBoxMin - bODR) / mODR
        yBoxMax = paramDict["yMax"]
        xBoxMax = (yBoxMax - bODR) / mODR
    else:
        yBoxMin = mODR * paramDict["xMin"] + bODR
        xBoxMin = paramDict["xMin"]
        yBoxMax = mODR * paramDict["xMax"] + bODR
        xBoxMax = paramDict["xMax"]

    bPerpMin = yBoxMin - mPerp * xBoxMin
    paramsOut["yBoxMin"] = yBoxMin
    paramsOut["bPerpMin"] = bPerpMin

    bPerpMax = yBoxMax - mPerp * xBoxMax
    paramsOut["yBoxMax"] = yBoxMax
    paramsOut["bPerpMax"] = bPerpMax

    # Use these perpendicular lines to choose the data and refit. The two
    # lines are parallel, so a point lies between them when it is above the
    # line with the lower intercept and below the one with the higher
    # intercept. For a shallow negative gradient bPerpMin > bPerpMax, so the
    # intercepts must be ordered, otherwise no point would be selected.
    bPerpLow, bPerpHigh = sorted((bPerpMin, bPerpMax))
    fitPoints = (ys > mPerp * xs + bPerpLow) & (ys < mPerp * xs + bPerpHigh)
    data = scipyODR.Data(xs[fitPoints], ys[fitPoints])
    odr = scipyODR.ODR(data, linear, beta0=[bODR, mODR])
    params = odr.run()

    paramsOut["mODR2"] = float(params.beta[1])
    paramsOut["bODR2"] = float(params.beta[0])

    paramsOut["mPerp"] = -1.0 / paramsOut["mODR2"]

    return paramsOut

415 

416 

def perpDistance(p1, p2, points):
    """Calculate the signed perpendicular distance from a line to points.

    Parameters
    ----------
    p1 : `numpy.ndarray`
        A point on the line
    p2 : `numpy.ndarray`
        Another point on the line
    points : `zip`
        The points to calculate the distance to

    Returns
    -------
    dists : `list`
        The distances from the line to the points, computed as the cross
        product of ``p2 - p1`` with ``point - p1`` divided by the length of
        ``p2 - p1``. For 2D points the sign is positive on the
        counter-clockwise side of the direction ``p1 -> p2``.
    """
    p1 = np.asarray(p1)
    direction = np.asarray(p2) - p1
    norm = np.linalg.norm(direction)

    dists = []
    for point in points:
        offset = np.array(point) - p1
        if direction.shape == (2,) and offset.shape == (2,):
            # z-component of the 2D cross product, written out because
            # np.cross on 2-element vectors is deprecated in NumPy 2.0.
            cross = direction[0] * offset[1] - direction[1] * offset[0]
        else:
            cross = np.cross(direction, offset)
        dists.append(cross / norm)

    return dists

440 

441 

def mkColormap(colorNames):
    """Make a colormap from the list of color names.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib named colors,
        ordered from the low end of the colormap to the high end. A single
        color produces a uniform colormap.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        Colormap named ``"newCmap"`` that interpolates linearly between the
        colors, placed at evenly spaced positions from 0 to 1.
    """
    colorNames = list(colorNames)
    if len(colorNames) == 1:
        # Segment data must run from x=0 to x=1, so a lone color needs to be
        # anchored at both ends.
        colorNames = colorNames * 2

    nums = np.linspace(0, 1, len(colorNames))
    rgbs = [colors.to_rgb(color) for color in colorNames]

    colorDict = {
        "red": [(num, r, r) for num, (r, _, _) in zip(nums, rgbs)],
        "green": [(num, g, g) for num, (_, g, _) in zip(nums, rgbs)],
        "blue": [(num, b, b) for num, (_, _, b) in zip(nums, rgbs)],
    }
    return colors.LinearSegmentedColormap("newCmap", colorDict)

467 

468 

def extremaSort(xs):
    """Order indices by absolute distance from the median.

    Parameters
    ----------
    xs : `np.array`
        Values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` arranged so that the value nearest the
        NaN-ignoring median comes first and the furthest comes last
        (NaN values, whose distance is NaN, are placed at the end by
        `numpy.argsort`).
    """
    return np.argsort(np.abs(xs - np.nanmedian(xs)))

485 

486 

def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder every array by the extrema ordering of one of them.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        Arrays to reorder together. The ordering is computed with
        `extremaSort` on ``arrsToSort[sortArrayIndex]``.
    sortArrayIndex : `int`, optional
        Zero-based position of the array that defines the ordering
        (default: the first array).

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, whose entries have been replaced in place by
        the reordered arrays.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    # Slice assignment keeps the caller's list object and updates it in place.
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort

508 

509 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot,
        can be a gridspec SubplotSpec, a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. This is the same for the tuple
        of int, int, index.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch. The key ``"tract"`` is drawn but not
        annotated.
    label : `str`
        Text drawn centered on the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure with the summary axes and its colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color coded rectangles
    # for each patch; the colors show the summary statistic of the patch.
    # (``stats`` rather than ``colors`` avoids shadowing matplotlib.colors.)
    patches = []
    stats = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        stats.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        cornerDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            cenX = ra + width / 2
            cenY = dec + height / 2
            axCorner.annotate(dataId, (cenX, cenY), color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    stats = np.ma.array(stats, mask=np.isnan(stats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(stats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the summary axes
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    # Use ``fig`` directly: pyplot.colorbar goes through the *current*
    # pyplot figure, which is not necessarily ``fig``.
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

599 

600 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the ``shape``, ``loc``,
    ``rowspan`` and ``colspan`` arguments of matplotlib.pyplot.subplot2grid.
    """

    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )

648 

649 

def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color limits symmetric about zero, using the
        larger of ``|vmin|`` and ``|vmax|`` as the half-range.
    nPointBinThresh : `int`, optional
        Threshold number of points above which binning will be implemented
        for the plotting. If the number of data points is lower than this
        threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. Defaults are the median of
        ``zs`` minus/plus twice its sigma-MAD.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection`
        The plot object with ``ax`` updated with data plotted here. When
        binning is used this is the scatter of the extreme points, or the
        `matplotlib.image.AxesImage` of the binned medians if there is at
        most one point.
    """
    med = np.nanmedian(zs)
    mad = nansigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    if len(xs) < nPointBinThresh:
        # Few enough points to show them all directly.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # The binned medians are only needed here, so they are not computed for
    # the plain scatter plot above.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )

    # Marker size shrinks as the number of points grows, clipped to [0.5, 10].
    s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
    lw = (s**0.5) / 10
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )

    if not isSorted:
        zs, xs, ys = sortAllArrays([zs, xs, ys])
    # The arrays are ordered by increasing distance of zs from its median,
    # so the most extreme 15% of points are the last 15% of the arrays.
    if len(xs) > 1:
        # Exact integer floor of 85% of the points.
        extremes = (len(xs) * 85) // 100
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )

    return plotOut