Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 10%

289 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:09 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Names exported by a star-import of this module; every other helper defined
# here has to be imported explicitly by name.
__all__ = ("PanelConfig", "sortAllArrays")

24 

25from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple 

26 

27import esutil 

28import matplotlib 

29import numpy as np 

30from lsst.geom import Box2D, SpherePoint, degrees 

31from lsst.pex.config import Config, Field 

32from matplotlib import cm, colors 

33from matplotlib.collections import PatchCollection 

34from matplotlib.patches import Rectangle 

35from scipy.stats import binned_statistic_2d 

36 

37from ...math import nanMedian, nanSigmaMad 

38 

39if TYPE_CHECKING: 

40 from matplotlib.figure import Figure 

41 

# Single shared formatter instance, installed on axes below to suppress
# tick-label text.
null_formatter = matplotlib.ticker.NullFormatter()

43 

44 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        A dictionary of the data to be plotted. It must contain a
        ``"patch"`` entry holding sequential patch indices and, whenever
        a patch contains data, one of the columns ``"y"``, ``"yStars"``,
        ``"yGalaxies"`` or ``"yUnknowns"`` (searched in that order).
    skymap : `lsst.skymap.BaseSkyMap`
        The skymap associated with the data.
    plotInfo : `dict`
        A dictionary of the plot information; only ``"tract"`` is read.

    Returns
    -------
    patchInfoDict : `dict`
        A dictionary keyed by sequential patch index whose values are
        ``(skyCoords, stat)`` tuples: the sky coordinates of the corners
        of the patch inner bounding box and the NaN-aware median of the
        y column over the rows on that patch (NaN for an empty patch).
        An additional ``"tract"`` key holds the tract corners and NaN.

    Raises
    ------
    KeyError
        Raised if a patch contains data but none of the supported y
        columns is present in ``data``.
    """
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Pick the first y column that is available.
    candidateCols = ("y", "yStars", "yGalaxies", "yUnknowns")
    availableCols = data.keys()
    yCol = next((name for name in candidateCols if name in availableCols), None)

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y

    # Histogram (group) the patch values, and return an array of
    # "reverse indices" which is a specially encoded array of where
    # every patch is in the overall array.
    if len(data["patch"]) == 0:
        # No rows at all: build a reverse-index array in which every
        # patch maps to an empty slice.
        rev = np.full(maxPatchNum + 2, maxPatchNum + 2)
    else:
        _, rev = esutil.stat.histogram(data["patch"], min=0, max=maxPatchNum - 1, rev=True)

    # The patch ids iterated here are numpy integers, so they are already
    # sequential (gen 3) indices and can be used directly.
    for patch in np.arange(0, maxPatchNum, 1):
        # Pull out the onPatch indices
        onPatch = rev[rev[patch] : rev[patch + 1]]

        if len(onPatch) == 0:
            stat = np.nan
        elif yCol is None:
            # Only complain once a statistic is actually required.
            raise KeyError(f"data must contain one of the columns {candidateCols} to compute patch statistics")
        else:
            stat = nanMedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        skyCoords = tractWcs.pixelToSky(corners)

        patchInfoDict[patch] = (skyCoords, stat)

    # The tract outline is stored alongside the patches with no statistic.
    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    skyCoords = tractWcs.pixelToSky(tractCorners)
    patchInfoDict["tract"] = (skyCoords, np.nan)

    return patchInfoDict

116 

117 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable):
    """Generate a summary statistic for each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        A dataframe of the data to be plotted; its ``detector`` column
        assigns each row to a detector.
    colName : `str`
        The name of the column to be summarised.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        The visit summary table; the ``id``, ``raCorners`` and
        ``decCorners`` columns are read.

    Returns
    -------
    visitInfoDict : `dict`
        A dictionary keyed by detector id whose values are
        ``(corners, stat)`` tuples: a list of `lsst.geom.SpherePoint`
        detector corners and the NaN-aware median of ``colName`` over
        the rows on that detector.
    """
    visitInfoDict = {}
    for detectorId in cat.detector.unique():
        if detectorId is None:
            continue

        inDetector = cat["detector"] == detectorId
        median = nanMedian(cat[colName].values[inDetector])

        # Select the summary row(s) for this detector and take the entry
        # at index 0 of the selection.
        # NOTE(review): assumes exactly one summary row matches -- confirm.
        summaryRow = visitSummaryTable["id"] == detectorId
        raCorners = visitSummaryTable["raCorners"][summaryRow][0]
        decCorners = visitSummaryTable["decCorners"][summaryRow][0]
        cornerPoints = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[detectorId] = (cornerPoints, median)

    return visitInfoDict

152 

153 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Strip the text from an Axis and its children, returning what was
    removed together with the line points of the Axis.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        All text strings found: title, axis labels, legend entries and
        free text of ``ax``, followed by those of its child axes.
    line_xys : `List[numpy.ndarray]`
        The ``_xy`` attribute of every line drawn directly on ``ax``.
    """
    # Record the line vertices of this axis.
    line_xys = [line._xy for line in ax.lines]

    # Title and axis labels: remember the text, then blank it.
    texts = [ax.title.get_text(), ax.xaxis.label.get_text(), ax.yaxis.label.get_text()]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # Legend entries are made transparent rather than emptied. An axis
    # without a legend yields None from get_legend(), hence the
    # AttributeError guard.
    try:
        legend_texts = ax.get_legend().texts
        texts.extend(entry.get_text() for entry in legend_texts)
        for entry in legend_texts:
            entry.set_alpha(0)
    except AttributeError:
        pass

    # Free-standing text artists.
    for annotation in ax.texts:
        texts.append(annotation.get_text())
        annotation.set_text("")

    # Suppress the tick labels on both axes.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Axes without a ``zaxis`` attribute are skipped here.
    try:
        ax.zaxis.set_major_formatter(null_formatter)
        ax.zaxis.set_minor_formatter(null_formatter)
    except AttributeError:
        pass

    # Recurse into child axes.
    # NOTE(review): only the child texts are kept; the child line points
    # are not added to ``line_xys`` -- confirm whether that is intended.
    for child in ax.child_axes:
        child_texts, _ = get_and_remove_axis_text(child)
        texts.extend(child_texts)

    return texts, line_xys

