Coverage for python / lsst / meas / extensions / multiprofit / plots.py: 0%

154 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-26 09:37 +0000

1# This file is part of meas_extensions_multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22 

23from abc import ABC, abstractmethod 

24from typing import Any, Iterable, Self, Type 

25 

26import astropy.table 

27import astropy.units as u 

28from lsst.multiprofit.plotting import bands_weights_lsst, plot_model_rgb 

29import matplotlib.axes 

30import matplotlib.figure 

31import matplotlib.pyplot as plt 

32import numpy as np 

33import pydantic 

34 

35from .rebuild_coadd_multiband import DataLoader, PatchCoaddRebuilder 

36 

__all__ = [
    # Table accessor classes.
    "ObjectTableBase",
    "TruthSummaryTable",
    "ObjectTable",
    "ObjectTableCModel",
    "ObjectTableMultiProFit",
    "ObjectTablePsf",
    # Selection and plotting functions.
    "downselect_table",
    "downselect_table_axis",
    "plot_blend",
    "plot_objects",
]

# Short aliases for the matplotlib types used in signatures in this module.
Figure = matplotlib.figure.Figure
# Either a single Axes or an iterable of them.
Axes = matplotlib.axes.Axes | Iterable[matplotlib.axes.Axes]
# A (figure, axes) pair.
FigureAxes = tuple[Figure, Axes]

53 

54 

class ObjectTableBase(ABC, pydantic.BaseModel):
    """Base class for retrieving columns from tract-based object tables.

    Subclasses map a common accessor interface (flux, id, extendedness,
    variability and pixel centroids) onto the column names of a specific
    table schema. Instances are immutable (``frozen=True``).
    """

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    # ``description`` is the supported pydantic v2 keyword; unknown extra
    # keywords such as ``doc`` are deprecated and emit a warning.
    table: astropy.table.Table = pydantic.Field(description="The object table")

    @abstractmethod
    def get_flux(self, band: str) -> np.ndarray:
        """Return the flux in a given band.

        Parameters
        ----------
        band
            The name of the band.

        Returns
        -------
        flux
            The configured flux in that band.
        """

    @abstractmethod
    def get_id(self) -> np.ndarray:
        """Return a unique source id."""

    @abstractmethod
    def get_is_extended(self) -> np.ndarray:
        """Return if the source is extended."""

    @abstractmethod
    def get_is_variable(self) -> np.ndarray:
        """Return if the source is variable."""

    @abstractmethod
    def get_x(self) -> np.ndarray:
        """Return the x pixel coordinates."""

    @abstractmethod
    def get_y(self) -> np.ndarray:
        """Return the y pixel coordinates."""

    def make_subset(self, subset) -> Self:
        """Make a new table of the same type as self with a subset of rows.

        Parameters
        ----------
        subset
            An array that can be used to select a subset of the rows in
            self.table (e.g. a boolean mask or an index array).

        Returns
        -------
        table
            An object of the same type as self with a subsetted table.
            All other model fields are copied from self. The table will be
            a copy, as it does not appear to be possible to return views of
            slices of astropy Table instances.
        """
        # Read model_fields from the class: instance access is deprecated
        # as of pydantic 2.11.
        fields = type(self).model_fields
        kwargs_table = {name: getattr(self, name) for name in fields if name != "table"}
        return type(self)(table=self.table[subset], **kwargs_table)

115 

116 

class TruthSummaryTable(ObjectTableBase):
    """Accessor for the columns of a DC2 truth summary table."""

    def get_flux(self, band: str) -> np.ndarray:
        # Truth fluxes are stored in columns named flux_<band>.
        name_column = f"flux_{band}"
        return self.table[name_column]

    def get_id(self) -> np.ndarray:
        return self.table["id"]

    def get_is_extended(self) -> np.ndarray:
        # Anything not flagged as a point source counts as extended. The
        # explicit comparison is elementwise over the column, hence noqa.
        is_pointsource = self.table["is_pointsource"]
        return is_pointsource == False  # noqa: E712

    def get_is_variable(self) -> np.ndarray:
        is_variable = self.table["is_variable"]
        return is_variable == True  # noqa: E712

    def get_x(self) -> np.ndarray:
        return self.table["x"]

    def get_y(self) -> np.ndarray:
        return self.table["y"]

137 

138 

