Coverage for python/lsst/summit/utils/plotting.py: 9%

122 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-17 08:53 +0000

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23 

24import astropy.visualization as vis 

25import matplotlib 

26import matplotlib.colors as colors 

27import matplotlib.pyplot as plt 

28import numpy as np 

29from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

30 

31import lsst.afw.detection as afwDetection 

32import lsst.afw.geom as afwGeom 

33import lsst.afw.image as afwImage 

34import lsst.afw.table as afwTable 

35import lsst.geom as geom 

36from lsst.afw.detection import Footprint, FootprintSet 

37from lsst.summit.utils import getQuantiles 

38 

39 

def drawCompass(
    ax: matplotlib.axes.Axes,
    wcs: afwGeom.SkyWcs,
    compassLocation: int = 300,
    arrowLength: float = 300.0,
) -> matplotlib.axes.Axes:
    """Draw a North/East compass on the axes.

    Both compass arrows have the same length, ``arrowLength``.
    The steps here are:

    - transform the pixel (compassLocation, compassLocation) to RA, Dec;
    - offset that sky position by 30 arcsec in Dec to get N, and by
      30 arcsec in RA to get E;
    - transform the N and E points back to pixel coordinates;
    - normalise the pixel-space vectors from the compass centre to the
      N and E points, and scale them by ``arrowLength`` to get the
      arrow tips;
    - place the N/E labels along the same directions, 50 pixels beyond
      the arrow tips.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS from exposure.
    compassLocation : `int`, optional
        How far in from the bottom left of the image to display the compass.
        The compass is centred on pixel (compassLocation, compassLocation).
    arrowLength : `float`, optional
        The length of the compass arrows, in pixels.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass.
    """
    color = "r"
    labelDistance = arrowLength + 50.0

    anchorRa, anchorDec = wcs.pixelToSky(compassLocation, compassLocation)
    east = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
    north = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30.0 * geom.arcseconds))

    for xy, label in [(north, "N"), (east, "E")]:
        # Use the unit direction vector rather than a slope. For north-up
        # images dx is usually tiny floating-point noise rather than exactly
        # zero. A slope-based solution then loses precision when recovering
        # y from x, and the arrow comes out with the wrong length.
        dx = xy[0] - compassLocation
        dy = xy[1] - compassLocation
        pixelDistance = np.hypot(dx, dy)
        ux = dx / pixelDistance
        uy = dy / pixelDistance

        ax.arrow(
            compassLocation,
            compassLocation,
            arrowLength * ux,
            arrowLength * uy,
            head_width=30.0,
            length_includes_head=True,
            color=color,
        )
        ax.text(
            compassLocation + labelDistance * ux,
            compassLocation + labelDistance * uy,
            label,
            ha="center",
            va="center",
            color=color,
        )
    return ax

119 

120 

121def plot( 

122 inputData: np.ndarray | afwImage.Exposure | afwImage.Image | afwImage.MaskedImage, 

123 figure: matplotlib.figure.Figure | None = None, 

124 centroids: list[tuple[int, int]] | None = None, 

125 footprints: ( 

126 afwDetection.FootprintSet | afwDetection.Footprint | list[afwDetection.Footprint] | None 

127 ) = None, 

128 sourceCat: afwTable.SourceCatalog = None, 

129 title: str | None = None, 

130 showCompass: bool = True, 

131 stretch: str = "linear", 

132 percentile: float = 99.0, 

133 cmap: str = "gray", 

134 compassLocation: int = 300, 

135 addLegend: bool = False, 

136 savePlotAs: str | None = None, 

137 logger: logging.Logger | None = None, 

138) -> matplotlib.figure.Figure: 

139 """Plot an input image accommodating different data types and additional 

140 features, like: overplotting centroids, compass (if the input image 

141 has a WCS), stretching, plot title, and legend. 

142 

143 Parameters 

144 ---------- 

145 inputData : `numpy.array` or 

146 `lsst.afw.image.Exposure` or 

147 `lsst.afw.image.Image`, or 

148 `lsst.afw.image.MaskedImage` 

149 The input data. 

150 figure : `matplotlib.figure.Figure`, optional 

151 The matplotlib figure that will be used for plotting. 

152 centroids : `list` 

153 The centroids parameter as a list of tuples. 

154 Each tuple is a centroid with its (X,Y) coordinates. 

155 footprints: `lsst.afw.detection.FootprintSet` or 

156 `lsst.afw.detection.Footprint` or 

157 `list` of `lsst.afw.detection.Footprint` 

158 The footprints containing centroids to plot. 

159 sourceCat: `lsst.afw.table.SourceCatalog`: 

160 An `lsst.afw.table.SourceCatalog` object containing centroids 

161 to plot. 

162 title : `str`, optional 

163 Title for the plot. 

164 showCompass : `bool`, optional 

165 Add compass to the plot? Defaults to True. 

166 stretch : `str', optional 

167 Changes mapping of colors for the image. Avaliable options: 

168 ccs, log, power, asinh, linear, sqrt. Defaults to linear. 

169 percentile : `float', optional 

170 Parameter for astropy.visualization.PercentileInterval. 

171 Sets lower and upper limits for a stretch. This parameter 

172 will be ignored if stretch='ccs'. 

173 cmap : `str`, optional 

174 The colormap to use for mapping the image values to colors. This can be 

175 a string representing a predefined colormap. Default is 'gray'. 

176 compassLocation : `int`, optional 

177 How far in from the bottom left of the image to display the compass. 

178 By default, compass will be placed at pixel (x,y) = (300,300). 

179 addLegend : `bool', optional 

180 Option to add legend to the plot. Recommended if centroids come from 

181 different sources. Default value is False. 

182 savePlotAs : `str`, optional 

183 The name of the file to save the plot as, including the file extension. 

184 The extention must be supported by `matplotlib.pyplot`. 

185 If None (default) plot will not be saved. 

186 logger : `logging.Logger`, optional 

187 The logger to use for errors, created if not supplied. 

188 Returns 

189 ------- 

190 figure : `matplotlib.figure.Figure` 

191 The rendered image. 

192 """ 

