Coverage for python/lsst/summit/utils/plotting.py: 7%
118 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-28 05:07 -0700
« prev ^ index » next coverage.py v7.4.4, created at 2024-03-28 05:07 -0700
1# This file is part of summit_utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
24import astropy.visualization as vis
25import matplotlib.colors as colors
26import matplotlib.pyplot as plt
27import numpy as np
28from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
30import lsst.afw.image as afwImage
31import lsst.geom as geom
32from lsst.afw.detection import Footprint, FootprintSet
33from lsst.summit.utils import getQuantiles
def drawCompass(ax, wcs, compassLocation=300, arrowLength=300.0):
    """Draw a north/east compass on an image plot.

    The compass is anchored at pixel ``(compassLocation, compassLocation)``
    and both arrows are drawn with the same length. The steps are:

    - transform the anchor pixel to (RA, Dec) coordinates;
    - move this point 30 arcsec in Dec to get N, and 30 arcsec in RA to get
      E;
    - transform the N and E points back to pixel coordinates;
    - normalize the pixel-space vectors from the anchor to the N and E
      points, so each arrow points along its on-image direction with
      length ``arrowLength`` regardless of the pixel scale;
    - place the N/E labels on the same lines, 50 pixels beyond the arrow
      tips.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS from exposure.
    compassLocation : `int`, optional
        How far in from the bottom left of the image to display the compass.
    arrowLength : `float`, optional
        The length of the compass arrows, in pixels.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass.

    Notes
    -----
    If a direction cannot be determined (the offset point maps back onto the
    anchor pixel, or the WCS returns non-finite pixel coordinates), that
    arrow and its label are skipped rather than drawn in an arbitrary
    direction.
    """
    anchorRa, anchorDec = wcs.pixelToSky(compassLocation, compassLocation)
    east = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
    north = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30.0 * geom.arcseconds))
    labelPosition = arrowLength + 50.0
    color = "r"

    for xy, label in [(north, "N"), (east, "E")]:
        dx = xy[0] - compassLocation
        dy = xy[1] - compassLocation
        distance = np.hypot(dx, dy)
        # A zero-length or non-finite direction has no well-defined unit
        # vector; skip it instead of drawing a misleading arrow.
        if not np.isfinite(distance) or distance == 0:
            continue
        # Unit vector along the on-image direction of N (or E). Scaling it
        # gives the arrow tip and the label position on the same line.
        ux = dx / distance
        uy = dy / distance

        ax.arrow(
            compassLocation,
            compassLocation,
            arrowLength * ux,
            arrowLength * uy,
            head_width=30.0,
            length_includes_head=True,
            color=color,
        )
        ax.text(
            compassLocation + labelPosition * ux,
            compassLocation + labelPosition * uy,
            label,
            ha="center",
            va="center",
            color=color,
        )
    return ax
