Coverage for python/lsst/summit/utils/plotting.py: 7%
118 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-06 15:47 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-06 15:47 +0000
1# This file is part of summit_utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import numpy as np
23import matplotlib.pyplot as plt
24import matplotlib.colors as colors
25from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
26import astropy.visualization as vis
27import logging
29from lsst.afw.detection import FootprintSet, Footprint
30import lsst.geom as geom
31from lsst.summit.utils import getQuantiles
32import lsst.afw.image as afwImage
def drawCompass(ax, wcs, compassLocation=300, arrowLength=300.):
    """Draw a North/East compass on the given axes.

    The compass is anchored at pixel ``(compassLocation, compassLocation)``
    and both arrows have the same length, ``arrowLength``.

    The steps here are:
    - transform the (compassLocation, compassLocation) to RA, DEC coordinates
    - move this point in DEC to get N; in RA to get E directions
    - transform N and E points back to pixel coordinates
    - normalise the pixel-space vector from the compass centre to each of
      those points to get the unit direction of the arrow
    - place the arrow tip at a distance of arrowLength from the
      (compassLocation, compassLocation) along that direction
    - centers of the N/E labels are located on the same lines, 50 pixels
      beyond the arrow tips.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS from exposure.
    compassLocation : `int`, optional
        How far in from the bottom left of the image to display the compass.
    arrowLength : `float`, optional
        The length of the compass arrow.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass.

    Notes
    -----
    An arrow whose direction cannot be determined (the projected point is
    non-finite, or it coincides with the compass centre) is skipped and a
    warning is logged, rather than drawing an arrow in an arbitrary
    direction.
    """
    anchorRa, anchorDec = wcs.pixelToSky(compassLocation, compassLocation)
    east = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
    north = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30. * geom.arcseconds))
    labelPosition = arrowLength + 50.
    color = 'r'

    for xy, label in [(north, 'N'), (east, 'E')]:
        # Pixel-space vector from the compass centre towards the sky direction.
        dx = xy[0] - compassLocation
        dy = xy[1] - compassLocation
        distance = np.hypot(dx, dy)

        # A zero-length or non-finite vector has no direction; working with
        # the unit vector also removes the need to special-case vertical
        # arrows (infinite slope).
        if not np.isfinite(distance) or distance == 0.:
            logging.getLogger(__name__).warning(
                "Could not determine the %s direction from the WCS. "
                "This compass arrow will not be plotted.", label)
            continue

        unitX = dx / distance
        unitY = dy / distance

        ax.arrow(compassLocation, compassLocation,
                 arrowLength * unitX, arrowLength * unitY,
                 head_width=30., length_includes_head=True, color=color)
        ax.text(compassLocation + labelPosition * unitX,
                compassLocation + labelPosition * unitY,
                label, ha='center', va='center', color=color)
    return ax