class ObjectTable(ObjectTableBase, ABC):
    """Abstract accessor for objectTable_tract; subclasses pick the flux column."""

    def get_id(self) -> np.ndarray:
        return self.table["objectId"]

    def get_is_extended(self) -> np.ndarray:
        # refExtendedness is thresholded at 0.5 to give a boolean.
        extendedness = self.table["refExtendedness"]
        return extendedness >= 0.5

    def get_is_variable(self) -> np.ndarray:
        # No variability information is read from this table: all False.
        return np.full(len(self.table), False)

    def get_x(self) -> np.ndarray:
        return self.table["x"]

    def get_y(self) -> np.ndarray:
        return self.table["y"]

156 

157 

class ObjectTableCModel(ObjectTable):
    """Accessor returning CModel fluxes from objectTable_tract."""

    def get_flux(self, band: str) -> np.ndarray:
        # CModel fluxes live in columns named <band>_cModelFlux.
        name_column = f"{band}_cModelFlux"
        return self.table[name_column]

163 

164 

class ObjectTableMultiProFit(ObjectTableBase):
    """Class for retrieving fluxes from objectTable_tract_multiprofit.

    Flux and centroid columns are named ``{prefix_col}{name_model}_...``,
    while id and extendedness are read from the same column names as in
    objectTable_tract (``objectId``, ``refExtendedness``).
    """

    # ``description`` is the supported pydantic v2 keyword; unknown extra
    # keywords such as ``doc`` are deprecated and emit a warning.
    name_model: str = pydantic.Field(description="The name of the MultiProFit model")
    prefix_col: str = pydantic.Field(description="The prefix for object fit columns", default="mpf_")

    def get_flux(self, band: str) -> np.ndarray:
        return self.table[f"{self.prefix_col}{self.name_model}_{band}_flux"]

    def get_id(self) -> np.ndarray:
        return self.table["objectId"]

    def get_is_extended(self) -> np.ndarray:
        # refExtendedness is thresholded at 0.5 to give a boolean.
        return self.table["refExtendedness"] >= 0.5

    def get_is_variable(self) -> np.ndarray:
        # No variability information is read from this table: all False.
        return np.zeros(len(self.table), dtype=bool)

    def get_x(self) -> np.ndarray:
        return self.table[f"{self.prefix_col}{self.name_model}_cen_x"]

    def get_y(self) -> np.ndarray:
        return self.table[f"{self.prefix_col}{self.name_model}_cen_y"]

188 

189 

class ObjectTablePsf(ObjectTable):
    """Class for retrieving PSF fluxes from objectTable_tract."""

    def get_flux(self, band: str) -> np.ndarray:
        # PSF fluxes live in columns named <band>_psfFlux.
        return self.table[f"{band}_psfFlux"]

195 

196 

def downselect_table(
    table: ObjectTableBase,
    x_min: float,
    x_max: float,
    y_min: float,
    y_max: float,
) -> ObjectTableBase:
    """Keep only the rows of a table whose x,y lie strictly inside a box.

    Parameters
    ----------
    table
        The table to downselect.
    x_min
        The minimum x value (exclusive).
    x_max
        The maximum x value (exclusive).
    y_min
        The minimum y value (exclusive).
    y_max
        The maximum y value (exclusive).

    Returns
    -------
    table
        A downselected table of the same class, built via
        ``table.make_subset``.
    """
    x_all, y_all = table.get_x(), table.get_y()
    # All four bounds are strict inequalities.
    inside_x = (x_all > x_min) & (x_all < x_max)
    inside_y = (y_all > y_min) & (y_all < y_max)
    return table.make_subset(inside_x & inside_y)

228 

229 

def downselect_table_axis(table: ObjectTableBase, axis) -> ObjectTableBase:
    """Select points from a table within a figure axis.

    Parameters
    ----------
    table
        The table to downselect.
    axis
        The figure axis to determine the extent from. Its ``axis()`` method
        is expected to return ``(xmin, xmax, ymin, ymax)`` limits; inverted
        limits (e.g. from ``imshow`` with ``origin="upper"``) are handled.

    Returns
    -------
    table
        A downselected table of the same class.
    """
    x_lim_0, x_lim_1, y_lim_0, y_lim_1 = np.array(axis.axis())
    # Axis limits are reported in display order, so an inverted axis (e.g.
    # imshow's default origin="upper") puts the larger y value first.
    # Without ordering each pair, downselect_table would receive min > max
    # and silently select nothing.
    return downselect_table(
        table,
        min(x_lim_0, x_lim_1),
        max(x_lim_0, x_lim_1),
        min(y_lim_0, y_lim_1),
        max(y_lim_0, y_lim_1),
    )

