Coverage for python / lsst / meas / extensions / multiprofit / plots.py: 0%
154 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:21 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-18 09:21 +0000
1# This file is part of meas_extensions_multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
23from abc import ABC, abstractmethod
24from typing import Any, Iterable, Self, Type
26import astropy.table
27import astropy.units as u
28from lsst.multiprofit.plotting import bands_weights_lsst, plot_model_rgb
29import matplotlib.axes
30import matplotlib.figure
31import matplotlib.pyplot as plt
32import numpy as np
33import pydantic
35from .rebuild_coadd_multiband import DataLoader, PatchCoaddRebuilder
# Names exported by ``from ... import *``.
__all__ = [
    "ObjectTableBase",
    "TruthSummaryTable",
    "ObjectTable",
    "ObjectTableCModel",
    "ObjectTableMultiProFit",
    "ObjectTablePsf",
    "downselect_table",
    "downselect_table_axis",
    "plot_blend",
    "plot_objects",
]

# Type aliases used in the annotations of the plotting functions below.
Figure = matplotlib.figure.Figure
# Either a single matplotlib Axes or an iterable of them.
Axes = matplotlib.axes.Axes | Iterable[matplotlib.axes.Axes]
FigureAxes = tuple[Figure, Axes]
class ObjectTableBase(ABC, pydantic.BaseModel):
    """Base class for retrieving columns from tract-based object tables.

    Subclasses implement the abstract accessors by mapping them onto the
    columns of the wrapped table.
    """

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

    # ``description`` is the documented pydantic keyword for field docs;
    # undeclared extra keywords to Field are deprecated in pydantic v2.
    table: astropy.table.Table = pydantic.Field(description="The object table")

    @abstractmethod
    def get_flux(self, band: str) -> np.ndarray:
        """Return the flux in a given band.

        Parameters
        ----------
        band
            The name of the band.

        Returns
        -------
        flux
            The configured flux in that band.
        """

    @abstractmethod
    def get_id(self) -> np.ndarray:
        """Return a unique source id."""

    @abstractmethod
    def get_is_extended(self) -> np.ndarray:
        """Return if the source is extended."""

    @abstractmethod
    def get_is_variable(self) -> np.ndarray:
        """Return if the source is variable."""

    @abstractmethod
    def get_x(self) -> np.ndarray:
        """Return the x pixel coordinates."""

    @abstractmethod
    def get_y(self) -> np.ndarray:
        """Return the y pixel coordinates."""

    def make_subset(self, subset) -> Self:
        """Make a new table of the same type as self with a subset of rows.

        Parameters
        ----------
        subset
            An array that can be used to select a subset of the rows in
            self.table.

        Returns
        -------
        table
            An object of the same type as self with a subsetted table.
            The table will be a copy, as it does not appear to be possible
            to return views of slices of astropy Table instances.
        """
        cls = type(self)
        # Look the field names up on the class: reading model_fields from an
        # instance is deprecated in newer pydantic releases.
        kwargs_table = {name: getattr(self, name) for name in cls.model_fields if name != "table"}
        # Every non-table field is carried over unchanged to the new object.
        return cls(table=self.table[subset], **kwargs_table)
class TruthSummaryTable(ObjectTableBase):
    """Class for retrieving columns from DC2 truth tables."""

    def get_flux(self, band: str) -> np.ndarray:
        """Return the ``flux_{band}`` column for the given band."""
        return self.table[f"flux_{band}"]

    def get_id(self) -> np.ndarray:
        """Return the ``id`` column."""
        return self.table["id"]

    def get_is_extended(self) -> np.ndarray:
        """Return whether each row's ``is_pointsource`` value equals False."""
        # Deliberate ``==`` so the comparison is applied to the whole column.
        return self.table["is_pointsource"] == False  # noqa: E712

    def get_is_variable(self) -> np.ndarray:
        """Return whether each row's ``is_variable`` value equals True."""
        # Deliberate ``==`` so the comparison is applied to the whole column.
        return self.table["is_variable"] == True  # noqa: E712

    def get_x(self) -> np.ndarray:
        """Return the ``x`` column."""
        return self.table["x"]

    def get_y(self) -> np.ndarray:
        """Return the ``y`` column."""
        return self.table["y"]