202 

203 

def get_and_remove_figure_text(figure: Figure):
    """Strip the text from a Figure and all of its Axes, returning what
    was removed together with the line points.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        The string form of the figure's ``_suptitle`` attribute, then the
        figure-level texts, then the texts gathered from each axis.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        A list of all line ``_xy`` attributes (arrays of shape ``(N, 2)``).
    """
    # The first entry is str() of the suptitle object itself, captured
    # before the suptitle is blanked.
    texts = [str(figure._suptitle)]
    figure.suptitle("")

    # Figure-level text artists: keep their strings, then drop them all.
    texts += [artist.get_text() for artist in figure.texts]
    figure.texts = []

    line_xys = []
    for axis in figure.get_axes():
        axis_texts, axis_lines = get_and_remove_axis_text(axis)
        texts += axis_texts
        line_xys += axis_lines

    return texts, line_xys

232 

233 

def parsePlotInfo(plotInfo: Mapping[str, str]) -> str:
    """Build the multi-line information string shown on a figure from the
    plotInfo dictionary.

    Parameters
    ----------
    plotInfo : `dict`[`str`, `str`]
        A plotInfo dictionary. The ``run``, ``tableName`` and ``bands``
        entries are required; ``tract`` and ``visit`` are used when
        present, as are keys containing ``SN``, ``S/N`` or ``Mag`` and
        keys starting with ``"Selection: "``.

    Returns
    -------
    infoText : `str`
        A string containing the plotInfo information, parsed in such a
        way that it can be included on a figure.
    """
    photocalibDataset = "None"
    astroDataset = "None"

    # Header: run, calibration datasets and table name on separate lines.
    pieces = [
        f"\n{plotInfo['run']}",
        f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}",
        f"\nTable: {plotInfo['tableName']}",
    ]

    # Optional data ID components.
    for key, title in (("tract", "Tract"), ("visit", "Visit")):
        if key in plotInfo:
            pieces.append(f", {title}: {plotInfo[key]}")

    bandList = ", ".join(plotInfo["bands"])
    pieces.append(f", Bands: {bandList}")
    infoText = "".join(pieces)

    # Find S/N, mag and selection keys, if present. A key is assigned to
    # the first category it matches.
    selectionPrefix = "Selection: "
    snKeys = []
    magKeys = []
    selectionKeys = []
    for key in plotInfo:
        if "SN" in key or "S/N" in key:
            snKeys.append(key)
        elif "Mag" in key:
            magKeys.append(key)
        elif key.startswith(selectionPrefix):
            selectionKeys.append(key)

    # Add S/N and mag values to label, if present.
    # TODO: Do something if there are multiple sn/mag keys. Log? Warn?
    if snKeys:
        # S/N starts a new line only when a magnitude entry follows it.
        separator = "\n" if magKeys else ", "
        infoText = f"{infoText}{separator}{snKeys[0]}{plotInfo[snKeys[0]]}"
    if magKeys:
        infoText = f"{infoText}, {magKeys[0]}{plotInfo[magKeys[0]]}"
    if selectionKeys:
        prefixLength = len(selectionPrefix)
        selections = ", ".join(f"{key[prefixLength:]}: {plotInfo[key]}" for key in selectionKeys)
        infoText = f"{infoText}, Selections: {selections}"

    return infoText

