Coverage for python/lsst/summit/utils/plotting.py: 7%

118 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-17 12:31 +0000

1# This file is part of summit_utils. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22import logging 

23 

24import astropy.visualization as vis 

25import matplotlib.colors as colors 

26import matplotlib.pyplot as plt 

27import numpy as np 

28from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable 

29 

30import lsst.afw.image as afwImage 

31import lsst.geom as geom 

32from lsst.afw.detection import Footprint, FootprintSet 

33from lsst.summit.utils import getQuantiles 

34 

35 

def drawCompass(ax, wcs, compassLocation=300, arrowLength=300.0):
    """Draw a North/East compass on an image plot.

    The compass is anchored at pixel ``(compassLocation, compassLocation)``
    and both arrows have the same length, ``arrowLength``. The steps are:

    - transform the anchor pixel to sky (RA, Dec) coordinates;
    - offset that sky position by 30 arcsec in Dec (North) and in RA
      (East), and transform both offset points back to pixel coordinates;
    - for each direction, take the unit vector pointing from the anchor to
      the offset point and draw an arrow of length ``arrowLength`` along it;
    - place the N/E label on the same line, 50 pixels beyond the arrow tip.

    A direction whose pixel position is non-finite, or coincides with the
    anchor, cannot be drawn; it is skipped and a warning is logged.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS from exposure.
    compassLocation : `int`, optional
        How far in from the bottom left of the image to display the compass,
        in pixels (used for both x and y).
    arrowLength : `float`, optional
        The length of the compass arrows, in pixels.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass.
    """
    anchorRa, anchorDec = wcs.pixelToSky(compassLocation, compassLocation)
    east = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
    north = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30.0 * geom.arcseconds))
    labelPosition = arrowLength + 50.0
    color = "r"

    for xy, label in [(north, "N"), (east, "E")]:
        dx = xy[0] - compassLocation
        dy = xy[1] - compassLocation
        length = np.hypot(dx, dy)
        # A NaN/inf pixel position, or an offset point landing exactly on the
        # anchor, gives no usable direction. Skip it rather than drawing a
        # bogus arrow or failing with an unbound tip position.
        if not np.isfinite(length) or length == 0:
            logging.getLogger(__name__).warning(
                "Cannot determine the %s direction of the compass; that arrow will not be drawn.", label
            )
            continue

        # Unit vector from the anchor towards the offset point; the arrow tip
        # and the label centre both lie along it.
        ux = dx / length
        uy = dy / length
        ax.arrow(
            compassLocation,
            compassLocation,
            arrowLength * ux,
            arrowLength * uy,
            head_width=30.0,
            length_includes_head=True,
            color=color,
        )
        ax.text(
            compassLocation + labelPosition * ux,
            compassLocation + labelPosition * uy,
            label,
            ha="center",
            va="center",
            color=color,
        )
    return ax

110 

111 

112def plot( 

113 inputData, 

114 figure=None, 

115 centroids=None, 

116 footprints=None, 

117 sourceCat=None, 

118 title=None, 

119 showCompass=True, 

120 stretch="linear", 

121 percentile=99.0, 

122 cmap="gray", 

123 compassLocation=300, 

124 addLegend=False, 

125 savePlotAs=None, 

126 logger=None, 

127): 

128 """Plot an input image accommodating different data types and additional 

129 features, like: overplotting centroids, compass (if the input image 

130 has a WCS), stretching, plot title, and legend. 

131 

132 Parameters 

133 ---------- 

134 inputData : `numpy.array` or 

135 `lsst.afw.image.Exposure` or 

136 `lsst.afw.image.Image`, or 

137 `lsst.afw.image.MaskedImage` 

138 The input data. 

139 figure : `matplotlib.figure.Figure`, optional 

140 The matplotlib figure that will be used for plotting. 

141 centroids : `list` 

142 The centroids parameter as a list of tuples. 

143 Each tuple is a centroid with its (X,Y) coordinates. 

144 footprints: `lsst.afw.detection.FootprintSet` or 

145 `lsst.afw.detection.Footprint` or 

146 `list` of `lsst.afw.detection.Footprint` 

147 The footprints containing centroids to plot. 

148 sourceCat: `lsst.afw.table.SourceCatalog`: 

149 An `lsst.afw.table.SourceCatalog` object containing centroids 

150 to plot. 

151 title : `str`, optional 

152 Title for the plot. 

153 showCompass : `bool`, optional 

154 Add compass to the plot? Defaults to True. 

155 stretch : `str', optional 

156 Changes mapping of colors for the image. Avaliable options: 

157 ccs, log, power, asinh, linear, sqrt. Defaults to linear. 

158 percentile : `float', optional 

159 Parameter for astropy.visualization.PercentileInterval. 

160 Sets lower and upper limits for a stretch. This parameter 

161 will be ignored if stretch='ccs'. 

162 cmap : `str`, optional 

163 The colormap to use for mapping the image values to colors. This can be 

164 a string representing a predefined colormap. Default is 'gray'. 

165 compassLocation : `int`, optional 

166 How far in from the bottom left of the image to display the compass. 

167 By default, compass will be placed at pixel (x,y) = (300,300). 

168 addLegend : `bool', optional 

169 Option to add legend to the plot. Recommended if centroids come from 

170 different sources. Default value is False. 

171 savePlotAs : `str`, optional 

172 The name of the file to save the plot as, including the file extension. 

173 The extention must be supported by `matplotlib.pyplot`. 

174 If None (default) plot will not be saved. 

175 logger : `logging.Logger`, optional 

176 The logger to use for errors, created if not supplied. 

177 Returns 

178 ------- 

179 figure : `matplotlib.figure.Figure` 

180 The rendered image. 

181 """ 