class ObjectTable(ObjectTableBase, ABC):
    """Base class for objectTable_tract.

    ``get_flux`` is not implemented here; subclasses choose which flux
    column to return.
    """

    def get_id(self) -> np.ndarray:
        """Return the ``objectId`` column."""
        return self.table["objectId"]

    def get_is_extended(self) -> np.ndarray:
        """Return whether ``refExtendedness`` is at least 0.5 for each row."""
        return self.table["refExtendedness"] >= 0.5

    def get_is_variable(self) -> np.ndarray:
        """Return an all-False array with one entry per table row."""
        return np.zeros(len(self.table), dtype=bool)

    def get_x(self) -> np.ndarray:
        """Return the ``x`` column."""
        return self.table["x"]

    def get_y(self) -> np.ndarray:
        """Return the ``y`` column."""
        return self.table["y"]
class ObjectTableCModel(ObjectTable):
    """Accessor for the CModel flux columns of objectTable_tract."""

    def get_flux(self, band: str) -> np.ndarray:
        """Return the ``{band}_cModelFlux`` column for the given band."""
        column_name = f"{band}_cModelFlux"
        return self.table[column_name]
class ObjectTableMultiProFit(ObjectTableBase):
    """Class for retrieving fluxes from objectTable_tract_multiprofit.

    Model-specific columns are looked up as
    ``{prefix_col}{name_model}_{suffix}``.
    """

    # ``description`` is the documented pydantic keyword for field docs;
    # undeclared extra keywords to Field are deprecated in pydantic v2.
    name_model: str = pydantic.Field(description="The name of the MultiProFit model")
    prefix_col: str = pydantic.Field(description="The prefix for object fit columns", default="mpf_")

    def _get_column_model(self, suffix: str):
        """Return the column of this model with the given name suffix."""
        return self.table[f"{self.prefix_col}{self.name_model}_{suffix}"]

    def get_flux(self, band: str) -> np.ndarray:
        """Return the model's ``{band}_flux`` column."""
        return self._get_column_model(f"{band}_flux")

    def get_id(self) -> np.ndarray:
        """Return the ``objectId`` column."""
        return self.table["objectId"]

    def get_is_extended(self) -> np.ndarray:
        """Return whether ``refExtendedness`` is at least 0.5 for each row."""
        return self.table["refExtendedness"] >= 0.5

    def get_is_variable(self) -> np.ndarray:
        """Return an all-False array with one entry per table row."""
        return np.zeros(len(self.table), dtype=bool)

    def get_x(self) -> np.ndarray:
        """Return the model's ``cen_x`` column."""
        return self._get_column_model("cen_x")

    def get_y(self) -> np.ndarray:
        """Return the model's ``cen_y`` column."""
        return self._get_column_model("cen_y")
class ObjectTablePsf(ObjectTable):
    """Accessor for the PSF flux columns of objectTable_tract."""

    def get_flux(self, band: str) -> np.ndarray:
        """Return the ``{band}_psfFlux`` column for the given band."""
        column_name = f"{band}_psfFlux"
        return self.table[column_name]
def downselect_table(
    table: ObjectTableBase,
    x_min: float,
    x_max: float,
    y_min: float,
    y_max: float,
) -> ObjectTableBase:
    """Keep only the rows of a table that lie inside an x,y box.

    The bounds are exclusive: a row is kept only if its coordinates are
    strictly greater than the minima and strictly less than the maxima.

    Parameters
    ----------
    table
        The table to downselect.
    x_min
        The minimum x value.
    x_max
        The maximum x value.
    y_min
        The minimum y value.
    y_max
        The maximum y value.

    Returns
    -------
    table
        A downselected table of the same class.
    """
    xs = table.get_x()
    ys = table.get_y()
    inside_x = (xs > x_min) & (xs < x_max)
    inside_y = (ys > y_min) & (ys < y_max)
    return table.make_subset(inside_x & inside_y)
def downselect_table_axis(table: ObjectTableBase, axis) -> ObjectTableBase:
    """Keep only the rows of a table that lie within a figure axis.

    Parameters
    ----------
    table
        The table to downselect.
    axis
        The figure axis to determine the extent from.

    Returns
    -------
    table
        A downselected table of the same class.
    """
    # The first four values of axis.axis() are used as x_min, x_max, y_min, y_max.
    limits = np.array(axis.axis())
    return downselect_table(
        table,
        x_min=limits[0],
        x_max=limits[1],
        y_min=limits[2],
        y_max=limits[3],
    )
