Coverage for python / lsst / analysis / tools / actions / plot / wholeTractImage.py: 15%

184 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 09:36 +0000

1# This file is part of analysis_tools. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

# Public API of this module: only the plot action is exported; addPlotInfo
# and parsePlotInfo are module-level helpers used by it.
__all__ = ("WholeTractImage",)

25 

26from collections.abc import Mapping 

27 

28import matplotlib.cm as cm 

29import matplotlib.patches as patches 

30import matplotlib.patheffects as pathEffects 

31import matplotlib.pyplot as plt 

32import numpy as np 

33from astropy.visualization import ImageNormalize 

34from matplotlib.figure import Figure 

35 

36from lsst.pex.config import ( 

37 ChoiceField, 

38 Field, 

39 FieldValidationError, 

40 ListField, 

41) 

42from lsst.pex.config.configurableActions import ConfigurableActionField 

43from lsst.skymap import BaseSkyMap 

44from lsst.utils.plotting import make_figure, set_rubin_plotstyle 

45 

46from ...interfaces import ( 

47 KeyedData, 

48 KeyedDataSchema, 

49 PlotAction, 

50 TensorAction, 

51 VectorAction, 

52) 

53from ...utils import getPatchCorners, getTractCorners 

54from .calculateRange import Asinh, Perc 

55 

56 

