Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 15%

269 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 04:38 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

# Only ``PanelConfig`` is exported by ``from ... import *``; the other public
# helpers in this module are not listed here.
__all__ = ("PanelConfig",)

24 

25from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple 

26 

27import matplotlib 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from lsst.geom import Box2D, SpherePoint, degrees 

31from lsst.pex.config import Config, Field 

32from matplotlib import colors 

33from matplotlib.collections import PatchCollection 

34from matplotlib.patches import Rectangle 

35from scipy.stats import binned_statistic_2d 

36 

37from ...math import nanMedian, nanSigmaMad 

38 

39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true

40 from matplotlib.figure import Figure 

41 

# A single shared NullFormatter instance; get_and_remove_axis_text installs it
# on the x/y (and z, when present) axes to blank out tick labels.
null_formatter = matplotlib.ticker.NullFormatter()

43 

44 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic (the median) in each patch of a tract.

    Parameters
    ----------
    data : `dict`
        A dictionary of the data to be plotted. It must contain a
        ``"patch"`` column holding sequential patch indices and at least one
        of the columns ``"y"``, ``"yStars"``, ``"yGalaxies"`` or
        ``"yUnknowns"``; the first one found (in that order) is summarized.
    skymap : `lsst.skymap.BaseSkyMap`
        The skymap associated with the data.
    plotInfo : `dict`
        A dictionary of the plot information; ``plotInfo["tract"]`` selects
        the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Keyed by sequential patch index. Each value is a tuple of the sky
        coordinates of the corners of the patch inner bounding box and the
        median (via ``nanMedian``) of the selected column for rows on that
        patch, or NaN when no rows fall on the patch. An extra ``"tract"``
        entry holds the corners of the whole tract bounding box with a NaN
        statistic.

    Raises
    ------
    KeyError
        Raised if ``data`` contains none of the supported y columns.
    """
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    # Pick the first y column present and fail early, with a clear message,
    # if there is none (otherwise yCol would be unbound further down).
    dataKeys = data.keys()
    yCol = next((col for col in ("y", "yStars", "yGalaxies", "yUnknowns") if col in dataKeys), None)
    if yCol is None:
        raise KeyError("data must contain one of the columns 'y', 'yStars', 'yGalaxies' or 'yUnknowns'.")

    patchIds = data["patch"]
    yValues = data[yCol]

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are iterated by their sequential index, so they can be passed
    # straight to getPatchInfo without any "x,y" string conversion.
    for patch in np.arange(maxPatchNum):
        onPatch = patchIds == patch
        # NaN marks patches without data (addSummaryPlot masks NaN values).
        stat = nanMedian(yValues[onPatch]) if np.any(onPatch) else np.nan

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

109 

110 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable):
    """Generate a summary statistic (the median) in each detector of a visit.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        A dataframe of the data to be plotted; must have a ``detector``
        column and the column ``colName``.
    colName : `str`
        The name of the column to be summarized.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        A dataframe of the visit summary table, with ``id``, ``raCorners``
        and ``decCorners`` columns (corner coordinates in degrees).

    Returns
    -------
    visitInfoDict : `dict`
        Keyed by detector id. Each value is a tuple of the list of
        `lsst.geom.SpherePoint` corners of the detector, taken from the first
        matching row of the visit summary table, and the median (via
        ``nanMedian``) of ``colName`` for rows on that detector.
    """
    # Convert to plain arrays so that ``[mask][0]`` below is positional.
    # On a pandas Series it would be a *label* lookup and raise KeyError for
    # any detector whose row is not stored at label 0.
    summaryIds = np.asarray(visitSummaryTable["id"])
    raCorners = np.asarray(visitSummaryTable["raCorners"])
    decCorners = np.asarray(visitSummaryTable["decCorners"])
    values = cat[colName].values

    visitInfoDict = {}
    for ccd in cat.detector.unique():
        if ccd is None:
            continue
        onCcd = cat["detector"] == ccd
        stat = nanMedian(values[onCcd])

        sumRow = summaryIds == ccd
        cornersOut = [
            SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners[sumRow][0], decCorners[sumRow][0])
        ]

        visitInfoDict[ccd] = (cornersOut, stat)

    return visitInfoDict

145 

146 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    The title, axis labels and free text are blanked, legend text is made
    fully transparent, and tick labels are replaced by ``null_formatter``.
    The function recurses into ``ax.child_axes``.

    Parameters
    ----------
    ax : `plt.Axis`
        A matplotlib figure axis; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        The title, axis-label, legend and free-text strings of ``ax`` and of
        its child axes, collected before they were removed.
    line_xys : `List[numpy.ndarray]`
        The line ``_xy`` attributes (arrays of shape ``(N, 2)``) of ``ax``
        and of its child axes.
    """
    line_xys = [line._xy for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    # get_legend() returns None when the axis has no legend.
    legend = ax.get_legend()
    if legend is not None:
        texts.extend(text.get_text() for text in legend.texts)
        # Legend entries are hidden via alpha rather than cleared.
        for text in legend.texts:
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only 3D axes have a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        # Keep child-axis lines as well, consistent with their texts.
        line_xys.extend(lines_child)

    return texts, line_xys

195 

196 

def get_and_remove_figure_text(figure: Figure):
    """Strip the text from a Figure and all of its Axes, returning it.

    Parameters
    ----------
    figure : `matplotlib.pyplot.Figure`
        A matplotlib figure; it is modified in place.

    Returns
    -------
    texts : `List[str]`
        In order: ``str()`` of the figure suptitle, the figure-level text
        strings, then the texts returned by `get_and_remove_axis_text` for
        each Axes.
    line_xys : `List[numpy.ndarray]`, (N, 2)
        The line ``_xy`` arrays returned by `get_and_remove_axis_text` for
        each Axes.
    """
    # NOTE(review): str() of the suptitle yields "None" when no suptitle is
    # set and, for a Text artist, presumably its repr rather than the bare
    # string -- kept as-is; confirm whether this is intended.
    collectedTexts = [str(figure._suptitle)]
    # Blank the suptitle *before* reading figure.texts so that any Text it
    # creates or modifies is included in the figure-level strings below.
    figure.suptitle("")

    collectedTexts += [figText.get_text() for figText in figure.texts]
    figure.texts = []

    collectedLines = []
    for axes in figure.get_axes():
        axesTexts, axesLines = get_and_remove_axis_text(axes)
        collectedTexts += axesTexts
        collectedLines += axesLines

    return collectedTexts, collectedLines

225 

226 

def parsePlotInfo(plotInfo: Mapping[str, str]) -> str:
    """Format the contents of a plotInfo mapping as text for a figure.

    Parameters
    ----------
    plotInfo : `dict`[`str`, `str`]
        Plot metadata. ``"run"``, ``"tableName"`` and ``"bands"`` are
        required; ``"tract"`` and ``"visit"`` are included when present.
        The first key containing ``"SN"`` or ``"S/N"``, the first key
        containing ``"Mag"``, and every key starting with ``"Selection: "``
        are appended as extra information.

    Returns
    -------
    infoText : `str`
        A multi-line string summarizing ``plotInfo``, suitable for adding
        to a figure.
    """
    # Dataset names are not tracked here; both are always reported as "None".
    photocalibDataset = "None"
    astroDataset = "None"
    selectionPrefix = "Selection: "

    run = plotInfo["run"]
    tableName = plotInfo["tableName"]
    dataIdText = "".join(
        f", {label}: {plotInfo[key]}"
        for key, label in (("tract", "Tract"), ("visit", "Visit"))
        if key in plotInfo.keys()
    )
    bandsText = ", Bands: " + ", ".join(plotInfo["bands"])

    infoText = (
        f"\n{run}"
        f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
        f"\nTable: {tableName}"
        f"{dataIdText}{bandsText}"
    )

    # Classify keys; the S/N test takes precedence over the others.
    snKeys, magKeys, selectionKeys = [], [], []
    for key in plotInfo:
        if "SN" in key or "S/N" in key:
            snKeys.append(key)
        elif "Mag" in key:
            magKeys.append(key)
        elif key.startswith(selectionPrefix):
            selectionKeys.append(key)

    # TODO: Do something if there are multiple sn/mag keys. Log? Warn?
    if snKeys:
        # S/N goes on a new line when a magnitude entry will follow it.
        separator = "\n" if magKeys else ", "
        infoText += f"{separator}{snKeys[0]}{plotInfo.get(snKeys[0])}"
    if magKeys:
        infoText += f", {magKeys[0]}{plotInfo.get(magKeys[0])}"
    if selectionKeys:
        prefixLength = len(selectionPrefix)
        selections = ", ".join(f"{key[prefixLength:]}: {plotInfo[key]}" for key in selectionKeys)
        infoText += f", Selections: {selections}"

    return infoText

287 

288 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Write the plot name and a summary of ``plotInfo`` onto a figure.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate; it is modified in place.
    plotInfo : `dict`
        Plot metadata; ``"plotName"`` is required, and the mapping is also
        passed to `parsePlotInfo`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with two text artists added in its top-left corner.
    """
    # Both texts are anchored top-left in figure-fraction coordinates.
    anchorKwargs = dict(transform=fig.transFigure, ha="left", va="top")
    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=7, **anchorKwargs)
    fig.text(0.01, 0.984, parsePlotInfo(plotInfo), fontsize=6, alpha=0.6, **anchorKwargs)
    return fig

309 

310 

def mkColormap(colorNames):
    """Build a linearly interpolated colormap from a list of named colors.

    Parameters
    ----------
    colorNames : `list`
        Matplotlib color names, placed at evenly spaced positions from 0 to 1
        in the order given.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap named ``"newCmap"`` stepping through ``colorNames``.
    """
    positions = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.to_rgb(name) for name in colorNames]
    # Each segment row is (x, value below x, value above x); using the same
    # value twice gives a continuous ramp between neighbouring anchors.
    segmentData = {
        channel: [(pos, rgb[idx], rgb[idx]) for pos, rgb in zip(positions, rgbTriples)]
        for channel, idx in (("blue", 2), ("red", 0), ("green", 1))
    }
    return colors.LinearSegmentedColormap("newCmap", segmentData)

337 

338 

def extremaSort(xs):
    """Return indices ordering ``xs`` so values furthest from the median are last.

    Parameters
    ----------
    xs : `np.array`
        The values to order.

    Returns
    -------
    ids : `np.array`
        Indices that sort ``xs`` by increasing absolute distance from its
        median, as computed by ``nanMedian``.
    """
    distanceFromMedian = np.abs(xs - nanMedian(xs))
    return np.argsort(distanceFromMedian)

356 

357 

def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder several arrays together using the `extremaSort` order of one.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        Arrays to reorder simultaneously. The list is modified in place.
    sortArrayIndex : `int`, optional
        Zero-based position in ``arrsToSort`` of the array whose
        `extremaSort` order is applied to all of them (default: the first).

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, with every array reordered.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    # Slice assignment keeps the caller's list object, as in-place updates do.
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort

379 

380 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Each entry of ``sumStats`` is drawn as a black outline plus a rectangle
    colored by its statistic (NaN statistics are left transparent), and is
    annotated with its id unless the id is ``"tract"``. A horizontal
    colorbar labelled ``label`` is placed above the subplot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, index)`
        Describes the location in the figure to put the summary plot,
        can be a gridspec SubplotSpec, a 3 digit integer where the first
        digit is the number of rows, the second is the number of columns
        and the third is the index. This is the same for the tuple
        of int, int, index.
    sumStats : `dict`
        A dictionary where the patchIds are the keys which store the R.A.
        and the dec of the corners of the patch, along with a summary
        statistic for each patch.
    label : `str`
        The label to be used for the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The input figure with the summary axes and colorbar added.
    """
    # Add the subplot to the relevant place in the figure and put the
    # ticks/labels on the top and right.
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # One rectangle per entry, spanning corner 0 to corner 2, colored by the
    # entry's statistic. ``statValues`` (not ``colors``) avoids shadowing the
    # module-level ``matplotlib.colors`` import.
    rectangles = []
    statValues = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        rectangles.append(Rectangle((ra, dec), width, height))
        statValues.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, cornerDec) in corners]
        cornerDecs = [cornerDec.asDegrees() for (cornerRa, cornerDec) in corners]
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            axCorner.annotate(
                dataId, (ra + width / 2, dec + height / 2), color="k", fontsize=4, ha="center", va="center"
            )

    # Set the bad color to transparent and mask NaN statistics.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    maskedStats = np.ma.array(statValues, mask=np.isnan(statValues))
    collection = PatchCollection(rectangles, cmap=cmapPatch)
    collection.set_array(maskedStats)
    axCorner.add_collection(collection)

    # Add some labels; R.A. increases to the left.
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the subplot.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    # Use ``fig`` directly: ``plt.colorbar`` works on pyplot's *current*
    # figure, which need not be ``fig`` (and may even be created on demand).
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