182 

183 if not figure: 

184 figure = plt.figure(figsize=(10, 10)) 

185 

186 ax = figure.add_subplot(111) 

187 

188 if not logger: 

189 logger = logging.getLogger(__name__) 

190 

191 match inputData: 

192 case np.ndarray(): 

193 imageData = inputData 

194 case afwImage.MaskedImage(): 

195 imageData = inputData.image.array 

196 case afwImage.Image(): 

197 imageData = inputData.array 

198 case afwImage.Exposure(): 

199 imageData = inputData.image.array 

200 case _: 

201 raise TypeError( 

202 "This function accepts numpy array, lsst.afw.image.Exposure components." 

203 f" Got {type(inputData)}" 

204 ) 

205 

206 if np.isnan(imageData).all(): 

207 im = ax.imshow(imageData, origin="lower", aspect="equal") 

208 logger.warning("The imageData contains only NaN values.") 

209 else: 

210 interval = vis.PercentileInterval(percentile) 

211 match stretch: 

212 case "ccs": 

213 quantiles = getQuantiles(imageData, 256) 

214 norm = colors.BoundaryNorm(quantiles, 256) 

215 case "asinh": 

216 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.AsinhStretch(a=0.1)) 

217 case "power": 

218 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.PowerStretch(a=2)) 

219 case "log": 

220 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LogStretch(a=1)) 

221 case "linear": 

222 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LinearStretch()) 

223 case "sqrt": 

224 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.SqrtStretch()) 

225 case _: 

226 raise ValueError( 

227 f"Invalid value for stretch : {stretch}. " 

228 "Accepted options are: ccs, asinh, power, log, linear, sqrt." 

229 ) 

230 

231 im = ax.imshow(imageData, cmap=cmap, origin="lower", norm=norm, aspect="equal") 

232 div = make_axes_locatable(ax) 

233 cax = div.append_axes("right", size="5%", pad=0.05) 

234 figure.colorbar(im, cax=cax) 

235 

236 if showCompass: 

237 try: 

238 wcs = inputData.getWcs() 

239 except AttributeError: 

240 logger.warning("Failed to get WCS from input data. Compass will not be plotted.") 

241 wcs = None 

242 

243 if wcs: 

244 arrowLength = min(imageData.shape) * 0.05 

245 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength) 

246 

247 if centroids: 

248 ax.plot( 

249 *zip(*centroids), 

250 marker="x", 

251 markeredgecolor="r", 

252 markerfacecolor="None", 

253 linestyle="None", 

254 label="List of centroids", 

255 ) 

256 

257 if sourceCat: 

258 ax.plot( 

259 list(zip(sourceCat.getX(), sourceCat.getY())), 

260 marker="o", 

261 markeredgecolor="c", 

262 markerfacecolor="None", 

263 linestyle="None", 

264 label="Source catalog", 

265 ) 

266 

267 if footprints: 

268 match footprints: 

269 case FootprintSet(): 

270 fs = FootprintSet.getFootprints(footprints) 

271 xy = [_.getCentroid() for _ in fs] 

272 case Footprint(): 

273 xy = [footprints.getCentroid()] 

274 case list(): 

275 xy = [] 

276 for i, ft in enumerate(footprints): 

277 try: 

278 ft.getCentroid() 

279 except AttributeError: 

280 raise TypeError( 

281 "Cannot get centroids for one of the " 

282 "elements from the footprints list. " 

283 "Expected lsst.afw.detection.Footprint, " 

284 f"got {type(ft)} for footprints[{i}]" 

285 ) 

286 xy.append(ft.getCentroid()) 

287 case _: 

288 raise TypeError( 

289 "This function works with FootprintSets, " 

290 "single Footprints, and iterables of Footprints. " 

291 f"Got {type(footprints)}" 

292 ) 

293 

294 ax.plot( 

295 *zip(*xy), 

296 marker="x", 

297 markeredgecolor="b", 

298 markerfacecolor="None", 

299 linestyle="None", 

300 label="Footprints centroids", 

301 ) 

302 

303 if addLegend: 

304 ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=5) 

305 

306 if title: 

307 ax.set_title(title) 

308 

309 if savePlotAs: 

310 plt.savefig(savePlotAs) 

311 

312 return figure