Coverage for python / lsst / analysis / tools / actions / plot / plotUtils.py: 10%

292 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-04 17:42 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("PanelConfig", "sortAllArrays") 

24 

25from collections.abc import Iterable, Mapping 

26from typing import TYPE_CHECKING 

27 

28import esutil 

29import matplotlib 

30import numpy as np 

31from matplotlib import cm, colors 

32from matplotlib.collections import PatchCollection 

33from matplotlib.patches import Rectangle 

34from scipy.stats import binned_statistic_2d 

35 

36from lsst.geom import Box2D, SpherePoint, degrees 

37from lsst.pex.config import Config, Field 

38 

39from ...math import nanMedian, nanSigmaMad 

40 

41if TYPE_CHECKING: 

42 from matplotlib.figure import Figure 

43 

44null_formatter = matplotlib.ticker.NullFormatter() 

45 

46 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a median summary statistic in each patch of a tract.

    For every patch of the tract named by ``plotInfo["tract"]``, the median
    of the y-value column is computed over the rows of ``data`` that fall on
    that patch. The result is stored together with the sky coordinates of
    the corners of the patch's inner bounding box. An extra ``"tract"`` entry
    holds the corners of the whole tract, with a NaN statistic.

    Parameters
    ----------
    data : `dict`
        A dictionary of the data to be plotted. It must contain a ``"patch"``
        column of sequential patch indices. Whenever any patch holds data, it
        must also contain one of the ``"y"``, ``"yStars"``, ``"yGalaxies"``
        or ``"yUnknowns"`` columns (the first one present, in that order, is
        used).
    skymap : `lsst.skymap.BaseSkyMap`
        The skymap associated with the data.
    plotInfo : `dict`
        A dictionary of the plot information; must contain ``"tract"``.

    Returns
    -------
    patchInfoDict : `dict`
        Mapping of sequential patch index (plus the key ``"tract"``) to a
        ``(skyCoords, stat)`` tuple. ``skyCoords`` are the sky positions of
        the corners and ``stat`` is the median y value (NaN for patches
        without data and for the ``"tract"`` entry).

    Raises
    ------
    KeyError
        Raised if a patch contains data but none of the supported y-value
        columns is present in ``data``.
    """
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Pick the first available y-value column. It is only required once a
    # patch actually contains data, so do not fail eagerly here.
    yCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in yCandidates if col in data), None)

    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y

    # Histogram (group) the patch values and return an array of "reverse
    # indices": rev[rev[i]:rev[i + 1]] are the positions in data["patch"] of
    # the rows that fall in patch i.
    if len(data["patch"]) == 0:
        # With every entry equal to maxPatchNum + 2, each slice
        # rev[rev[i]:rev[i + 1]] is empty.
        rev = np.full(maxPatchNum + 2, maxPatchNum + 2)
    else:
        _, rev = esutil.stat.histogram(data["patch"], min=0, max=maxPatchNum - 1, rev=True)

    patchInfoDict = {}
    for patch in range(maxPatchNum):
        onPatch = rev[rev[patch] : rev[patch + 1]]

        if len(onPatch) == 0:
            stat = np.nan
        else:
            if yCol is None:
                raise KeyError(f"No y-value column found in data; expected one of {yCandidates}.")
            stat = nanMedian(data[yCol][onPatch])

        # Patch ids here are sequential (gen 3) indices, so no conversion
        # from gen 2 "x,y" strings is required.
        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

118 

119 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable):
    """Generate a median summary statistic in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        A dataframe of the data to be plotted; must contain a ``detector``
        column and the ``colName`` column.
    colName : `str`
        The name of the column to be plotted.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        The visit summary table. Its ``id`` column is matched against the
        detector values of ``cat``. Its ``raCorners`` and ``decCorners``
        columns hold, per row, the corner coordinates in degrees.

    Returns
    -------
    visitInfoDict : `dict`
        Mapping of detector to ``(corners, stat)``. ``corners`` is a list of
        `lsst.geom.SpherePoint` and ``stat`` is the NaN-ignoring median of
        ``colName`` on that detector. ``None`` detectors are skipped.

    Raises
    ------
    IndexError
        Raised if a detector in ``cat`` has no row in ``visitSummaryTable``.
    """
    # Convert to plain arrays so that row selection below is positional. On a
    # pandas Series, ``series[mask][0]`` looks up the index *label* 0 rather
    # than the first matching row.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])
    detectors = np.asarray(cat["detector"])
    values = cat[colName].values

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        stat = nanMedian(values[detectors == ccd])

        # First summary-table row describing this detector.
        row = np.flatnonzero(summaryIds == ccd)[0]
        cornersOut = [
            SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners[row], decCorners[row])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict

154 

155 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> tuple[list[str], list[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    The following are recorded and then blanked:

    - the title and the x/y axis labels;
    - the legend entries, which are made fully transparent rather than
      emptied;
    - free-standing text artists.

    Every tick formatter is replaced by the module's null formatter, which
    blanks the tick labels (their text is not recorded). Child axes are
    processed recursively.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis.

    Returns
    -------
    texts : `list` [`str`]
        The title, axis-label, legend and text-artist strings of ``ax``,
        followed by those of its child axes.
    line_xys : `list` [`numpy.ndarray`]
        All line ``_xy`` attributes (arrays of shape ``(N, 2)``) of ``ax``,
        followed by those of its child axes.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        # Hide legend entries via alpha so the legend layout is unchanged.
        for text in legend.texts:
            texts.append(text.get_text())
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    tickAxes = [ax.xaxis, ax.yaxis]
    zaxis = getattr(ax, "zaxis", None)  # only present on 3D axes
    if zaxis is not None:
        tickAxes.append(zaxis)
    for axis in tickAxes:
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys

204 

205 

def get_and_remove_figure_text(figure: Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `list` [`str`]
        The texts are returned in this order:

        - the suptitle string (``""`` if the figure has no suptitle);
        - the strings of the figure-level text artists;
        - the texts of every axis, in ``figure.get_axes()`` order.
    line_xys : `list` [`numpy.ndarray`]
        All line ``_xy`` attributes (arrays of shape ``(N, 2)``) of every
        axis, in ``figure.get_axes()`` order.
    """
    # ``_suptitle`` is a Text artist (or None). Record its string rather than
    # the artist's repr.
    suptitle = figure._suptitle
    texts = ["" if suptitle is None else suptitle.get_text()]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines

