Coverage for python/lsst/scarlet/lite/display.py: 5%

338 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-16 02:46 -0700

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from typing import Sequence 

23 

24import matplotlib 

25import matplotlib.pyplot as plt 

26import numpy as np 

27from astropy.visualization.lupton_rgb import AsinhMapping, LinearMapping, Mapping 

28from matplotlib.patches import Rectangle 

29from matplotlib.ticker import MaxNLocator 

30 

31from .bbox import Box 

32from .blend import Blend 

33from .image import Image 

34from .observation import Observation 

35from .source import Source 

36 

# Width/height of one sub-plot panel in matplotlib ``figsize`` units;
# multi-panel figures scale this by the number of rows/columns.
panel_size: float = 4.0

39 

40 

def channels_to_rgb(channels: int) -> np.ndarray:
    """Get the linear mapping of multiple channels to RGB channels.

    The mapping created here assumes the channels are ordered in wavelength
    direction, starting with the shortest wavelength.
    The mapping seeks to produce relatively even weights across
    all channels. It does not consider e.g.
    signal-to-noise variations across channels or human perception.

    Parameters
    ----------
    channels:
        Number of channels, an integer from 0 to 7 (inclusive).

    Returns
    -------
    channel_map:
        Array (3, `channels`) to map onto RGB. Row 0 receives the last
        (longest-wavelength) channels and row 2 the first ones, matching an
        (R, G, B) row order.

    Raises
    ------
    ValueError
        If `channels` is not in the range 0 to 7.
    """
    if channels not in range(0, 8):
        msg = f"No mapping has been implemented for more than 7 channels, got {channels=}"
        raise ValueError(msg)

    channel_map = np.zeros((3, channels))
    if channels == 1:
        # A single channel is shown as grayscale
        channel_map[0, 0] = channel_map[1, 0] = channel_map[2, 0] = 1
    elif channels == 2:
        channel_map[0, 1] = 0.667
        channel_map[1, 1] = 0.333
        channel_map[1, 0] = 0.333
        channel_map[2, 0] = 0.667
        channel_map /= 0.667
    elif channels == 3:
        channel_map[0, 2] = 1
        channel_map[1, 1] = 1
        channel_map[2, 0] = 1
    elif channels == 4:
        channel_map[0, 3] = 1
        channel_map[0, 2] = 0.333
        channel_map[1, 2] = 0.667
        channel_map[1, 1] = 0.667
        channel_map[2, 1] = 0.333
        channel_map[2, 0] = 1
        channel_map /= 1.333
    elif channels == 5:
        channel_map[0, 4] = 1
        channel_map[0, 3] = 0.667
        channel_map[1, 3] = 0.333
        channel_map[1, 2] = 1
        channel_map[1, 1] = 0.333
        channel_map[2, 1] = 0.667
        channel_map[2, 0] = 1
        channel_map /= 1.667
    elif channels == 6:
        channel_map[0, 5] = 1
        channel_map[0, 4] = 0.667
        channel_map[0, 3] = 0.333
        channel_map[1, 4] = 0.333
        channel_map[1, 3] = 0.667
        channel_map[1, 2] = 0.667
        channel_map[1, 1] = 0.333
        channel_map[2, 2] = 0.333
        channel_map[2, 1] = 0.667
        channel_map[2, 0] = 1
        channel_map /= 2
    elif channels == 7:
        # The 7th (reddest) channel contributes equally to all three colors
        channel_map[:, 6] = 2 / 3.0
        channel_map[0, 5] = 1
        channel_map[0, 4] = 0.667
        channel_map[0, 3] = 0.333
        channel_map[1, 4] = 0.333
        channel_map[1, 3] = 0.667
        channel_map[1, 2] = 0.667
        channel_map[1, 1] = 0.333
        channel_map[2, 2] = 0.333
        channel_map[2, 1] = 0.667
        channel_map[2, 0] = 1
        channel_map /= 2
    return channel_map

117 

118 

