Coverage for python / lsst / summit / utils / plotting.py: 10%

167 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 09:02 +0000

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21from __future__ import annotations 

22 

23import logging 

24from typing import TYPE_CHECKING 

25 

26import astropy.visualization as vis 

27import matplotlib 

28import matplotlib.colors as colors 

29import numpy as np 

30from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

31 

32import lsst.afw.detection as afwDetection 

33import lsst.afw.geom as afwGeom 

34import lsst.afw.image as afwImage 

35import lsst.afw.table as afwTable 

36import lsst.geom as geom 

37from lsst.summit.utils.utils import getImageArray, getQuantiles 

38from lsst.utils.plotting.figures import make_figure 

39 

40if TYPE_CHECKING: 

41 from matplotlib.figure import Figure 

42 

43 

def drawCompass(
    ax: matplotlib.axes.Axes,
    wcs: afwGeom.SkyWcs,
    compassLocation: int = 300,
    arrowLength: float = 300.0,
) -> matplotlib.axes.Axes:
    """Draw North and East compass arrows on the given axes.

    The compass is anchored at pixel ``(compassLocation, compassLocation)``.
    That pixel is converted to sky coordinates, offset by 30 arcseconds in
    Dec (for North) and in RA (for East), and the offset points are
    converted back to pixel coordinates. Each arrow lies on the straight
    line from the anchor through the corresponding offset point, on the
    same side of the anchor as that point, and is ``arrowLength`` pixels
    long. The "N"/"E" labels are centered on the same lines, 50 pixels
    beyond the arrow tips.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS used for the pixel <-> sky conversions.
    compassLocation : `int`, optional
        Pixel coordinate, used for both x and y, of the compass anchor.
    arrowLength : `float`, optional
        The length of each compass arrow, in pixels.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass drawn on them.
    """
    origin = compassLocation
    skyOffset = 30.0 * geom.arcseconds
    anchorRa, anchorDec = wcs.pixelToSky(origin, origin)
    eastPixel = wcs.skyToPixel(geom.SpherePoint(anchorRa + skyOffset, anchorDec))
    northPixel = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + skyOffset))
    labelDistance = arrowLength + 50.0
    arrowColor = "r"

    for target, name in ((northPixel, "N"), (eastPixel, "E")):
        targetX = target[0]
        targetY = target[1]
        # Defaults cover the vertical case, where x does not move.
        tipX = origin
        labelX = origin

        if targetX == origin:
            # Vertical line: the slope is undefined, so step directly in y.
            direction = 1.0 if targetY > origin else -1.0
            tipY = origin + direction * arrowLength
            labelY = origin + direction * labelDistance
        else:
            slope = (targetY - origin) / (targetX - origin)
            # Moving a distance d along a line of this slope changes x by
            # d / sqrt(1 + slope**2).
            stretchFactor = np.sqrt(1.0 + slope**2)
            tipStepX = arrowLength / stretchFactor
            labelStepX = labelDistance / stretchFactor

            # Of the two candidate points on the line, keep the one on the
            # same side of the anchor as the target.
            if targetX > origin:
                tipX = origin + tipStepX
                labelX = origin + labelStepX
            elif targetX < origin:
                tipX = origin - tipStepX
                labelX = origin - labelStepX
            tipY = slope * (tipX - origin) + origin
            labelY = slope * (labelX - origin) + origin

        ax.arrow(
            origin,
            origin,
            tipX - origin,
            tipY - origin,
            head_width=30.0,
            length_includes_head=True,
            color=arrowColor,
        )
        ax.text(labelX, labelY, name, ha="center", va="center", color=arrowColor)
    return ax

123 

124 

125def plot( 

126 inputData: np.ndarray | afwImage.Exposure | afwImage.Image | afwImage.MaskedImage, 

127 figure: matplotlib.figure.Figure | None = None, 

128 centroids: list[tuple[int, int]] | None = None, 

129 footprints: ( 

130 afwDetection.FootprintSet | afwDetection.Footprint | list[afwDetection.Footprint] | None 

131 ) = None, 

132 sourceCat: afwTable.SourceCatalog = None, 

133 title: str | None = None, 

134 showCompass: bool = True, 

135 stretch: str = "linear", 

136 percentile: float = 99.0, 

137 cmap: str = "gray", 

138 compassLocation: int = 300, 

139 addLegend: bool = False, 

140 savePlotAs: str | None = None, 

141 logger: logging.Logger | None = None, 

142) -> Figure: 