247 

248 

def plot_objects(
    table: ObjectTableBase,
    axes: Axes,
    bands: Iterable[str],
    table_downselected: bool = False,
    kwargs_annotate: dict[str, Any] | None = None,
    kwargs_scatter: dict[str, Any] | None = None,
    labels_extended: tuple[str, str] = ("S", "G"),
) -> Axes:
    """Plot catalog objects on an existing image.

    Each object is drawn as a scatter marker and annotated with a label
    made of an optional ``V`` prefix (variable objects), the extendedness
    label from ``labels_extended``, and the AB magnitude of its flux
    summed over ``bands`` (fluxes are converted from nanojansky).

    Parameters
    ----------
    table
        The object table to plot sources from.
    axes
        The figure axes to plot on.
    bands
        The bands to sum over fluxes to derive a total mag label.
    table_downselected
        Whether the table has already been downselected to contain only
        points within the bounds of the axes.
    kwargs_annotate
        Keyword arguments to pass to axes.annotate. Defaults to white,
        14pt, left/bottom-aligned text.
    kwargs_scatter
        Keyword arguments to pass to axes.scatter. Defaults to white
        ``+`` markers of size 100.
    labels_extended
        Label prefixes for non-extended and extended objects, respectively.

    Returns
    -------
    axes
        The input axes with added points and labels.
    """
    if kwargs_annotate is None:
        kwargs_annotate = dict(color="white", fontsize=14, ha="left", va="bottom")
    if kwargs_scatter is None:
        kwargs_scatter = dict(c="white", marker="+", s=100)
    table_within = table if table_downselected else downselect_table_axis(table, axes)
    x = table_within.get_x()
    y = table_within.get_y()
    axes.scatter(x, y, **kwargs_scatter)
    # Materialize the per-band flux columns once, outside the object loop.
    fluxes = [table_within.get_flux(band) for band in bands]
    is_extended = table_within.get_is_extended()
    is_variable = table_within.get_is_variable()

    for idx, (x_obj, y_obj) in enumerate(zip(x, y)):
        flux_total = np.sum([fluxcol[idx] for fluxcol in fluxes])
        mag = u.nanojansky.to(u.ABmag, flux_total)
        prefix_variable = "V" if is_variable[idx] else ""
        label_extended = labels_extended[1 if is_extended[idx] else 0]
        axes.annotate(f"{prefix_variable}{label_extended}{mag:.1f}", (x_obj, y_obj), **kwargs_annotate)

    return axes

301 

302 