234 

235 

def parsePlotInfo(plotInfo: Mapping[str, str]) -> str:
    """Build the informational caption for a figure from ``plotInfo``.

    Parameters
    ----------
    plotInfo : `dict` [`str`, `str`]
        Plot metadata. ``"run"``, ``"tableName"`` and ``"bands"`` are
        required, and ``"tract"`` and ``"visit"`` are optional data-ID
        entries. Other keys are classified as follows:

        - a key containing ``"SN"`` or ``"S/N"`` is a S/N cut;
        - any other key containing ``"Mag"`` is a magnitude cut;
        - any remaining key starting with ``"Selection: "`` is a named
          selection.

    Returns
    -------
    infoText : `str`
        The caption. It starts with a newline and contains the run, table,
        data ID and bands. When present, it also contains the first S/N cut,
        the first magnitude cut and all selections.
    """
    # Placeholders: nothing sets these yet, so the dataset line is never
    # emitted.
    photocalibDataset = "None"
    astroDataset = "None"

    runName = plotInfo["run"]
    tableName = plotInfo["tableName"]
    dataIdText = "".join(
        f", {label}: {plotInfo[key]}"
        for key, label in (("tract", "Tract"), ("visit", "Visit"))
        if key in plotInfo
    )
    bandsText = ", Bands: " + ", ".join(plotInfo["bands"])

    header = f"\n{runName}"
    if photocalibDataset != "None" or astroDataset != "None":
        header += f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    infoText = f"{header}\nTable: {tableName}{dataIdText}{bandsText}"

    # Classify the remaining keys into S/N cuts, magnitude cuts and selections.
    selectionPrefix = "Selection: "
    snKeys, magKeys, selectionKeys = [], [], []
    for key in plotInfo:
        if "SN" in key or "S/N" in key:
            snKeys.append(key)
        elif "Mag" in key:
            magKeys.append(key)
        elif key.startswith(selectionPrefix):
            selectionKeys.append(key)

    # TODO: Do something if there are multiple sn/mag keys. Log? Warn?
    if snKeys:
        # The S/N cut starts a new line when a magnitude cut follows it.
        separator = "\n" if magKeys else ", "
        infoText += f"{separator}{snKeys[0]}{plotInfo[snKeys[0]]}"
    if magKeys:
        infoText += f", {magKeys[0]}{plotInfo[magKeys[0]]}"
    if selectionKeys:
        selections = ", ".join(f"{key[len(selectionPrefix):]}: {plotInfo[key]}" for key in selectionKeys)
        infoText += f", Selections: {selections}"

    return infoText