class LinearPercentileNorm(LinearMapping):
    """Linear norm spanning a lower/upper percentile range of an image.

    Parameters
    ----------
    img:
        Image whose pixel distribution sets the limits of the norm.
    percentiles:
        Lower and upper percentile used as the minimum and maximum of the
        linear mapping (default = ``(1,99)``). Pixel values below the lower
        limit are set to zero, values above the upper limit saturate.
    """

    def __init__(self, img: np.ndarray, percentiles: tuple[int, int] | None = None):
        bounds = (1, 99) if percentiles is None else percentiles
        if len(bounds) != 2:
            raise ValueError(f"Percentiles must have two values, got {percentiles=}")
        lower, upper = np.percentile(img, bounds)
        super().__init__(minimum=lower, maximum=upper)

139 

140 

class AsinhPercentileNorm(AsinhMapping):
    """Create an asinh norm scaled to the lower and upper percentile of img.

    The lower percentile is used as the ``minimum`` of the mapping, the
    difference between the upper and lower percentile as its ``stretch``,
    and the softening parameter ``Q`` is set to ``stretch / sinh(1)``.

    Parameters
    ----------
    img:
        Image to normalize.
    percentiles:
        Lower and upper percentile to consider (default = ``(1,99)``).
        Pixel values below the lower percentile will be
        set to zero, values above the upper percentile saturate.

    Raises
    ------
    ValueError
        If `percentiles` does not contain exactly two values.
    """

    def __init__(self, img: np.ndarray, percentiles: tuple[int, int] | None = None):
        if percentiles is None:
            percentiles = (1, 99)
        if len(percentiles) != 2:
            raise ValueError(f"Percentiles must have two values, got {percentiles=}")
        vmin, vmax = np.percentile(img, percentiles)
        # solution for beta assumes flat spectrum at vmax
        stretch = vmax - vmin
        beta = stretch / np.sinh(1)
        super().__init__(minimum=vmin, stretch=stretch, Q=beta)

164 

165 

166def img_to_3channel( 

167 img: np.ndarray, channel_map: np.ndarray | None = None, fill_value: float = 0 

168) -> np.ndarray: 

169 """Convert multi-band image cube into 3 RGB channels 

170 

171 Parameters 

172 ---------- 

173 img: 

174 This should be an array with dimensions (channels, height, width). 

175 channel_map: 

176 Linear mapping with dimensions (3, channels) 

177 fill_value: 

178 Value to use for any masked pixels. 

179 

180 Returns 

181 ------- 

182 RGB: 

183 The input image converted into an RGB array that can be displayed 

184 with `matplotlib.imshow`. 

185 """ 

186 # expand single img into cube 

187 if img.ndim not in [2, 3]: 

188 msg = f"The image must have 2 or 3 dimensions, got {img.ndim}" 

189 raise ValueError(msg) 

190 

191 if len(img.shape) == 2: 

192 ny, nx = img.shape 

193 img_ = img.reshape((1, ny, nx)) 

194 elif len(img.shape) == 3: 

195 img_ = img 

196 else: 

197 raise ValueError(f"Image must have either 2 or 3 dimensions, got {len(img.shape)}") 

198 dimensions = len(img_) 

199 

200 # filterWeights: channel x band 

201 if channel_map is None: 

202 channel_map = channels_to_rgb(dimensions) 

203 elif channel_map.shape != (3, len(img)): 

204 raise ValueError("Invalid channel_map returned, something unexpected happened") 

205 

206 # map channels onto RGB channels 

207 _, ny, nx = img_.shape 

208 rgb = np.dot(channel_map, img_.reshape(dimensions, -1)).reshape((3, ny, nx)) 

209 

210 if hasattr(rgb, "mask"): 

211 rgb = rgb.filled(fill_value) 

212 

213 return rgb 

214 

215 

