Coverage for python / lsst / summit / utils / plotting.py: 10%
167 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 09:02 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-07 09:02 +0000
1# This file is part of summit_utils.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
21from __future__ import annotations
23import logging
24from typing import TYPE_CHECKING
26import astropy.visualization as vis
27import matplotlib
28import matplotlib.colors as colors
29import numpy as np
30from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
32import lsst.afw.detection as afwDetection
33import lsst.afw.geom as afwGeom
34import lsst.afw.image as afwImage
35import lsst.afw.table as afwTable
36import lsst.geom as geom
37from lsst.summit.utils.utils import getImageArray, getQuantiles
38from lsst.utils.plotting.figures import make_figure
40if TYPE_CHECKING:
41 from matplotlib.figure import Figure
def drawCompass(
    ax: matplotlib.axes.Axes,
    wcs: afwGeom.SkyWcs,
    compassLocation: int = 300,
    arrowLength: float = 300.0,
) -> matplotlib.axes.Axes:
    """Overlay North and East arrows on an image plot.

    Both arrows start at pixel (compassLocation, compassLocation) and have
    the same length, ``arrowLength``. The construction is:

    - convert the anchor pixel to sky coordinates with the WCS;
    - offset that sky position by 30 arcseconds in Dec (for N) and in RA
      (for E), and convert both offset positions back to pixels;
    - for each direction, take the straight line through the anchor and the
      offset pixel and place the arrow tip on it, ``arrowLength`` away from
      the anchor on the side of the offset pixel;
    - place the "N"/"E" label on the same line, 50 pixels beyond the tip.

    Parameters
    ----------
    ax : `matplotlib.axes.Axes`
        The axes on which the compass will be drawn.
    wcs : `lsst.afw.geom.SkyWcs`
        WCS from exposure.
    compassLocation : `int`, optional
        How far in from the bottom left of the image to display the compass.
    arrowLength : `float`, optional
        The length of the compass arrow.

    Returns
    -------
    ax : `matplotlib.axes.Axes`
        The axes with the compass.
    """
    arrowColor = "r"
    origin = compassLocation
    labelDistance = arrowLength + 50.0

    anchorRa, anchorDec = wcs.pixelToSky(origin, origin)
    eastPixel = wcs.skyToPixel(geom.SpherePoint(anchorRa + 30.0 * geom.arcseconds, anchorDec))
    northPixel = wcs.skyToPixel(geom.SpherePoint(anchorRa, anchorDec + 30.0 * geom.arcseconds))

    for target, label in ((northPixel, "N"), (eastPixel, "E")):
        targetX, targetY = target[0], target[1]
        # Default x positions; only moved when the line is not vertical.
        tipX = origin
        labelX = origin
        if origin == targetX:
            # Vertical line: the slope is undefined, so step straight up or
            # down depending on which side the offset pixel lies.
            if targetY > origin:
                tipY = origin + arrowLength
                labelY = origin + labelDistance
            else:
                tipY = origin - arrowLength
                labelY = origin - labelDistance
        else:
            slope = (targetY - origin) / (targetX - origin)
            # Horizontal projection of a segment of given length along a
            # line with this slope.
            tipRun = arrowLength / np.sqrt(1.0 + slope**2)
            labelRun = labelDistance / np.sqrt(1.0 + slope**2)

            # Of the two points at the requested distance, keep the one on
            # the same side of the anchor as the offset pixel.
            if targetX > origin:
                tipX = origin + tipRun
                labelX = origin + labelRun
            elif targetX < origin:
                tipX = origin - tipRun
                labelX = origin - labelRun
            tipY = slope * (tipX - origin) + origin
            labelY = slope * (labelX - origin) + origin

        ax.arrow(
            origin,
            origin,
            tipX - origin,
            tipY - origin,
            head_width=30.0,
            length_includes_head=True,
            color=arrowColor,
        )
        ax.text(labelX, labelY, label, ha="center", va="center", color=arrowColor)
    return ax
