Coverage for python/lsst/analysis/tools/actions/plot/plotUtils.py: 14%

279 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-24 04:10 -0700

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23__all__ = ("PanelConfig",) 

24 

25from typing import TYPE_CHECKING, Iterable, List, Mapping, Tuple 

26 

27import matplotlib 

28import matplotlib.pyplot as plt 

29import numpy as np 

30from lsst.geom import Box2D, SpherePoint, degrees 

31from lsst.pex.config import Config, Field 

32from matplotlib import colors 

33from matplotlib.collections import PatchCollection 

34from matplotlib.patches import Rectangle 

35from scipy.stats import binned_statistic_2d 

36 

37from ...math import nanMedian, nanSigmaMad 

38 

39if TYPE_CHECKING: 39 ↛ 40line 39 didn't jump to line 40, because the condition on line 39 was never true

40 from matplotlib.figure import Figure 

41 

# Shared tick formatter that renders every tick label as an empty string;
# installed on axes whose tick labels should be blanked out.
null_formatter: matplotlib.ticker.NullFormatter = matplotlib.ticker.NullFormatter()

43 

44 

def parsePlotInfo(dataId, runName, tableName, bands, plotName, SN):
    """Parse plot info from the dataId.

    Parameters
    ----------
    dataId : `dict`
        The dataId of the data to be plotted. Iterating over it must yield
        objects with a ``name`` attribute, and ``dataId[name]`` must return
        the corresponding value.
    runName : `str`
        The name of the run.
    tableName : `str`
        The name of the table.
    bands : `list` [`str`]
        The bands to be plotted.
    plotName : `str`
        The name of the plot.
    SN : `str`
        The signal to noise of the data.

    Returns
    -------
    plotInfo : `dict`
        A dictionary of the plot information. ``plotInfo["bands"]`` holds the
        bands joined into one comma-separated string (e.g. ``"g, r"``), and
        ``"tract"`` / ``"visit"`` are set to ``"N/A"`` when the dataId does
        not provide them.
    """
    plotInfo = {"run": runName, "tableName": tableName, "plotName": plotName, "SN": SN}

    # Copy each dataId entry into the info dictionary under its name.
    for dataInfo in dataId:
        plotInfo[dataInfo.name] = dataId[dataInfo.name]

    plotInfo["bands"] = ", ".join(bands)

    # Guarantee these keys exist, using a placeholder if the dataId lacks them.
    plotInfo.setdefault("tract", "N/A")
    plotInfo.setdefault("visit", "N/A")

    return plotInfo

84 

85 

def generateSummaryStats(data, skymap, plotInfo):
    """Generate a summary statistic in each patch of a tract.

    The statistic is the NaN-ignoring median of the first available y
    column (``y``, ``yStars``, ``yGalaxies`` or ``yUnknowns``, checked in
    that order) over the rows whose ``patch`` value equals each patch index.

    Parameters
    ----------
    data : `dict`
        A dictionary of the data to be plotted. Must contain a ``"patch"``
        column holding sequential patch indices and at least one of the y
        columns listed above.
    skymap : `lsst.skymap.ringsSkyMap.RingsSkyMap`
        The skymap associated with the data.
    plotInfo : `dict`
        A dictionary of the plot information; ``plotInfo["tract"]`` selects
        the tract.

    Returns
    -------
    patchInfoDict : `dict`
        Maps each sequential patch index to ``(skyCoords, stat)``, where
        ``skyCoords`` are the sky positions of the corners of the patch inner
        bounding box and ``stat`` is the median (NaN for patches with no
        data). The extra key ``"tract"`` holds the corners of the tract
        bounding box with a NaN statistic.

    Raises
    ------
    KeyError
        Raised if none of the supported y columns is present in ``data``.
    """
    # TODO: what is the more generic type of skymap?
    tractInfo = skymap.generateTract(plotInfo["tract"])
    tractWcs = tractInfo.getWcs()

    yColCandidates = ("y", "yStars", "yGalaxies", "yUnknowns")
    yCol = next((col for col in yColCandidates if col in data.keys()), None)
    if yCol is None:
        # Fail clearly here rather than on an undefined column name below.
        raise KeyError(f"None of the columns {yColCandidates} are present in data.")

    patchInfoDict = {}
    maxPatchNum = tractInfo.num_patches.x * tractInfo.num_patches.y
    # Patches are iterated by their sequential (gen 3) integer index, so the
    # old gen 2 "x,y" string conversion can never apply here.
    for patch in np.arange(0, maxPatchNum, 1):
        onPatch = data["patch"] == patch
        if not np.any(onPatch):
            stat = np.nan
        else:
            stat = nanMedian(data[yCol][onPatch])

        patchInfo = tractInfo.getPatchInfo(patch)
        corners = Box2D(patchInfo.getInnerBBox()).getCorners()
        patchInfoDict[patch] = (tractWcs.pixelToSky(corners), stat)

    tractCorners = Box2D(tractInfo.getBBox()).getCorners()
    patchInfoDict["tract"] = (tractWcs.pixelToSky(tractCorners), np.nan)

    return patchInfoDict

