Coverage for python / lsst / multiprofit / plotting / plot_model_rgb.py: 3%

229 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-30 08:48 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module.
__all__ = ["plot_model_rgb"]

23 

24import math 

25from typing import Any 

26 

27import astropy.visualization as apVis 

28import lsst.gauss2d as g2 

29import lsst.gauss2d.fit as g2f 

30import matplotlib as mpl 

31import matplotlib.pyplot as plt 

32import numpy as np 

33 

34from .types import Axes, Figure 

35 

36 

def plot_model_rgb(
    model: g2f.ModelD | None,
    weights: dict[str, float] | None = None,
    high_sn_threshold: float | None = None,
    plot_singleband: bool = True,
    plot_chi_hist: bool = True,
    chi_max: float = 5.0,
    rgb_min_auto: bool = False,
    rgb_stretch_auto: bool = False,
    **kwargs: Any,
) -> tuple[Figure, Axes, Figure, Axes, np.ndarray]:
    """Plot RGB images of a model, its data and residuals thereof.

    Parameters
    ----------
    model
        The model to plot. If None, a dict of observations by band must be
        passed as the ``observations`` kwarg, and only the data will be
        plotted (``plot_chi_hist`` must then be False).
    weights
        Linear weights to multiply each band's image by, keyed by band. Bands
        are assigned to the (red, green, blue) channels in dict order. The
        default is a weight of one for each band. Ignored if the model has a
        single observation, which is then duplicated into three pseudo-bands.
    high_sn_threshold
        If non-None and given a model, this will return an image with the
        pixels having a model S/N above this threshold in every band.
    plot_singleband
        Whether to make grayscale plots for each band.
    plot_chi_hist
        Whether to plot histograms of the chi (scaled residual) values.
        Requires a model.
    chi_max
        The maximum absolute value of chi in residual plots. Values of 3-5 are
        suitable for good models while inadequate ones may need larger values.
    rgb_min_auto
        Whether to set the minimum in RGB plots automatically. Cannot supply
        minimum in kwargs if enabled.
    rgb_stretch_auto
        Whether to set the stretch in RGB plots automatically. Cannot supply
        stretch in kwargs if enabled.
    **kwargs
        Additional keyword arguments to pass to make_lupton_rgb when creating
        RGB images (``observations`` is consumed when ``model`` is None).

    Returns
    -------
    fig_rgb
        The Figure for the RGB plots.
    ax_rgb
        The Axes for the RGB plots: a single Axes without a model, otherwise
        a 2D array with data/model in the first column, absolute/chi
        residuals in the second and (optionally) the chi histogram in [0][2].
    fig_gs
        The Figure for the grayscale plots, or None if not plot_singleband.
    ax_gs
        The Axes for the grayscale plots, or None if not plot_singleband.
    mask_inv_highsn
        The inverse mask (1=selected) if high_sn_threshold was specified and
        a model was given, otherwise None.

    Raises
    ------
    ValueError
        If incompatible arguments are given, if observations have duplicate
        bands or inconsistent coordinate systems, if there is nothing to plot,
        or if the bands cannot fill all three RGB channels.
    NotImplementedError
        If the model has exactly two observations.
    """
    if rgb_min_auto and "minimum" in kwargs:
        raise ValueError(f"Cannot set rgb_min_auto and pass {kwargs['minimum']=}")
    if rgb_stretch_auto and "stretch" in kwargs:
        raise ValueError(f"Cannot set rgb_stretch_auto and pass {kwargs['stretch']=}")
    if not (chi_max > 0):
        raise ValueError(f"{chi_max=} not >0")

    has_model = model is not None
    if not has_model and plot_chi_hist:
        raise ValueError("Cannot plot chi histograms without a model")

    # Model images are read below (including in the single-observation case),
    # so make sure they have been evaluated before touching model.outputs.
    if has_model and (not model.outputs or any(output is None for output in model.outputs)):
        model.setup_evaluators(g2f.EvaluatorMode.image)
        model.evaluate()

    if weights is None:
        if has_model:
            # dict.fromkeys keeps the first-seen order of each unique band
            weights = dict.fromkeys((obs.channel.name for obs in model.data), 1.0)
        else:
            weights = {band: 1.0 for band in kwargs["observations"]}

    n_data = len(model.data) if has_model else 0
    observations = {}
    models = {}

    if has_model and (n_data < 3):
        if n_data == 1:
            # pretend this is three bands
            obs, output_data = model.data[0], model.outputs[0].data
            band = obs.channel.name
            weights = {}
            for idx in range(1, 4):
                key = f"{band}{idx}"
                weights[key] = 1.0
                observations[key] = obs
                models[key] = output_data
        elif n_data == 2:
            raise NotImplementedError("RGB images for two-band data are not supported (yet)")

    bands = tuple(weights.keys())
    band_str = ",".join(bands)
    n_bands = len(bands)

    if not has_model:
        obs_kwarg = kwargs.pop("observations")
        for band in bands:
            observations[band] = obs_kwarg[band]
    elif n_data >= 3:
        for obs, output in zip(model.data, model.outputs):
            band = obs.channel.name
            if band in weights:
                if band in observations:
                    raise ValueError(f"Cannot plot {model=} because {band=} has multiple observations")
                observations[band] = obs
                models[band] = output.data

    if not observations:
        raise ValueError(f"No observations to plot for {bands=}")

    # Find the union of the observations' pixel bounds.
    x_min, x_max, y_min, y_max = np.inf, -np.inf, np.inf, -np.inf
    coordsys_last = None
    has_coordsys_none = False
    for band, obs in observations.items():
        coordsys = obs.image.coordsys
        if coordsys:
            coordsys_last = coordsys
            x_min = int(round(min(x_min, coordsys.x_min), 0))
            x_max = int(round(max(x_max, coordsys.x_min + obs.image.n_cols), 0))
            y_min = int(round(min(y_min, coordsys.y_min), 0))
            y_max = int(round(max(y_max, coordsys.y_min + obs.image.n_rows), 0))
        else:
            has_coordsys_none = True
        # Check in both directions: None after non-None and vice versa
        if has_coordsys_none and (coordsys_last is not None):
            raise ValueError(
                f"coordinate system for {band=} is inconsistent with previous observations; they must"
                f" either all be None or all non-None"
            )

    if coordsys_last is not None:
        # Zero-pad any observation that does not already span the full bounds
        shape_new = (y_max - y_min, x_max - x_min)
        keys = ("image", "mask_inv", "sigma_inv") + (("model",) if has_model else ())
        for band, obs in observations.items():
            coordsys = obs.image.coordsys
            x_min_o = int(round(coordsys.x_min, 0)) - x_min
            y_min_o = int(round(coordsys.y_min, 0)) - y_min
            x_max_o, y_max_o = x_min_o + obs.image.n_cols, y_min_o + obs.image.n_rows
            if (y_min_o, x_min_o, y_max_o, x_max_o) == (0, 0, *shape_new):
                continue
            data_new = {}
            for key in keys:
                values = models[band] if (key == "model") else getattr(obs, key).data
                # Keep the source dtype (e.g. a boolean mask for ImageB)
                img = np.zeros(shape_new, dtype=np.asarray(values).dtype)
                img[y_min_o:y_max_o, x_min_o:x_max_o] = values
                if key == "model":
                    models[band] = img
                else:
                    data_new[key] = (g2.ImageB if (key == "mask_inv") else g2.ImageD)(img)
            # Replacing values of existing keys while iterating is safe
            observations[band] = g2f.ObservationD(channel=obs.channel, **data_new)
    else:
        # Without coordinate systems, plot in pixel indices. The channel
        # images are summed elementwise below, so they must share a shape.
        image_first = next(iter(observations.values())).image
        x_min, x_max, y_min, y_max = 0, image_first.n_cols, 0, image_first.n_rows

    extent = (x_min, x_max, y_min, y_max)

    images_data = [None] * 3
    images_data_unweighted = [None] * 3 if has_model else None
    images_model = [None] * 3 if has_model else None
    images_model_unweighted = [None] * 3 if has_model else None
    variances_rgb = [None] * 3 if has_model else None
    masks_inv_rgb = [None] * 3

    # Each band spans a 3/n_bands-wide interval of the [0, 3] channel axis;
    # bands straddling a channel boundary are split between two channels.
    weights_channel = np.linspace(0, 3, len(weights) + 1)[1:]
    idx_channel = 0
    weight_channel = 0

    def add_if_not_none(array: list, index: int, arg: np.ndarray) -> None:
        """Add arg to array[index] in place, or store it if the slot is empty."""
        if array[index] is not None:
            array[index] += arg
        else:
            array[index] = arg

    chis_unweighted = {}

    for idx_band, (band, weight) in enumerate(weights.items()):
        observation = observations[band]
        data = observation.image.data
        mask_inv = observation.mask_inv.data
        if has_model:
            model_band = models[band]
            sigma_inv = observation.sigma_inv.data
            sigma_inv_good = sigma_inv > 0
            variance_band = np.empty_like(sigma_inv)
            variance_band[sigma_inv_good] = sigma_inv[sigma_inv_good] ** -2
            variance_band[~sigma_inv_good] = np.nan
            if plot_chi_hist:
                chi_good = sigma_inv_good & np.isfinite(sigma_inv)
                chis_unweighted[band] = (data[chi_good] - model_band[chi_good]) * sigma_inv[chi_good]
        weight_channel_new = weights_channel[idx_band]
        idx_channel_new = int(weight_channel_new // 1)
        if idx_channel_new == idx_channel:
            weight_low = weight_channel_new - weight_channel
            weight_high = 0.0
        else:
            weight_low = idx_channel_new - weight_channel
            weight_high = weight_channel_new - idx_channel_new
        # Internal invariants: weights_channel is strictly increasing
        assert weight_high >= 0
        assert weight_low >= 0
        # Note: every product below creates a new array, so the in-place
        # additions never modify the observations' own data.
        if weight_low > 0:
            data_band = data * weight_low
            add_if_not_none(images_data, idx_channel, data_band * weight)
            add_if_not_none(masks_inv_rgb, idx_channel, mask_inv * weight_low)
            if has_model:
                add_if_not_none(images_data_unweighted, idx_channel, data_band)
                model_sub = model_band * weight_low
                add_if_not_none(images_model, idx_channel, model_sub * weight)
                add_if_not_none(images_model_unweighted, idx_channel, model_sub)
                # Var(sum w_i x_i) = sum w_i^2 Var(x_i)
                add_if_not_none(variances_rgb, idx_channel, variance_band * weight_low**2)
        if (idx_channel_new != idx_channel) and (weight_high > 0):
            data_band = data * weight_high
            images_data[idx_channel_new] = data_band * weight
            masks_inv_rgb[idx_channel_new] = mask_inv * weight_high
            if has_model:
                images_data_unweighted[idx_channel_new] = data_band
                model_sub = model_band * weight_high
                images_model[idx_channel_new] = model_sub * weight
                images_model_unweighted[idx_channel_new] = model_sub
                variances_rgb[idx_channel_new] = variance_band * weight_high**2
        weight_channel = weight_channel_new
        idx_channel = idx_channel_new

    if any(image is None for image in images_data):
        raise ValueError(f"Could not fill all three RGB channels from {bands=}; at least three are needed")

    # convert variance to 1/sigma
    images_sigma_inv = [1 / np.sqrt(variance) for variance in variances_rgb] if has_model else None

    if rgb_min_auto or rgb_stretch_auto:
        # The model won't have negative pixels, so it ought to stretch fine
        # the max/stretch is not as important anyway
        rgb_min, rgb_max = np.nanpercentile(
            np.concatenate([image[mask_inv != 0] for mask_inv, image in zip(masks_inv_rgb, images_data)]),
            (5, 95),
        )
        if rgb_min_auto:
            kwargs["minimum"] = rgb_min
        if rgb_stretch_auto:
            kwargs["stretch"] = 2 * (rgb_max - rgb_min)

    img_rgb = apVis.make_lupton_rgb(*images_data, **kwargs)
    img_model_rgb = apVis.make_lupton_rgb(*images_model, **kwargs) if has_model else None
    aspect = np.clip((y_max - y_min) / (x_max - x_min), 0.25, 4)

    n_rows = 1 + has_model
    n_cols_gs = 1 + has_model
    n_cols_rgb = 1 + has_model * (1 + plot_chi_hist)
    figsize_y = 8 * n_rows * aspect

    fig_rgb, ax_rgb = plt.subplots(nrows=n_rows, ncols=n_cols_rgb, figsize=(8 * n_cols_rgb, figsize_y))
    fig_gs, ax_gs = (
        plt.subplots(nrows=n_bands, ncols=n_cols_gs, figsize=(8 * n_cols_gs, 8 * aspect * n_bands))
        if plot_singleband
        else (None, None)
    )
    ax_data = ax_rgb[0][0] if has_model else ax_rgb
    ax_data.imshow(img_rgb, extent=extent, origin="lower")
    ax_data.set_title("Data")
    if has_model:
        ax_rgb[1][0].imshow(img_model_rgb, extent=extent, origin="lower")
        ax_rgb[1][0].set_title(f"Model ({band_str})")

    # Create a mask of high-sn pixels (based on the model), independent of plotting
    mask_inv_highsn = (
        np.ones(img_rgb.shape[:2], dtype=bool) if (has_model and high_sn_threshold is not None) else None
    )

    for idx, band in enumerate(bands):
        obs = observations[band]
        mask_inv = obs.mask_inv.data
        img_data = obs.image.data
        img_sigma_inv = obs.sigma_inv.data
        img_model = models[band] if has_model else None
        if mask_inv_highsn is not None:
            mask_inv_highsn &= (img_model * np.nanmedian(img_sigma_inv)) > high_sn_threshold
        if not plot_singleband:
            continue
        if has_model:
            residual = (img_data - img_model) * mask_inv
            value_max = np.nanpercentile(np.abs(residual), 98)
            ax_resid, ax_chi = ax_gs[idx][0], ax_gs[idx][1]
            ax_resid.imshow(
                residual, cmap="gray", vmin=-value_max, vmax=value_max, extent=extent, origin="lower"
            )
            ax_resid.tick_params(labelleft=False)
            ax_resid.set_title(f"{band}-band Residual (abs.)")
            # Fix the color scale to +/- chi_max so that zero is mid-gray, as titled
            ax_chi.imshow(
                np.clip(residual * img_sigma_inv, -chi_max, chi_max),
                cmap="gray",
                vmin=-chi_max,
                vmax=chi_max,
                extent=extent,
                origin="lower",
            )
            ax_chi.tick_params(labelleft=False)
            ax_chi.set_title(f"{band}-band Residual (chi, +/- {chi_max:.2f})")
        else:
            ax_gs[idx].imshow(
                img_data * mask_inv * (img_sigma_inv > 0), cmap="gray", extent=extent, origin="lower"
            )
            ax_gs[idx].set_title(band)

    if has_model:
        # TODO: Draw masks in each channel? or draw the combined mask, like:
        # mask_inv_all = np.prod([obs.mask_inv.data for obs in observations.values()], axis=0)
        # Residuals are data - model, matching the single-band plots and chi histograms
        residuals = [images_data_unweighted[idx] - images_model_unweighted[idx] for idx in range(3)]
        resid_max = np.nanpercentile(
            np.abs(np.concatenate([residual[np.isfinite(residual)] for residual in residuals])), 98
        )

        # This may or may not be equivalent to make_lupton_rgb
        # I just can't figure out how to get that scaled so zero = 50% gray
        stretch = 3
        residual_rgb = np.stack(
            [np.arcsinh(np.clip(residual, -resid_max, resid_max) * stretch) for residual in residuals],
            axis=-1,
        )
        # A perfect fit has resid_max == 0 and all-zero arcsinh values; avoid 0/0
        if resid_max > 0:
            residual_rgb /= 2 * np.arcsinh(resid_max * stretch)
        residual_rgb += 0.5

        ax_rgb[0][1].imshow(residual_rgb, extent=extent, origin="lower")
        ax_rgb[0][1].set_title(f"Residual (abs., +/- {resid_max:.3e})")
        ax_rgb[0][1].tick_params(labelleft=False)

        if plot_chi_hist:
            cmap = mpl.colormaps["coolwarm"]
            chis_all = np.concatenate(tuple(chis_unweighted.values()))
            chis_abs = np.abs(chis_all)
            n_chi = len(chis_abs)
            # Widen the histogram range (without touching chi_max) if >10% of
            # pixels exceed |chi| of 5 and/or 7.5
            frac_above = [np.sum(chis_abs > limit) / n_chi if n_chi else 0.0 for limit in (5.0, 7.5)]
            chi_max_hist = 5 + 2.5 * sum(frac > 0.1 for frac in frac_above)
            n_bins = int(math.ceil(np.clip(n_chi / 50, 2, 20)) * chi_max_hist)
            # Shared bin edges so that the per-band histograms are comparable
            bin_edges = np.linspace(-chi_max_hist, chi_max_hist, n_bins + 1)
            ax_hist = ax_rgb[0][2]
            ax_hist.hist(
                np.clip(chis_all, -chi_max_hist, chi_max_hist),
                bins=bin_edges,
                histtype="step",
                color="black",
                label="all",
            )
            band_colors = cmap(np.linspace(0, 1, n_bands))
            for band, band_color in zip(bands, band_colors):
                ax_hist.hist(
                    np.clip(chis_unweighted[band], -chi_max_hist, chi_max_hist),
                    bins=bin_edges,
                    histtype="step",
                    color=band_color,
                    label=band,
                )
            ax_hist.legend()

        # TODO: Plot unscaled residuals in ax_rgb[1][2]? It's unused now.
        chi_rgb = np.stack(
            [
                (np.clip(residuals[idx] * images_sigma_inv[idx], -chi_max, chi_max) + chi_max) / (2 * chi_max)
                for idx in range(3)
            ],
            axis=-1,
        )

        ax_rgb[1][1].imshow(chi_rgb, extent=extent, origin="lower")
        ax_rgb[1][1].set_title(f"Residual (chi, +/- {chi_max:.2f})")
        ax_rgb[1][1].tick_params(labelleft=False)

    return fig_rgb, ax_rgb, fig_gs, ax_gs, mask_inv_highsn