471 

472 

def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers into a compact range string.

    Runs of consecutive integers are collapsed into
    ``start<range_indicator>end`` and the pieces are joined with
    ``range_separator``. Input order and duplicate values are ignored.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the sorted, de-duplicated
        numbers; the empty string for empty input.

    Examples
    --------
    >>> shorten_list([1, 2, 3, 5, 6, 8])
    '1-3,5-6,8'

    >>> shorten_list((1, 2, 3, 5, 6, 8, 9, 10, 11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'

    >>> shorten_list([5, 5, 7])
    '5,7'
    """
    # De-duplicate before sorting so that repeated values cannot produce
    # degenerate ranges such as "5-5".
    values = sorted(set(numbers))
    if not values:
        return ""

    # Collect (first, last) pairs for each run of consecutive integers.
    runs = []
    runStart = runEnd = values[0]
    for value in values[1:]:
        if value != runEnd + 1:
            runs.append((runStart, runEnd))
            runStart = value
        runEnd = value
    runs.append((runStart, runEnd))

    return range_separator.join(
        str(first) if first == last else range_indicator.join((str(first), str(last)))
        for first, last in runs
    )

537 

538 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the ``shape``, ``loc``,
    ``rowspan`` and ``colspan`` arguments of
    `matplotlib.pyplot.subplot2grid`.
    """

    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    # Fixed doc: this is a column span, not a row span.
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )

586 

587 

def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    showExtremeOutliers=True,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color limits symmetric about zero, using the
        larger of ``abs(vmin)`` and ``abs(vmax)`` as the half-range.
    nPointBinThresh : `int`, optional
        Threshold number of points above which binning will be implemented
        for the plotting. If the number of data points is lower than this
        threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. Default to the median of
        ``zs`` minus/plus twice its sigma-MAD.
    showExtremeOutliers: `bool`, default True
        Use overlaid scatter points to show the x-y positions of the 15%
        most extreme values.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection` or `matplotlib.image.AxesImage`
        The last artist added to ``ax``: the plain scatter plot, the
        overlaid scatter of extreme points, or the binned image when no
        scatter is overlaid.
    """
    med = nanMedian(zs)
    mad = nanSigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    if len(xs) < nPointBinThresh:
        # Few enough points to show every one of them individually.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # Binned path: the 2D median image is only needed (and computed) here.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )

    if not isSorted:
        zs, xs, ys = sortAllArrays([zs, xs, ys])

    if showExtremeOutliers and len(xs) > 1:
        # Marker size shrinks with the number of points, clipped to [0.5, 10].
        s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
        lw = (s**0.5) / 10
        # The arrays are ordered by distance from the median, so the most
        # extreme 15% of points form the tail. Integer arithmetic keeps the
        # split at floor(0.85 * N) instead of truncating N / 100 first.
        extremes = len(xs) * 85 // 100
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )
    return plotOut