143 """Plot an input image accommodating different data types and additional 

144 features, like: overplotting centroids, compass (if the input image 

145 has a WCS), stretching, plot title, and legend. 

146 

147 Parameters 

148 ---------- 

149 inputData : `numpy.array` or 

150 `lsst.afw.image.Exposure` or 

151 `lsst.afw.image.Image`, or 

152 `lsst.afw.image.MaskedImage` 

153 The input data. 

154 figure : `matplotlib.figure.Figure`, optional 

155 The matplotlib figure that will be used for plotting. 

156 centroids : `list` 

157 The centroids parameter as a list of tuples. 

158 Each tuple is a centroid with its (X,Y) coordinates. 

159 footprints: `lsst.afw.detection.FootprintSet` or 

160 `lsst.afw.detection.Footprint` or 

161 `list` of `lsst.afw.detection.Footprint` 

162 The footprints containing centroids to plot. 

163 sourceCat: `lsst.afw.table.SourceCatalog`: 

164 An `lsst.afw.table.SourceCatalog` object containing centroids 

165 to plot. 

166 title : `str`, optional 

167 Title for the plot. 

168 showCompass : `bool`, optional 

169 Add compass to the plot? Defaults to True. 

170 stretch : `str', optional 

171 Changes mapping of colors for the image. Avaliable options: 

172 ccs, log, power, asinh, linear, sqrt, midtone. Defaults to linear. 

173 percentile : `float', optional 

174 Parameter for astropy.visualization.PercentileInterval. 

175 Sets lower and upper limits for a stretch. This parameter 

176 will be ignored if stretch='ccs'. 

177 cmap : `str`, optional 

178 The colormap to use for mapping the image values to colors. This can be 

179 a string representing a predefined colormap. Default is 'gray'. 

180 compassLocation : `int`, optional 

181 How far in from the bottom left of the image to display the compass. 

182 By default, compass will be placed at pixel (x,y) = (300,300). 

183 addLegend : `bool', optional 

184 Option to add legend to the plot. Recommended if centroids come from 

185 different sources. Default value is False. 

186 savePlotAs : `str`, optional 

187 The name of the file to save the plot as, including the file extension. 

188 The extention must be supported by `matplotlib.pyplot`. 

189 If None (default) plot will not be saved. 

190 logger : `logging.Logger`, optional 

191 The logger to use for errors, created if not supplied. 

192 Returns 

193 ------- 

194 figure : `matplotlib.figure.Figure` 

195 The rendered image. 

196 """ 

197 

198 if not figure: 

199 figure = make_figure(figsize=(10, 10)) 

200 

201 ax = figure.add_subplot(111) 

202 

203 if not logger: 

204 logger = logging.getLogger(__name__) 

205 

206 imageData = getImageArray(inputData) 

207 

208 if np.isnan(imageData).all(): 

209 im = ax.imshow(imageData, origin="lower", aspect="equal") 

210 logger.warning("The imageData contains only NaN values.") 

211 else: 

212 interval = vis.PercentileInterval(percentile) 

213 match stretch: 

214 case "ccs": 

215 quantiles = getQuantiles(imageData, 256) 

216 norm = colors.BoundaryNorm(quantiles, 256) 

217 case "asinh": 

218 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.AsinhStretch(a=0.1)) 

219 case "power": 

220 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.PowerStretch(a=2)) 

221 case "log": 

222 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LogStretch(a=1)) 

223 case "linear": 

224 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LinearStretch()) 

225 case "sqrt": 

226 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.SqrtStretch()) 

227 case "midtone": 

228 imageData = stretchDataMidTone(imageData) 

229 # no interval in this norm as imageData is now [0, 1] aready 

230 norm = vis.ImageNormalize(imageData, stretch=vis.LinearStretch()) 

231 case _: 

232 raise ValueError( 

233 f"Invalid value for stretch : {stretch}. " 

234 "Accepted options are: ccs, asinh, power, log, linear, sqrt." 

235 ) 

236 

237 im = ax.imshow(imageData, cmap=cmap, origin="lower", norm=norm, aspect="equal", interpolation="auto") 

238 

239 if stretch != "midtone": 

240 div = make_axes_locatable(ax) 

241 cax = div.append_axes("right", size="5%", pad=0.05) 

242 figure.colorbar(im, cax=cax) 

243 

244 if showCompass: 

245 try: 

246 assert hasattr(inputData, "getWcs"), "inputData does not have a getWcs method" 

247 wcs = inputData.getWcs() 

248 except AssertionError: 

249 logger.warning("Failed to get WCS from input data. Compass will not be plotted.") 

250 wcs = None 

251 

252 if wcs: 

253 arrowLength = min(imageData.shape) * 0.05 

254 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength) 

255 

256 if centroids: 

257 ax.plot( 

258 *zip(*centroids), 

259 marker="x", 

260 markeredgecolor="r", 

261 markerfacecolor="None", 

262 linestyle="None", 

263 label="List of centroids", 

264 ) 

