Coverage for python / lsst / analysis / tools / actions / plot / wholeTractImage.py: 15%
184 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:53 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-05 18:53 +0000
1# This file is part of analysis_tools.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from __future__ import annotations
24__all__ = ("WholeTractImage",)
26from typing import Mapping, Optional
28import matplotlib.cm as cm
29import matplotlib.patches as patches
30import matplotlib.patheffects as pathEffects
31import matplotlib.pyplot as plt
32import numpy as np
33from astropy.visualization import ImageNormalize
34from lsst.pex.config import (
35 ChoiceField,
36 Field,
37 FieldValidationError,
38 ListField,
39)
40from lsst.pex.config.configurableActions import ConfigurableActionField
41from lsst.skymap import BaseSkyMap
42from lsst.utils.plotting import make_figure, set_rubin_plotstyle
43from matplotlib.figure import Figure
45from ...interfaces import (
46 KeyedData,
47 KeyedDataSchema,
48 PlotAction,
49 TensorAction,
50 VectorAction,
51)
52from ...utils import getPatchCorners, getTractCorners
53from .calculateRange import Asinh, Perc
class WholeTractImage(PlotAction):
    """
    Produces a figure displaying whole-tract coadd pixel data as a 2D image.

    The figure is constructed from all patches covering the tract. Regions of
    NO_DATA are shaded in ``noDataColor`` (red by default), and patches where
    no coadd exists are marked with red hatches.

    Either the image, pixel mask, or variance components of the coadd can be
    displayed. In the case of the pixel mask, one or more bitmaskPlanes must
    be specified; the specified bitmaskPlanes are OR-combined, with flagged
    pixels given a value of 1, and unflagged pixels given a value of 0.
    """

    component = ChoiceField[str](
        doc="Coadd component to display. Can take one of image, mask, variance. Default: image.",
        default="image",
        allowed={plane: plane for plane in ("image", "mask", "variance")},
    )

    bitmaskPlanes = ListField[str](
        doc="List of names of bitmask plane(s) to display when displaying the "
        "mask plane. Bitmask planes are OR-combined. Flagged pixels are given "
        "a value of 1; unflagged pixels are given a value of 0. "
        "Must not be specified when displaying either the image or variance planes. "
        "Required when displaying the mask plane.",
        optional=True,
    )

    showPatchIds = Field[bool](
        doc="Show the patch IDs in the centre of each patch. Default: False",
        default=False,
    )

    showColorbar = Field[bool](
        doc="Show a colorbar alongside the main plot. Default: False",
        default=False,
    )

    zAxisLabel = Field[str](
        doc="Label to display on the colorbar. Optional",
        optional=True,
    )

    interval = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max values of the image scale. Default: Perc.",
        default=Perc,
    )

    colorbarCmap = ChoiceField[str](
        doc="Matplotlib colormap to use for the displayed image. Default: gray",
        default="gray",
        allowed={name: name for name in plt.colormaps()},
    )

    noDataColor = Field[str](
        doc="Matplotlib color to use to indicate regions of no data. Default: red",
        default="red",
    )

    noDataValue = Field[int](
        doc="If data doesn't contain a mask plane, the value in the image plane to "
        "assign the noDataColor to. Optional.",
        optional=True,
    )

    vmaxFloor = Field[float](
        doc="The floor of the vmax value of the colorbar",
        default=None,
        optional=True,
    )

    stretch = ConfigurableActionField[TensorAction](
        doc="Action to calculate the stretch of the image scale. Default: Asinh",
        default=Asinh,
    )

    displayAsPostageStamp = Field[bool](
        doc="Display as a figure to be used as postage stamp. No plotInfo or legend is shown, "
        "and large fonts are used for axis labels.",
        default=False,
    )

    def validate(self):
        """Check that ``component`` and ``bitmaskPlanes`` are consistent.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            If the mask component is requested without any bitmask planes, or
            if bitmask planes are given for a non-mask component.
        """
        super().validate()

        # An empty plane list would make makeFigure display the raw mask
        # integers, so at least one plane is required for the mask component.
        if self.component == "mask" and not self.bitmaskPlanes:
            raise FieldValidationError(
                self.__class__.bitmaskPlanes,
                self,
                "'bitmaskPlanes' must be specified if displaying the mask plane.",
            )
        if self.bitmaskPlanes is not None and self.component != "mask":
            raise FieldValidationError(
                self.__class__.component,
                self,
                "'component' must be set to the mask plane if 'bitmaskPlanes' is specified.",
            )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single input key required: the configured component."""
        return [(self.component, KeyedData)]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the figure."""
        self._validateInput(data, **kwargs)
        return self.makeFigure(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Raise `ValueError` if any schema key is missing from ``data``.

        Keys on both sides are passed through ``str.format(**kwargs)`` before
        being compared.
        """
        needed = self.getInputSchema()
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in the input data")

    def makeFigure(
        self,
        data: KeyedData,
        tractId: int,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make a figure displaying the input pixel data.

        Parameters
        ----------
        data : `lsst.analysis.tools.interfaces.KeyedData`
            A python dict-of-dicts containing the pixel data to display in the
            figure. The top level keys are named after the coadd component(s)
            and must contain at least the configured ``component``. If a
            ``"mask"`` entry is also present, its NO_DATA plane is used to
            mask pixels. The next level keys are the patch IDs, with the
            corresponding coadd component of that patch as values.
        tractId : `int`
            Identification number of the tract to be displayed.
        skymap : `lsst.skymap.BaseSkyMap`
            The sky map used for this dataset. This is referred-to to determine
            the location of the tract on-sky (for RA and Dec axis ranges) and
            the location of the patches within the tract.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted. It is
            copied, never modified. If `None`, no information text is added
            to the figure. Keys read by this module:

            ``"plotName"``
                The name of the plot (`str`).
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"tract"``
                The tract that the data comes from (`str`).
            ``"bands"``
                The filter(s) used for this data, listed in the info text.
            ``"band"``
                The filter used for this data (`str`). Optional; shown in the
                title when ``displayAsPostageStamp`` is set.
        **kwargs
            Ignored by this method.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example wholeTractImage plot may be seen below:

        .. image:: /_static/analysis_tools/wholeTractImageExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        tractInfo = skymap.generateTract(tractId)
        tractCorners = getTractCorners(skymap, tractId)
        tractRas = [ra for (ra, dec) in tractCorners]
        tractDecs = [dec for (ra, dec) in tractCorners]
        # Corner RAs above 360 deg are taken to mean the tract straddles RA=0.
        raSpansZero = max(tractRas) > 360.0

        # Masked (NO_DATA) pixels are drawn with the colormap's "bad" colour.
        # cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap remains.
        noDataAlpha = 0.6 if self.noDataColor == "red" else 1.0
        cmap = plt.get_cmap(self.colorbarCmap).reversed()
        cmap.set_bad(self.noDataColor, alpha=noDataAlpha)

        set_rubin_plotstyle()
        fig = make_figure()
        ax = fig.add_subplot(111)

        # Only add information text when the caller supplied plotInfo (an
        # empty dict lacks the keys addPlotInfo needs); work on a copy so the
        # caller's mapping is left untouched.
        showPlotInfo = plotInfo is not None
        plotInfo = {} if plotInfo is None else dict(plotInfo)
        plotInfo["component"] = self.component
        if self.bitmaskPlanes is not None:
            plotInfo["maskPlanes"] = self.bitmaskPlanes

        if self.displayAsPostageStamp:
            axisLabelFontSize = 20
            tickMarkFontSize = 10
            boundaryColor = "k"
            boundaryAlpha = 1.0
            boundaryWidth = 0.5
        else:
            axisLabelFontSize = 8
            tickMarkFontSize = 8
            boundaryColor = "r" if "viridis" in self.colorbarCmap.lower() else "c"
            boundaryAlpha = 0.3
            boundaryWidth = 1.0

        # Extract the pixel arrays for all patches prior to plotting so that a
        # global image normalisation can be calculated.
        patchIds = list(data[self.component].keys())
        noDataBitmask = None
        bitmasks = None
        if patchIds and "mask" in data:
            # Mask plane definitions are read from the first patch only.
            firstMask = data["mask"][patchIds[0]]
            noDataBitmask = firstMask.getPlaneBitMask("NO_DATA")
            if self.bitmaskPlanes:
                bitmasks = firstMask.getPlaneBitMask(self.bitmaskPlanes)

        imStack = {}
        pixelChunks = []
        for patchId in patchIds:
            rawIm = data[self.component][patchId].array
            im = ((rawIm & bitmasks) > 0) * 1.0 if self.bitmaskPlanes else rawIm

            if "mask" in data:
                noDataMask = (data["mask"][patchId].array & noDataBitmask) > 0
            elif self.noDataValue is not None:
                noDataMask = rawIm == self.noDataValue
            else:
                noDataMask = np.zeros(rawIm.shape, dtype=bool)

            pixelChunks.append(im[~noDataMask])
            imStack[patchId] = np.ma.masked_array(im, mask=noDataMask)

        # Concatenate once rather than growing an array per patch; cast to
        # float64 to match the dtype promotion of the former np.append loop.
        if pixelChunks:
            allPix = np.concatenate(pixelChunks).astype(np.float64, copy=False)
        else:
            allPix = np.array([])

        # It is possible that all pixels are flagged NO_DATA.
        # In which case, set vmin & vmax to arbitrary values.
        if allPix.size == 0:
            vmin, vmax = (0, 1)
        else:
            vmin, vmax = self.interval(allPix)

        # Set a floor to vmax. Useful for low dynamic range data.
        if self.vmaxFloor is not None:
            vmax = max(vmax, self.vmaxFloor)

        norm = ImageNormalize(vmin=vmin, vmax=vmax)
        # Stays None when there are no patch images; the colorbar is then skipped.
        plotIm = None
        for patchId, im in imStack.items():
            ras, decs, extent = self._drawPatchOutline(
                ax, tractInfo, patchId, raSpansZero, boundaryColor, boundaryWidth, boundaryAlpha
            )
            stretchedIm = self.stretch(norm(im))
            maskedStretched = np.ma.masked_array(
                norm.inverse(stretchedIm.data),
                mask=stretchedIm.mask,
            )
            plotIm = ax.imshow(maskedStretched, vmin=vmin, vmax=vmax, extent=extent, cmap=cmap)
            if self.showPatchIds:
                self._annotatePatch(ax, patchId, ras, decs)

        # Indicate the patches that have no coadd with red hatching.
        numPatches = tractInfo.getNumPatches()
        emptyPatches = [p for p in range(numPatches[0] * numPatches[1]) if p not in imStack]
        for patchId in emptyPatches:
            ras, decs, extent = self._drawPatchOutline(
                ax, tractInfo, patchId, raSpansZero, boundaryColor, boundaryWidth, boundaryAlpha
            )
            cs = ax.contourf(np.ones((10, 10)), 1, hatches=["xx"], extent=extent, colors="none")
            cs.set_edgecolors("red")
            if self.showPatchIds:
                self._annotatePatch(ax, patchId, ras, decs)

        # Draw axes around the entire tract:
        ax.set_xlabel("R.A. (deg)", fontsize=axisLabelFontSize)
        ax.set_ylabel("Dec. (deg)", fontsize=axisLabelFontSize)

        # Account for the RA wrapping using negative RA values.
        axisRas = [ra - 360.0 for ra in tractRas] if raSpansZero else tractRas
        raMin, raMax = min(axisRas), max(axisRas)
        ax.set_xlim(raMax, raMin)
        ticks = [t for t in ax.get_xticks() if raMin <= t <= raMax]
        # Rectify potential negative RA values via tick labels.
        tickLabels = [f"{t % 360:.1f}" for t in ticks]
        ax.set_xticks(ticks, tickLabels)

        ax.set_ylim(min(tractDecs), max(tractDecs))
        ax.tick_params(axis="both", labelsize=tickMarkFontSize, length=0, pad=1.5)

        if self.showColorbar and plotIm is not None:
            cax = fig.add_axes([0.90, 0.11, 0.04, 0.77])
            cbar = fig.colorbar(plotIm, cax=cax, extend="both")
            cbar.ax.tick_params(labelsize=tickMarkFontSize)
            text = cax.text(
                0.5,
                0.5,
                self.zAxisLabel or "",
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

        if self.displayAsPostageStamp:
            title = f"{tractId}; {plotInfo['band']}" if "band" in plotInfo else f"{tractId}"
            ax.set_title(title, fontsize=20)
        else:
            if "mask" in data:
                # Legend swatch matches the colour/alpha used for NO_DATA pixels.
                noDataPatch = patches.Rectangle(
                    (0.8, 1.1),
                    0.05,
                    0.04,
                    transform=ax.transAxes,
                    facecolor=self.noDataColor,
                    alpha=noDataAlpha,
                    clip_on=False,
                )
                ax.add_patch(noDataPatch)
                ax.text(0.86, 1.115, "NO_DATA", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            noCoaddPatch = patches.Rectangle(
                (0.8, 1.02),
                0.05,
                0.04,
                transform=ax.transAxes,
                facecolor="none",
                edgecolor="red",
                hatch="xx",
                clip_on=False,
            )
            ax.add_patch(noCoaddPatch)
            ax.text(0.86, 1.04, "No coadd", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            if showPlotInfo:
                fig = addPlotInfo(fig, plotInfo)

        fig.canvas.draw()

        return fig

    @staticmethod
    def _drawPatchOutline(ax, tractInfo, patchId, raSpansZero, boundaryColor, boundaryWidth, boundaryAlpha):
        """Draw a patch's boundary and return its corner RAs/Decs and extent.

        When ``raSpansZero`` is set, RAs above 180 deg are shifted by -360 deg;
        the tick labels in makeFigure map them back into [0, 360).

        Returns
        -------
        ras, decs : `list`
            The (possibly shifted) corner RAs and the corner Decs.
        extent : `tuple`
            ``(maxRa, minRa, maxDec, minDec)`` for ``imshow``/``contourf``.
        """
        patchCorners = getPatchCorners(tractInfo, patchId)
        ras = [ra for (ra, dec) in patchCorners]
        decs = [dec for (ra, dec) in patchCorners]
        if raSpansZero:
            ras = [ra - 360 if ra > 180.0 else ra for ra in ras]

        raMin, raMax = min(ras), max(ras)
        decMin, decMax = min(decs), max(decs)
        ax.plot(
            [raMin, raMax, raMax, raMin, raMin],
            [decMin, decMin, decMax, decMax, decMin],
            boundaryColor,
            lw=boundaryWidth,
            alpha=boundaryAlpha,
        )
        return ras, decs, (raMax, raMin, decMax, decMin)

    @staticmethod
    def _annotatePatch(ax, patchId, ras, decs):
        """Write ``patchId`` at the mean of the patch's corner coordinates."""
        ax.annotate(
            patchId,
            (np.mean(ras), np.mean(decs)),
            color="k",
            ha="center",
            va="center",
            path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
        )
def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Write the plot name and a provenance summary in the figure's top-left.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure that receives the two text annotations.
    plotInfo : `dict`
        Plot information; ``"plotName"`` is shown as the heading and the
        remaining details are formatted by ``parsePlotInfo``.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, with the text added.
    """
    # Both annotations are anchored at their top-left corner in figure coords.
    anchor = dict(transform=fig.transFigure, ha="left", va="top")

    heading = plotInfo["plotName"]
    fig.text(0.01, 0.99, heading, fontsize=7, **anchor)

    details = parsePlotInfo(plotInfo)
    fig.text(0.01, 0.984, details, fontsize=6, alpha=0.6, **anchor)

    return fig
466def parsePlotInfo(plotInfo: Mapping[str, str]) -> str:
467 """Extract information from the plotInfo dictionary and parses it into
468 a meaningful string that can be added to a figure. The default function
469 in .plotUtils is not suitable for image plotting.
471 Parameters
472 ----------
473 plotInfo : `dict`[`str`, `str`]
474 A plotInfo dictionary containing useful information to
475 be included on a figure.
477 Returns
478 -------
479 infoText : `str`
480 A string containing the plotInfo information, parsed in such a
481 way that it can be included on a figure.
482 """
483 run = plotInfo["run"]
484 componentType = f"\nComponent: {plotInfo['component']}"
486 maskPlaneText = ""
487 if "maskPlanes" in plotInfo:
488 for maskPlane in plotInfo["maskPlanes"]:
489 maskPlaneText += maskPlane + ", "
490 maskPlaneText = f", Mask Plane(s): {maskPlaneText[:-2]}"
492 dataIdText = f"\nSkyMap:{plotInfo['skymap']}, Tract: {plotInfo['tract']}"
494 bandText = ""
495 for band in plotInfo["bands"]:
496 bandText += band + ", "
497 bandsText = f", Bands: {bandText[:-2]}"
498 infoText = f"\n{run}{componentType}{maskPlaneText}{dataIdText}{bandsText}"
500 return infoText