Coverage for python / lsst / analysis / tools / actions / plot / wholeTractImage.py: 15%
184 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 09:07 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-06 09:07 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("WholeTractImage",)
26from collections.abc import Mapping
28import matplotlib.cm as cm
29import matplotlib.patches as patches
30import matplotlib.patheffects as pathEffects
31import matplotlib.pyplot as plt
32import numpy as np
33from astropy.visualization import ImageNormalize
34from matplotlib.figure import Figure
36from lsst.pex.config import (
37 ChoiceField,
38 Field,
39 FieldValidationError,
40 ListField,
41)
42from lsst.pex.config.configurableActions import ConfigurableActionField
43from lsst.skymap import BaseSkyMap
44from lsst.utils.plotting import make_figure, set_rubin_plotstyle
46from ...interfaces import (
47 KeyedData,
48 KeyedDataSchema,
49 PlotAction,
50 TensorAction,
51 VectorAction,
52)
53from ...utils import getPatchCorners, getTractCorners
54from .calculateRange import Asinh, Perc
class WholeTractImage(PlotAction):
    """
    Produces a figure displaying whole-tract coadd pixel data as a 2D image.

    The figure is constructed from all patches covering the tract. Regions of
    NO_DATA or where no coadd exists are shown as red shading or red hatches,
    respectively.

    Either the image, pixel mask, or variance components of the coadd can be
    displayed. In the case of the pixel mask, one or more bitmaskPlanes must
    be specified; the specified bitmaskPlanes are OR-combined, with flagged
    pixels given a value of 1, and unflagged pixels given a value of 0.
    """

    component = ChoiceField[str](
        doc="Coadd component to display. Can take one of image, mask, variance. Default: image.",
        default="image",
        allowed={plane: plane for plane in ("image", "mask", "variance")},
    )

    bitmaskPlanes = ListField[str](
        doc="List of names of bitmask plane(s) to display when displaying the "
        "mask plane. Bitmask planes are OR-combined. Flagged pixels are given "
        "a value of 1; unflagged pixels are given a value of 0. "
        "Optional when displaying either the image or variance planes. "
        "Required when displaying the mask plane.",
        optional=True,
    )

    showPatchIds = Field[bool](
        doc="Show the patch IDs in the centre of each patch. Default: False",
        default=False,
    )

    showColorbar = Field[bool](
        doc="Show a colorbar alongside the main plot. Default: False",
        default=False,
    )

    zAxisLabel = Field[str](
        doc="Label to display on the colorbar. Optional",
        optional=True,
    )

    interval = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max values of the image scale. Default: Perc.",
        default=Perc,
    )

    colorbarCmap = ChoiceField[str](
        doc="Matplotlib colormap to use for the displayed image. Default: gray",
        default="gray",
        allowed={name: name for name in plt.colormaps()},
    )

    noDataColor = Field[str](
        doc="Matplotlib color to use to indicate regions of no data. Default: red",
        default="red",
    )

    noDataValue = Field[int](
        doc="If data doesn't contain a mask plane, the value in the image plane to "
        "assign the noDataColor to. Optional.",
        optional=True,
    )

    vmaxFloor = Field[float](
        doc="The floor of the vmax value of the colorbar",
        default=None,
        optional=True,
    )

    stretch = ConfigurableActionField[TensorAction](
        doc="Action to calculate the stretch of the image scale. Default: Asinh",
        default=Asinh,
    )

    displayAsPostageStamp = Field[bool](
        doc="Display as a figure to be used as postage stamp. No plotInfo or legend is shown, "
        "and large fonts are used for axis labels.",
        default=False,
    )

    def validate(self):
        super().validate()

        if self.component == "mask" and self.bitmaskPlanes is None:
            raise FieldValidationError(
                self.__class__.bitmaskPlanes,
                self,
                "'bitmaskPlanes' must be specified if displaying the mask plane.",
            )
        if self.bitmaskPlanes is not None and self.component != "mask":
            raise FieldValidationError(
                self.__class__.component,
                self,
                "'component' must be set to the mask plane if 'bitmaskPlanes' is specified.",
            )

    def getInputSchema(self) -> KeyedDataSchema:
        # Only the configured component is required; a "mask" entry is
        # optional and, when present, is used to locate NO_DATA pixels.
        return [(self.component, KeyedData)]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makeFigure(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        needed = self.getInputSchema()
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in the input data")

    def makeFigure(
        self,
        data: KeyedData,
        tractId: int,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str] | None = None,
        **kwargs,
    ) -> Figure:
        """Make a figure displaying the input pixel data.

        Parameters
        ----------
        data : `lsst.analysis.tools.interfaces.KeyedData`
            A python dict-of-dicts containing the pixel data to display in the
            figure. The top level keys are named after the coadd component(s)
            and must include the configured ``component``; a ``'mask'`` entry
            is optional and, if present, is used to identify NO_DATA pixels.
            The next level keys are the patch IDs, mapping to the coadd
            component for that patch.
        tractId : `int`
            Identification number of the tract to be displayed.
        skymap : `lsst.skymap.BaseSkyMap`
            The sky map used for this dataset. This is referred-to to determine
            the location of the tract on-sky (for RA and Dec axis ranges) and
            the location of the patches within the tract.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted. It is
            copied, not modified. Unless ``displayAsPostageStamp`` is set, it
            must provide the keys read by `addPlotInfo`, which include:

            ``"plotName"``
                The name of the plot (`str`).
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            ``"bands"``
                The filter(s) used for this data.

            ``"band"`` (`str`), if present, is added to the postage-stamp
            title.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example wholeTractImage plot may be seen below:

        .. image:: /_static/analysis_tools/wholeTractImageExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        tractInfo = skymap.generateTract(tractId)
        tractCorners = getTractCorners(skymap, tractId)
        tractRas = [ra for (ra, _) in tractCorners]
        tractDecs = [dec for (_, dec) in tractCorners]
        # A tract whose corner RAs exceed 360 deg straddles RA = 0; it is
        # drawn using negative RAs and relabelled via the tick labels below.
        raSpansZero = max(tractRas) > 360.0

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
        # pyplot.get_cmap works across versions.
        cmap = plt.get_cmap(self.colorbarCmap).reversed().copy()
        cmap.set_bad(self.noDataColor, alpha=0.6 if self.noDataColor == "red" else 1.0)

        set_rubin_plotstyle()
        fig = make_figure()
        ax = fig.add_subplot(111)

        # Work on a copy so that the caller's mapping is left untouched.
        plotInfo = {} if plotInfo is None else dict(plotInfo)
        plotInfo["component"] = self.component
        if self.bitmaskPlanes is not None:
            plotInfo["maskPlanes"] = self.bitmaskPlanes

        if self.displayAsPostageStamp:
            axisLabelFontSize = 20
            tickMarkFontSize = 10
            boundaryColor = "k"
            boundaryAlpha = 1.0
            boundaryWidth = 0.5
        else:
            axisLabelFontSize = 8
            tickMarkFontSize = 8
            boundaryColor = "r" if "viridis" in self.colorbarCmap.lower() else "c"
            boundaryAlpha = 0.3
            boundaryWidth = 1.0

        # Keep a record of the "empty" patches that do not have coadds.
        numPatches = tractInfo.getNumPatches()
        emptyPatches = list(range(numPatches[0] * numPatches[1]))

        patchIds = list(data[self.component].keys())
        hasMask = "mask" in data

        # The bitmask values are taken from the first patch's mask.
        noDataBitmask = None
        bitmasks = None
        if hasMask and patchIds:
            firstMask = data["mask"][patchIds[0]]
            noDataBitmask = firstMask.getPlaneBitMask("NO_DATA")
            if self.bitmaskPlanes:
                bitmasks = firstMask.getPlaneBitMask(self.bitmaskPlanes)

        # Extract the pixel arrays for all patches prior to plotting.
        # This allows for a global image normalisation to be calculated.
        imStack = {}
        pixelChunks = []
        for patchId in patchIds:
            emptyPatches.remove(patchId)
            componentArray = data[self.component][patchId].array
            im = componentArray
            if self.bitmaskPlanes:
                # OR-combine the requested planes: flagged -> 1, unflagged -> 0.
                im = ((componentArray & bitmasks) > 0) * 1.0

            if hasMask:
                noDataMask = (data["mask"][patchId].array & noDataBitmask) > 0
            elif self.noDataValue is not None:
                noDataMask = componentArray == self.noDataValue
            else:
                noDataMask = np.zeros(componentArray.shape, dtype=bool)

            # Collect chunks and concatenate once, rather than np.append per
            # patch, which would re-copy every pixel gathered so far.
            pixelChunks.append(im[~noDataMask].astype(np.float64, copy=False))
            imStack[patchId] = np.ma.masked_array(im, mask=noDataMask)

        allPix = np.concatenate(pixelChunks) if pixelChunks else np.array([])

        # It is possible that all pixels are flagged NO_DATA.
        # In which case, set vmin & vmax to arbitrary values.
        if allPix.size == 0:
            vmin, vmax = (0, 1)
        else:
            vmin, vmax = self.interval(allPix)

        # Set a floor to vmax. Useful for low dynamic range data.
        if self.vmaxFloor is not None:
            vmax = max(vmax, self.vmaxFloor)

        # The normalisation is the same for every patch.
        norm = ImageNormalize(vmin=vmin, vmax=vmax)
        plotIm = None
        for patchId, im in imStack.items():
            # Place the patch image at its location in the tract.
            ras, decs = self._patchRaDecs(tractInfo, patchId, raSpansZero)
            extent = (max(ras), min(ras), max(decs), min(decs))
            self._drawPatchOutline(ax, ras, decs, boundaryColor, boundaryWidth, boundaryAlpha)

            stretchedIm = self.stretch(norm(im))
            maskedStretched = np.ma.masked_array(
                norm.inverse(stretchedIm.data),
                mask=stretchedIm.mask,
            )
            plotIm = ax.imshow(maskedStretched, vmin=vmin, vmax=vmax, extent=extent, cmap=cmap)

            if self.showPatchIds:
                self._annotatePatch(ax, patchId, ras, decs)

        # Indicate the empty patches with red hatching.
        for patchId in emptyPatches:
            ras, decs = self._patchRaDecs(tractInfo, patchId, raSpansZero)
            extent = (max(ras), min(ras), max(decs), min(decs))
            self._drawPatchOutline(ax, ras, decs, boundaryColor, boundaryWidth, boundaryAlpha)

            hatching = ax.contourf(np.ones((10, 10)), 1, hatches=["xx"], extent=extent, colors="none")
            hatching.set_edgecolors("red")
            if self.showPatchIds:
                self._annotatePatch(ax, patchId, ras, decs)

        # Draw axes around the entire tract:
        ax.set_xlabel("R.A. (deg)", fontsize=axisLabelFontSize)
        ax.set_ylabel("Dec. (deg)", fontsize=axisLabelFontSize)

        # Account for the RA wrapping using negative RA values.
        axisRas = [ra - 360.0 for ra in tractRas] if raSpansZero else tractRas
        raMin, raMax = min(axisRas), max(axisRas)
        ax.set_xlim(raMax, raMin)
        ticks = [t for t in ax.get_xticks() if raMin <= t <= raMax]

        # Rectify potential negative RA values via tick labels.
        tickLabels = [f"{t % 360:.1f}" for t in ticks]
        ax.set_xticks(ticks, tickLabels)

        ax.set_ylim(min(tractDecs), max(tractDecs))

        ax.tick_params(axis="both", labelsize=tickMarkFontSize, length=0, pad=1.5)

        # A colorbar needs at least one plotted image to describe.
        if self.showColorbar and plotIm is not None:
            cax = fig.add_axes([0.90, 0.11, 0.04, 0.77])
            cbar = fig.colorbar(plotIm, cax=cax, extend="both")
            cbar.ax.tick_params(labelsize=tickMarkFontSize)
            colorbarLabel = self.zAxisLabel if self.zAxisLabel else ""
            text = cax.text(
                0.5,
                0.5,
                colorbarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

        if self.displayAsPostageStamp:
            title = f"{tractId}; {plotInfo['band']}" if "band" in plotInfo else f"{tractId}"
            ax.set_title(title, fontsize=20)
        else:
            # Legend for the NO_DATA shading and the no-coadd hatching, plus
            # the plot information text (omitted for postage stamps).
            if hasMask:
                noDataPatch = patches.Rectangle(
                    (0.8, 1.1), 0.05, 0.04, transform=ax.transAxes, facecolor="red", alpha=0.6, clip_on=False
                )
                ax.add_patch(noDataPatch)
                ax.text(0.86, 1.115, "NO_DATA", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            noCoaddPatch = patches.Rectangle(
                (0.8, 1.02),
                0.05,
                0.04,
                transform=ax.transAxes,
                facecolor="none",
                edgecolor="red",
                hatch="xx",
                clip_on=False,
            )
            ax.add_patch(noCoaddPatch)
            ax.text(0.86, 1.04, "No coadd", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            fig = addPlotInfo(fig, plotInfo)

        fig.canvas.draw()

        return fig

    @staticmethod
    def _patchRaDecs(tractInfo, patchId, raSpansZero: bool) -> tuple[list[float], list[float]]:
        """Return the corner RAs and Decs of a patch.

        When the tract straddles RA = 0, RAs above 180 are shifted down by 360
        so that the patch lies on a continuous axis.
        """
        patchCorners = getPatchCorners(tractInfo, patchId)
        ras = [ra for (ra, _) in patchCorners]
        decs = [dec for (_, dec) in patchCorners]
        if raSpansZero:
            ras = [ra - 360 if ra > 180.0 else ra for ra in ras]
        return ras, decs

    @staticmethod
    def _drawPatchOutline(ax, ras, decs, color, width, alpha) -> None:
        """Draw the rectangular RA/Dec bounding box of a patch."""
        raLo, raHi = min(ras), max(ras)
        decLo, decHi = min(decs), max(decs)
        ax.plot(
            [raLo, raHi, raHi, raLo, raLo],
            [decLo, decLo, decHi, decHi, decLo],
            color,
            lw=width,
            alpha=alpha,
        )

    @staticmethod
    def _annotatePatch(ax, patchId, ras, decs) -> None:
        """Write the patch ID at the mean position of the patch corners."""
        ax.annotate(
            patchId,
            (np.mean(ras), np.mean(decs)),
            color="k",
            ha="center",
            va="center",
            path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
        )
def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Annotate the top-left corner of a figure with plot information.

    The plot name is written first, and the summary produced by
    `parsePlotInfo` is written in a smaller, semi-transparent font just
    beneath it.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        The figure to annotate. It is modified in place.
    plotInfo : `dict`
        Plot information; must contain ``"plotName"`` plus the keys read by
        `parsePlotInfo`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the information text added.
    """
    # Both text items are anchored by their top-left corner in figure
    # coordinates.
    anchor = dict(transform=fig.transFigure, ha="left", va="top")

    plotName = plotInfo["plotName"]
    fig.text(0.01, 0.99, plotName, fontsize=7, **anchor)

    summary = parsePlotInfo(plotInfo)
    fig.text(0.01, 0.984, summary, fontsize=6, alpha=0.6, **anchor)

    return fig
