Coverage for python / lsst / analysis / tools / actions / plot / wholeTractImage.py: 15%

184 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:08 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = ("WholeTractImage",) 

25 

26from typing import Mapping, Optional 

27 

28import matplotlib.cm as cm 

29import matplotlib.patches as patches 

30import matplotlib.patheffects as pathEffects 

31import matplotlib.pyplot as plt 

32import numpy as np 

33from astropy.visualization import ImageNormalize 

34from lsst.pex.config import ( 

35 ChoiceField, 

36 Field, 

37 FieldValidationError, 

38 ListField, 

39) 

40from lsst.pex.config.configurableActions import ConfigurableActionField 

41from lsst.skymap import BaseSkyMap 

42from lsst.utils.plotting import make_figure, set_rubin_plotstyle 

43from matplotlib.figure import Figure 

44 

45from ...interfaces import ( 

46 KeyedData, 

47 KeyedDataSchema, 

48 PlotAction, 

49 TensorAction, 

50 VectorAction, 

51) 

52from ...utils import getPatchCorners, getTractCorners 

53from .calculateRange import Asinh, Perc 

54 

55 

class WholeTractImage(PlotAction):
    """
    Produces a figure displaying whole-tract coadd pixel data as a 2D image.

    The figure is constructed from all patches covering the tract. Regions of
    NO_DATA are shaded in ``noDataColor`` (red by default), and patches for
    which no coadd exists are marked with red hatches.

    Either the image, pixel mask, or variance components of the coadd can be
    displayed. In the case of the pixel mask, one or more bitmaskPlanes must
    be specified; the specified bitmaskPlanes are OR-combined, with flagged
    pixels given a value of 1, and unflagged pixels given a value of 0.
    """

    component = ChoiceField[str](
        doc="Coadd component to display. Can take one of image, mask, variance. Default: image.",
        default="image",
        allowed={plane: plane for plane in ("image", "mask", "variance")},
    )

    bitmaskPlanes = ListField[str](
        doc="List of names of bitmask plane(s) to display when displaying the "
        "mask plane. Bitmask planes are OR-combined. Flagged pixels are given "
        "a value of 1; unflagged pixels are given a value of 0. "
        "Must not be set when displaying either the image or variance planes. "
        "Required when displaying the mask plane.",
        optional=True,
    )

    showPatchIds = Field[bool](
        doc="Show the patch IDs in the centre of each patch. Default: False",
        default=False,
    )

    showColorbar = Field[bool](
        doc="Show a colorbar alongside the main plot. Default: False",
        default=False,
    )

    zAxisLabel = Field[str](
        doc="Label to display on the colorbar. Optional",
        optional=True,
    )

    interval = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max values of the image scale. Default: Perc.",
        default=Perc,
    )

    colorbarCmap = ChoiceField[str](
        doc="Matplotlib colormap to use for the displayed image. Default: gray",
        default="gray",
        allowed={name: name for name in plt.colormaps()},
    )

    noDataColor = Field[str](
        doc="Matplotlib color to use to indicate regions of no data. Default: red",
        default="red",
    )

    noDataValue = Field[int](
        doc="If data doesn't contain a mask plane, the value in the image plane to "
        "assign the noDataColor to. Optional.",
        optional=True,
    )

    vmaxFloor = Field[float](
        doc="The floor of the vmax value of the colorbar",
        default=None,
        optional=True,
    )

    stretch = ConfigurableActionField[TensorAction](
        doc="Action to calculate the stretch of the image scale. Default: Asinh",
        default=Asinh,
    )

    displayAsPostageStamp = Field[bool](
        doc="Display as a figure to be used as postage stamp. No plotInfo or legend is shown, "
        "and large fonts are used for axis labels.",
        default=False,
    )

    def validate(self):
        """Check that ``component`` and ``bitmaskPlanes`` are consistent.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            Raised if the mask plane is displayed without ``bitmaskPlanes``,
            or if ``bitmaskPlanes`` is set for a non-mask component.
        """
        super().validate()

        if self.component == "mask" and self.bitmaskPlanes is None:
            raise FieldValidationError(
                self.__class__.bitmaskPlanes,
                self,
                "'bitmaskPlanes' must be specified if displaying the mask plane.",
            )
        if self.bitmaskPlanes is not None and self.component != "mask":
            raise FieldValidationError(
                self.__class__.component,
                self,
                "'component' must be set to the mask plane if 'bitmaskPlanes' is specified.",
            )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single input key: the coadd component being displayed."""
        return [(self.component, KeyedData)]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        """Validate ``data`` against the input schema, then build the figure."""
        self._validateInput(data, **kwargs)
        return self.makeFigure(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Raise `ValueError` if any key required by the schema is missing."""
        needed = self.getInputSchema()
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in the input data")

    @staticmethod
    def _patchRaDec(tractInfo, patchId, raSpansZero: bool) -> tuple[list, list]:
        """Return the corner RAs and Decs of a patch, in degrees.

        When the tract spans RA=0, RAs above 180 deg are shifted by -360 so
        the tract is contiguous; tick labels are mapped back modulo 360 later.
        """
        patchCorners = getPatchCorners(tractInfo, patchId)
        ras = [ra for (ra, dec) in patchCorners]
        decs = [dec for (ra, dec) in patchCorners]
        if raSpansZero:
            ras = [ra - 360 if ra > 180.0 else ra for ra in ras]
        return ras, decs

    @staticmethod
    def _drawPatchBoundary(ax, ras, decs, color, width, alpha) -> None:
        """Draw the rectangular outline bounding the given patch corners."""
        ax.plot(
            [min(ras), max(ras), max(ras), min(ras), min(ras)],
            [min(decs), min(decs), max(decs), max(decs), min(decs)],
            color,
            lw=width,
            alpha=alpha,
        )

    @staticmethod
    def _annotatePatchId(ax, patchId, ras, decs) -> None:
        """Write the patch ID, with a white halo, at the mean corner position."""
        ax.annotate(
            patchId,
            (np.mean(ras), np.mean(decs)),
            color="k",
            ha="center",
            va="center",
            path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
        )

    def makeFigure(
        self,
        data: KeyedData,
        tractId: int,
        skymap: BaseSkyMap,
        plotInfo: Optional[Mapping[str, str]] = None,
        **kwargs,
    ) -> Figure:
        """Make a figure displaying the input pixel data.

        Parameters
        ----------
        data : `lsst.analysis.tools.interfaces.KeyedData`
            A python dict-of-dicts containing the pixel data to display in the
            figure. The top level keys are named after the coadd component(s),
            and must contain at least the configured ``component``. If a
            'mask' key is also present, its NO_DATA plane is used to flag
            missing pixels; otherwise ``noDataValue`` (if set) is used. The
            next level keys are named after the patch ID of the coadd
            component contained as their corresponding value.
        tractId : `int`
            Identification number of the tract to be displayed.
        skymap : `lsst.skymap.BaseSkyMap`
            The sky map used for this dataset. This is referred-to to determine
            the location of the tract on-sky (for RA and Dec axis ranges) and
            the location of the patches within the tract.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted. It is
            copied, never modified in place. If `None`, no plot information
            text is added to the figure. Keys used:

            ``"plotName"``
                The name of the plot (`str`).
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"bands"``
                The filters used for this data (iterable of `str`).
            ``"band"``
                The filter used for this data (`str`). Optional; shown in
                the title of postage-stamp figures.
            ``"tract"``
                The tract that the data comes from (`str`).

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example wholeTractImage plot may be seen below:

        .. image:: /_static/analysis_tools/wholeTractImageExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """

        tractInfo = skymap.generateTract(tractId)
        tractCorners = getTractCorners(skymap, tractId)
        tractRas = [ra for (ra, dec) in tractCorners]
        # Tract corner RAs beyond 360 deg mean the tract straddles RA=0.
        raSpansZero = max(tractRas) > 360.0

        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9;
        # pyplot.get_cmap is the supported accessor.
        cmap = plt.get_cmap(self.colorbarCmap).reversed().copy()
        noDataAlpha = 0.6 if self.noDataColor == "red" else 1.0
        cmap.set_bad(self.noDataColor, alpha=noDataAlpha)

        set_rubin_plotstyle()
        fig = make_figure()
        ax = fig.add_subplot(111)

        # Work on a copy so that the caller's dict is not polluted with
        # component/mask-plane entries that could leak into other plots.
        hasPlotInfo = plotInfo is not None
        plotInfo = {} if plotInfo is None else dict(plotInfo)
        plotInfo["component"] = self.component
        if self.bitmaskPlanes is not None:
            plotInfo["maskPlanes"] = self.bitmaskPlanes

        if self.displayAsPostageStamp:
            axisLabelFontSize = 20
            tickMarkFontSize = 10
            boundaryColor = "k"
            boundaryAlpha = 1.0
            boundaryWidth = 0.5
        else:
            axisLabelFontSize = 8
            tickMarkFontSize = 8
            boundaryColor = "r" if "viridis" in self.colorbarCmap.lower() else "c"
            boundaryAlpha = 0.3
            boundaryWidth = 1.0

        # Keep a record of the "empty" patches that do not have coadds.
        numPatches = tractInfo.getNumPatches()
        emptyPatches = list(range(numPatches[0] * numPatches[1]))

        # Extract the pixel arrays for all patches prior to plotting.
        # This allows for a global image normalisation to be calculated.
        patchIds = list(data[self.component].keys())
        hasMask = "mask" in data
        noDataBitmask = None
        bitmasks = None
        if patchIds:
            # Bitmask definitions are read from the first patch's mask.
            firstPatchId = patchIds[0]
            if hasMask:
                noDataBitmask = data["mask"][firstPatchId].getPlaneBitMask("NO_DATA")
            if self.bitmaskPlanes:
                bitmasks = data["mask"][firstPatchId].getPlaneBitMask(self.bitmaskPlanes)

        imStack = {}
        validPixelChunks = []
        for patchId in patchIds:
            emptyPatches.remove(patchId)
            im = data[self.component][patchId].array
            if self.bitmaskPlanes:
                im = ((im & bitmasks) > 0) * 1.0

            if hasMask:
                noDataMask = (data["mask"][patchId].array & noDataBitmask) > 0
            elif self.noDataValue is not None:
                noDataMask = data[self.component][patchId].array == self.noDataValue
            else:
                noDataMask = np.zeros(im.shape, dtype=bool)

            validPixelChunks.append(im[~noDataMask])
            imStack[patchId] = np.ma.masked_array(im, mask=noDataMask)

        # Concatenate once (repeated np.append is quadratic). Seeding with a
        # float64 empty array keeps the same dtype promotion as np.append.
        allPix = np.concatenate([np.array([]), *validPixelChunks])

        # It is possible that all pixels are flagged NO_DATA.
        # In which case, set vmin & vmax to arbitrary values.
        if len(allPix) == 0:
            vmin, vmax = (0, 1)
        else:
            vmin, vmax = self.interval(allPix)

        # Set a floor to vmax. Useful for low dynamic range data.
        if self.vmaxFloor is not None:
            vmax = max(vmax, self.vmaxFloor)

        norm = ImageNormalize(vmin=vmin, vmax=vmax)
        plotIm = None
        for patchId, im in imStack.items():
            ras, decs = self._patchRaDec(tractInfo, patchId, raSpansZero)
            extent = (max(ras), min(ras), max(decs), min(decs))
            self._drawPatchBoundary(ax, ras, decs, boundaryColor, boundaryWidth, boundaryAlpha)

            # Stretch in normalised units, then map back to data units so the
            # colorbar is labelled in the original pixel values.
            stretchedIm = self.stretch(norm(im))
            maskedStretched = np.ma.masked_array(
                norm.inverse(stretchedIm.data),
                mask=stretchedIm.mask,
            )
            plotIm = ax.imshow(maskedStretched, vmin=vmin, vmax=vmax, extent=extent, cmap=cmap)

            if self.showPatchIds:
                self._annotatePatchId(ax, patchId, ras, decs)

        # Indicate the empty patches with red hatching
        for patchId in emptyPatches:
            ras, decs = self._patchRaDec(tractInfo, patchId, raSpansZero)
            extent = (max(ras), min(ras), max(decs), min(decs))
            self._drawPatchBoundary(ax, ras, decs, boundaryColor, boundaryWidth, boundaryAlpha)

            hatching = ax.contourf(np.ones((10, 10)), 1, hatches=["xx"], extent=extent, colors="none")
            hatching.set_edgecolors("red")
            if self.showPatchIds:
                self._annotatePatchId(ax, patchId, ras, decs)

        # Draw axes around the entire tract:
        ax.set_xlabel("R.A. (deg)", fontsize=axisLabelFontSize)
        ax.set_ylabel("Dec. (deg)", fontsize=axisLabelFontSize)

        tractRas = [ra for (ra, dec) in tractCorners]
        # Account for the RA wrapping using negative RA values.
        if raSpansZero:
            tractRas = [ra - 360.0 for ra in tractRas]

        ax.set_xlim(max(tractRas), min(tractRas))
        ticks = [t for t in ax.get_xticks() if min(tractRas) <= t <= max(tractRas)]

        # Rectify potential negative RA values via tick labels
        tickLabels = [f"{t % 360:.1f}" for t in ticks]
        ax.set_xticks(ticks, tickLabels)

        tractDecs = [dec for (ra, dec) in tractCorners]
        ax.set_ylim(min(tractDecs), max(tractDecs))

        ax.tick_params(axis="both", labelsize=tickMarkFontSize, length=0, pad=1.5)

        # plotIm is None when no patch had data; there is nothing to scale.
        if self.showColorbar and plotIm is not None:
            cax = fig.add_axes([0.90, 0.11, 0.04, 0.77])
            cbar = fig.colorbar(plotIm, cax=cax, extend="both")
            cbar.ax.tick_params(labelsize=tickMarkFontSize)
            colorbarLabel = self.zAxisLabel if self.zAxisLabel else ""
            text = cax.text(
                0.5,
                0.5,
                colorbarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

        if self.displayAsPostageStamp:
            if "band" in plotInfo:
                title = f"{str(tractId)}; {plotInfo['band']}"
            else:
                title = f"{str(tractId)}"
            ax.set_title(title, fontsize=20)
        else:
            if hasMask:
                # Match the legend swatch to the colour used for NO_DATA pixels.
                noDataPatch = patches.Rectangle(
                    (0.8, 1.1),
                    0.05,
                    0.04,
                    transform=ax.transAxes,
                    facecolor=self.noDataColor,
                    alpha=noDataAlpha,
                    clip_on=False,
                )
                ax.add_patch(noDataPatch)
                ax.text(0.86, 1.115, "NO_DATA", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            noCoaddPatch = patches.Rectangle(
                (0.8, 1.02),
                0.05,
                0.04,
                transform=ax.transAxes,
                facecolor="none",
                edgecolor="red",
                hatch="xx",
                clip_on=False,
            )
            ax.add_patch(noCoaddPatch)
            ax.text(0.86, 1.04, "No coadd", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            # Without caller-supplied plotInfo the required keys (plotName,
            # run, ...) are absent, so no information text can be drawn.
            if hasPlotInfo:
                fig = addPlotInfo(fig, plotInfo)

        fig.canvas.draw()

        return fig

442 

443 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Write the plot name and a metadata summary in the figure's top-left corner.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure that receives the two text annotations.
    plotInfo : `dict`
        Plot metadata. ``"plotName"`` is used as the heading; the remaining
        entries are rendered by `parsePlotInfo`.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure, now carrying the annotations.
    """
    figureCoords = fig.transFigure
    leftEdge = 0.01

    # Heading first, then the fainter, smaller summary just beneath it.
    fig.text(
        leftEdge,
        0.99,
        plotInfo["plotName"],
        fontsize=7,
        transform=figureCoords,
        ha="left",
        va="top",
    )
    summary = parsePlotInfo(plotInfo)
    fig.text(
        leftEdge,
        0.984,
        summary,
        fontsize=6,
        transform=figureCoords,
        alpha=0.6,
        ha="left",
        va="top",
    )

    return fig

464 

465 

def parsePlotInfo(plotInfo: Mapping[str, str]) -> str:
    """Extract information from the plotInfo dictionary and parse it into
    a meaningful string that can be added to a figure. The default function
    in .plotUtils is not suitable for image plotting.

    Parameters
    ----------
    plotInfo : `dict`[`str`, `str`]
        A plotInfo dictionary containing useful information to
        be included on a figure. The keys ``"run"``, ``"component"``,
        ``"skymap"`` and ``"tract"`` are required; ``"maskPlanes"`` is
        optional. Bands are taken from ``"bands"`` (an iterable of band
        names) if present, otherwise from ``"band"`` (a single band name);
        if neither is present the bands are omitted from the text.

    Returns
    -------
    infoText : `str`
        A string containing the plotInfo information, parsed in such a
        way that it can be included on a figure.

    Raises
    ------
    KeyError
        Raised if one of the required keys is missing.
    """
    run = plotInfo["run"]
    componentType = f"\nComponent: {plotInfo['component']}"

    maskPlaneText = ""
    if "maskPlanes" in plotInfo:
        maskPlaneText = f", Mask Plane(s): {', '.join(plotInfo['maskPlanes'])}"

    dataIdText = f"\nSkyMap:{plotInfo['skymap']}, Tract: {plotInfo['tract']}"

    # makeFigure documents a single optional "band" key, so fall back to it
    # rather than failing when the multi-band "bands" entry is absent.
    if "bands" in plotInfo:
        bands = plotInfo["bands"]
    elif plotInfo.get("band"):
        bands = [plotInfo["band"]]
    else:
        bands = None
    bandsText = "" if bands is None else f", Bands: {', '.join(bands)}"

    return f"\n{run}{componentType}{maskPlaneText}{dataIdText}{bandsText}"