class WholeTractImage(PlotAction):
    """
    Produces a figure displaying whole-tract coadd pixel data as a 2D image.

    The figure is constructed from all patches covering the tract. Regions of
    NO_DATA are shaded in ``noDataColor`` (red by default), and patches for
    which no coadd exists are marked with red hatches.

    Either the image, pixel mask, or variance components of the coadd can be
    displayed. In the case of the pixel mask, one or more bitmaskPlanes must
    be specified; the specified bitmaskPlanes are OR-combined, with flagged
    pixels given a value of 1, and unflagged pixels given a value of 0.
    """

    component = ChoiceField[str](
        doc="Coadd component to display. Can take one of image, mask, variance. Default: image.",
        default="image",
        allowed={plane: plane for plane in ("image", "mask", "variance")},
    )

    bitmaskPlanes = ListField[str](
        doc="List of names of bitmask plane(s) to display when displaying the "
        "mask plane. Bitmask planes are OR-combined. Flagged pixels are given "
        "a value of 1; unflagged pixels are given a value of 0. "
        "Must not be set when displaying either the image or variance planes. "
        "Required when displaying the mask plane.",
        optional=True,
    )

    showPatchIds = Field[bool](
        doc="Show the patch IDs in the centre of each patch. Default: False",
        default=False,
    )

    showColorbar = Field[bool](
        doc="Show a colorbar alongside the main plot. Default: False",
        default=False,
    )

    zAxisLabel = Field[str](
        doc="Label to display on the colorbar. Optional",
        optional=True,
    )

    interval = ConfigurableActionField[VectorAction](
        doc="Action to calculate the min and max values of the image scale. Default: Perc.",
        default=Perc,
    )

    colorbarCmap = ChoiceField[str](
        doc="Matplotlib colormap to use for the displayed image. Default: gray",
        default="gray",
        allowed={name: name for name in plt.colormaps()},
    )

    noDataColor = Field[str](
        doc="Matplotlib color to use to indicate regions of no data. Default: red",
        default="red",
    )

    noDataValue = Field[int](
        doc="If data doesn't contain a mask plane, the value in the image plane to "
        "assign the noDataColor to. Optional.",
        optional=True,
    )

    vmaxFloor = Field[float](
        doc="The floor of the vmax value of the colorbar",
        default=None,
        optional=True,
    )

    stretch = ConfigurableActionField[TensorAction](
        doc="Action to calculate the stretch of the image scale. Default: Asinh",
        default=Asinh,
    )

    displayAsPostageStamp = Field[bool](
        doc="Display as a figure to be used as postage stamp. No plotInfo or legend is shown, "
        "and large fonts are used for axis labels.",
        default=False,
    )

    def validate(self):
        """Check that ``component`` and ``bitmaskPlanes`` are consistent.

        Raises
        ------
        lsst.pex.config.FieldValidationError
            If the mask component is selected without any bitmask planes
            (``None`` or an empty list), or if bitmask planes are given for a
            component other than the mask.
        """
        super().validate()

        # ``not`` also rejects an empty list: with no planes there is nothing
        # to OR-combine and the raw integer mask would be shown instead.
        if self.component == "mask" and not self.bitmaskPlanes:
            raise FieldValidationError(
                self.__class__.bitmaskPlanes,
                self,
                "'bitmaskPlanes' must be specified if displaying the mask plane.",
            )
        if self.bitmaskPlanes is not None and self.component != "mask":
            raise FieldValidationError(
                self.__class__.component,
                self,
                "'component' must be set to the mask plane if 'bitmaskPlanes' is specified.",
            )

    def getInputSchema(self) -> KeyedDataSchema:
        """Return the single required input: the selected coadd component."""
        return [(self.component, KeyedData)]

    def __call__(self, data: KeyedData, **kwargs) -> Mapping[str, Figure] | Figure:
        self._validateInput(data, **kwargs)
        return self.makeFigure(data, **kwargs)

    def _validateInput(self, data: KeyedData, **kwargs) -> None:
        """Raise `ValueError` if any key required by the schema is missing."""
        needed = self.getInputSchema()
        if remainder := {key.format(**kwargs) for key, _ in needed} - {
            key.format(**kwargs) for key in data.keys()
        }:
            raise ValueError(f"Task needs keys {remainder} but they were not found in the input data")

    def makeFigure(
        self,
        data: KeyedData,
        tractId: int,
        skymap: BaseSkyMap,
        plotInfo: Mapping[str, str] | None = None,
        **kwargs,
    ) -> Figure:
        """Make a figure displaying the input pixel data.

        Parameters
        ----------
        data : `lsst.analysis.tools.interfaces.KeyedData`
            A python dict-of-dicts containing the pixel data to display in the
            figure. The top level keys are named after the coadd component(s),
            and must contain at least the configured ``component``. If a
            'mask' entry is also present, its NO_DATA plane is used to
            identify pixels without data; otherwise ``noDataValue`` (if set)
            is used. The next level keys are named after the patch ID of the
            coadd component contained as their corresponding value.
        tractId : `int`
            Identification number of the tract to be displayed.
        skymap : `lsst.skymap.BaseSkyMap`
            The sky map used for this dataset. This is referred-to to determine
            the location of the tract on-sky (for RA and Dec axis ranges) and
            the location of the patches within the tract.
        plotInfo : `dict`, optional
            A dictionary of information about the data being plotted. It is
            copied, not modified. Keys read by this module:

            ``"plotName"``
                The name of the plot (`str`).
            ``"run"``
                The output run for the plots (`str`).
            ``"skymap"``
                The type of skymap used for the data (`str`).
            ``"bands"``
                The filters used for this data (iterable of `str`). Optional.
            ``"band"``
                The filter shown in the postage-stamp title (`str`). Optional.
            ``"tract"``
                The tract that the data comes from (`str`).

            The non-optional keys are needed unless ``displayAsPostageStamp``
            is set.

        Returns
        -------
        fig : `matplotlib.figure.Figure`
            The resulting figure.

        Examples
        --------
        An example wholeTractImage plot may be seen below:

        .. image:: /_static/analysis_tools/wholeTractImageExample.png

        For further details on how to generate a plot, please refer to the
        :ref:`getting started guide<analysis-tools-getting-started>`.
        """
        tractInfo = skymap.generateTract(tractId)
        tractCorners = getTractCorners(skymap, tractId)
        tractRas = [ra for (ra, dec) in tractCorners]
        # A tract corner RA beyond 360 deg is taken to mean the tract
        # straddles RA=0; RAs are then shifted negative for plotting.
        raSpansZero = max(tractRas) > 360.0

        # Reversed colormap; masked (no-data) pixels are drawn in noDataColor.
        # ``plt.get_cmap`` is used because ``matplotlib.cm.get_cmap`` was
        # removed in matplotlib 3.9. ``reversed()`` returns a new colormap, so
        # ``set_bad`` does not touch the registered one.
        noDataAlpha = 0.6 if self.noDataColor == "red" else 1.0
        cmap = plt.get_cmap(self.colorbarCmap).reversed()
        cmap.set_bad(self.noDataColor, alpha=noDataAlpha)

        set_rubin_plotstyle()
        fig = make_figure()
        ax = fig.add_subplot(111)

        # Work on a copy so the caller's mapping is never mutated.
        plotInfo = {} if plotInfo is None else dict(plotInfo)
        plotInfo["component"] = self.component
        if self.bitmaskPlanes is not None:
            plotInfo["maskPlanes"] = self.bitmaskPlanes

        if self.displayAsPostageStamp:
            axisLabelFontSize = 20
            tickMarkFontSize = 10
            boundaryColor = "k"
            boundaryAlpha = 1.0
            boundaryWidth = 0.5
        else:
            axisLabelFontSize = 8
            tickMarkFontSize = 8
            boundaryColor = "r" if "viridis" in self.colorbarCmap.lower() else "c"
            boundaryAlpha = 0.3
            boundaryWidth = 1.0

        # Extract the pixel arrays for all patches prior to plotting.
        # This allows for a global image normalisation to be calculated.
        patchIds = list(data[self.component].keys())
        hasMask = "mask" in data
        if patchIds and hasMask:
            # Bitmask values are looked up on the first patch only; this
            # assumes every patch's mask uses the same plane definitions.
            firstMask = data["mask"][patchIds[0]]
            noDataBitmask = firstMask.getPlaneBitMask("NO_DATA")
            if self.bitmaskPlanes:
                bitmasks = firstMask.getPlaneBitMask(self.bitmaskPlanes)

        imStack = {}
        pixelChunks = []
        for patchId in patchIds:
            raw = data[self.component][patchId].array
            im = raw
            if self.bitmaskPlanes:
                # OR-combined planes: flagged -> 1.0, unflagged -> 0.0.
                im = ((raw & bitmasks) > 0) * 1.0

            if hasMask:
                noDataMask = (data["mask"][patchId].array & noDataBitmask) > 0
            elif self.noDataValue is not None:
                noDataMask = raw == self.noDataValue
            else:
                noDataMask = np.zeros(raw.shape, dtype=bool)

            # Collect valid pixels per patch and concatenate once below;
            # np.append in the loop would re-copy all previous pixels each time.
            pixelChunks.append(im[~noDataMask].ravel())
            imStack[patchId] = np.ma.masked_array(im, mask=noDataMask)

        # Patches of the tract with no coadd in the input data.
        numPatches = tractInfo.getNumPatches()
        emptyPatches = [p for p in range(numPatches[0] * numPatches[1]) if p not in imStack]

        if pixelChunks:
            allPix = np.concatenate(pixelChunks).astype(np.float64, copy=False)
        else:
            allPix = np.empty(0)

        # It is possible that all pixels are flagged NO_DATA (or that no patch
        # has data). In which case, set vmin & vmax to arbitrary values.
        if allPix.size == 0:
            vmin, vmax = (0, 1)
        else:
            vmin, vmax = self.interval(allPix)

        # Set a floor to vmax. Useful for low dynamic range data.
        if self.vmaxFloor is not None:
            vmax = max(vmax, self.vmaxFloor)

        # vmin/vmax are fixed from here on, so one normalisation serves all patches.
        norm = ImageNormalize(vmin=vmin, vmax=vmax)
        plotIm = None
        for patchId, im in imStack.items():
            ras, decs = self._getPatchRaDecs(tractInfo, patchId, raSpansZero)
            extent = (max(ras), min(ras), max(decs), min(decs))
            self._drawPatchOutline(ax, ras, decs, boundaryColor, boundaryWidth, boundaryAlpha)

            # Stretch in normalised space, then map back to data units so the
            # image (and colorbar) use the data-valued vmin/vmax.
            stretchedIm = self.stretch(norm(im))
            maskedStretched = np.ma.masked_array(
                norm.inverse(stretchedIm.data),
                mask=stretchedIm.mask,
            )
            plotIm = ax.imshow(maskedStretched, vmin=vmin, vmax=vmax, extent=extent, cmap=cmap)

            if self.showPatchIds:
                self._annotatePatchId(ax, patchId, ras, decs)

        # Indicate the empty patches with red hatching
        for patchId in emptyPatches:
            ras, decs = self._getPatchRaDecs(tractInfo, patchId, raSpansZero)
            extent = (max(ras), min(ras), max(decs), min(decs))
            self._drawPatchOutline(ax, ras, decs, boundaryColor, boundaryWidth, boundaryAlpha)

            cs = ax.contourf(np.ones((10, 10)), 1, hatches=["xx"], extent=extent, colors="none")
            cs.set_edgecolors("red")
            if self.showPatchIds:
                self._annotatePatchId(ax, patchId, ras, decs)

        # Draw axes around the entire tract:
        ax.set_xlabel("R.A. (deg)", fontsize=axisLabelFontSize)
        ax.set_ylabel("Dec. (deg)", fontsize=axisLabelFontSize)

        # Account for the RA wrapping using negative RA values.
        if raSpansZero:
            tractRas = [ra - 360.0 for ra in tractRas]

        ax.set_xlim(max(tractRas), min(tractRas))
        ticks = [t for t in ax.get_xticks() if min(tractRas) <= t <= max(tractRas)]

        # Rectify potential negative RA values via tick labels
        tickLabels = [f"{t % 360:.1f}" for t in ticks]
        ax.set_xticks(ticks, tickLabels)

        tractDecs = [dec for (ra, dec) in tractCorners]
        ax.set_ylim(min(tractDecs), max(tractDecs))

        ax.tick_params(axis="both", labelsize=tickMarkFontSize, length=0, pad=1.5)

        if self.showColorbar:
            cax = fig.add_axes([0.90, 0.11, 0.04, 0.77])
            # With no patch image drawn there is no image artist to attach the
            # colorbar to; use a stand-alone mappable with the same scaling.
            mappable = plotIm if plotIm is not None else cm.ScalarMappable(norm=norm, cmap=cmap)
            cbar = fig.colorbar(mappable, cax=cax, extend="both")
            cbar.ax.tick_params(labelsize=tickMarkFontSize)
            colorbarLabel = self.zAxisLabel or ""
            text = cax.text(
                0.5,
                0.5,
                colorbarLabel,
                color="k",
                rotation="vertical",
                transform=cax.transAxes,
                ha="center",
                va="center",
                fontsize=10,
            )
            text.set_path_effects([pathEffects.Stroke(linewidth=3, foreground="w"), pathEffects.Normal()])

        if self.displayAsPostageStamp:
            if "band" in plotInfo:
                title = f"{tractId}; {plotInfo['band']}"
            else:
                title = f"{tractId}"
            ax.set_title(title, fontsize=20)
        else:
            if hasMask:
                # Legend swatch matches the colour actually used for NO_DATA pixels.
                noDataPatch = patches.Rectangle(
                    (0.8, 1.1),
                    0.05,
                    0.04,
                    transform=ax.transAxes,
                    facecolor=self.noDataColor,
                    alpha=noDataAlpha,
                    clip_on=False,
                )
                ax.add_patch(noDataPatch)
                ax.text(0.86, 1.115, "NO_DATA", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            noCoaddPatch = patches.Rectangle(
                (0.8, 1.02),
                0.05,
                0.04,
                transform=ax.transAxes,
                facecolor="none",
                edgecolor="red",
                hatch="xx",
                clip_on=False,
            )
            ax.add_patch(noCoaddPatch)
            ax.text(0.86, 1.04, "No coadd", transform=ax.transAxes, va="center", ha="left", fontsize=8)

            fig = addPlotInfo(fig, plotInfo)

        fig.canvas.draw()

        return fig

    @staticmethod
    def _getPatchRaDecs(tractInfo, patchId, raSpansZero: bool) -> tuple[list[float], list[float]]:
        """Return the corner RAs and Decs (deg) of one patch.

        When the tract straddles RA=0, RAs above 180 deg are shifted by
        -360 deg so the patch occupies a continuous RA range; the x-axis tick
        labels built in `makeFigure` map these back into [0, 360).
        """
        patchCorners = getPatchCorners(tractInfo, patchId)
        ras = [ra for (ra, dec) in patchCorners]
        decs = [dec for (ra, dec) in patchCorners]
        if raSpansZero:
            ras = [ra - 360 if ra > 180.0 else ra for ra in ras]
        return ras, decs

    @staticmethod
    def _drawPatchOutline(ax, ras, decs, color, width, alpha) -> None:
        """Draw the bounding rectangle of a patch on ``ax``."""
        ax.plot(
            [min(ras), max(ras), max(ras), min(ras), min(ras)],
            [min(decs), min(decs), max(decs), max(decs), min(decs)],
            color,
            lw=width,
            alpha=alpha,
        )

    @staticmethod
    def _annotatePatchId(ax, patchId, ras, decs) -> None:
        """Write the patch ID, outlined in white, at the patch centre."""
        ax.annotate(
            patchId,
            (np.mean(ras), np.mean(decs)),
            color="k",
            ha="center",
            va="center",
            path_effects=[pathEffects.withStroke(linewidth=2, foreground="w")],
        )

444 

def addPlotInfo(fig: Figure, plotInfo: Mapping[str, str]) -> Figure:
    """Annotate the top-left corner of a figure with plot metadata.

    The plot name is written first, and the metadata text produced by
    `parsePlotInfo` is written just below it in a smaller, semi-transparent font.

    Parameters
    ----------
    fig : `matplotlib.figure.Figure`
        Figure to annotate; it is modified in place.
    plotInfo : `dict`
        Plot metadata; must provide ``"plotName"`` plus whatever
        `parsePlotInfo` requires.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The same figure object, now carrying the annotations.
    """
    # Both text items are anchored top-left in figure coordinates.
    placement = dict(transform=fig.transFigure, ha="left", va="top")

    title = plotInfo["plotName"]
    fig.text(0.01, 0.99, title, fontsize=7, **placement)

    details = parsePlotInfo(plotInfo)
    fig.text(0.01, 0.984, details, fontsize=6, alpha=0.6, **placement)

    return fig

465 

466 

def parsePlotInfo(plotInfo: Mapping[str, str]) -> str:
    """Extract information from the plotInfo dictionary and parse it into
    a meaningful string that can be added to a figure. The default function
    in .plotUtils is not suitable for image plotting.

    Parameters
    ----------
    plotInfo : `dict`[`str`, `str`]
        A plotInfo dictionary containing useful information to be included
        on a figure. ``"run"``, ``"component"``, ``"skymap"`` and ``"tract"``
        are required. ``"maskPlanes"`` and ``"bands"`` (iterables of `str`)
        are optional and are omitted from the text when missing or empty.

    Returns
    -------
    infoText : `str`
        A string containing the plotInfo information, parsed in such a
        way that it can be included on a figure.

    Raises
    ------
    KeyError
        If one of the required keys is missing.
    """
    run = plotInfo["run"]
    componentType = f"\nComponent: {plotInfo['component']}"

    maskPlanes = plotInfo.get("maskPlanes")
    maskPlaneText = f", Mask Plane(s): {', '.join(maskPlanes)}" if maskPlanes else ""

    dataIdText = f"\nSkyMap:{plotInfo['skymap']}, Tract: {plotInfo['tract']}"

    bands = plotInfo.get("bands")
    bandsText = f", Bands: {', '.join(bands)}" if bands else ""

    return f"\n{run}{componentType}{maskPlaneText}{dataIdText}{bandsText}"