Coverage for python/lsst/summit/utils/plotting.py: 9%
122 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-12 03:03 -0700
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-12 03:03 -0700
1# This file is part of summit_utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import logging
24import astropy.visualization as vis
25import matplotlib
26import matplotlib.colors as colors
27import matplotlib.pyplot as plt
28import numpy as np
29from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
31import lsst.afw.detection as afwDetection
32import lsst.afw.geom as afwGeom
33import lsst.afw.image as afwImage
34import lsst.afw.table as afwTable
35import lsst.geom as geom
36from lsst.afw.detection import Footprint, FootprintSet
37from lsst.summit.utils import getQuantiles
def drawCompass(
    ax: matplotlib.axes.Axes,
    wcs: afwGeom.SkyWcs,
    compassLocation: int = 300,
    arrowLength: float = 300.0,
    headWidth: float = 30.0,
    labelOffset: float = 50.0,
    color: str = "r",
) -> matplotlib.axes.Axes:
    """
    Draw a north/east compass on ``ax`` using the supplied WCS.

    The steps are:
    - transform the anchor pixel (compassLocation, compassLocation) to
      RA, Dec;
    - offset that sky position by 30 arcseconds in Dec to get north and by
      30 arcseconds in RA to get east;
    - transform the north and east points back to pixel coordinates;
    - draw each arrow from the anchor along the unit vector pointing at the
      corresponding offset pixel, so both arrows have the same length
      (``arrowLength``) in pixels;
    - place the "N"/"E" label centres on the same lines, ``labelOffset``
      pixels beyond the arrow tips.

    If a direction cannot be determined (the offset point maps to a
    non-finite pixel, or to the anchor pixel itself), that arrow is skipped
    and a warning is logged instead of drawing a broken arrow.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS from exposure.
    compassLocation : `int`, optional
        How far in from the bottom left of the image to display the compass.
    arrowLength : `float`, optional
        The length of the compass arrows, in pixels.
    headWidth : `float`, optional
        The width of the arrow heads, in pixels.
    labelOffset : `float`, optional
        Distance, in pixels, from each arrow tip to the centre of its label.
    color : `str`, optional
        Matplotlib color used for both the arrows and the labels.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass.
    """
    anchorRa, anchorDec = wcs.pixelToSky(compassLocation, compassLocation)
    east = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
    north = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30.0 * geom.arcseconds))
    labelDistance = arrowLength + labelOffset

    for xy, label in [(north, "N"), (east, "E")]:
        dx = xy[0] - compassLocation
        dy = xy[1] - compassLocation
        length = np.hypot(dx, dy)
        # A NaN/inf pixel position (e.g. outside the WCS's valid region) or a
        # degenerate zero-length offset gives no usable direction.
        if not np.isfinite(length) or length == 0:
            logging.getLogger(__name__).warning(
                "Cannot determine the %s direction at pixel (%s, %s); skipping that compass arrow.",
                label,
                compassLocation,
                compassLocation,
            )
            continue

        # Unit vector from the anchor towards the offset point.
        ux = dx / length
        uy = dy / length

        ax.arrow(
            compassLocation,
            compassLocation,
            arrowLength * ux,
            arrowLength * uy,
            head_width=headWidth,
            length_includes_head=True,
            color=color,
        )
        ax.text(
            compassLocation + labelDistance * ux,
            compassLocation + labelDistance * uy,
            label,
            ha="center",
            va="center",
            color=color,
        )
    return ax