151 

152 

def generateSummaryStatsVisit(cat, colName, visitSummaryTable):
    """Generate a summary statistic for each detector of a visit.

    For every detector present in ``cat`` the NaN-ignoring median of
    ``colName`` is computed, and the detector's sky corners are looked up in
    ``visitSummaryTable``.

    Parameters
    ----------
    cat : `pandas.core.frame.DataFrame`
        A dataframe of the data to be plotted; must have a ``detector``
        column.
    colName : `str`
        The name of the column to be plotted.
    visitSummaryTable : `pandas.core.frame.DataFrame`
        A table of the visit summary, with ``id``, ``raCorners`` and
        ``decCorners`` columns (corners in degrees). NOTE(review): the
        ``[mask][0]`` indexing below picks the first matching row
        positionally only for array-like columns; confirm the actual table
        type passed by callers.

    Returns
    -------
    visitInfoDict : `dict`
        Maps each detector id to ``(corners, stat)`` where ``corners`` is a
        list of `lsst.geom.SpherePoint` and ``stat`` is the median.
    """
    visitInfoDict = {}
    for detector in cat.detector.unique():
        if detector is None:
            continue
        inDetector = cat["detector"] == detector
        stat = nanMedian(cat[colName].values[inDetector])

        summaryRow = visitSummaryTable["id"] == detector
        raCorners = visitSummaryTable["raCorners"][summaryRow][0]
        decCorners = visitSummaryTable["decCorners"][summaryRow][0]
        skyCorners = [SpherePoint(ra, dec, units=degrees) for ra, dec in zip(raCorners, decCorners)]

        visitInfoDict[detector] = (skyCorners, stat)

    return visitInfoDict

187 

188 

# Inspired by matplotlib.testing.remove_ticks_and_titles
def get_and_remove_axis_text(ax) -> Tuple[List[str], List[np.ndarray]]:
    """Remove text from an Axis and its children and return with line points.

    Titles, axis labels and free text are blanked, legend texts are made
    fully transparent, and tick labels are hidden with a null formatter. The
    same is applied recursively to every child axis.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        A matplotlib figure axis.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings (title and axis/legend/text labels),
        including those of child axes.
    line_xys : `List[numpy.ndarray]`
        A list of the xy data (arrays of shape ``(N, 2)``) of every line on
        ``ax`` and on its child axes.
    """
    # get_xydata() recaches stale line data, unlike the private ``_xy``.
    line_xys = [line.get_xydata() for line in ax.lines]
    texts = [text.get_text() for text in (ax.title, ax.xaxis.label, ax.yaxis.label)]
    ax.set_title("")
    ax.set_xlabel("")
    ax.set_ylabel("")

    legend = ax.get_legend()
    if legend is not None:
        for text in legend.texts:
            texts.append(text.get_text())
            # Hide rather than clear, so the legend layout is unchanged.
            text.set_alpha(0)

    for text in ax.texts:
        texts.append(text.get_text())
        text.set_text("")

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(null_formatter)
        axis.set_minor_formatter(null_formatter)
    # Only some axes (e.g. 3D ones) have a z axis.
    zaxis = getattr(ax, "zaxis", None)
    if zaxis is not None:
        zaxis.set_major_formatter(null_formatter)
        zaxis.set_minor_formatter(null_formatter)

    for child in ax.child_axes:
        texts_child, lines_child = get_and_remove_axis_text(child)
        texts.extend(texts_child)
        line_xys.extend(lines_child)

    return texts, line_xys

237 

238 