def img_to_rgb(
    img: np.ndarray | Image,
    channel_map: np.ndarray | None = None,
    fill_value: float = 0,
    norm: Mapping | None = None,
    mask: np.ndarray | None = None,
) -> np.ndarray:
    """Convert images to normalized RGB.

    If normalized values are outside of the range [0..255], they will be
    truncated such as to preserve the corresponding color.

    Parameters
    ----------
    img:
        This should be an array with dimensions (channels, height, width),
        or an `Image` whose ``data`` has those dimensions.
    channel_map:
        Linear mapping with dimensions (3, channels)
    fill_value:
        Value to use for any masked pixels.
    norm:
        Norm to use for mapping in the allowed range [0..255].
        If ``norm=None``, a `LinearMapping` spanning the minimum and
        maximum of the 3-channel image is used.
    mask:
        A (height, width) binary mask, either boolean or [0,1] integer,
        applied as an alpha channel: pixels with mask==1 (True) are made
        fully transparent.

    Returns
    -------
    rgb:
        Array with dimensions (height, width, 3), or (height, width, 4)
        when `mask` is given, suitable for `matplotlib.imshow`.
    """
    if isinstance(img, Image):
        img = img.data
    _rgb = img_to_3channel(img, channel_map=channel_map, fill_value=fill_value)
    if norm is None:
        norm = LinearMapping(image=_rgb)
    rgb = norm.make_rgb_image(*_rgb)
    if mask is not None:
        # `~mask` is only a logical NOT for boolean arrays; for an integer
        # [0,1] mask it is a bitwise NOT (-1/-2), so convert to bool first.
        alpha = np.where(np.asarray(mask, dtype=bool), 0, 255).astype(np.uint8)
        rgb = np.dstack([rgb, alpha])
    return rgb

257 

258 

def show_likelihood(
    blend: Blend, figsize: tuple[float, float] | None = None, **kwargs
) -> matplotlib.pyplot.Figure:
    """Display a plot of the log-likelihood in each iteration for a blend.

    Parameters
    ----------
    blend:
        The blend to generate the likelihood plot for. Its
        ``log_likelihood`` sequence is plotted against iteration number.
    figsize:
        The size of the figure, passed to `matplotlib.pyplot.subplots`.
    kwargs:
        Keyword arguments passed to ``ax.plot`` when drawing the curve.

    Returns
    -------
    fig:
        The figure containing the log-likelihood plot.
    """
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ax.plot(blend.log_likelihood, **kwargs)
    ax.set_xlabel("Iteration")
    # Iterations are whole numbers, so only use integer tick positions
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_ylabel("log-Likelihood")
    return fig

284 

285 

def _add_markers(
    src: Source,
    extent: tuple[float, float, float, float],
    ax: matplotlib.pyplot.Axes,
    add_markers: bool,
    add_boxes: bool,
    marker_kwargs: dict,
    box_kwargs: dict,
):
    """Decorate an axis with a source's center marker and/or bounding box.

    Parameters
    ----------
    src:
        The source to annotate. A white "x" is drawn at ``src.center``
        (given in (y, x) order) if the source has a non-``None`` center.
    extent:
        ``(x_min, x_max, y_min, y_max)`` bounds of the source box.
    ax:
        The axis to draw on.
    add_markers:
        Whether to draw the "x" at the center of the source.
    add_boxes:
        Whether to outline the full source box.
    marker_kwargs:
        Extra keyword arguments for ``ax.plot`` when drawing the marker.
    box_kwargs:
        Extra keyword arguments for `~matplotlib.patches.Rectangle`.
    """
    if add_markers and getattr(src, "center", None) is not None:
        # Reverse (y, x) into the (x, y) order matplotlib expects
        xy = np.array(src.center)[::-1]
        ax.plot(*xy, "wx", **marker_kwargs)

    if not add_boxes:
        return

    x_min, x_max, y_min, y_max = extent
    box = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, **box_kwargs)
    ax.add_artist(box)

327 

328 