121def plot(
122 inputData: np.ndarray | afwImage.Exposure | afwImage.Image | afwImage.MaskedImage,
123 figure: matplotlib.figure.Figure | None = None,
124 centroids: list[tuple[int, int]] | None = None,
125 footprints: (
126 afwDetection.FootprintSet | afwDetection.Footprint | list[afwDetection.Footprint] | None
127 ) = None,
128 sourceCat: afwTable.SourceCatalog = None,
129 title: str | None = None,
130 showCompass: bool = True,
131 stretch: str = "linear",
132 percentile: float = 99.0,
133 cmap: str = "gray",
134 compassLocation: int = 300,
135 addLegend: bool = False,
136 savePlotAs: str | None = None,
137 logger: logging.Logger | None = None,
138) -> matplotlib.figure.Figure:
139 """Plot an input image accommodating different data types and additional
140 features, like: overplotting centroids, compass (if the input image
141 has a WCS), stretching, plot title, and legend.
143 Parameters
144 ----------
145 inputData : `numpy.array` or
146 `lsst.afw.image.Exposure` or
147 `lsst.afw.image.Image`, or
148 `lsst.afw.image.MaskedImage`
149 The input data.
150 figure : `matplotlib.figure.Figure`, optional
151 The matplotlib figure that will be used for plotting.
152 centroids : `list`
153 The centroids parameter as a list of tuples.
154 Each tuple is a centroid with its (X,Y) coordinates.
155 footprints: `lsst.afw.detection.FootprintSet` or
156 `lsst.afw.detection.Footprint` or
157 `list` of `lsst.afw.detection.Footprint`
158 The footprints containing centroids to plot.
159 sourceCat: `lsst.afw.table.SourceCatalog`:
160 An `lsst.afw.table.SourceCatalog` object containing centroids
161 to plot.
162 title : `str`, optional
163 Title for the plot.
164 showCompass : `bool`, optional
165 Add compass to the plot? Defaults to True.
166 stretch : `str', optional
167 Changes mapping of colors for the image. Avaliable options:
168 ccs, log, power, asinh, linear, sqrt. Defaults to linear.
169 percentile : `float', optional
170 Parameter for astropy.visualization.PercentileInterval.
171 Sets lower and upper limits for a stretch. This parameter
172 will be ignored if stretch='ccs'.
173 cmap : `str`, optional
174 The colormap to use for mapping the image values to colors. This can be
175 a string representing a predefined colormap. Default is 'gray'.
176 compassLocation : `int`, optional
177 How far in from the bottom left of the image to display the compass.
178 By default, compass will be placed at pixel (x,y) = (300,300).
179 addLegend : `bool', optional
180 Option to add legend to the plot. Recommended if centroids come from
181 different sources. Default value is False.
182 savePlotAs : `str`, optional
183 The name of the file to save the plot as, including the file extension.
184 The extention must be supported by `matplotlib.pyplot`.
185 If None (default) plot will not be saved.
186 logger : `logging.Logger`, optional
187 The logger to use for errors, created if not supplied.
188 Returns
189 -------
190 figure : `matplotlib.figure.Figure`
191 The rendered image.
192 """
194 if not figure:
195 figure = plt.figure(figsize=(10, 10))
197 ax = figure.add_subplot(111)
199 if not logger:
200 logger = logging.getLogger(__name__)
202 match inputData:
203 case np.ndarray():
204 imageData = inputData
205 case afwImage.MaskedImage():
206 imageData = inputData.image.array
207 case afwImage.Image():
208 imageData = inputData.array
209 case afwImage.Exposure():
210 imageData = inputData.image.array
211 case _:
212 raise TypeError(
213 "This function accepts numpy array, lsst.afw.image.Exposure components."
214 f" Got {type(inputData)}"
215 )
217 if np.isnan(imageData).all():
218 im = ax.imshow(imageData, origin="lower", aspect="equal")
219 logger.warning("The imageData contains only NaN values.")
220 else:
221 interval = vis.PercentileInterval(percentile)
222 match stretch:
223 case "ccs":
224 quantiles = getQuantiles(imageData, 256)
225 norm = colors.BoundaryNorm(quantiles, 256)
226 case "asinh":
227 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.AsinhStretch(a=0.1))
228 case "power":
229 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.PowerStretch(a=2))
230 case "log":
231 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LogStretch(a=1))
232 case "linear":
233 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LinearStretch())
234 case "sqrt":
235 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.SqrtStretch())
236 case _:
237 raise ValueError(
238 f"Invalid value for stretch : {stretch}. "
239 "Accepted options are: ccs, asinh, power, log, linear, sqrt."
240 )
242 im = ax.imshow(imageData, cmap=cmap, origin="lower", norm=norm, aspect="equal")
243 div = make_axes_locatable(ax)
244 cax = div.append_axes("right", size="5%", pad=0.05)
245 figure.colorbar(im, cax=cax)
247 if showCompass:
248 try:
249 wcs = inputData.getWcs()
250 except AttributeError:
251 logger.warning("Failed to get WCS from input data. Compass will not be plotted.")
252 wcs = None
254 if wcs:
255 arrowLength = min(imageData.shape) * 0.05
256 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength)
258 if centroids:
259 ax.plot(
260 *zip(*centroids),
261 marker="x",
262 markeredgecolor="r",
263 markerfacecolor="None",
264 linestyle="None",
265 label="List of centroids",
266 )
268 if sourceCat:
269 ax.plot(
270 list(zip(sourceCat.getX(), sourceCat.getY())),
271 marker="o",
272 markeredgecolor="c",
273 markerfacecolor="None",
274 linestyle="None",
275 label="Source catalog",
276 )
278 if footprints:
279 match footprints:
280 case FootprintSet():
281 fs = FootprintSet.getFootprints(footprints)
282 xy = [_.getCentroid() for _ in fs]
283 case Footprint():
284 xy = [footprints.getCentroid()]
285 case list():
286 xy = []
287 for i, ft in enumerate(footprints):
288 try:
289 ft.getCentroid()
290 except AttributeError:
291 raise TypeError(
292 "Cannot get centroids for one of the "
293 "elements from the footprints list. "
294 "Expected lsst.afw.detection.Footprint, "
295 f"got {type(ft)} for footprints[{i}]"
296 )
297 xy.append(ft.getCentroid())
298 case _:
299 raise TypeError(
300 "This function works with FootprintSets, "
301 "single Footprints, and iterables of Footprints. "
302 f"Got {type(footprints)}"
303 )
305 ax.plot(
306 *zip(*xy),
307 marker="x",
308 markeredgecolor="b",
309 markerfacecolor="None",
310 linestyle="None",
311 label="Footprints centroids",
312 )
314 if addLegend:
315 ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=5)
317 if title:
318 ax.set_title(title)
320 if savePlotAs:
321 plt.savefig(savePlotAs)
323 return figure