def parsePlotInfo(plotInfo: Mapping[str, str]) -> str:
    """Extract information from the plotInfo dictionary and parse it into
    a meaningful string that can be added to a figure. The default function
    in .plotUtils is not suitable for image plotting.

    Parameters
    ----------
    plotInfo : `dict`[`str`, `str`]
        A plotInfo dictionary containing useful information to be included
        on a figure. ``"run"``, ``"component"``, ``"skymap"`` and
        ``"tract"`` are required. ``"maskPlanes"`` is optional. The band
        text uses ``"bands"`` (an iterable of strings) if present, otherwise
        a single ``"band"``; it is omitted if neither key is given.

    Returns
    -------
    infoText : `str`
        A string containing the plotInfo information, parsed in such a
        way that it can be included on a figure.

    Raises
    ------
    KeyError
        If one of the required keys is missing.
    """
    run = plotInfo["run"]
    componentText = f"\nComponent: {plotInfo['component']}"

    maskPlaneText = ""
    if "maskPlanes" in plotInfo:
        maskPlaneText = f", Mask Plane(s): {', '.join(plotInfo['maskPlanes'])}"

    dataIdText = f"\nSkyMap:{plotInfo['skymap']}, Tract: {plotInfo['tract']}"

    # Prefer the list of bands; fall back to a single "band" entry (the key
    # documented for makeFigure), and omit the band text if neither exists.
    if "bands" in plotInfo:
        bands = plotInfo["bands"]
    elif "band" in plotInfo:
        bands = [plotInfo["band"]]
    else:
        bands = None
    bandsText = "" if bands is None else f", Bands: {', '.join(bands)}"

    return f"\n{run}{componentText}{maskPlaneText}{dataIdText}{bandsText}"