def plot_objects(
    table: ObjectTableBase,
    axes: Axes,
    bands: Iterable[str],
    table_downselected: bool = False,
    kwargs_annotate: dict[str, Any] | None = None,
    kwargs_scatter: dict[str, Any] | None = None,
    labels_extended: tuple[str, str] = ("S", "G"),
) -> Axes:
    """Plot catalog objects on an existing image.

    Each object is drawn as a scatter point and annotated with a label of
    the form ``[V]<type><mag>``, where ``V`` marks variable sources, the
    type comes from ``labels_extended`` and the magnitude is derived from
    the flux summed over ``bands``.

    Parameters
    ----------
    table
        The object table to plot source from.
    axes
        The figure axes to plot on.
    bands
        The bands to sum over fluxes to derive a total mag label.
    table_downselected
        Whether the table has already been downselected to contain only
        points within the bounds of the axes.
    kwargs_annotate
        Keyword arguments to pass to axes.annotate. If None, white
        14-point text anchored left/bottom is used.
    kwargs_scatter
        Keyword arguments to pass to axes.scatter. If None, white "+"
        markers of size 100 are used.
    labels_extended
        Label prefixes for non-extended and extended objects, respectively.

    Returns
    -------
    axes
        The input axes with added points and labels.
    """
    if kwargs_annotate is None:
        kwargs_annotate = dict(color="white", fontsize=14, ha="left", va="bottom")
    if kwargs_scatter is None:
        kwargs_scatter = dict(c="white", marker="+", s=100)

    table_within = table if table_downselected else downselect_table_axis(table, axes)
    x = table_within.get_x()
    y = table_within.get_y()
    axes.scatter(x, y, **kwargs_scatter)

    # One flux column per band; they are summed row by row for the label.
    fluxes = [table_within.get_flux(band) for band in bands]
    is_extended = table_within.get_is_extended()
    is_variable = table_within.get_is_variable()

    # strict=True turns a length mismatch between columns into an error
    # instead of silently dropping rows.
    rows = zip(x, y, is_extended, is_variable, strict=True)
    for idx, (x_src, y_src, extended, variable) in enumerate(rows):
        flux_total = np.sum([fluxcol[idx] for fluxcol in fluxes])
        mag = u.nanojansky.to(u.ABmag, flux_total)
        prefix_variable = "V" if variable else ""
        type_src = f"{prefix_variable}{labels_extended[1 if extended else 0]}"
        axes.annotate(f"{type_src}{mag:.1f}", (x_src, y_src), **kwargs_annotate)

    return axes
