Coverage for python / lsst / multiprofit / plotting / plot_model_rgb.py: 3%
229 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:43 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-28 08:43 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Names exported by ``from ... import *``: only the plotting entry point.
__all__ = [
    "plot_model_rgb",
]
24import math
25from typing import Any
27import astropy.visualization as apVis
28import lsst.gauss2d as g2
29import lsst.gauss2d.fit as g2f
30import matplotlib as mpl
31import matplotlib.pyplot as plt
32import numpy as np
34from .types import Axes, Figure
def plot_model_rgb(
    model: g2f.ModelD | None,
    weights: dict[str, float] | None = None,
    high_sn_threshold: float | None = None,
    plot_singleband: bool = True,
    plot_chi_hist: bool = True,
    chi_max: float = 5.0,
    rgb_min_auto: bool = False,
    rgb_stretch_auto: bool = False,
    **kwargs: Any,
) -> tuple[Figure, Axes, Figure, Axes, np.ndarray]:
    """Plot RGB images of a model, its data and residuals thereof.

    Bands are mapped onto the three RGB channels in order: the interval
    [0, 3] is split evenly between the bands, and each band contributes to
    the channel(s) its sub-interval overlaps, in proportion to the overlap.
    A model with a single observation is drawn as three identical channels.
    Residuals are data minus model in every panel.

    Parameters
    ----------
    model
        The model to plot. If None, a dict of observations by band must be
        passed as the ``observations`` kwarg, only the data will be
        plotted, and ``plot_chi_hist`` must be False.
    weights
        Linear weights to multiply each band's image by, keyed by band
        name. The default is a weight of one for each band. Ignored (reset
        to unit weights) for a model with a single observation.
    high_sn_threshold
        If non-None and given a model, a boolean mask is returned that
        selects the pixels whose model value times the median inverse sigma
        of the band exceeds this threshold in every band.
    plot_singleband
        Whether to make grayscale plots for each band.
    plot_chi_hist
        Whether to plot histograms of the chi (scaled residual) values.
    chi_max
        The maximum absolute value of chi in residual plots. Values of 3-5 are
        suitable for good models while inadequate ones may need larger values.
        If plot_chi_hist is True, the RGB chi image and histogram instead use
        an automatic range of 5, 7.5 or 10.
    rgb_min_auto
        Whether to set the minimum in RGB plots automatically. Cannot supply
        minimum in kwargs if enabled.
    rgb_stretch_auto
        Whether to set the stretch in RGB plots automatically. Cannot supply
        stretch in kwargs if enabled.
    **kwargs
        Additional keyword arguments to pass to make_lupton_rgb when creating
        RGB images (plus ``observations`` if model is None; see above).

    Returns
    -------
    fig_rgb
        The Figure for the RGB plots.
    ax_rgb
        The Axes for the RGB plots.
    fig_gs
        The Figure for the grayscale plots, or None if not plot_singleband.
    ax_gs
        The Axes for the grayscale plots, or None if not plot_singleband.
    mask_inv_highsn
        The inverse mask (True=selected) if high_sn_threshold was specified
        and a model was given; otherwise None.

    Raises
    ------
    ValueError
        If incompatible or invalid arguments are given, if the model has no
        observations, if a band has multiple observations, or if only some
        observations have a coordinate system.
    NotImplementedError
        If fewer than three bands are to be plotted (a model with a single
        observation is handled separately and is allowed).
    KeyError
        If a band in weights has no matching observation.
    """
    if rgb_min_auto and "minimum" in kwargs:
        raise ValueError(f"Cannot set rgb_min_auto and pass {kwargs['minimum']=}")
    if rgb_stretch_auto and "stretch" in kwargs:
        raise ValueError(f"Cannot set rgb_stretch_auto and pass {kwargs['stretch']=}")
    if not (chi_max > 0):
        raise ValueError(f"{chi_max=} not >0")

    has_model = model is not None
    # Only a model carries its own observations; without one there is no data count.
    n_data = len(model.data) if has_model else 0
    if has_model and n_data == 0:
        raise ValueError(f"Cannot plot {model=} because it has no observations")
    if n_data == 2:
        raise NotImplementedError("RGB images for two-band data are not supported (yet)")

    if weights is None:
        if has_model:
            # One unit weight per distinct band, in order of first appearance
            weights = {obs.channel.name: 1.0 for obs in model.data}
        else:
            weights = {band: 1.0 for band in kwargs["observations"].keys()}

    # Evaluate the model before any of its outputs are read below
    if has_model and (not model.outputs or any(output is None for output in model.outputs)):
        model.setup_evaluators(g2f.EvaluatorMode.image)
        model.evaluate()

    observations = {}
    models = {}

    if n_data == 1:
        # pretend this is three bands, so it renders as a grayscale RGB image
        obs, output_data = model.data[0], model.outputs[0].data
        band = obs.channel.name
        weights = {f"{band}{idx}": 1.0 for idx in range(1, 4)}
        for key in weights:
            observations[key] = obs
            models[key] = output_data

    bands = tuple(weights.keys())
    band_str = ",".join(bands)
    n_bands = len(bands)

    if not has_model:
        if plot_chi_hist:
            raise ValueError("Cannot plot chi histograms without a model")
        obs_kwarg = kwargs.pop("observations")
        for band in bands:
            observations[band] = obs_kwarg[band]
    elif n_data >= 3:
        for obs, output in zip(model.data, model.outputs):
            band = obs.channel.name
            if band in bands:
                if band in observations:
                    raise ValueError(f"Cannot plot {model=} because {band=} has multiple observations")
                observations[band] = obs
                models[band] = output.data

    if n_bands < 3:
        raise NotImplementedError(f"RGB images for fewer than three bands are not supported ({bands=})")
    bands_missing = [band for band in bands if band not in observations]
    if bands_missing:
        raise KeyError(f"No observations found for bands={bands_missing}")

    # Find the bounding box (in pixel coordinates) enclosing all observations
    x_min, x_max, y_min, y_max = np.inf, -np.inf, np.inf, -np.inf
    coordsys_last = None
    bands_no_coordsys = []
    for band, obs in observations.items():
        coordsys = obs.image.coordsys
        if coordsys:
            coordsys_last = coordsys
            x_min_obs = int(round(coordsys.x_min, 0))
            y_min_obs = int(round(coordsys.y_min, 0))
            x_min = min(x_min, x_min_obs)
            x_max = max(x_max, x_min_obs + obs.image.n_cols)
            y_min = min(y_min, y_min_obs)
            y_max = max(y_max, y_min_obs + obs.image.n_rows)
        else:
            bands_no_coordsys.append(band)

    if coordsys_last is None:
        # No coordinate systems at all: use the first image's pixel grid.
        # NOTE(review): assumes all images share that shape - confirm.
        image_first = next(iter(observations.values())).image
        x_min, x_max, y_min, y_max = 0, image_first.n_cols, 0, image_first.n_rows
    else:
        if bands_no_coordsys:
            band = bands_no_coordsys[0]
            raise ValueError(
                f"coordinate system for {band=} is None but others are not; they must either "
                f"all be None or all non-None"
            )
        shape_new = (y_max - y_min, x_max - x_min)
        keys = ("image", "mask_inv", "sigma_inv") + (("model",) if has_model else ())
        for band, obs in observations.items():
            coordsys = obs.image.coordsys
            x_min_o = int(round(coordsys.x_min, 0)) - x_min
            y_min_o = int(round(coordsys.y_min, 0)) - y_min
            x_max_o = x_min_o + obs.image.n_cols
            y_max_o = y_min_o + obs.image.n_rows
            if (y_min_o, y_max_o, x_min_o, x_max_o) == (0, shape_new[0], 0, shape_new[1]):
                # Already covers the full bounding box; nothing to pad
                continue
            # zero-pad the relevant images into a new observation
            data_new = {}
            for key in keys:
                img = np.zeros(shape_new)
                img[y_min_o:y_max_o, x_min_o:x_max_o] = (
                    models[band] if (key == "model") else getattr(obs, key).data
                )
                if key == "model":
                    models[band] = img
                else:
                    data_new[key] = (g2.ImageB if (key == "mask_inv") else g2.ImageD)(img)
            observations[band] = g2f.ObservationD(channel=obs.channel, **data_new)

    extent = (x_min, x_max, y_min, y_max)

    images_data = [None] * 3
    images_data_unweighted = [None] * 3 if has_model else None
    images_model = [None] * 3 if has_model else None
    images_model_unweighted = [None] * 3 if has_model else None
    images_variance = [None] * 3 if has_model else None
    masks_inv_rgb = [None] * 3

    def add_if_not_none(array: list[np.ndarray | None], index: int, arg: np.ndarray) -> None:
        """Add arg to array[index], or store it there if the entry is None."""
        if array[index] is not None:
            array[index] += arg
        else:
            array[index] = arg

    # Upper edge of each band's share of the [0, 3] channel interval
    weights_channel = np.linspace(0, 3, n_bands + 1)[1:]
    idx_channel = 0
    weight_channel = 0
    chis_unweighted = {}

    for idx_band, (band, weight) in enumerate(weights.items()):
        observation = observations[band]
        if has_model:
            model_band = models[band]
            sigma_inv = observation.sigma_inv.data
            sigma_inv_good = sigma_inv > 0
            variance_band = np.empty_like(sigma_inv)
            variance_band[sigma_inv_good] = sigma_inv[sigma_inv_good] ** -2
            variance_band[~sigma_inv_good] = np.nan
            if plot_chi_hist:
                chi_good = sigma_inv_good & np.isfinite(sigma_inv)
                chis_unweighted[band] = (
                    observation.image.data[chi_good] - model_band[chi_good]
                ) * sigma_inv[chi_good]
        weight_channel_new = weights_channel[idx_band]
        idx_channel_new = int(weight_channel_new // 1)
        if idx_channel_new == idx_channel:
            weight_low, weight_high = weight_channel_new - weight_channel, 0.0
        else:
            # The band's share crosses into the next channel: split it
            weight_low = idx_channel_new - weight_channel
            weight_high = weight_channel_new - idx_channel_new
        assert weight_high >= 0
        assert weight_low >= 0
        for idx_target, fraction in ((idx_channel, weight_low), (idx_channel_new, weight_high)):
            if not (fraction > 0):
                continue
            data_frac = observation.image.data * fraction
            add_if_not_none(images_data, idx_target, data_frac * weight)
            add_if_not_none(masks_inv_rgb, idx_target, observation.mask_inv.data * fraction)
            if has_model:
                model_frac = model_band * fraction
                add_if_not_none(images_data_unweighted, idx_target, data_frac)
                add_if_not_none(images_model, idx_target, model_frac * weight)
                add_if_not_none(images_model_unweighted, idx_target, model_frac)
                # The variance of a weighted sum scales with the weight squared
                add_if_not_none(images_variance, idx_target, variance_band * fraction**2)
        weight_channel = weight_channel_new
        idx_channel = idx_channel_new

    # convert summed variances to 1/sigma
    images_sigma_inv = [1 / np.sqrt(variance) for variance in images_variance] if has_model else None

    if rgb_min_auto or rgb_stretch_auto:
        # The model won't have negative pixels, so it ought to stretch fine
        # the max/stretch is not as important anyway
        rgb_min, rgb_max = np.nanpercentile(
            np.concatenate([image[mask_inv != 0] for mask_inv, image in zip(masks_inv_rgb, images_data)]),
            (5, 95),
        )
        if rgb_min_auto:
            kwargs["minimum"] = rgb_min
        if rgb_stretch_auto:
            kwargs["stretch"] = 2 * (rgb_max - rgb_min)

    img_rgb = apVis.make_lupton_rgb(*images_data, **kwargs)
    img_model_rgb = apVis.make_lupton_rgb(*images_model, **kwargs) if has_model else None
    aspect = np.clip((y_max - y_min) / (x_max - x_min), 0.25, 4)

    n_rows = 1 + has_model
    n_cols_gs = 1 + has_model
    n_cols_rgb = 1 + has_model * (1 + plot_chi_hist)
    figsize_y = 8 * n_rows * aspect

    fig_rgb, ax_rgb = plt.subplots(nrows=n_rows, ncols=n_cols_rgb, figsize=(8 * n_cols_rgb, figsize_y))
    if plot_singleband:
        fig_gs, ax_gs = plt.subplots(
            nrows=n_bands,
            ncols=n_cols_gs,
            figsize=(8 * n_cols_gs, 8 * aspect * n_bands),
        )
    else:
        fig_gs, ax_gs = None, None

    ax_data = ax_rgb[0][0] if has_model else ax_rgb
    ax_data.imshow(img_rgb, extent=extent, origin="lower")
    ax_data.set_title("Data")
    if has_model:
        ax_rgb[1][0].imshow(img_model_rgb, extent=extent, origin="lower")
        ax_rgb[1][0].set_title(f"Model ({band_str})")

    # Create a mask of high-sn pixels (based on the model), one entry per pixel
    mask_inv_highsn = (
        np.ones(img_rgb.shape[:2], dtype=bool) if (has_model and high_sn_threshold is not None) else None
    )

    for idx, band in enumerate(bands):
        obs = observations[band]
        mask_inv = obs.mask_inv.data
        img_data = obs.image.data
        img_sigma_inv = obs.sigma_inv.data
        img_model = models[band] if has_model else None
        if mask_inv_highsn is not None:
            mask_inv_highsn &= (img_model * np.nanmedian(img_sigma_inv)) > high_sn_threshold
        if not plot_singleband:
            continue
        if has_model:
            residual = (img_data - img_model) * mask_inv
            value_max = np.nanpercentile(np.abs(residual), 98)
            ax_gs[idx][0].imshow(residual, cmap="gray", vmin=-value_max, vmax=value_max, origin="lower")
            ax_gs[idx][0].tick_params(labelleft=False)
            ax_gs[idx][0].set_title(f"{band}-band Residual (abs.)")
            ax_gs[idx][1].imshow(
                np.clip(residual * img_sigma_inv, -chi_max, chi_max),
                cmap="gray",
                origin="lower",
            )
            ax_gs[idx][1].tick_params(labelleft=False)
            ax_gs[idx][1].set_title(f"{band}-band Residual (chi, +/- {chi_max:.2f})")
        else:
            ax_gs[idx].imshow(img_data * mask_inv * (img_sigma_inv > 0), cmap="gray", origin="lower")
            ax_gs[idx].set_title(band)

    if has_model:
        # TODO: Draw masks in each channel? or draw the combined mask, like:
        # mask_inv_all = np.prod([obs.mask_inv.data for obs in observations.values()], axis=0)
        # data - model, matching the grayscale and chi histogram panels
        residuals = [images_data_unweighted[idx] - images_model_unweighted[idx] for idx in range(3)]
        resid_max = np.nanpercentile(
            np.abs(np.concatenate([residual[np.isfinite(residual)] for residual in residuals])), 98
        )

        # make_lupton_rgb cannot easily be scaled so that zero = 50% gray,
        # so apply a comparable asinh stretch by hand (may not be identical)
        stretch = 3
        residual_rgb = np.stack(
            [np.arcsinh(np.clip(residual, -resid_max, resid_max) * stretch) for residual in residuals],
            axis=-1,
        )
        # Map [-resid_max, resid_max] to [0, 1]; guard against all-zero residuals
        norm = 2 * np.arcsinh(resid_max * stretch)
        residual_rgb /= norm if norm > 0 else 1.0
        residual_rgb += 0.5

        ax_rgb[0][1].imshow(residual_rgb, origin="lower")
        ax_rgb[0][1].set_title(f"Residual (abs., +/- {resid_max:.3e})")
        ax_rgb[0][1].tick_params(labelleft=False)

        if plot_chi_hist:
            cmap = mpl.colormaps["coolwarm"]
            chis_all = np.concatenate(tuple(chis_unweighted.values()))
            chis_abs = np.abs(chis_all)
            n_chi = len(chis_abs)
            # Widen the range by 2.5 for each of the 5 and 7.5 thresholds that
            # more than 10% of |chi| values exceed. int() is required because
            # adding two numpy bools is a logical or, not an integer sum.
            # NOTE(review): this overrides the chi_max argument for the
            # histogram and the RGB chi image below - confirm intended.
            chi_max = 5 + 2.5 * (int(np.mean(chis_abs > 5) > 0.1) + int(np.mean(chis_abs > 7.5) > 0.1))
            n_bins = int(math.ceil(np.clip(n_chi / 50, 2, 20)) * chi_max)
            ax_rgb[0][2].hist(
                np.clip(chis_all, -chi_max, chi_max),
                bins=n_bins,
                histtype="step",
                label="all",
            )
            band_colors = cmap(np.linspace(0, 1, n_bands))
            for band, band_color in zip(bands, band_colors):
                ax_rgb[0][2].hist(
                    np.clip(chis_unweighted[band], -chi_max, chi_max),
                    bins=n_bins,
                    histtype="step",
                    label=band,
                    color=band_color,
                )
            ax_rgb[0][2].legend()

        # TODO: Plot unscaled residuals in ax_rgb[1][2]? It's unused now.
        # Map chi in [-chi_max, chi_max] to [0, 1] per channel
        residual_rgb = np.stack(
            [
                (np.clip(residuals[idx] * images_sigma_inv[idx], -chi_max, chi_max) + chi_max) / (2 * chi_max)
                for idx in range(3)
            ],
            axis=-1,
        )

        ax_rgb[1][1].imshow(residual_rgb, origin="lower")
        ax_rgb[1][1].set_title(f"Residual (chi, +/- {chi_max:.2f})")
        ax_rgb[1][1].tick_params(labelleft=False)

    return fig_rgb, ax_rgb, fig_gs, ax_gs, mask_inv_highsn