265 

266 if sourceCat: 

267 ax.scatter( 

268 sourceCat.getX(), 

269 sourceCat.getY(), 

270 marker="o", 

271 edgecolors="c", # cyan rings 

272 c="None", # empty cicrles (no fill) 

273 label="Source catalog", 

274 ) 

275 

276 if footprints: 

277 match footprints: 

278 case afwDetection.FootprintSet(): 

279 fs = afwDetection.FootprintSet.getFootprints(footprints) 

280 xy = [_.getCentroid() for _ in fs] 

281 case afwDetection.Footprint(): 

282 xy = [footprints.getCentroid()] 

283 case list(): 

284 xy = [] 

285 for i, ft in enumerate(footprints): 

286 try: 

287 ft.getCentroid() 

288 except AttributeError: 

289 raise TypeError( 

290 "Cannot get centroids for one of the " 

291 "elements from the footprints list. " 

292 "Expected lsst.afw.detection.Footprint, " 

293 f"got {type(ft)} for footprints[{i}]" 

294 ) 

295 xy.append(ft.getCentroid()) 

296 case _: 

297 raise TypeError( 

298 "This function works with FootprintSets, " 

299 "single Footprints, and iterables of Footprints. " 

300 f"Got {type(footprints)}" 

301 ) 

302 

303 ax.plot( 

304 *zip(*xy), 

305 marker="x", 

306 markeredgecolor="b", 

307 markerfacecolor="None", 

308 linestyle="None", 

309 label="Footprints centroids", 

310 ) 

311 

312 if addLegend: 

313 ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=5) 

314 

315 if title: 

316 ax.set_title(title) 

317 

318 if savePlotAs: 

319 figure.savefig(savePlotAs) 

320 

321 return figure 

322 

323 

def _computeMtf(image: np.ndarray, midtonesBalance: float) -> np.ndarray:
    """
    Compute the midtones transfer function (MTF) for an image.

    The value computed for each pixel ``x`` with balance ``m`` is
    ``(m - 1) * x / ((2 * m - 1) * x - m)``. Wherever that expression is not
    finite (e.g. 0/0, or a NaN input pixel) the result is replaced by a
    fallback: 1 where the input is exactly 1, 0.5 where the input is exactly
    0.5, and 0 everywhere else.

    Parameters
    ----------
    image : `np.ndarray`
        Input image normalized to [0, 1].
    midtonesBalance : `float`
        Balance parameter controlling midtones emphasis.

    Returns
    -------
    image : `np.ndarray`
        Image after applying the MTF, with non-finite values replaced by
        fallback values.
    """
    M = np.full(image.shape, midtonesBalance)
    fallback = np.where(image == 1, 1.0, np.where(image == 0.5, 0.5, 0.0))

    # Divisions by zero are expected for degenerate balances and are repaired
    # just below, so do not emit floating-point warnings for them.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = (M - 1) * image / ((2 * M - 1) * image - M)

    nonFinite = ~np.isfinite(result)
    result[nonFinite] = fallback[nonFinite]
    return result

349 

350 

def _applyClip(image: np.ndarray, clipLow: float, clipHigh: float) -> np.ndarray:
    """
    Rescale an image linearly so that clipLow maps to 0 and clipHigh to 1.

    Pixels below ``clipLow`` are zeroed and pixels above ``clipHigh`` are
    pushed above 1 before the final clip to [0, 1].

    Parameters
    ----------
    image : `np.ndarray`
        Input image normalized to [0, 1].
    clipLow : `float`
        Lower clipping threshold.
    clipHigh : `float`
        Upper clipping threshold.

    Returns
    -------
    clipped : `np.ndarray`
        Clipped and scaled image, with values in [0, 1].
    """
    span = clipHigh - clipLow
    rescaled = (image - clipLow) / span
    notBelow = ~(image < clipLow)
    saturated = image > clipHigh
    # Multiplying by the boolean mask zeroes pixels below clipLow; adding the
    # saturation mask lifts pixels above clipHigh so the clip pins them at 1.
    return np.clip(rescaled * notBelow + saturated, 0.0, 1.0)

373 

374 

def _applyExpansion(image: np.ndarray, outMin: float, outMax: float) -> np.ndarray:
    """
    Linearly rescale an image so that outMin maps to 0 and outMax maps to 1.

    Parameters
    ----------
    image : `np.ndarray`
        Input image after MTF.
    outMin : `float`
        Value that is mapped to 0 (usually 0.0).
    outMax : `float`
        Value that is mapped to 1 (usually 1.0).

    Returns
    -------
    expanded : `np.ndarray`
        The rescaled image. No clipping is applied here.
    """
    outputRange = outMax - outMin
    offsetImage = image - outMin
    return offsetImage / outputRange