193 

194 if not figure: 

195 figure = plt.figure(figsize=(10, 10)) 

196 

197 ax = figure.add_subplot(111) 

198 

199 if not logger: 

200 logger = logging.getLogger(__name__) 

201 

202 match inputData: 

203 case np.ndarray(): 

204 imageData = inputData 

205 case afwImage.MaskedImage(): 

206 imageData = inputData.image.array 

207 case afwImage.Image(): 

208 imageData = inputData.array 

209 case afwImage.Exposure(): 

210 imageData = inputData.image.array 

211 case _: 

212 raise TypeError( 

213 "This function accepts numpy array, lsst.afw.image.Exposure components." 

214 f" Got {type(inputData)}" 

215 ) 

216 

217 if np.isnan(imageData).all(): 

218 im = ax.imshow(imageData, origin="lower", aspect="equal") 

219 logger.warning("The imageData contains only NaN values.") 

220 else: 

221 interval = vis.PercentileInterval(percentile) 

222 match stretch: 

223 case "ccs": 

224 quantiles = getQuantiles(imageData, 256) 

225 norm = colors.BoundaryNorm(quantiles, 256) 

226 case "asinh": 

227 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.AsinhStretch(a=0.1)) 

228 case "power": 

229 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.PowerStretch(a=2)) 

230 case "log": 

231 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LogStretch(a=1)) 

232 case "linear": 

233 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LinearStretch()) 

234 case "sqrt": 

235 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.SqrtStretch()) 

236 case _: 

237 raise ValueError( 

238 f"Invalid value for stretch : {stretch}. " 

239 "Accepted options are: ccs, asinh, power, log, linear, sqrt." 

240 ) 

241 

242 im = ax.imshow(imageData, cmap=cmap, origin="lower", norm=norm, aspect="equal") 

243 div = make_axes_locatable(ax) 

244 cax = div.append_axes("right", size="5%", pad=0.05) 

245 figure.colorbar(im, cax=cax) 

246 

247 if showCompass: 

248 try: 

249 wcs = inputData.getWcs() 

250 except AttributeError: 

251 logger.warning("Failed to get WCS from input data. Compass will not be plotted.") 

252 wcs = None 

253 

254 if wcs: 

255 arrowLength = min(imageData.shape) * 0.05 

256 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength) 

257 

258 if centroids: 

259 ax.plot( 

260 *zip(*centroids), 

261 marker="x", 

262 markeredgecolor="r", 

263 markerfacecolor="None", 

264 linestyle="None", 

265 label="List of centroids", 

266 ) 

267 

268 if sourceCat: 

269 ax.plot( 

270 list(zip(sourceCat.getX(), sourceCat.getY())), 

271 marker="o", 

272 markeredgecolor="c", 

273 markerfacecolor="None", 

274 linestyle="None", 

275 label="Source catalog", 

276 ) 

277 

278 if footprints: 

279 match footprints: 

280 case FootprintSet(): 

281 fs = FootprintSet.getFootprints(footprints) 

282 xy = [_.getCentroid() for _ in fs] 

283 case Footprint(): 

284 xy = [footprints.getCentroid()] 

285 case list(): 

286 xy = [] 

287 for i, ft in enumerate(footprints): 

288 try: 

289 ft.getCentroid() 

290 except AttributeError: 

291 raise TypeError( 

292 "Cannot get centroids for one of the " 

293 "elements from the footprints list. " 

294 "Expected lsst.afw.detection.Footprint, " 

295 f"got {type(ft)} for footprints[{i}]" 

296 ) 

297 xy.append(ft.getCentroid()) 

298 case _: 

299 raise TypeError( 

300 "This function works with FootprintSets, " 

301 "single Footprints, and iterables of Footprints. " 

302 f"Got {type(footprints)}" 

303 ) 

304 

305 ax.plot( 

306 *zip(*xy), 

307 marker="x", 

308 markeredgecolor="b", 

309 markerfacecolor="None", 

310 linestyle="None", 

311 label="Footprints centroids", 

312 ) 

313 

314 if addLegend: 

315 ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=5) 

316 

317 if title: 

318 ax.set_title(title) 

319 

320 if savePlotAs: 

321 plt.savefig(savePlotAs) 

322 

323 return figure