def plot_blend(
    rebuilder: PatchCoaddRebuilder,
    idx_row_parent: int,
    weights: dict[str, float] | None = None,
    table_ref_type: Type = TruthSummaryTable,
    kwargs_plot_parent: dict[str, Any] | None = None,
    kwargs_plot_children: dict[str, Any] | None = None,
) -> tuple[Figure, Axes, Figure, Axes]:
    """Plot an image of an entire blend and its deblended children.

    The parent blend is plotted first and overlaid with reference and
    CModel objects. Then, for every child and every matched model, a model
    (or deblended-observation) plot is made and overlaid with reference
    and MultiProFit objects. Figures are displayed with ``plt.show()``.

    Parameters
    ----------
    rebuilder
        The patch rebuilder to plot from.
    idx_row_parent
        The row index of the parent object in the reference SourceCatalog.
    weights
        Multiplicative weights by band name for RGB plots. Defaults to
        ``bands_weights_lsst``.
    table_ref_type
        The type of reference table to construct when downselecting.
    kwargs_plot_parent
        Keyword arguments to pass to make RGB plots of the parent blend.
    kwargs_plot_children
        Keyword arguments to pass to make RGB plots of deblended children.
        A ``plot_chi_hist`` entry (default True) is extracted and only
        applied to models that are not loaded from a DataLoader. The
        passed dict itself is not modified.

    Returns
    -------
    fig_rgb
        The Figure for the RGB plots of the parent.
    ax_rgb
        The Axes for the RGB plots of the parent.
    fig_gs
        The Figure for the grayscale plots of the parent.
    ax_gs
        The Axes for the grayscale plots of the parent.

    Notes
    -----
    Any exception raised while plotting a child is caught, reported with
    ``print`` and the function moves on to the next model/child.
    """
    if kwargs_plot_parent is None:
        kwargs_plot_parent = {}
    # Work on a copy: the pop() below must not strip plot_chi_hist from the
    # caller's dict, or a repeated call would silently revert to the default.
    kwargs_plot_children = {} if kwargs_plot_children is None else dict(kwargs_plot_children)
    if weights is None:
        weights = bands_weights_lsst

    plot_chi_hist = kwargs_plot_children.pop("plot_chi_hist", True)
    rebuilder_ref = rebuilder.matches[rebuilder.name_model_ref].rebuilder
    observations = {
        catexp.band: catexp.get_source_observation(catexp.get_catalog()[idx_row_parent], skip_flags=True)
        for catexp in rebuilder_ref.catexps
    }

    # Parent blend: data only (no model), so no chi histogram.
    fig_rgb, ax_rgb, fig_gs, ax_gs, *_ = plot_model_rgb(
        model=None,
        weights=weights,
        observations=observations,
        plot_singleband=False,
        plot_chi_hist=False,
        **kwargs_plot_parent,
    )
    table_within_ref = downselect_table_axis(table_ref_type(table=rebuilder.reference), ax_rgb)
    # Iterating the weights dict yields its band names.
    plot_objects(table_within_ref, ax_rgb, weights, table_downselected=True)

    objects_primary = rebuilder.objects[rebuilder.objects["detect_isPrimary"] == True]  # noqa: E712
    # Measured objects use a different marker and label anchor than the
    # reference objects so that both overlays stay readable.
    kwargs_annotate_obs = dict(color="white", fontsize=14, ha="right", va="top")
    kwargs_scatter_obs = dict(c="white", marker="x", s=70)
    table_within_cmodel = downselect_table_axis(ObjectTableCModel(table=objects_primary), ax_rgb)
    labels_extended_model = ("C", "E")
    plot_objects(
        table_within_cmodel,
        ax_rgb,
        weights,
        table_downselected=True,
        kwargs_annotate=kwargs_annotate_obs,
        kwargs_scatter=kwargs_scatter_obs,
        labels_extended=labels_extended_model,
    )
    plt.show()

    # Per-model MultiProFit tables restricted to the parent's axes.
    objects_mpf = rebuilder.objects_multiprofit
    objects_mpf_within = {}
    for name, matched in rebuilder.matches.items():
        if matched.rebuilder and objects_mpf:
            objects_mpf_within[name] = downselect_table_axis(
                ObjectTableMultiProFit(name_model=name, table=objects_mpf),
                ax_rgb,
            )

    cat_ref = rebuilder_ref.catalog_multi
    row_parent = cat_ref[idx_row_parent]
    # A row whose parent is 0 is plotted as its own only child; otherwise
    # the children are all rows whose parent matches this row's id.
    idx_children = (
        (idx_row_parent,)
        if (row_parent["parent"] == 0)
        else (np.where(cat_ref["parent"] == row_parent["id"])[0])
    )

    for idx_child in idx_children:
        for name, matched in rebuilder.matches.items():
            print(f"Model: {name}")
            rebuilder_child = matched.rebuilder
            is_dataloader = isinstance(rebuilder_child, DataLoader)
            is_scarlet = is_dataloader and (name == "scarlet")
            if is_scarlet or rebuilder_child:
                try:
                    if is_dataloader:
                        # DataLoaders provide deblended observations, not a model.
                        model = None
                        observations = rebuilder_child.load_deblended_object(idx_child)
                    else:
                        model = rebuilder_child.make_model(idx_child)
                        observations = None

                    _, ax_rgb_c, *_ = plot_model_rgb(
                        model=model,
                        weights=weights,
                        plot_singleband=False,
                        plot_chi_hist=(not is_dataloader) and plot_chi_hist,
                        observations=observations,
                        **kwargs_plot_children,
                    )
                    ax_rgb_c0 = ax_rgb_c[0][0]
                    # Not flagged as downselected, so the reference table is
                    # cut again to the child's (smaller) axes.
                    plot_objects(table_within_ref, ax_rgb_c0, weights)
                    tab_mpf = objects_mpf_within.get(name)
                    if tab_mpf is not None:
                        plot_objects(
                            tab_mpf,
                            ax_rgb_c0,
                            weights,
                            kwargs_annotate=kwargs_annotate_obs,
                            kwargs_scatter=kwargs_scatter_obs,
                            labels_extended=labels_extended_model,
                        )
                    plt.show()
                except Exception as exc:
                    # Best-effort: report the failure and continue plotting.
                    print(f"{idx_child=} failed to rebuild due to {exc}")

    return fig_rgb, ax_rgb, fig_gs, ax_gs