Coverage for python/lsst/summit/utils/plotting.py: 7%

118 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-06 14:05 +0000

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import numpy as np 

23import matplotlib.pyplot as plt 

24import matplotlib.colors as colors 

25from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

26import astropy.visualization as vis 

27import logging 

28 

29from lsst.afw.detection import FootprintSet, Footprint 

30import lsst.geom as geom 

31from lsst.summit.utils import getQuantiles 

32import lsst.afw.image as afwImage 

33 

34 

def _tipAlongDirection(origin, target, distance):
    """Return the point ``distance`` pixels from ``(origin, origin)`` toward
    ``target``.

    Parameters
    ----------
    origin : `float`
        The x and y pixel coordinate of the starting point. The compass
        centre lies on the image diagonal, so one value serves both axes.
    target : sequence of `float`
        An ``(x, y)`` pixel position defining the direction of travel.
    distance : `float`
        Distance from the starting point to the returned point.

    Returns
    -------
    tip : `tuple` of `float` or `None`
        The ``(x, y)`` position, or `None` if the direction is undefined
        (``target`` coincides with the starting point or is not finite).
    """
    dx = target[0] - origin
    dy = target[1] - origin
    length = np.hypot(dx, dy)
    # NaN/inf pixel positions or a zero-length offset give no usable
    # direction; report that instead of inventing one.
    if not np.isfinite(length) or length == 0:
        return None
    return origin + distance * dx / length, origin + distance * dy / length


def drawCompass(ax, wcs, compassLocation=300, arrowLength=300.):
    """Draw North and East arrows, with labels, on an image.

    The steps are:
    - transform the pixel (compassLocation, compassLocation) to RA, Dec;
    - offset that sky position by 30 arcsec in Dec (North) and in RA (East);
    - transform the offset positions back to pixel coordinates;
    - draw an arrow of length ``arrowLength`` from the anchor pixel along each
      of those directions, with its label centred 50 pixels beyond the tip.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS from exposure.
    compassLocation : `int`, optional
        How far in from the bottom left of the image to display the compass.
    arrowLength : `float`, optional
        The length of the compass arrows, in pixels. Both arrows have this
        length.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass.
    """
    anchorRa, anchorDec = wcs.pixelToSky(compassLocation, compassLocation)
    east = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
    north = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30. * geom.arcseconds))
    # Labels sit on the same line as their arrow, a fixed 50 px past the tip.
    labelPosition = arrowLength + 50.
    color = 'r'

    for xy, label in [(north, 'N'), (east, 'E')]:
        tip = _tipAlongDirection(compassLocation, xy, arrowLength)
        labelCenter = _tipAlongDirection(compassLocation, xy, labelPosition)
        if tip is None or labelCenter is None:
            # Skip rather than fail the whole plot (or draw a wrong arrow).
            logging.getLogger(__name__).warning(
                "Could not determine the %s direction of the compass; skipping that arrow.", label)
            continue
        ax.arrow(compassLocation, compassLocation,
                 tip[0] - compassLocation, tip[1] - compassLocation,
                 head_width=30., length_includes_head=True, color=color)
        ax.text(labelCenter[0], labelCenter[1], label, ha='center', va='center', color=color)
    return ax

103 

104 

105def plot(inputData, 

106 figure=None, 

107 centroids=None, 

108 footprints=None, 

109 sourceCat=None, 

110 title=None, 

111 showCompass=True, 

112 stretch='linear', 

113 percentile=99., 

114 cmap='gray', 

115 compassLocation=300, 

116 addLegend=False, 

117 savePlotAs=None, 

118 logger=None): 

119 

120 """Plot an input image accommodating different data types and additional 

121 features, like: overplotting centroids, compass (if the input image 

122 has a WCS), stretching, plot title, and legend. 

123 

124 Parameters 

125 ---------- 

126 inputData : `numpy.array` or 

127 `lsst.afw.image.Exposure` or 

128 `lsst.afw.image.Image`, or 

129 `lsst.afw.image.MaskedImage` 

130 The input data. 

131 figure : `matplotlib.figure.Figure`, optional 

132 The matplotlib figure that will be used for plotting. 

133 centroids : `list` 

134 The centroids parameter as a list of tuples. 

135 Each tuple is a centroid with its (X,Y) coordinates. 

136 footprints: `lsst.afw.detection.FootprintSet` or 

137 `lsst.afw.detection.Footprint` or 

138 `list` of `lsst.afw.detection.Footprint` 

139 The footprints containing centroids to plot. 

140 sourceCat: `lsst.afw.table.SourceCatalog`: 

141 An `lsst.afw.table.SourceCatalog` object containing centroids 

142 to plot. 

143 title : `str`, optional 

144 Title for the plot. 

145 showCompass : `bool`, optional 

146 Add compass to the plot? Defaults to True. 

147 stretch : `str', optional 

148 Changes mapping of colors for the image. Avaliable options: 

149 ccs, log, power, asinh, linear, sqrt. Defaults to linear. 

150 percentile : `float', optional 

151 Parameter for astropy.visualization.PercentileInterval. 

152 Sets lower and upper limits for a stretch. This parameter 

153 will be ignored if stretch='ccs'. 

154 cmap : `str`, optional 

155 The colormap to use for mapping the image values to colors. This can be 

156 a string representing a predefined colormap. Default is 'gray'. 

157 compassLocation : `int`, optional 

158 How far in from the bottom left of the image to display the compass. 

159 By default, compass will be placed at pixel (x,y) = (300,300). 

160 addLegend : `bool', optional 

161 Option to add legend to the plot. Recommended if centroids come from 

162 different sources. Default value is False. 

163 savePlotAs : `str`, optional 

164 The name of the file to save the plot as, including the file extension. 

165 The extention must be supported by `matplotlib.pyplot`. 

166 If None (default) plot will not be saved. 

167 logger : `logging.Logger`, optional 

168 The logger to use for errors, created if not supplied. 

169 Returns 

170 ------- 

171 figure : `matplotlib.figure.Figure` 

172 The rendered image. 

173 """ 