125def plot(
126 inputData: np.ndarray | afwImage.Exposure | afwImage.Image | afwImage.MaskedImage,
127 figure: matplotlib.figure.Figure | None = None,
128 centroids: list[tuple[int, int]] | None = None,
129 footprints: (
130 afwDetection.FootprintSet | afwDetection.Footprint | list[afwDetection.Footprint] | None
131 ) = None,
132 sourceCat: afwTable.SourceCatalog = None,
133 title: str | None = None,
134 showCompass: bool = True,
135 stretch: str = "linear",
136 percentile: float = 99.0,
137 cmap: str = "gray",
138 compassLocation: int = 300,
139 addLegend: bool = False,
140 savePlotAs: str | None = None,
141 logger: logging.Logger | None = None,
142) -> Figure:
143 """Plot an input image accommodating different data types and additional
144 features, like: overplotting centroids, compass (if the input image
145 has a WCS), stretching, plot title, and legend.
147 Parameters
148 ----------
149 inputData : `numpy.array` or
150 `lsst.afw.image.Exposure` or
151 `lsst.afw.image.Image`, or
152 `lsst.afw.image.MaskedImage`
153 The input data.
154 figure : `matplotlib.figure.Figure`, optional
155 The matplotlib figure that will be used for plotting.
156 centroids : `list`
157 The centroids parameter as a list of tuples.
158 Each tuple is a centroid with its (X,Y) coordinates.
159 footprints: `lsst.afw.detection.FootprintSet` or
160 `lsst.afw.detection.Footprint` or
161 `list` of `lsst.afw.detection.Footprint`
162 The footprints containing centroids to plot.
163 sourceCat: `lsst.afw.table.SourceCatalog`:
164 An `lsst.afw.table.SourceCatalog` object containing centroids
165 to plot.
166 title : `str`, optional
167 Title for the plot.
168 showCompass : `bool`, optional
169 Add compass to the plot? Defaults to True.
170 stretch : `str', optional
171 Changes mapping of colors for the image. Avaliable options:
172 ccs, log, power, asinh, linear, sqrt, midtone. Defaults to linear.
173 percentile : `float', optional
174 Parameter for astropy.visualization.PercentileInterval.
175 Sets lower and upper limits for a stretch. This parameter
176 will be ignored if stretch='ccs'.
177 cmap : `str`, optional
178 The colormap to use for mapping the image values to colors. This can be
179 a string representing a predefined colormap. Default is 'gray'.
180 compassLocation : `int`, optional
181 How far in from the bottom left of the image to display the compass.
182 By default, compass will be placed at pixel (x,y) = (300,300).
183 addLegend : `bool', optional
184 Option to add legend to the plot. Recommended if centroids come from
185 different sources. Default value is False.
186 savePlotAs : `str`, optional
187 The name of the file to save the plot as, including the file extension.
188 The extention must be supported by `matplotlib.pyplot`.
189 If None (default) plot will not be saved.
190 logger : `logging.Logger`, optional
191 The logger to use for errors, created if not supplied.
192 Returns
193 -------
194 figure : `matplotlib.figure.Figure`
195 The rendered image.
196 """
198 if not figure:
199 figure = make_figure(figsize=(10, 10))
201 ax = figure.add_subplot(111)
203 if not logger:
204 logger = logging.getLogger(__name__)
206 imageData = getImageArray(inputData)
208 if np.isnan(imageData).all():
209 im = ax.imshow(imageData, origin="lower", aspect="equal")
210 logger.warning("The imageData contains only NaN values.")
211 else:
212 interval = vis.PercentileInterval(percentile)
213 match stretch:
214 case "ccs":
215 quantiles = getQuantiles(imageData, 256)
216 norm = colors.BoundaryNorm(quantiles, 256)
217 case "asinh":
218 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.AsinhStretch(a=0.1))
219 case "power":
220 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.PowerStretch(a=2))
221 case "log":
222 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LogStretch(a=1))
223 case "linear":
224 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.LinearStretch())
225 case "sqrt":
226 norm = vis.ImageNormalize(imageData, interval=interval, stretch=vis.SqrtStretch())
227 case "midtone":
228 imageData = stretchDataMidTone(imageData)
229 # no interval in this norm as imageData is now [0, 1] aready
230 norm = vis.ImageNormalize(imageData, stretch=vis.LinearStretch())
231 case _:
232 raise ValueError(
233 f"Invalid value for stretch : {stretch}. "
234 "Accepted options are: ccs, asinh, power, log, linear, sqrt."
235 )
237 im = ax.imshow(imageData, cmap=cmap, origin="lower", norm=norm, aspect="equal", interpolation="auto")
239 if stretch != "midtone":
240 div = make_axes_locatable(ax)
241 cax = div.append_axes("right", size="5%", pad=0.05)
242 figure.colorbar(im, cax=cax)
244 if showCompass:
245 try:
246 assert hasattr(inputData, "getWcs"), "inputData does not have a getWcs method"
247 wcs = inputData.getWcs()
248 except AssertionError:
249 logger.warning("Failed to get WCS from input data. Compass will not be plotted.")
250 wcs = None
252 if wcs:
253 arrowLength = min(imageData.shape) * 0.05
254 ax = drawCompass(ax, wcs, compassLocation=compassLocation, arrowLength=arrowLength)
256 if centroids:
257 ax.plot(
258 *zip(*centroids),
259 marker="x",
260 markeredgecolor="r",
261 markerfacecolor="None",
262 linestyle="None",
263 label="List of centroids",
264 )
266 if sourceCat:
267 ax.scatter(
268 sourceCat.getX(),
269 sourceCat.getY(),
270 marker="o",
271 edgecolors="c", # cyan rings
272 c="None", # empty cicrles (no fill)
273 label="Source catalog",
274 )
276 if footprints:
277 match footprints:
278 case afwDetection.FootprintSet():
279 fs = afwDetection.FootprintSet.getFootprints(footprints)
280 xy = [_.getCentroid() for _ in fs]
281 case afwDetection.Footprint():
282 xy = [footprints.getCentroid()]
283 case list():
284 xy = []
285 for i, ft in enumerate(footprints):
286 try:
287 ft.getCentroid()
288 except AttributeError:
289 raise TypeError(
290 "Cannot get centroids for one of the "
291 "elements from the footprints list. "
292 "Expected lsst.afw.detection.Footprint, "
293 f"got {type(ft)} for footprints[{i}]"
294 )
295 xy.append(ft.getCentroid())
296 case _:
297 raise TypeError(
298 "This function works with FootprintSets, "
299 "single Footprints, and iterables of Footprints. "
300 f"Got {type(footprints)}"
301 )
303 ax.plot(
304 *zip(*xy),
305 marker="x",
306 markeredgecolor="b",
307 markerfacecolor="None",
308 linestyle="None",
309 label="Footprints centroids",
310 )
312 if addLegend:
313 ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=5)
315 if title:
316 ax.set_title(title)
318 if savePlotAs:
319 figure.savefig(savePlotAs)
321 return figure
324def _computeMtf(image: np.ndarray, midtonesBalance: float) -> np.ndarray:
325 """
326 Compute the midtones transfer function (MTF) for an image.
328 Parameters
329 ----------
330 image : `np.ndarray`
331 Input image normalized to [0, 1].
332 midtonesBalance : `float`
333 Balance parameter controlling midtones emphasis.
335 Returns
336 -------
337 image : `np.ndarray`
338 Image after applying the MTF, with NaNs replaced by fallback values.
339 """
340 M = np.full(image.shape, midtonesBalance)
341 maskHalf = image == 0.5
342 maskZero = image == 0
343 maskOne = image == 1
344 fallback = (maskHalf * 0.5) * (1 - maskZero) + maskOne
345 result = (M - 1) * image / ((2 * M - 1) * image - M)
346 nanMask = ~np.isfinite(result)
347 result[nanMask] = fallback[nanMask]
348 return result
351def _applyClip(image: np.ndarray, clipLow: float, clipHigh: float) -> np.ndarray:
352 """
353 Linearly clip and scale image values between clipLow and clipHigh.
355 Parameters
356 ----------
357 image : `np.ndarray`
358 Input image normalized to [0, 1].
359 clipLow : `float`
360 Lower clipping threshold.
361 clipHigh : `float`
362 Upper clipping threshold.
364 Returns
365 -------
366 clipped : `np.ndarray`
367 Clipped and scaled image, with values in [0, 1].
368 """
369 belowLow = image < clipLow
370 aboveHigh = image > clipHigh
371 scaled = (image - clipLow) / (clipHigh - clipLow)
372 return np.clip(scaled * (~belowLow) + aboveHigh, 0.0, 1.0)
375def _applyExpansion(image: np.ndarray, outMin: float, outMax: float) -> np.ndarray:
376 """
377 Expand image dynamic range from [outMin, outMax] to [0, 1].
379 Parameters
380 ----------
381 image : `np.ndarray`
382 Input image after MTF.
383 outMin : `float`
384 Minimum output value (usually 0.0).
385 outMax : `float`
386 Maximum output value (usually 1.0).
388 Returns
389 -------
390 expanded : `np.ndarray`
391 Expanded image in [0, 1].
392 """
393 return (image - outMin) / (outMax - outMin)
def _applyDisplayFunction(
    image: np.ndarray, midtonesBalance: float, clipLow: float, clipHigh: float, outMin: float, outMax: float
) -> np.ndarray:
    """
    Run the three display-function stages in order: clipping, midtones
    transfer function, and expansion.

    Parameters
    ----------
    image : `np.ndarray`
        Input image normalized to [0, 1].
    midtonesBalance : `float`
        Midtones balance parameter.
    clipLow : `float`
        Lower clipping threshold.
    clipHigh : `float`
        Upper clipping threshold.
    outMin : `float`
        Minimum output of expansion.
    outMax : `float`
        Maximum output of expansion.

    Returns
    -------
    stretched : `np.ndarray`
        Stretched image ready for display.
    """
    return _applyExpansion(
        _computeMtf(_applyClip(image, clipLow, clipHigh), midtonesBalance),
        outMin,
        outMax,
    )