112def plot(
113 inputData,
114 figure=None,
115 centroids=None,
116 footprints=None,
117 sourceCat=None,
118 title=None,
119 showCompass=True,
120 stretch="linear",
121 percentile=99.0,
122 cmap="gray",
123 compassLocation=300,
124 addLegend=False,
125 savePlotAs=None,
126 logger=None,
127):
128 """Plot an input image accommodating different data types and additional
129 features, like: overplotting centroids, compass (if the input image
130 has a WCS), stretching, plot title, and legend.
132 Parameters
133 ----------
134 inputData : `numpy.array` or
135 `lsst.afw.image.Exposure` or
136 `lsst.afw.image.Image`, or
137 `lsst.afw.image.MaskedImage`
138 The input data.
139 figure : `matplotlib.figure.Figure`, optional
140 The matplotlib figure that will be used for plotting.
141 centroids : `list`
142 The centroids parameter as a list of tuples.
143 Each tuple is a centroid with its (X,Y) coordinates.
144 footprints: `lsst.afw.detection.FootprintSet` or
145 `lsst.afw.detection.Footprint` or
146 `list` of `lsst.afw.detection.Footprint`
147 The footprints containing centroids to plot.
148 sourceCat: `lsst.afw.table.SourceCatalog`:
149 An `lsst.afw.table.SourceCatalog` object containing centroids
150 to plot.
151 title : `str`, optional
152 Title for the plot.
153 showCompass : `bool`, optional
154 Add compass to the plot? Defaults to True.
155 stretch : `str', optional
156 Changes mapping of colors for the image. Avaliable options:
157 ccs, log, power, asinh, linear, sqrt. Defaults to linear.
158 percentile : `float', optional
159 Parameter for astropy.visualization.PercentileInterval.
160 Sets lower and upper limits for a stretch. This parameter
161 will be ignored if stretch='ccs'.
162 cmap : `str`, optional
163 The colormap to use for mapping the image values to colors. This can be
164 a string representing a predefined colormap. Default is 'gray'.
165 compassLocation : `int`, optional
166 How far in from the bottom left of the image to display the compass.
167 By default, compass will be placed at pixel (x,y) = (300,300).
168 addLegend : `bool', optional
169 Option to add legend to the plot. Recommended if centroids come from
170 different sources. Default value is False.
171 savePlotAs : `str`, optional
172 The name of the file to save the plot as, including the file extension.
173 The extention must be supported by `matplotlib.pyplot`.
174 If None (default) plot will not be saved.
175 logger : `logging.Logger`, optional
176 The logger to use for errors, created if not supplied.
177 Returns
178 -------
179 figure : `matplotlib.figure.Figure`
180 The rendered image.
181 """
183 if not figure:
184 figure = plt.figure(figsize=(10, 10))
186 ax = figure.add_subplot(111)
188 if not logger:
189 logger = logging.getLogger(__name__)
191 match inputData:
192 case np.ndarray():
193 imageData = inputData
194 case afwImage.MaskedImage():
195 imageData = inputData.image.array
196 case afwImage.Image():
197 imageData = inputData.array
198 case afwImage.Exposure():
199 imageData = inputData.image.array
200 case _:
201 raise TypeError(
202 "This function accepts numpy array, lsst.afw.image.Exposure components."
203 f" Got {type(inputData)}"
204 )
206 if np.isnan(imageData).all():
207 im = ax.imshow(imageData, origin="lower", aspect="equal")
208 logger.warning("The imageData contains only NaN values.")
209 else:
210 interval = vis.PercentileInterval(percentile)
211 match stretch:
212 case "ccs":
213 quantiles = getQuantiles(imageData, 256)
214 norm = colors.BoundaryNorm(quantiles, 256)
215 case "asinh":
216 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.AsinhStretch(a=0.1))
217 case "power":
218 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.PowerStretch(a=2))
219 case "log":
220 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LogStretch(a=1))
221 case "linear":
222 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LinearStretch())
223 case "sqrt":
224 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.SqrtStretch())
225 case _:
226 raise ValueError(
227 f"Invalid value for stretch : {stretch}. "
228 "Accepted options are: ccs, asinh, power, log, linear, sqrt."
229 )
231 im = ax.imshow(imageData, cmap=cmap, origin="lower", norm=norm, aspect="equal")
232 div = make_axes_locatable(ax)
233 cax = div.append_axes("right", size="5%", pad=0.05)
234 figure.colorbar(im, cax=cax)
236 if showCompass:
237 try:
238 wcs = inputData.getWcs()
239 except AttributeError:
240 logger.warning("Failed to get WCS from input data. Compass will not be plotted.")
241 wcs = None
243 if wcs:
244 arrowLength = min(imageData.shape) * 0.05
245 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength)
247 if centroids:
248 ax.plot(
249 *zip(*centroids),
250 marker="x",
251 markeredgecolor="r",
252 markerfacecolor="None",
253 linestyle="None",
254 label="List of centroids",
255 )
257 if sourceCat:
258 ax.plot(
259 list(zip(sourceCat.getX(), sourceCat.getY())),
260 marker="o",
261 markeredgecolor="c",
262 markerfacecolor="None",
263 linestyle="None",
264 label="Source catalog",
265 )
267 if footprints:
268 match footprints:
269 case FootprintSet():
270 fs = FootprintSet.getFootprints(footprints)
271 xy = [_.getCentroid() for _ in fs]
272 case Footprint():
273 xy = [footprints.getCentroid()]
274 case list():
275 xy = []
276 for i, ft in enumerate(footprints):
277 try:
278 ft.getCentroid()
279 except AttributeError:
280 raise TypeError(
281 "Cannot get centroids for one of the "
282 "elements from the footprints list. "
283 "Expected lsst.afw.detection.Footprint, "
284 f"got {type(ft)} for footprints[{i}]"
285 )
286 xy.append(ft.getCentroid())
287 case _:
288 raise TypeError(
289 "This function works with FootprintSets, "
290 "single Footprints, and iterables of Footprints. "
291 f"Got {type(footprints)}"
292 )
294 ax.plot(
295 *zip(*xy),
296 marker="x",
297 markeredgecolor="b",
298 markerfacecolor="None",
299 linestyle="None",
300 label="Footprints centroids",
301 )
303 if addLegend:
304 ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=5)
306 if title:
307 ax.set_title(title)
309 if savePlotAs:
310 plt.savefig(savePlotAs)
312 return figure