def show_observation(
    observation: Observation,
    norm: Mapping | None = None,
    channel_map: np.ndarray | None = None,
    centers: Sequence | None = None,
    psf_scaling: str | None = None,
    figsize: tuple[float, float] | None = None,
):
    """Plot observation in standardized form.

    Parameters
    ----------
    observation:
        The observation to show.
    norm:
        An ``astropy.visualization.lupton_rgb.Mapping`` to map the colors.
    channel_map:
        A mapping to convert the multiband image into an RGB image.
        It is used for both the observation and the PSF panel.
    centers:
        A list of source centers, in (y, x) order, to label on the plot.
        If `centers` is ``None`` then no labels are added.
    psf_scaling:
        How to display the PSF in a second panel.
        If `psf_scaling` is ``None`` then the PSF is not displayed.
        If `psf_scaling` is "native", then the PSF image is displayed
        at its own size.
        If `psf_scaling` is "same", then the PSF is centered in an image
        with the same shape as the observation.
        In both cases the PSF is rescaled so that its brightest
        band-averaged pixel matches the brightest band-averaged pixel of
        the observation.
    figsize:
        The size of the output figure.
        If no size is specified then the figsize is calculated automatically
        based on the number of panels shown.

    Returns
    -------
    fig:
        The generated figure.

    Raises
    ------
    ValueError
        If `psf_scaling` is not ``None``, "native" or "same".
    """
    if psf_scaling is None:
        panels = 1
    else:
        panels = 2
        if psf_scaling not in ["native", "same"]:
            raise ValueError(f"psf_scaling must be either 'same' or 'native', got {psf_scaling}")
    if figsize is None:
        figsize = (panel_size * panels, panel_size)
    fig, ax = plt.subplots(1, panels, figsize=figsize)
    if not hasattr(ax, "__iter__"):
        ax = (ax,)

    # Mask any pixels with zero weight in all bands
    mask = np.sum(observation.weights.data, axis=0) == 0
    # if there are no masked pixels, do not use a mask
    if np.all(mask == 0):
        mask = None

    panel = 0
    ax[panel].imshow(
        img_to_rgb(observation.images, norm=norm, channel_map=channel_map, mask=mask),
        extent=get_extent(observation.bbox),
        origin="lower",
    )
    ax[panel].set_title("Observation")

    if centers is not None:
        # If the image is multi-band, use a white label,
        # otherwise the image will be black and white so use red.
        color = "w" if observation.images.shape[0] > 1 else "r"
        for k, center in enumerate(centers):
            ax[panel].text(*center[::-1], k, color=color, ha="center", va="center")

    panel += 1
    if psf_scaling is not None:
        image_shape = observation.images.shape
        psf_image = np.zeros(image_shape)

        # NOTE(review): this guards on `model_psf` but displays
        # `observation.psfs`; confirm that is the intended condition.
        if observation.model_psf is not None:
            psfs = observation.psfs
            # make PSF as bright as the brightest pixel of the observation.
            # Scale a copy: an in-place `*=` would modify observation.psfs.
            scale = np.max(np.mean(observation.images.data, axis=0)) / np.max(np.mean(psfs, axis=0))
            psf_model = psfs * scale
            if psf_scaling == "native":
                psf_image = psf_model
            else:
                # Center the PSF in an otherwise empty image of the
                # observation's shape.
                height = psf_model.shape[1]
                width = psf_model.shape[2]
                y0 = (image_shape[1] - height) // 2
                x0 = (image_shape[2] - width) // 2
                psf_image[:, y0 : y0 + height, x0 : x0 + width] = psf_model
        ax[panel].imshow(img_to_rgb(psf_image, norm=norm, channel_map=channel_map), origin="lower")
        ax[panel].set_title("PSF")

    fig.tight_layout()
    return fig

422 

423 