299 

300 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Annotate the top-left corner of a figure with plot metadata.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate.
    plotInfo : `dict`
        Plot metadata. ``"plotName"`` is written as the heading, and the
        whole mapping is passed to ``parsePlotInfo`` to build the caption
        below it.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the two text artists added.
    """
    # Both lines are anchored by their top-left corner in figure coordinates.
    sharedKwargs = {"transform": fig.transFigure, "ha": "left", "va": "top"}
    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=7, **sharedKwargs)
    # Smaller, semi-transparent caption just below the plot name.
    fig.text(0.01, 0.984, parsePlotInfo(plotInfo), fontsize=6, alpha=0.6, **sharedKwargs)
    return fig

321 

322 

def mkColormap(colorNames):
    """Make a colormap from the list of color names.

    The alpha profile depends on the number of colors:

    - A single color ramps from partially transparent (alpha 0.3) at the
      bottom of the scale to opaque at the top.
    - Multiple colors are spaced evenly along the colormap. Lists of 3 or 5
      colors are opaque at the ends and most transparent in the middle;
      lists of any other length are fully opaque.

    Parameters
    ----------
    colorNames : `list` [`str`]
        Matplotlib color specifications, e.g. named colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"`` stepping through ``colorNames``.

    Raises
    ------
    ValueError
        Raised if ``colorNames`` is empty.
    """
    colorNames = list(colorNames)
    if not colorNames:
        raise ValueError("mkColormap requires at least one color name.")

    if len(colorNames) == 1:
        # Repeat the single color, fading in from alpha 0.3 to fully opaque.
        # This looks good for ComCam data but might want to be changed in
        # the future.
        colorStops = [colorNames[0], colorNames[0]]
        alphaRange = [0.3, 1.0]
    elif len(colorNames) == 3:
        colorStops = colorNames
        alphaRange = [1.0, 0.3, 1.0]
    elif len(colorNames) == 5:
        colorStops = colorNames
        alphaRange = [1.0, 0.7, 0.3, 0.7, 1.0]
    else:
        colorStops = colorNames
        alphaRange = np.ones(len(colorNames))

    # Each channel is a list of (position, valueBelow, valueAbove) anchors.
    # Equal below/above values give a continuous gradient between the stops.
    channels = {"red": [], "green": [], "blue": [], "alpha": []}
    positions = np.linspace(0, 1, len(colorStops))
    for num, color, alpha in zip(positions, colorStops, alphaRange):
        r, g, b = colors.to_rgb(color)
        channels["red"].append((num, r, r))
        channels["green"].append((num, g, g))
        channels["blue"].append((num, b, b))
        channels["alpha"].append((num, alpha, alpha))

    return colors.LinearSegmentedColormap("newCmap", channels)

374 

375 

def extremaSort(xs):
    """Order point indices by absolute distance from the median of ``xs``.

    Parameters
    ----------
    xs : `np.array`
        The values to order.

    Returns
    -------
    ids : `np.array`
        Indices into ``xs`` such that the points closest to the
        (NaN-ignoring) median come first. The points furthest from it, in
        absolute terms, come last.
    """
    distancesFromMedian = np.abs(xs - nanMedian(xs))
    return np.argsort(distancesFromMedian)

393 

394 

def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder every array in a list by the extrema ordering of one of them.

    The reference array is ordered with ``extremaSort`` (closest to its
    median first, most extreme last). The same permutation is then applied
    to every array in the list.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        Arrays to reorder together. They are indexed by the permutation, so
        they should all be the same length as the reference array.
    sortArrayIndex : `int`, optional
        Zero-based position in ``arrsToSort`` of the array whose values
        determine the ordering (the first array by default).

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The input list itself, updated in place with the reordered arrays.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort

416 

417 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Each entry of ``sumStats`` is drawn as an outlined R.A./Dec. rectangle,
    colored by its summary statistic. NaN statistics are drawn transparent.
    Every entry except ``"tract"`` is annotated with its key. A horizontal
    colorbar, labelled with ``label``, is placed above the subplot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot. It can
        be a gridspec SubplotSpec, or a 3-digit integer whose first digit is
        the number of rows, second the number of columns and third the
        index. The tuple of (int, int, index) works the same way.
    sumStats : `dict`
        A dictionary keyed by patch/detector id (plus optionally
        ``"tract"``). Each value is ``(corners, stat)``, where ``corners``
        are the sky coordinates of the corners and ``stat`` is a summary
        statistic.
    label : `str`
        The label drawn on the colorbar; nothing is drawn if it is empty or
        `None`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The input figure with the summary subplot and colorbar added.
    """
    # Add the subplot to the relevant place in the figure and put its ticks
    # and labels on the top/right.
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Outline each entry and collect a rectangle colored by its statistic.
    # ``patchStats`` is used instead of ``colors`` to avoid shadowing the
    # module-level ``matplotlib.colors`` import.
    rectangles = []
    patchStats = []
    for dataId, (corners, stat) in sumStats.items():
        raCorners = [ra.asDegrees() for ra, _ in corners]
        decCorners = [dec.asDegrees() for _, dec in corners]
        ra0, dec0 = raCorners[0], decCorners[0]
        # Assumes corners[0] and corners[2] are diagonally opposite corners
        # of an R.A./Dec.-aligned box.
        width = raCorners[2] - ra0
        height = decCorners[2] - dec0
        rectangles.append(Rectangle((ra0, dec0), width, height))
        patchStats.append(stat)
        axCorner.plot(raCorners + [ra0], decCorners + [dec0], "k", lw=0.5)
        if dataId != "tract":
            axCorner.annotate(
                dataId,
                (ra0 + width / 2, dec0 + height / 2),
                color="k",
                fontsize=4,
                ha="center",
                va="center",
            )

    # Draw NaN statistics as transparent by masking them.
    cmapPatch = cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    patchStats = np.ma.array(patchStats, mask=np.isnan(patchStats))
    collection = PatchCollection(rectangles, cmap=cmapPatch)
    collection.set_array(patchStats)
    axCorner.add_collection(collection)

    # Add some labels.
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the subplot.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    if label:
        # Draw the caller-supplied label centered on the colorbar itself, so
        # it does not compete for space with the tick labels above it.
        cax.text(0.5, 0.5, label, color="k", transform=cax.transAxes, ha="center", va="center", fontsize=6)
    cornerLabel = "Median of patch\nvalues for\ny axis"
    axCorner.text(1.1, 1.3, cornerLabel, transform=axCorner.transAxes, fontsize=6)
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