427def _computeDisplayParameters(data: np.ndarray) -> tuple[float, float, float, float, float]:
428 """
429 Compute parameters for display function based on data statistics.
431 Parameters
432 ----------
433 data : `np.ndarray`
434 Normalized image data array.
436 Returns
437 -------
438 tuple[float, float, float, float, float]
439 midtonesBalance, clipLow, clipHigh, outMin, outMax
440 """
441 median = np.median(data)
442 deviations = np.abs(data.ravel() - median)
443 madn = 1.4826 * np.median(np.sort(deviations))
444 targetBackground = 0.25
445 clippingFactor = -2.8
447 aboveHalf = median > 0.5
449 if not aboveHalf and madn != 0:
450 clipLow = min(1.0, max(0.0, median + clippingFactor * madn))
451 else:
452 clipLow = 0.0
454 if aboveHalf and madn != 0:
455 clipHigh = min(1.0, max(0.0, median - clippingFactor * madn))
456 else:
457 clipHigh = 1.0
459 if median <= 0.5:
460 midtonesBalance = (
461 (targetBackground - 1)
462 * (median - clipLow)
463 / ((2 * targetBackground - 1) * (median - clipLow) - targetBackground)
464 )
465 else:
466 midtonesBalance = (
467 (clipHigh - median - 1)
468 * targetBackground
469 / (2 * (clipHigh - median - 1) * targetBackground - (clipHigh - median))
470 )
472 return midtonesBalance, clipLow, clipHigh, 0.0, 1.0
def stretchDataMidTone(
    imageLike: np.ndarray | afwImage.Exposure | afwImage.Image | afwImage.MaskedImage,
) -> np.ndarray:
    """
    Normalize and stretch image data using the Midtone Transfer Function
    (MTF).

    This is following:
    https://pixinsight.com/doc/docs/XISF-1.0-spec/XISF-1.0-spec.html
    #__XISF_Data_Objects_:_XISF_Image_:_Display_Function__

    The data are first scaled by their maximum (after subtracting the
    minimum if it is negative), then the display parameters are derived
    from the scaled data and the display function is applied. NaN pixels
    are ignored when finding the minimum and maximum.

    Parameters
    ----------
    imageLike : `numpy.ndarray`, `lsst.afw.image.Exposure`,
                `lsst.afw.image.Image`, or `lsst.afw.image.MaskedImage`
        The image-like object containing the data to be stretched.

    Returns
    -------
    stretched : `np.ndarray`
        The stretched image array.
    """
    data = getImageArray(imageLike)

    # NaN-aware extrema: with plain min/max one NaN pixel makes the
    # normalization, and hence the whole stretched image, degenerate.
    pedestal = np.nanmin(data)
    if pedestal >= 0.0:
        shifted = data
    else:
        shifted = data - pedestal
    norm = np.nanmax(shifted)

    if norm == 0:
        # A flat image has no dynamic range to stretch; avoid the 0/0
        # division and return an all-zero image directly.
        return np.zeros(data.shape)

    normalized = shifted / norm

    midtonesBalance, clipLow, clipHigh, outMin, outMax = _computeDisplayParameters(normalized)
    stretched = _applyDisplayFunction(normalized, midtonesBalance, clipLow, clipHigh, outMin, outMax)
    return stretched