def show_scene(
    blend: Blend,
    norm: Mapping | None = None,
    channel_map: np.ndarray | None = None,
    show_model: bool = True,
    show_observed: bool = False,
    show_rendered: bool = False,
    show_residual: bool = False,
    add_labels: bool = True,
    add_boxes: bool = False,
    figsize: tuple[float, float] | None = None,
    linear: bool = True,
    use_flux: bool = False,
    box_kwargs: dict | None = None,
) -> matplotlib.pyplot.Figure:
    """Plot all sources to recreate the scene.

    The function provides a fast way of evaluating the quality
    of the entire model,
    i.e. the combination of all sources that seek to fit the observation.

    Parameters
    ----------
    blend:
        The blend containing the observations and sources to plot.
    norm:
        Norm to compress image intensity to the range [0,255].
    channel_map:
        Linear mapping with dimensions (3, channels).
    show_model:
        Whether the model is shown in the model frame.
    show_observed:
        Whether the observation is shown.
    show_rendered:
        Whether the model, rendered to match the observation, is shown.
    show_residual:
        Whether the residuals between rendered model and observation is shown.
    add_labels:
        Whether each source is labeled with its numerical
        index in the source list.
    add_boxes:
        Whether each source box is shown.
    figsize:
        Size of the final figure.
    linear:
        Whether to display the panels in a single row (`True`) or
        on two rows (`False`).
    use_flux:
        Whether to show the flux redistributed model (`source.flux`) or
        the model itself (`source.get_model()`) for each source.
        Flux models are not convolved before rendering.
    box_kwargs:
        Keyword arguments to create boxes (`matplotlib.patches.Rectangle`)
        around sources, if `add_boxes == True`.

    Returns
    -------
    fig:
        The figure that is generated based on the parameters.
    """
    if box_kwargs is None:
        box_kwargs = {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    panels = sum((show_model, show_observed, show_rendered, show_residual))
    if linear:
        if figsize is None:
            figsize = (panel_size * panels, panel_size)
        fig, ax = plt.subplots(1, panels, figsize=figsize)
    else:
        columns = int(np.ceil(panels / 2))
        if figsize is None:
            figsize = (panel_size * columns, panel_size * 2)
        fig = plt.figure(figsize=figsize)
        ax = [fig.add_subplot(2, columns, n + 1) for n in range(panels)]
    if not hasattr(ax, "__iter__"):
        ax = (ax,)

    observation = blend.observation
    sources = blend.sources
    model = blend.get_model(use_flux=use_flux)
    bbox = blend.bbox

    # Mask any pixels with zero weight in all bands
    if observation is not None:
        mask = np.sum(observation.weights.data, axis=0) == 0
        # if there are no masked pixels, do not use a mask
        if np.all(mask == 0):
            mask = None
    else:
        mask = None

    # Compute every extent up front so that each panel gets the extent of
    # the frame it displays, independent of which other panels are shown
    # (previously `extent` could be unbound, e.g. with only show_observed).
    model_extent = get_extent(bbox)
    obs_extent = get_extent(observation.bbox) if observation is not None else model_extent

    panel = 0
    if show_model:
        ax[panel].imshow(
            img_to_rgb(model.data, norm=norm, channel_map=channel_map, mask=mask),
            extent=model_extent,
            origin="lower",
        )
        ax[panel].set_title("Model")
        panel += 1

    if (show_rendered or show_residual) and not use_flux:
        # The convolved model lives in the observation frame
        model = observation.convolve(model)
        rendered_extent = obs_extent
    else:
        rendered_extent = model_extent

    if show_rendered:
        ax[panel].imshow(
            img_to_rgb(model.data, norm=norm, channel_map=channel_map, mask=mask),
            extent=rendered_extent,
            origin="lower",
        )
        ax[panel].set_title("Model Rendered")
        panel += 1

    if show_observed:
        ax[panel].imshow(
            img_to_rgb(observation.images.data, norm=norm, channel_map=channel_map, mask=mask),
            extent=obs_extent,
            origin="lower",
        )
        ax[panel].set_title("Observation")
        panel += 1

    if show_residual:
        residual = observation.images - model
        # Residuals have their own dynamic range, so they get their own norm
        norm_ = LinearPercentileNorm(residual.data)
        ax[panel].imshow(
            img_to_rgb(residual.data, norm=norm_, channel_map=channel_map, mask=mask),
            extent=obs_extent,
            origin="lower",
        )
        ax[panel].set_title("Residual")
        panel += 1

    # Panels that receive source boxes/labels: the model panel (if shown)
    # and, when there is an observation, every panel after it.
    annotated_panels = [0] if show_model else []
    if observation is not None:
        annotated_panels.extend(range(len(annotated_panels), panels))

    for k, src in enumerate(sources):
        if add_boxes:
            x_min, x_max, y_min, y_max = get_extent(src.bbox)
            for panel in annotated_panels:
                # A matplotlib artist can only belong to one axis,
                # so a new Rectangle is created for each panel.
                rect = Rectangle(
                    (x_min, y_min),
                    x_max - x_min,
                    y_max - y_min,
                    **box_kwargs,
                )
                ax[panel].add_artist(rect)

        if add_labels and hasattr(src, "center") and src.center is not None:
            center = src.center
            for panel in annotated_panels:
                ax[panel].text(*center[::-1], k, color="w", ha="center", va="center")

    fig.tight_layout()
    return fig

593 

594 

def get_extent(bbox: Box) -> tuple[int, int, int, int]:
    """Return the matplotlib ``imshow`` extent covered by a `Box`.

    Parameters
    ----------
    bbox:
        The box to convert. Only its last two axes are used: the final
        axis is treated as the horizontal (x) axis and the one before it
        as the vertical (y) axis.

    Returns
    -------
    extent:
        ``(x_start, x_stop, y_start, y_stop)``, i.e. the
        ``(left, right, bottom, top)`` ordering that ``imshow`` expects
        for its ``extent`` argument.
    """
    start, stop = bbox.start, bbox.stop
    return start[-1], stop[-1], start[-2], stop[-2]

610 

611 

def show_sources(
    blend: Blend,
    sources: list[Source] | None = None,
    norm: Mapping | None = None,
    channel_map: np.ndarray | None = None,
    show_model: bool = True,
    show_observed: bool = False,
    show_rendered: bool = False,
    show_spectrum: bool = True,
    figsize: tuple[float, float] | None = None,
    model_mask: bool = True,
    add_markers: bool = True,
    add_boxes: bool = False,
    use_flux: bool = False,
) -> matplotlib.pyplot.Figure:
    """Plot individual source models

    The function provides a fast way of evaluating the quality of
    individual sources. Each non-null source gets one row of panels.

    Parameters
    ----------
    blend:
        The blend that contains the sources.
    sources:
        The list of sources to plot.
        If `sources` is `None` then all of the sources in `blend` are
        displayed. Null sources (``src.is_null``) are skipped.
    norm:
        Norm to compress image intensity to the range [0,255].
    channel_map:
        Linear mapping with dimensions (3, channels).
    show_model:
        Whether the model is shown in the model frame.
    show_observed:
        Whether the observation is shown.
    show_rendered:
        Whether the model, rendered to match the observation, is shown.
    show_spectrum:
        Whether or not to show a plot of the source spectrum, i.e. the
        total model flux in each band.
    figsize:
        Size of the final figure.
    model_mask:
        Whether pixels with no flux in a model are masked.
    add_markers:
        Whether the center of the source is marked in each of its panels.
    add_boxes:
        Whether the source box is shown in the rendered and observed panels.
    use_flux:
        Whether to show the flux redistributed model (`source.flux`) or
        the model itself (`source.get_model()`) for each source.

    Returns
    -------
    fig:
        The figure that is generated based on the parameters.

    Raises
    ------
    ValueError
        If `use_flux` is set but a source has no flux weighted image.
    """
    observation = blend.observation
    if sources is None:
        sources = blend.sources
    panels = sum((show_model, show_observed, show_rendered, show_spectrum))
    n_sources = len([src for src in sources if not src.is_null])
    if figsize is None:
        figsize = (panel_size * panels, panel_size * n_sources)

    # squeeze=False keeps `ax` 2D, so `ax[row][panel]` always works
    fig, ax = plt.subplots(n_sources, panels, figsize=figsize, squeeze=False)

    marker_kwargs = {"mew": 1, "ms": 10}
    box_kwargs = {"facecolor": "none", "edgecolor": "w", "lw": 0.5}

    row = 0
    for k, src in enumerate(sources):
        if src.is_null:
            continue
        row_ax = ax[row]
        row += 1

        if use_flux:
            if src.flux_weighted_image is None:
                raise ValueError(f"Flux has not been calculated for src {k}, rerun measure.conserve_flux")
            src_box = src.flux_weighted_image.bbox
        else:
            src_box = src.bbox

        extent = get_extent(src_box)

        # model in its bbox
        panel = 0
        model = src.get_model(use_flux=use_flux)

        if show_model:
            if model_mask:
                # Hide pixels where no band has positive flux
                _model_mask = np.max(model.data, axis=0) <= 0
            else:
                _model_mask = None
            # Show the unrendered model in its bbox
            row_ax[panel].imshow(
                img_to_rgb(model.data, norm=norm, channel_map=channel_map, mask=_model_mask),
                extent=extent,
                origin="lower",
            )
            row_ax[panel].set_title("Model Source {}".format(k))
            _add_markers(
                src,
                extent,
                row_ax[panel],
                add_markers,
                False,
                marker_kwargs,
                box_kwargs,
            )
            panel += 1

        # model in observation frame
        if show_rendered:
            # Insert the model into an empty observation-sized image and
            # (unless showing the flux model) convolve it with the PSF
            model_ = Image(np.zeros(observation.shape), bands=observation.bands)
            model_.insert(src.get_model(use_flux=use_flux))
            if not use_flux:
                model_ = observation.convolve(model_)
            row_ax[panel].imshow(
                img_to_rgb(model_.data, norm=norm, channel_map=channel_map),
                extent=get_extent(observation.bbox),
                origin="lower",
            )
            row_ax[panel].set_title("Model Source {} Rendered".format(k))
            _add_markers(
                src,
                extent,
                row_ax[panel],
                add_markers,
                add_boxes,
                marker_kwargs,
                box_kwargs,
            )
            panel += 1

        if show_observed:
            # Show the full observation with the source marked on it
            _images = observation.images
            row_ax[panel].imshow(
                img_to_rgb(_images.data, norm=norm, channel_map=channel_map),
                extent=get_extent(observation.bbox),
                origin="lower",
            )
            row_ax[panel].set_title(f"Observation {k}")
            _add_markers(
                src,
                extent,
                row_ax[panel],
                add_markers,
                add_boxes,
                marker_kwargs,
                box_kwargs,
            )
            panel += 1

        if show_spectrum:
            # Total model flux in each band
            spectrum = np.sum(model.data, axis=(1, 2))
            row_ax[panel].plot(spectrum)
            # One tick per band (previously the number of plotted lines, i.e. 1)
            row_ax[panel].set_xticks(range(len(spectrum)))
            row_ax[panel].set_title("Spectrum")
            row_ax[panel].set_xlabel("Band")
            row_ax[panel].set_ylabel("Intensity")

    fig.tight_layout()
    return fig

780 

781 

def compare_spectra(
    use_flux: bool = True, use_template: bool = True, **all_sources: list[Source]
) -> matplotlib.pyplot.Figure:
    """Compare spectra from multiple different deblending results of the
    same sources.

    Each source gets its own panel (4 panels per row), in which the total
    flux per band of that source is plotted for every deblending result.

    Parameters
    ----------
    use_flux:
        Whether or not to show the re-distributed flux version of the model.
    use_template:
        Whether or not to show the scarlet model templates.
    all_sources:
        The list of sources for each different deblending model, keyed by
        the label used in the legend.

    Returns
    -------
    fig:
        The figure containing the spectra.

    Raises
    ------
    ValueError
        If no source lists are given, or if the source lists do not all
        have the same length.
    """
    if not all_sources:
        raise ValueError("At least one list of sources must be provided.")
    first_key = next(iter(all_sources))
    nbr_sources = len(all_sources[first_key])
    for key, sources in all_sources.items():
        if len(sources) != nbr_sources:
            msg = (
                "All source lists must have the same number of components. "
                f"Received {nbr_sources} sources for the list {first_key} and {len(sources)} "
                f"for list {key}."
            )
            raise ValueError(msg)

    columns = 4
    # At least one row so that an empty source list still yields a figure
    rows = max(1, int(np.ceil(nbr_sources / columns)))
    # squeeze=False keeps `ax` 2D even for a single row, so that
    # `ax[row][column]` indexing is always valid.
    fig, ax = plt.subplots(rows, columns, figsize=(15, 15 * rows / columns), squeeze=False)

    for k in range(nbr_sources):
        row, column = divmod(k, columns)
        ax[row][column].set_title(f"source {k}")
        for key, sources in all_sources.items():
            if sources[k].is_null:
                continue
            if use_template or not hasattr(sources[k], "flux"):
                spectrum = np.sum(sources[k].get_model().data, axis=(1, 2))
                ax[row][column].plot(spectrum, ".-", label=key + " model")
            if use_flux and hasattr(sources[k], "flux"):
                spectrum = np.sum(sources[k].get_model(use_flux=True).data, axis=(1, 2))
                ax[row][column].plot(spectrum, ".--", label=key + " flux")
    # Every panel uses the same labels, so the first panel's legend entries
    # describe the whole figure.
    handles, labels = ax[0][0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=4)
    return fig