499 

500 

def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers into a compact range string.

    Duplicate values are ignored and the input does not need to be sorted.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the list; ``""`` if empty.

    Examples
    --------
    >>> shorten_list([1, 2, 3, 5, 6, 8])
    '1-3,5-6,8'

    >>> shorten_list((1, 2, 3, 5, 6, 8, 9, 10, 11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'

    >>> shorten_list([2, 1, 1])
    '1-2'
    """
    # De-duplicate before sorting so repeated values cannot produce
    # degenerate ranges such as "1-1".
    values = sorted(set(numbers))
    if not values:
        return ""

    # Collect (first, last) pairs of maximal runs of consecutive integers.
    runs = []
    start = previous = values[0]
    for value in values[1:]:
        if value != previous + 1:
            runs.append((start, previous))
            start = value
        previous = value
    runs.append((start, previous))

    return range_separator.join(
        str(first) if first == last else range_indicator.join((str(first), str(last))) for first, last in runs
    )

565 

566 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to matplotlib.pyplot.subplot2grid:
    ``shape=(subplot2gridShapeRow, subplot2gridShapeColumn)``,
    ``loc=(subplot2gridLocRow, subplot2gridLocColumn)``,
    ``rowspan=subplot2gridRowspan`` and ``colspan=subplot2gridColspan``.
    """

    # Visibility of each of the four panel spines.
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Grid geometry. By default the panel is a 5x5 block starting at row 1,
    # column 1 of a 10x10 grid.
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )

614 

615 

def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    showExtremeOutliers=True,
    scatPtSize=7,
    edgecolor="white",
    alpha=1.0,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Points with non-finite ``zs`` are discarded. If at least
    ``nPointBinThresh`` points remain, the per-bin median of ``zs`` is shown
    as an image, optionally overlaid with the most extreme points. Otherwise
    a plain scatter plot is drawn.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color scale symmetric about zero. The larger of
        ``|vmin|`` and ``|vmax|`` is used as the limit.
    nPointBinThresh : `int`, optional
        Threshold number of points above which binning will be implemented
        for the plotting. If the number of data points is lower than this
        threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have already been ordered by distance of ``zs`` from
        its median (see ``sortAllArrays``). This ordering is what lets the
        most extreme points be overplotted last.
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. Defaults to the median of
        ``zs`` minus/plus twice its sigma MAD.
    showExtremeOutliers : `bool`, default True
        In the binned case, overlay scatter points showing the x-y positions
        of the 15% most extreme values.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.
    edgecolor : `str`, optional
        The edge color to use for the scatter plot points. Default white.
    alpha : `float`, optional
        The transparency or alpha to use for scatter plot points. Default 1.0.

    Returns
    -------
    plotOut : `matplotlib.image.AxesImage` or `matplotlib.collections.PathCollection`
        The last artist added to ``ax``. This is the scatter collection when
        points are scattered, otherwise the binned image.
    """
    med = nanMedian(zs)
    mad = nanSigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    yNumBins = xNumBins if yNumBins is None else yNumBins

    finiteMask = np.isfinite(zs)
    xs = xs[finiteMask]
    ys = ys[finiteMask]
    zs = zs[finiteMask]
    nPoints = len(xs)

    if nPoints >= nPointBinThresh:
        # The binned statistic is only needed for the image, so it is only
        # computed on this path.
        xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
        yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
        binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
            xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
        )

        # Overlay marker size shrinks with point density, clamped to
        # [0.5, 10]. max(..., 1) guards against nPointBinThresh <= 0 with no
        # points.
        s = min(10, max(0.5, nPointBinThresh / 10 / (max(nPoints, 1) ** 0.5)))
        lw = (s**0.5) / 10
        plotOut = ax.imshow(
            binnedStats.T,
            cmap=cmap,
            extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
            vmin=vmin,
            vmax=vmax,
        )
        if not isSorted:
            zs, xs, ys = sortAllArrays([zs, xs, ys])
        if nPoints > 1 and showExtremeOutliers:
            # The arrays are ordered by distance from the median, so the most
            # extreme 15% of points are the last 15% of the arrays.
            extremes = int(0.85 * nPoints)
            plotOut = ax.scatter(
                xs[extremes:],
                ys[extremes:],
                c=zs[extremes:],
                s=s,
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
                edgecolor=edgecolor,
                linewidths=lw,
                alpha=alpha,
            )
    else:
        plotOut = ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor=edgecolor,
            linewidths=0.2,
            alpha=alpha,
        )
    return plotOut