105def plot(inputData,
106 figure=None,
107 centroids=None,
108 footprints=None,
109 sourceCat=None,
110 title=None,
111 showCompass=True,
112 stretch='linear',
113 percentile=99.,
114 cmap='gray',
115 compassLocation=300,
116 addLegend=False,
117 savePlotAs=None,
118 logger=None):
120 """Plot an input image accommodating different data types and additional
121 features, like: overplotting centroids, compass (if the input image
122 has a WCS), stretching, plot title, and legend.
124 Parameters
125 ----------
126 inputData : `numpy.array` or
127 `lsst.afw.image.Exposure` or
128 `lsst.afw.image.Image`, or
129 `lsst.afw.image.MaskedImage`
130 The input data.
131 figure : `matplotlib.figure.Figure`, optional
132 The matplotlib figure that will be used for plotting.
133 centroids : `list`
134 The centroids parameter as a list of tuples.
135 Each tuple is a centroid with its (X,Y) coordinates.
136 footprints: `lsst.afw.detection.FootprintSet` or
137 `lsst.afw.detection.Footprint` or
138 `list` of `lsst.afw.detection.Footprint`
139 The footprints containing centroids to plot.
140 sourceCat: `lsst.afw.table.SourceCatalog`:
141 An `lsst.afw.table.SourceCatalog` object containing centroids
142 to plot.
143 title : `str`, optional
144 Title for the plot.
145 showCompass : `bool`, optional
146 Add compass to the plot? Defaults to True.
147 stretch : `str', optional
148 Changes mapping of colors for the image. Avaliable options:
149 ccs, log, power, asinh, linear, sqrt. Defaults to linear.
150 percentile : `float', optional
151 Parameter for astropy.visualization.PercentileInterval.
152 Sets lower and upper limits for a stretch. This parameter
153 will be ignored if stretch='ccs'.
154 cmap : `str`, optional
155 The colormap to use for mapping the image values to colors. This can be
156 a string representing a predefined colormap. Default is 'gray'.
157 compassLocation : `int`, optional
158 How far in from the bottom left of the image to display the compass.
159 By default, compass will be placed at pixel (x,y) = (300,300).
160 addLegend : `bool', optional
161 Option to add legend to the plot. Recommended if centroids come from
162 different sources. Default value is False.
163 savePlotAs : `str`, optional
164 The name of the file to save the plot as, including the file extension.
165 The extention must be supported by `matplotlib.pyplot`.
166 If None (default) plot will not be saved.
167 logger : `logging.Logger`, optional
168 The logger to use for errors, created if not supplied.
169 Returns
170 -------
171 figure : `matplotlib.figure.Figure`
172 The rendered image.
173 """
175 if not figure:
176 figure = plt.figure(figsize=(10, 10))
178 ax = figure.add_subplot(111)
180 if not logger:
181 logger = logging.getLogger(__name__)
183 match inputData:
184 case np.ndarray():
185 imageData = inputData
186 case afwImage.MaskedImage():
187 imageData = inputData.image.array
188 case afwImage.Image():
189 imageData = inputData.array
190 case afwImage.Exposure():
191 imageData = inputData.image.array
192 case _:
193 raise TypeError("This function accepts numpy array, lsst.afw.image.Exposure components."
194 f" Got {type(inputData)}")
196 if np.isnan(imageData).all():
197 im = ax.imshow(imageData, origin='lower', aspect='equal')
198 logger.warning("The imageData contains only NaN values.")
199 else:
200 interval = vis.PercentileInterval(percentile)
201 match stretch:
202 case 'ccs':
203 quantiles = getQuantiles(imageData, 256)
204 norm = colors.BoundaryNorm(quantiles, 256)
205 case 'asinh':
206 norm = vis.ImageNormalize(imageData,
207 interval=interval,
208 stretch=vis.AsinhStretch(a=0.1))
209 case 'power':
210 norm = vis.ImageNormalize(imageData,
211 interval=interval,
212 stretch=vis.PowerStretch(a=2))
213 case 'log':
214 norm = vis.ImageNormalize(imageData,
215 interval=interval,
216 stretch=vis.LogStretch(a=1))
217 case 'linear':
218 norm = vis.ImageNormalize(imageData,
219 interval=interval,
220 stretch=vis.LinearStretch())
221 case 'sqrt':
222 norm = vis.ImageNormalize(imageData,
223 interval=interval,
224 stretch=vis.SqrtStretch())
225 case _:
226 raise ValueError(f"Invalid value for stretch : {stretch}. "
227 "Accepted options are: ccs, asinh, power, log, linear, sqrt.")
229 im = ax.imshow(imageData, cmap=cmap, origin='lower', norm=norm, aspect='equal')
230 div = make_axes_locatable(ax)
231 cax = div.append_axes("right", size="5%", pad=0.05)
232 figure.colorbar(im, cax=cax)
234 if showCompass:
235 try:
236 wcs = inputData.getWcs()
237 except AttributeError:
238 logger.warning("Failed to get WCS from input data. Compass will not be plotted.")
239 wcs = None
241 if wcs:
242 arrowLength = min(imageData.shape) * 0.05
243 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength)
245 if centroids:
246 ax.plot(*zip(*centroids),
247 marker='x',
248 markeredgecolor='r',
249 markerfacecolor='None',
250 linestyle='None',
251 label='List of centroids')
253 if sourceCat:
254 ax.plot(list(zip(sourceCat.getX(), sourceCat.getY())),
255 marker='o',
256 markeredgecolor='c',
257 markerfacecolor='None',
258 linestyle='None',
259 label='Source catalog')
261 if footprints:
262 match footprints:
263 case FootprintSet():
264 fs = FootprintSet.getFootprints(footprints)
265 xy = [_.getCentroid() for _ in fs]
266 case Footprint():
267 xy = [footprints.getCentroid()]
268 case list():
269 xy = []
270 for i, ft in enumerate(footprints):
271 try:
272 ft.getCentroid()
273 except AttributeError:
274 raise TypeError("Cannot get centroids for one of the "
275 "elements from the footprints list. "
276 "Expected lsst.afw.detection.Footprint, "
277 f"got {type(ft)} for footprints[{i}]")
278 xy.append(ft.getCentroid())
279 case _:
280 raise TypeError("This function works with FootprintSets, "
281 "single Footprints, and iterables of Footprints. "
282 f"Got {type(footprints)}")
284 ax.plot(*zip(*xy),
285 marker='x',
286 markeredgecolor='b',
287 markerfacecolor='None',
288 linestyle='None',
289 label='Footprints centroids')
291 if addLegend:
292 ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=5)
294 if title:
295 ax.set_title(title)
297 if savePlotAs:
298 plt.savefig(savePlotAs)
300 return figure