def get_and_remove_figure_text(figure: Figure):
    """Remove text from a Figure and its Axes and return with line points.

    Parameters
    ----------
    figure : `matplotlib.figure.Figure`
        A matplotlib figure.

    Returns
    -------
    texts : `List[str]`
        A list of all text strings: the suptitle first (``""`` if the figure
        had none), then figure-level texts, then the texts of every axis.
    line_xys : `List[numpy.ndarray]`
        A list of all line xy data arrays (each of shape ``(N, 2)``).
    """
    # ``_suptitle`` is a private Figure attribute holding a Text artist or
    # None; use its string, not the Text repr ("Text(x, y, '...')").
    suptitle = figure._suptitle
    texts = [suptitle.get_text() if suptitle is not None else ""]
    lines = []
    figure.suptitle("")

    texts.extend(text.get_text() for text in figure.texts)
    figure.texts = []

    for ax in figure.get_axes():
        texts_ax, lines_ax = get_and_remove_axis_text(ax)
        texts.extend(texts_ax)
        lines.extend(lines_ax)

    return texts, lines

267 

268 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Add useful information to the plot.

    Two text blocks are written at the top-left of the figure: the plot
    name, and below it a smaller summary of the run, table, dataId, bands and
    any S/N, magnitude and selection entries found in ``plotInfo``.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to add the information to.
    plotInfo : `dict`
        A dictionary of the plot information. Must contain ``"plotName"``,
        ``"run"``, ``"tableName"`` and ``"bands"``. ``"bands"`` may be an
        iterable of band names or a single pre-joined string such as the one
        produced by `parsePlotInfo`. ``"tract"`` and ``"visit"`` are shown
        when present. The first key containing ``"SN"`` or ``"S/N"``, the
        first key containing ``"Mag"``, and every key starting with
        ``"Selection: "`` are also added to the label.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The figure with the information added.
    """
    # TO DO: figure out how to get this information
    photocalibDataset = "None"
    astroDataset = "None"

    fig.text(0.01, 0.99, plotInfo["plotName"], fontsize=7, transform=fig.transFigure, ha="left", va="top")

    run = plotInfo["run"]
    datasetsUsed = f"\nPhotoCalib: {photocalibDataset}, Astrometry: {astroDataset}"
    tableType = f"\nTable: {plotInfo['tableName']}"

    dataIdText = ""
    if "tract" in plotInfo:
        dataIdText += f", Tract: {plotInfo['tract']}"
    if "visit" in plotInfo:
        dataIdText += f", Visit: {plotInfo['visit']}"

    bands = plotInfo["bands"]
    # A string is already joined (e.g. "g, r"); iterating it would split it
    # into individual characters, so only join genuine collections.
    bandText = bands if isinstance(bands, str) else ", ".join(bands)
    bandsText = f", Bands: {bandText}"
    infoText = f"\n{run}{datasetsUsed}{tableType}{dataIdText}{bandsText}"

    # Find S/N and mag keys, if present.
    snKeys = []
    magKeys = []
    selectionKeys = []
    selectionPrefix = "Selection: "
    for key in plotInfo:
        if "SN" in key or "S/N" in key:
            snKeys.append(key)
        elif "Mag" in key:
            magKeys.append(key)
        elif key.startswith(selectionPrefix):
            selectionKeys.append(key)
    # Add S/N and mag values to label, if present.
    # TODO: Do something if there are multiple sn/mag keys. Log? Warn?
    newline = "\n"
    if snKeys:
        infoText = f"{infoText}{newline if magKeys else ', '}{snKeys[0]}{plotInfo.get(snKeys[0])}"
    if magKeys:
        infoText = f"{infoText}, {magKeys[0]}{plotInfo.get(magKeys[0])}"
    if selectionKeys:
        nPrefix = len(selectionPrefix)
        selections = ", ".join(f"{key[nPrefix:]}: {plotInfo[key]}" for key in selectionKeys)
        infoText = f"{infoText}, Selections: {selections}"

    fig.text(0.01, 0.984, infoText, fontsize=6, transform=fig.transFigure, alpha=0.6, ha="left", va="top")

    return fig

333 

334 

def mkColormap(colorNames):
    """Build a linear colormap that steps evenly through named colors.

    Parameters
    ----------
    colorNames : `list`
        A list of strings that correspond to matplotlib named colors. The
        colors are placed at evenly spaced anchors between 0 and 1.

    Returns
    -------
    cmap : `matplotlib.colors.LinearSegmentedColormap`
        A colormap called ``"newCmap"`` interpolating between the colors.
    """
    anchors = np.linspace(0, 1, len(colorNames))
    rgbTriples = [colors.to_rgb(name) for name in colorNames]

    # Each channel's segment entry is (anchor, value below, value above);
    # using the same value on both sides gives a continuous ramp.
    channelIndex = (("blue", 2), ("red", 0), ("green", 1))
    segmentData = {
        channel: [(anchor, rgb[idx], rgb[idx]) for anchor, rgb in zip(anchors, rgbTriples)]
        for channel, idx in channelIndex
    }
    return colors.LinearSegmentedColormap("newCmap", segmentData)

361 

362 

def extremaSort(xs):
    """Order indices so the points furthest from the median come last.

    Parameters
    ----------
    xs : `np.array`
        An array of the values to sort.

    Returns
    -------
    ids : `np.array`
        Indices that sort ``xs`` by increasing absolute distance from its
        NaN-ignoring median.
    """
    distanceFromMedian = np.abs(xs - nanMedian(xs))
    return np.argsort(distanceFromMedian)

380 

381 

def sortAllArrays(arrsToSort, sortArrayIndex=0):
    """Reorder several arrays together using the extrema order of one of them.

    Parameters
    ----------
    arrsToSort : `list` [`np.array`]
        The arrays to reorder. The list is modified in place: each element is
        replaced by its reordered copy.
    sortArrayIndex : `int`, optional
        Zero-based position in ``arrsToSort`` of the array whose
        `extremaSort` order is applied to all arrays (defaults to the first).

    Returns
    -------
    arrsToSort : `list` [`np.array`]
        The same list object, now holding the reordered arrays.
    """
    order = extremaSort(arrsToSort[sortArrayIndex])
    arrsToSort[:] = [arr[order] for arr in arrsToSort]
    return arrsToSort

403 

404 

def addSummaryPlot(fig, loc, sumStats, label):
    """Add a summary subplot to the figure.

    Each entry of ``sumStats`` is drawn as an outlined rectangle filled with
    a color-coded summary statistic, and a horizontal colorbar is added above
    the subplot.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure that the summary plot is to be added to.
    loc : `matplotlib.gridspec.SubplotSpec` or `int` or `(int, int, int)`
        Describes the location in the figure to put the summary plot. It can
        be a gridspec SubplotSpec, or a 3 digit integer where the first digit
        is the number of rows, the second is the number of columns and the
        third is the index. A tuple of (rows, columns, index) works the same
        way.
    sumStats : `dict`
        Maps an id (e.g. a patch or detector id, or ``"tract"``) to a tuple
        ``(corners, stat)``: the sky corners of the region, each a (R.A.,
        Dec.) pair of angles exposing ``asDegrees()``, and a summary statistic
        for the region. Corners 0 and 2 are used as opposite corners of the
        filled rectangle. The ``"tract"`` entry is outlined but not labelled.
    label : `str`
        The label to be used for the colorbar.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The input figure with the summary axes and colorbar added.
    """
    # Add the subplot to the relevant place in the figure
    # and sort the axis out
    axCorner = fig.add_subplot(loc)
    axCorner.yaxis.tick_right()
    axCorner.yaxis.set_label_position("right")
    axCorner.xaxis.tick_top()
    axCorner.xaxis.set_label_position("top")
    axCorner.set_aspect("equal")

    # Plot the corners of the patches and make the color
    # coded rectangles for each patch, the colors show
    # the median of the given value in the patch.
    # (Named ``patchStats`` so it does not shadow ``matplotlib.colors``.)
    patches = []
    patchStats = []
    for dataId, (corners, stat) in sumStats.items():
        ra = corners[0][0].asDegrees()
        dec = corners[0][1].asDegrees()
        width = corners[2][0].asDegrees() - ra
        height = corners[2][1].asDegrees() - dec
        patches.append(Rectangle((ra, dec), width, height))
        patchStats.append(stat)

        cornerRas = [cornerRa.asDegrees() for (cornerRa, _) in corners]
        cornerDecs = [cornerDec.asDegrees() for (_, cornerDec) in corners]
        # Close the outline by returning to the first corner.
        axCorner.plot(cornerRas + [cornerRas[0]], cornerDecs + [cornerDecs[0]], "k", lw=0.5)
        if dataId != "tract":
            axCorner.annotate(
                dataId, (ra + width / 2, dec + height / 2), color="k", fontsize=4, ha="center", va="center"
            )

    # Set the bad color to transparent and make a masked array so that
    # NaN statistics are not colored.
    cmapPatch = plt.cm.coolwarm.copy()
    cmapPatch.set_bad(color="none")
    patchStats = np.ma.array(patchStats, mask=np.isnan(patchStats))
    collection = PatchCollection(patches, cmap=cmapPatch)
    collection.set_array(patchStats)
    axCorner.add_collection(collection)

    # Add some labels
    axCorner.set_xlabel("R.A. (deg)", fontsize=7)
    axCorner.set_ylabel("Dec. (deg)", fontsize=7)
    axCorner.tick_params(axis="both", labelsize=6, length=0, pad=1.5)
    axCorner.invert_xaxis()

    # Add a colorbar above the subplot. Use ``fig.colorbar`` rather than
    # ``plt.colorbar``: the latter acts on pyplot's current figure, which
    # need not be ``fig``.
    pos = axCorner.get_position()
    yOffset = (pos.y1 - pos.y0) / 3
    cax = fig.add_axes([pos.x0, pos.y1 + yOffset, pos.x1 - pos.x0, 0.025])
    fig.colorbar(collection, cax=cax, orientation="horizontal")
    cax.text(
        0.5,
        0.48,
        label,
        color="k",
        transform=cax.transAxes,
        rotation="horizontal",
        horizontalalignment="center",
        verticalalignment="center",
        fontsize=6,
    )
    cax.tick_params(
        axis="x", labelsize=6, labeltop=True, labelbottom=False, bottom=False, top=True, pad=0.5, length=2
    )

    return fig

495 

496 

def shorten_list(numbers: Iterable[int], *, range_indicator: str = "-", range_separator: str = ",") -> str:
    """Shorten an iterable of integers.

    Runs of consecutive values are collapsed into ``start<indicator>end``;
    duplicate values are ignored.

    Parameters
    ----------
    numbers : `~collections.abc.Iterable` [`int`]
        Any iterable (list, set, tuple, numpy.array) of integers.
    range_indicator : `str`, optional
        The string to use to indicate a range of numbers.
    range_separator : `str`, optional
        The string to use to separate ranges of numbers.

    Returns
    -------
    result : `str`
        A shortened string representation of the list; ``""`` for an empty
        input.

    Examples
    --------
    >>> shorten_list([1,2,3,5,6,8])
    '1-3,5-6,8'

    >>> shorten_list((1,2,3,5,6,8,9,10,11), range_separator=", ")
    '1-3, 5-6, 8-11'

    >>> shorten_list(range(4), range_indicator="..")
    '0..3'
    """
    # Deduplicate before sorting so repeated values cannot produce
    # degenerate ranges such as "5-5".
    values = sorted(set(numbers))
    if not values:  # empty container
        return ""

    # Collect (first, last) pairs of maximal consecutive runs.
    runs = []
    start = prev = values[0]
    for value in values[1:]:
        if value != prev + 1:
            runs.append((start, prev))
            start = value
        prev = value
    runs.append((start, prev))

    return range_separator.join(
        str(first) if first == last else range_indicator.join((str(first), str(last))) for first, last in runs
    )

561 

562 

class PanelConfig(Config):
    """Configuration options for the plot panels used by DiaSkyPlot.

    The defaults will produce a good-looking single panel plot.
    The subplot2grid* fields correspond to the ``shape``, ``loc``,
    ``rowspan`` and ``colspan`` arguments of matplotlib.pyplot.subplot2grid.
    """

    # Visibility of the spine (line and ticks) on each side of the panel.
    topSpinesVisible = Field[bool](
        doc="Draw line and ticks on top of panel?",
        default=False,
    )
    bottomSpinesVisible = Field[bool](
        doc="Draw line and ticks on bottom of panel?",
        default=True,
    )
    leftSpinesVisible = Field[bool](
        doc="Draw line and ticks on left side of panel?",
        default=True,
    )
    rightSpinesVisible = Field[bool](
        doc="Draw line and ticks on right side of panel?",
        default=True,
    )
    # Grid shape, location and span of the panel.
    subplot2gridShapeRow = Field[int](
        doc="Number of rows of the grid in which to place axis.",
        default=10,
    )
    subplot2gridShapeColumn = Field[int](
        doc="Number of columns of the grid in which to place axis.",
        default=10,
    )
    subplot2gridLocRow = Field[int](
        doc="Row of the axis location within the grid.",
        default=1,
    )
    subplot2gridLocColumn = Field[int](
        doc="Column of the axis location within the grid.",
        default=1,
    )
    subplot2gridRowspan = Field[int](
        doc="Number of rows for the axis to span downwards.",
        default=5,
    )
    subplot2gridColspan = Field[int](
        doc="Number of columns for the axis to span to the right.",
        default=5,
    )

610 

611 

def plotProjectionWithBinning(
    ax,
    xs,
    ys,
    zs,
    cmap,
    xMin,
    xMax,
    yMin,
    yMax,
    xNumBins=45,
    yNumBins=None,
    fixAroundZero=False,
    nPointBinThresh=5000,
    isSorted=False,
    vmin=None,
    vmax=None,
    showExtremeOutliers=True,
    scatPtSize=7,
):
    """Plot color-mapped data in projection and with binning when appropriate.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        Axis on which to plot the projection data.
    xs, ys : `np.array`
        Arrays containing the x and y positions of the data.
    zs : `np.array`
        Array containing the scaling value associated with the (``xs``, ``ys``)
        positions.
    cmap : `matplotlib.colors.Colormap`
        Colormap for the ``zs`` values.
    xMin, xMax, yMin, yMax : `float`
        Data limits within which to compute bin sizes.
    xNumBins : `int`, optional
        The number of bins along the x-axis.
    yNumBins : `int`, optional
        The number of bins along the y-axis. If `None`, this is set to equal
        ``xNumBins``.
    fixAroundZero : `bool`, optional
        If `True`, make the color limits symmetric about zero, using the
        larger of ``|vmin|`` and ``|vmax|``.
    nPointBinThresh : `int`, optional
        Threshold number of points at or above which binning will be
        implemented for the plotting. If the number of data points is lower
        than this threshold, a basic scatter plot will be generated.
    isSorted : `bool`, optional
        Whether the data have been sorted in ``zs`` (the sorting is to
        accommodate the overplotting of points in the upper and lower
        extrema of the data).
    vmin, vmax : `float`, optional
        The min and max limits for the colorbar. Default to the median of
        ``zs`` minus/plus twice its sigma MAD.
    showExtremeOutliers: `bool`, default True
        When binning, use overlaid scatter points to show the x-y positions
        of the 15% most extreme values.
    scatPtSize : `float`, optional
        The point size to use if just plotting a regular scatter plot.

    Returns
    -------
    plotOut : `matplotlib.collections.PathCollection` or `matplotlib.image.AxesImage`
        The last plot object added to ``ax``: the scatter collection, or the
        binned image when no outlier scatter is overlaid.
    """
    med = nanMedian(zs)
    mad = nanSigmaMad(zs)
    if vmin is None:
        vmin = med - 2 * mad
    if vmax is None:
        vmax = med + 2 * mad
    if fixAroundZero:
        scaleEnd = np.max([np.abs(vmin), np.abs(vmax)])
        vmin = -1 * scaleEnd
        vmax = scaleEnd

    if len(xs) < nPointBinThresh:
        # Few enough points to draw every one of them directly.
        return ax.scatter(
            xs,
            ys,
            c=zs,
            cmap=cmap,
            s=scatPtSize,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=0.2,
        )

    # The binned statistic is only needed for the image, so it is computed
    # only in this branch.
    yNumBins = xNumBins if yNumBins is None else yNumBins
    xBinEdges = np.linspace(xMin, xMax, xNumBins + 1)
    yBinEdges = np.linspace(yMin, yMax, yNumBins + 1)
    binnedStats, xEdges, yEdges, _ = binned_statistic_2d(
        xs, ys, zs, statistic="median", bins=(xBinEdges, yBinEdges)
    )

    # Outlier marker size shrinks as the number of points grows, clamped
    # to the range [0.5, 10].
    s = min(10, max(0.5, nPointBinThresh / 10 / (len(xs) ** 0.5)))
    lw = (s**0.5) / 10
    plotOut = ax.imshow(
        binnedStats.T,
        cmap=cmap,
        extent=[xEdges[0], xEdges[-1], yEdges[-1], yEdges[0]],
        vmin=vmin,
        vmax=vmax,
    )
    if not isSorted:
        zs, xs, ys = sortAllArrays([zs, xs, ys])
    if len(xs) > 1 and showExtremeOutliers:
        # The arrays are ordered by distance from the median, so the most
        # extreme 15% of points are the last 15% of the arrays.
        extremes = (len(xs) * 85) // 100
        plotOut = ax.scatter(
            xs[extremes:],
            ys[extremes:],
            c=zs[extremes:],
            s=s,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            edgecolor="white",
            linewidths=lw,
        )
    return plotOut