def plot_blend(
    rebuilder: PatchCoaddRebuilder,
    idx_row_parent: int,
    weights: dict[str, float] | None = None,
    table_ref_type: Type[ObjectTableBase] = TruthSummaryTable,
    kwargs_plot_parent: dict[str, Any] | None = None,
    kwargs_plot_children: dict[str, Any] | None = None,
) -> tuple[Figure, Axes, Figure, Axes]:
    """Plot an image of an entire blend and its deblended children.

    The parent blend is drawn as an RGB image of the observations, with
    reference objects (``table_ref_type`` built from ``rebuilder.reference``)
    and primary CModel objects overplotted. Then, for each child row and
    each matched model that has a rebuilder, an RGB plot of the model (or
    of the deblended data for DataLoader instances) is drawn with reference
    and MultiProFit objects overplotted. Failures for an individual child
    and model are printed and skipped.

    Parameters
    ----------
    rebuilder
        The patch rebuilder to plot from.
    idx_row_parent
        The row index of the parent object in the reference SourceCatalog.
    weights
        Multiplicative weights by band name for RGB plots. Defaults to
        ``bands_weights_lsst``. Its keys are also the bands summed over
        for the magnitude labels of overplotted objects.
    table_ref_type
        The type of reference table to construct when downselecting.
    kwargs_plot_parent
        Keyword arguments to pass to make RGB plots of the parent blend.
    kwargs_plot_children
        Keyword arguments to pass to make RGB plots of deblended children.
        The special key ``plot_chi_hist`` (default True) toggles chi
        histograms for non-DataLoader models and is not passed through.
        The caller's dict is not modified.

    Returns
    -------
    fig_rgb
        The Figure for the RGB plots of the parent.
    ax_rgb
        The Axes for the RGB plots of the parent.
    fig_gs
        The Figure for the grayscale plots of the parent.
    ax_gs
        The Axes for the grayscale plots of the parent.
    """
    if kwargs_plot_parent is None:
        kwargs_plot_parent = {}
    # Copy before popping so that the caller's dict is not mutated (otherwise
    # a reused dict would silently lose its plot_chi_hist setting).
    kwargs_plot_children = {} if kwargs_plot_children is None else dict(kwargs_plot_children)
    if weights is None:
        weights = bands_weights_lsst

    plot_chi_hist = kwargs_plot_children.pop("plot_chi_hist", True)
    rebuilder_ref = rebuilder.matches[rebuilder.name_model_ref].rebuilder
    observations = {
        catexp.band: catexp.get_source_observation(catexp.get_catalog()[idx_row_parent], skip_flags=True)
        for catexp in rebuilder_ref.catexps
    }

    # Parent blend: data only (no model), no single-band or chi plots.
    fig_rgb, ax_rgb, fig_gs, ax_gs, *_ = plot_model_rgb(
        model=None,
        weights=weights,
        observations=observations,
        plot_singleband=False,
        plot_chi_hist=False,
        **kwargs_plot_parent,
    )
    table_within_ref = downselect_table_axis(table_ref_type(table=rebuilder.reference), ax_rgb)
    plot_objects(table_within_ref, ax_rgb, weights, table_downselected=True)

    objects_primary = rebuilder.objects[rebuilder.objects["detect_isPrimary"] == True]  # noqa: E712
    # Measured objects use different markers/alignment than reference ones
    # so that both sets of labels remain readable when overlapping.
    kwargs_annotate_obs = dict(color="white", fontsize=14, ha="right", va="top")
    kwargs_scatter_obs = dict(c="white", marker="x", s=70)
    table_within_cmodel = downselect_table_axis(ObjectTableCModel(table=objects_primary), ax_rgb)
    labels_extended_model = ("C", "E")
    plot_objects(
        table_within_cmodel,
        ax_rgb,
        weights,
        table_downselected=True,
        kwargs_annotate=kwargs_annotate_obs,
        kwargs_scatter=kwargs_scatter_obs,
        labels_extended=labels_extended_model,
    )
    plt.show()

    objects_mpf = rebuilder.objects_multiprofit
    objects_mpf_within = {}
    for name, matched in rebuilder.matches.items():
        if matched.rebuilder and objects_mpf:
            objects_mpf_within[name] = downselect_table_axis(
                ObjectTableMultiProFit(name_model=name, table=objects_mpf),
                ax_rgb,
            )

    cat_ref = rebuilder_ref.catalog_multi
    row_parent = cat_ref[idx_row_parent]
    # NOTE(review): when the selected row has parent == 0 only that row is
    # plotted; otherwise rows whose parent equals this row's id are used.
    # If idx_row_parent is always a blend parent (as documented above),
    # these branches may be inverted - confirm against the catalog layout.
    idx_children = (
        (idx_row_parent,)
        if (row_parent["parent"] == 0)
        else (np.where(cat_ref["parent"] == row_parent["id"])[0])
    )

    for idx_child in idx_children:
        for name, matched in rebuilder.matches.items():
            print(f"Model: {name}")
            rebuilder_child = matched.rebuilder
            is_dataloader = isinstance(rebuilder_child, DataLoader)
            is_scarlet = is_dataloader and (name == "scarlet")
            if is_scarlet or rebuilder_child:
                # Deliberately best-effort: one failing child/model should
                # not prevent the remaining plots from being drawn.
                try:
                    if is_dataloader:
                        model = None
                        observations = rebuilder_child.load_deblended_object(idx_child)
                    else:
                        model = rebuilder_child.make_model(idx_child)
                        observations = None

                    _, ax_rgb_c, *_ = plot_model_rgb(
                        model=model,
                        weights=weights,
                        plot_singleband=False,
                        plot_chi_hist=(not is_dataloader) and plot_chi_hist,
                        observations=observations,
                        **kwargs_plot_children,
                    )
                    ax_rgb_c0 = ax_rgb_c[0][0]
                    plot_objects(table_within_ref, ax_rgb_c0, weights)
                    tab_mpf = objects_mpf_within.get(name)
                    if tab_mpf is not None:
                        plot_objects(
                            tab_mpf,
                            ax_rgb_c0,
                            weights,
                            kwargs_annotate=kwargs_annotate_obs,
                            kwargs_scatter=kwargs_scatter_obs,
                            labels_extended=labels_extended_model,
                        )
                    plt.show()
                except Exception as exc:
                    print(f"{idx_child=} failed to rebuild due to {exc}")

    return fig_rgb, ax_rgb, fig_gs, ax_gs