294 

295 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Write the plot name and the parsed plot information in the top-left
    corner of a figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to add the information to.
    plotInfo : `dict`
        A dictionary of the plot information; ``plotName`` is required in
        addition to the entries read by `parsePlotInfo`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the information added.
    """
    # Both texts are anchored by their top-left corner in figure coordinates.
    anchor = dict(transform=fig.transFigure, ha="left", va="top")

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=7, **anchor)
    # The detailed information sits just below the name, smaller and fainter.
    fig.text(0.01, 0.984, parsePlotInfo(plotInfo), fontsize=6, alpha=0.6, **anchor)

    return fig

316 

317 

def mkColormap(colorNames):
    """Make a colormap from the list of color names.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib named colors.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap stepping through the supplied list of names, with the
        stops evenly spaced between 0 and 1.
    """
    if len(colorNames) == 1:
        # A single color becomes a two-stop map of that color whose alpha
        # ramps from 0.3 to 1.0.
        stopColors = [colorNames[0], colorNames[0]]
        stopAlphas = [0.3, 1.0]
    else:
        stopColors = colorNames
        # Three and five color maps fade towards the middle stop; any
        # other length is fully opaque.
        if len(colorNames) == 3:
            stopAlphas = [1.0, 0.3, 1.0]
        elif len(colorNames) == 5:
            stopAlphas = [1.0, 0.7, 0.3, 0.7, 1.0]
        else:
            stopAlphas = np.ones(len(colorNames))

    stopPositions = np.linspace(0, 1, len(stopAlphas))

    # Segment data in the (x, y0, y1) form expected by
    # LinearSegmentedColormap, one list per channel.
    segments = {"blue": [], "red": [], "green": [], "alpha": []}
    for position, color, alpha in zip(stopPositions, stopColors, stopAlphas):
        red, green, blue = colors.colorConverter.to_rgb(color)
        segments["blue"].append((position, blue, blue))
        segments["green"].append((position, green, green))
        segments["red"].append((position, red, red))
        segments["alpha"].append((position, alpha, alpha))

    return colors.LinearSegmentedColormap("newCmap", segments)

369 

370 

def extremaSort(xs):
    """Return the indices that order the points by increasing absolute
    distance from the median, so the most extreme points come last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort

    Returns
    -------
    ids : `np.array`
        The indices into ``xs`` in the order described above.
    """
    distanceFromMedian = np.abs(xs - nanMedian(xs))
    return np.argsort(distanceFromMedian)

388 

389 

def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder a set of arrays consistently, using the `extremaSort` order
    of one of them.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        A list of arrays to be simultaneously sorted based on the array in
        the list position given by ``sortArrayIndex`` (defaults to be the
        first array in the list). The list is modified in place.
    sortArrayIndex : `int`, optional
        Zero-based index indicating the array on which to base the sorting.

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, with every array reordered by the order
        derived from the array at ``sortArrayIndex``.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    # Apply the one ordering to every array, writing back into the input.
    for position, values in enumerate(arrsToSort):
        arrsToSort[position] = values[order]
    return arrsToSort

411 

412 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot,
        can be a gridspec SubplotSpec, a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. This is the same for the tuple
        of int, int, index.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch.
    label : `str`
        The label to be used for the colorbar. Currently not used by the
        implementation.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the summary axes and its colorbar added.
    """
    # Create the axes with ticks and labels on the top and right sides.
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # One rectangle per entry, spanning corner 0 to corner 2 and colored
    # by the entry's statistic; the outline of every entry is drawn too.
    rectangles = []
    stats = []
    for dataId, (corners, stat) in sumStats.items():
        originRa = corners[0][0].asDegrees()
        originDec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - originRa
        height = corners[2][1].asDegrees() - originDec
        rectangles.append(Rectangle((originRa, originDec), width, height))
        stats.append(stat)

        # Close the outline by repeating the first corner.
        outlineRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        outlineDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        axCorner.plot(outlineRas + [outlineRas[0]], outlineDecs + [outlineDecs[0]], "k", lw=0.5)

        # Label every entry except the tract outline at its center.
        if dataId != "tract":
            center = (originRa + width / 2, originDec + height / 2)
            axCorner.annotate(dataId, center, color="k", fontsize=4, ha="center", va="center")

    # Set the bad color to transparent and make a masked array
    cmapPatch = cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(stats, mask=np.isnan(stats))
    collection = PatchCollection(rectangles, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a horizontal colorbar above the axes, offset by a third of the
    # axes height.
    bounds = axCorner.get_position()
    yOffset = (bounds.y1 - bounds.y0) / 3
    cax = fig.add_axes([bounds.x0, bounds.y1 + yOffset, bounds.x1 - bounds.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cornerLabel = "Median of patch\nvalues for\ny axis"
    axCorner.text(1.1, 1.3, cornerLabel, transform=axCorner.transAxes, fontsize=6)
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

494 

495 

def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers.

    Runs of consecutive integers are collapsed into ``first-last`` ranges.
    The input is sorted first and duplicate values are ignored.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the list; the empty string
        for an empty input.

    Examples
    --------
    >>> shorten_list([1,2,3,5,6,8])
    '1-3,5-6,8'

    >>> shorten_list((1,2,3,5,6,8,9,10,11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'

    >>> shorten_list([1, 1, 3, 3])
    '1,3'
    """
    # Deduplicate before sorting so that a repeated value can never be
    # mistaken for a range (e.g. [1, 1] must give "1", not "1-1").
    values = sorted(set(numbers))

    if not values:  # empty container
        return ""

    def format_run(first, last):
        # A run of a single value is written as the bare number.
        if first == last:
            return str(first)
        return range_indicator.join((str(first), str(last)))

    result = []
    run_start = previous = values[0]
    for value in values[1:]:
        # A gap closes the current run and starts a new one.
        if value != previous + 1:
            result.append(format_run(run_start, previous))
            run_start = value
        previous = value

    # Add the final run to the result.
    result.append(format_run(run_start, previous))

    return range_separator.join(result)

560 

561 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to matplotlib.pyplot.subplot2grid.
    """

    # Visibility of the four panel spines.
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Shape of the grid.
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    # Location of the axis within the grid.
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    # Extent of the axis within the grid.
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )

609 

610 

def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    showExtremeOutliers=True,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions. Points with a non-finite ``zs`` are not plotted.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color scale symmetric about zero, using the
        larger of ``abs(vmin)`` and ``abs(vmax)`` as its extent.
    nPointBinThresh : `int`, optional
        Threshold number of points at or above which binning will be
        implemented for the plotting. If the number of data points is
        lower than this threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. Each defaults to the
        median of ``zs`` minus/plus twice its sigma-MAD.
    showExtremeOutliers: `bool`, default True
        Use overlaid scatter points to show the x-y positions of the 15%
        most extreme values.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection` or \
            `matplotlib.image.AxesImage`
        The last artist drawn on ``ax``: the scatter collection when
        points are drawn, otherwise the binned image.
    """
    # Color limits default to median +/- 2 sigma-MAD of the values.
    med = nanMedian(zs)
    mad = nanSigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    # Drop the points without a finite value.
    finiteMask = np.isfinite(zs)
    xs = xs[finiteMask]
    ys = ys[finiteMask]
    zs = zs[finiteMask]

    if len(xs) < nPointBinThresh:
        # Few enough points: plain scatter plot, no binning required.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # The median-binned image is only computed on the path that draws it.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )

    # Marker size shrinks as the number of points grows, clipped to
    # [0.5, 10]; the marker edge width follows the size.
    s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
    lw = (s**0.5) / 10
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )

    if not isSorted:
        # Order the points so the most extreme values come last.
        zs, xs, ys = sortAllArrays([zs, xs, ys])

    if len(xs) > 1 and showExtremeOutliers:
        # Find the most extreme 15% of points. The list is ordered
        # by the distance from the median, this is just the
        # head/tail 15% of points. The cut index is floor(N / 100) * 85,
        # so at least 15% of the points are overplotted.
        extremes = int(np.floor((len(xs) / 100)) * 85)
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )

    return plotOut