174 

175 if not figure: 

176 figure = plt.figure(figsize=(10, 10)) 

177 

178 ax = figure.add_subplot(111) 

179 

180 if not logger: 

181 logger = logging.getLogger(__name__) 

182 

183 match inputData: 

184 case np.ndarray(): 

185 imageData = inputData 

186 case afwImage.MaskedImage(): 

187 imageData = inputData.image.array 

188 case afwImage.Image(): 

189 imageData = inputData.array 

190 case afwImage.Exposure(): 

191 imageData = inputData.image.array 

192 case _: 

193 raise TypeError("This function accepts numpy array, lsst.afw.image.Exposure components." 

194 f" Got {type(inputData)}") 

195 

196 if np.isnan(imageData).all(): 

197 im = ax.imshow(imageData, origin='lower', aspect='equal') 

198 logger.warning("The imageData contains only NaN values.") 

199 else: 

200 interval = vis.PercentileInterval(percentile) 

201 match stretch: 

202 case 'ccs': 

203 quantiles = getQuantiles(imageData, 256) 

204 norm = colors.BoundaryNorm(quantiles, 256) 

205 case 'asinh': 

206 norm = vis.ImageNormalize(imageData, 

207 interval=interval, 

208 stretch=vis.AsinhStretch(a=0.1)) 

209 case 'power': 

210 norm = vis.ImageNormalize(imageData, 

211 interval=interval, 

212 stretch=vis.PowerStretch(a=2)) 

213 case 'log': 

214 norm = vis.ImageNormalize(imageData, 

215 interval=interval, 

216 stretch=vis.LogStretch(a=1)) 

217 case 'linear': 

218 norm = vis.ImageNormalize(imageData, 

219 interval=interval, 

220 stretch=vis.LinearStretch()) 

221 case 'sqrt': 

222 norm = vis.ImageNormalize(imageData, 

223 interval=interval, 

224 stretch=vis.SqrtStretch()) 

225 case _: 

226 raise ValueError(f"Invalid value for stretch : {stretch}. " 

227 "Accepted options are: ccs, asinh, power, log, linear, sqrt.") 

228 

229 im = ax.imshow(imageData, cmap=cmap, origin='lower', norm=norm, aspect='equal') 

230 div = make_axes_locatable(ax) 

231 cax = div.append_axes("right", size="5%", pad=0.05) 

232 figure.colorbar(im, cax=cax) 

233 

234 if showCompass: 

235 try: 

236 wcs = inputData.getWcs() 

237 except AttributeError: 

238 logger.warning("Failed to get WCS from input data. Compass will not be plotted.") 

239 wcs = None 

240 

241 if wcs: 

242 arrowLength = min(imageData.shape) * 0.05 

243 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength) 

244 

245 if centroids: 

246 ax.plot(*zip(*centroids), 

247 marker='x', 

248 markeredgecolor='r', 

249 markerfacecolor='None', 

250 linestyle='None', 

251 label='List of centroids') 

252 

253 if sourceCat: 

254 ax.plot(list(zip(sourceCat.getX(), sourceCat.getY())), 

255 marker='o', 

256 markeredgecolor='c', 

257 markerfacecolor='None', 

258 linestyle='None', 

259 label='Source catalog') 

260 

261 if footprints: 

262 match footprints: 

263 case FootprintSet(): 

264 fs = FootprintSet.getFootprints(footprints) 

265 xy = [_.getCentroid() for _ in fs] 

266 case Footprint(): 

267 xy = [footprints.getCentroid()] 

268 case list(): 

269 xy = [] 

270 for i, ft in enumerate(footprints): 

271 try: 

272 ft.getCentroid() 

273 except AttributeError: 

274 raise TypeError("Cannot get centroids for one of the " 

275 "elements from the footprints list. " 

276 "Expected lsst.afw.detection.Footprint, " 

277 f"got {type(ft)} for footprints[{i}]") 

278 xy.append(ft.getCentroid()) 

279 case _: 

280 raise TypeError("This function works with FootprintSets, " 

281 "single Footprints, and iterables of Footprints. " 

282 f"Got {type(footprints)}") 

283 

284 ax.plot(*zip(*xy), 

285 marker='x', 

286 markeredgecolor='b', 

287 markerfacecolor='None', 

288 linestyle='None', 

289 label='Footprints centroids') 

290 

291 if addLegend: 

292 ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=5) 

293 

294 if title: 

295 ax.set_title(title) 

296 

297 if savePlotAs: 

298 plt.savefig(savePlotAs) 

299 

300 return figure