394 

395 

def _applyDisplayFunction(
    image: np.ndarray, midtonesBalance: float, clipLow: float, clipHigh: float, outMin: float, outMax: float
) -> np.ndarray:
    """
    Run the complete display pipeline on an image.

    The three stages are applied in order: clipping, the midtones transfer
    function, and finally the output-range expansion.

    Parameters
    ----------
    image : `np.ndarray`
        Input image normalized to [0, 1].
    midtonesBalance : `float`
        Midtones balance parameter.
    clipLow : `float`
        Lower clipping threshold.
    clipHigh : `float`
        Upper clipping threshold.
    outMin : `float`
        Minimum output of expansion.
    outMax : `float`
        Maximum output of expansion.

    Returns
    -------
    stretched : `np.ndarray`
        Stretched image ready for display.
    """
    return _applyExpansion(
        _computeMtf(_applyClip(image, clipLow, clipHigh), midtonesBalance),
        outMin,
        outMax,
    )

425 

426 

def _computeDisplayParameters(data: np.ndarray) -> tuple[float, float, float, float, float]:
    """
    Compute parameters for display function based on data statistics.

    The statistics used are the median of the data and the median absolute
    deviation scaled by 1.4826. NaN pixels are ignored when computing both,
    so a few bad pixels do not turn every returned parameter into NaN.

    Parameters
    ----------
    data : `np.ndarray`
        Normalized image data array.

    Returns
    -------
    tuple[float, float, float, float, float]
        midtonesBalance, clipLow, clipHigh, outMin, outMax. The last two
        are always 0.0 and 1.0.
    """
    median = np.nanmedian(data)
    # The median does its own ordering, so no explicit sort or flattening of
    # the deviations is needed.
    madn = 1.4826 * np.nanmedian(np.abs(data - median))
    targetBackground = 0.25
    clippingFactor = -2.8

    aboveHalf = median > 0.5

    # Only one side is clipped: the low side for a dark-dominated image, the
    # high side for a bright-dominated one. With zero spread nothing is
    # clipped at all.
    if not aboveHalf and madn != 0:
        clipLow = min(1.0, max(0.0, median + clippingFactor * madn))
    else:
        clipLow = 0.0

    if aboveHalf and madn != 0:
        clipHigh = min(1.0, max(0.0, median - clippingFactor * madn))
    else:
        clipHigh = 1.0

    if median <= 0.5:
        midtonesBalance = (
            (targetBackground - 1)
            * (median - clipLow)
            / ((2 * targetBackground - 1) * (median - clipLow) - targetBackground)
        )
    else:
        midtonesBalance = (
            (clipHigh - median - 1)
            * targetBackground
            / (2 * (clipHigh - median - 1) * targetBackground - (clipHigh - median))
        )

    return midtonesBalance, clipLow, clipHigh, 0.0, 1.0

473 

474 

def stretchDataMidTone(
    imageLike: np.ndarray | afwImage.Exposure | afwImage.Image | afwImage.MaskedImage,
) -> np.ndarray:
    """
    Normalize and stretch image data using the Midtone Transfer Function
    (MTF).

    This is following:
    https://pixinsight.com/doc/docs/XISF-1.0-spec/XISF-1.0-spec.html
    #__XISF_Data_Objects_:_XISF_Image_:_Display_Function__

    The data are first normalized by dividing by their maximum; if the
    minimum is negative it is subtracted first. NaN pixels are ignored when
    finding that minimum and maximum.

    Parameters
    ----------
    imageLike : `numpy.ndarray`, `lsst.afw.image.Exposure`,
                `lsst.afw.image.Image`, or `lsst.afw.image.MaskedImage`
        The image-like object containing the data to be stretched.

    Returns
    -------
    stretched : `np.ndarray`
        The stretched image array.
    """
    data = getImageArray(imageLike)

    # NaN-aware extrema: a single NaN pixel must not turn the whole
    # normalized image into NaN.
    pedestal = np.nanmin(data)
    shifted = data if pedestal >= 0.0 else data - pedestal
    norm = np.nanmax(shifted)

    if norm == 0:
        # Constant image: there is no range to normalize by, so avoid 0/0.
        normalized = np.zeros_like(shifted, dtype=float)
    else:
        normalized = shifted / norm

    midtonesBalance, clipLow, clipHigh, outMin, outMax = _computeDisplayParameters(normalized)
    stretched = _applyDisplayFunction(normalized, midtonesBalance, clipLow, clipHigh, outMin, outMax)
    return stretched

509 return stretched