Coverage for python / lsst / meas / extensions / multiprofit / fit_coadd_multiband.py: 21%
515 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:57 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-01 08:57 +0000
1# This file is part of meas_extensions_multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = (
23 "PsfFitSuccessActionBase",
24 "PsfComponentsActionBase",
25 "SourceTablePsfFitSuccessAction",
26 "SourceTablePsfComponentsAction",
27 "MagnitudeDependentSizePriorConfig",
28 "ModelInitializer",
29 "MakeInitializerActionBase",
30 "BasicModelInitializer",
31 "CachedBasicModelInitializer",
32 "InitialInputData",
33 "MakeBasicInitializerAction",
34 "MakeCachedBasicInitializerAction",
35 "MultiProFitSourceConfig",
36 "CatalogExposurePsfs",
37 "MultiProFitSourceFitter",
38 "MultiProFitSourceTask",
39)
41from abc import ABC, abstractmethod
42from functools import cached_property
43import logging
44import math
45from typing import Any, ClassVar, Iterable, Mapping, Sequence
47from astropy.table import Table
48import astropy.units as u
49import lsst.afw.geom
50import lsst.afw.table as afwTable
51from lsst.daf.butler.formatters.parquet import astropy_to_arrow
52import lsst.gauss2d as g2
53import lsst.gauss2d.fit as g2f
54from lsst.multiprofit.errors import NoDataError, PsfRebuildFitFlagError
55from lsst.multiprofit.fitting.fit_psf import CatalogPsfFitterConfig, CatalogPsfFitterConfigData
56from lsst.multiprofit.fitting.fit_source import (
57 CatalogExposureSourcesABC,
58 CatalogSourceFitterABC,
59 CatalogSourceFitterConfig,
60 CatalogSourceFitterConfigData,
61)
62from lsst.multiprofit.modeller import Model
63from lsst.multiprofit.utils import frozen_arbitrary_allowed_config, get_params_uniq, set_config_from_dict
64import lsst.pex.config as pexConfig
65from lsst.pex.config.configurableActions import ConfigurableAction, ConfigurableActionField
66import lsst.pipe.base as pipeBase
67import lsst.pipe.tasks.fit_coadd_multiband as fitMB
68import lsst.utils.timer as utilsTimer
69import numpy as np
70import pydantic
72from .errors import IsParentError, NotPrimaryError
73from .input_config import InputConfig
74from .utils import get_spanned_image
# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)
# 2*sqrt(pi): the constant relating a zeroth moment to its integrated flux.
TWO_SQRT_PI = 2.0 * math.sqrt(np.pi)
class PsfFitSuccessActionBase(ConfigurableAction):
    """Interface for actions that report whether a source had a successful
    PSF fit.

    Both methods raise `NotImplementedError` here; subclasses must override
    them.
    """

    def get_schema(self) -> list[str]:
        """Return the names of the columns needed to call this action."""
        raise NotImplementedError("This method must be overloaded in subclasses")

    def __call__(self, source: Mapping[str, Any], *args: Any, **kwargs: Any) -> bool:
        """Return whether the PSF fit for ``source`` succeeded."""
        raise NotImplementedError("This method must be overloaded in subclasses")
class PsfComponentsActionBase(ConfigurableAction):
    """Interface for actions that build a list of Gaussians from a source
    mapping.

    Used unspecialized, this class serves as a sentinel meaning that a
    MultiProFit PSF fit table is in use; it only needs to be subclassed for
    external PSF fitters.
    """

    def get_schema(self) -> list[str]:
        """Return the names of the columns needed to call this action."""
        raise NotImplementedError("This method must be overloaded in subclasses")

    def __call__(self, source: Mapping[str, Any], *args: Any, **kwargs: Any) -> list[g2.Gaussian]:
        """Return the PSF Gaussian components for ``source``."""
        raise NotImplementedError("This method must be overloaded in subclasses")
class SourceTablePsfFitSuccessAction(PsfFitSuccessActionBase):
    """Action to return PSF fit status from a SourceTable row.

    The fit is considered successful only if none of the configured flag
    columns is set for the row.
    """

    flag_format = pexConfig.Field[str](
        doc="Format for the flag field; flag_prefix, flag_suffix and flag_sub are substituted",
        default="{flag_prefix}{flag_suffix}{flag_sub}",
    )
    flag_prefix = pexConfig.Field[str](
        doc="Prefix for the key for the summed flag field",
        default="modelfit_DoubleShapeletPsfApprox",
    )
    flag_suffix = pexConfig.Field[str](
        doc="Suffix for all flag fields",
        default="_flag",
    )
    flags_sub = pexConfig.ListField[str](
        doc="Suffixes for specific flag fields that must not be true",
        default=["_invalidPointForPsf", "_invalidMoments", "_maxIterations"],
    )

    def _format(self, flag_sub: str) -> str:
        """Return the full flag column name for one specific flag suffix."""
        return self.flag_format.format(
            flag_prefix=self.flag_prefix,
            flag_sub=flag_sub,
            flag_suffix=self.flag_suffix,
        )

    def get_schema(self) -> list[str]:
        """Return the flag column names read by this action.

        Returns
        -------
        columns
            One formatted column name per entry in ``flags_sub``, in order.
        """
        # Return a real list (not a generator) to honour the base-class
        # contract, so that callers may concatenate or re-iterate the result.
        return [self._format(flag_sub=flag_sub) for flag_sub in self.flags_sub]

    def __call__(self, source: Mapping[str, Any], *args: Any, **kwargs: Any) -> bool:
        """Return True if none of the configured flags is set in ``source``.

        Parameters
        ----------
        source
            A mapping (e.g. a catalog row) keyed by column name.
        *args, **kwargs
            Ignored.
        """
        return not any(source[self._format(flag_sub=flag_sub)] for flag_sub in self.flags_sub)
class SourceTablePsfComponentsAction(PsfComponentsActionBase):
    """Action to return PSF components from a SourceTable.

    This is anticipated to be a deepCoadd_meas with PSF fit parameters from a
    measurement plugin returning covariance matrix terms.
    """

    action_source = ConfigurableActionField[PsfFitSuccessActionBase](
        doc="Action to return whether the PSF fit was successful for a single source row",
        default=SourceTablePsfFitSuccessAction,
    )
    format = pexConfig.Field[str](
        doc="Format for the field names, where {idx_comp} is the index of the component and {moment}"
        "is the name of the moment (xx, xy or yy, integral)",
        default="modelfit_DoubleShapeletPsfApprox_{idx_comp}_{moment}",
    )
    name_moment_xx = pexConfig.Field[str](doc="Name of the xx (2nd x-axis) moment", default="xx")
    name_moment_xy = pexConfig.Field[str](doc="Name of the xy (covariance term) moment", default="xy")
    name_moment_yy = pexConfig.Field[str](doc="Name of the yy (2nd y-axis) moment", default="yy")
    name_moment_integral = pexConfig.Field[str](doc="Name of the integral (zeroth) moment", default="0")
    n_components = pexConfig.Field[int](
        doc="Number of Gaussian components",
        default=2,
        check=lambda x: x >= 2,
    )

    @staticmethod
    def get_integral(moment_zero: float) -> float:
        """Get the total integrated flux from a zeroth moment value.

        The zeroth moment is simply the integrated flux divided by a
        constant value of 2*sqrt(pi).

        Parameters
        ----------
        moment_zero
            The zeroth moment value.

        Returns
        -------
        integral
            The total integrated weight (flux).
        """
        return moment_zero * TWO_SQRT_PI

    def _column(self, moment: str, idx_comp: int) -> str:
        """Return the column name of one moment of one component."""
        return self.format.format(moment=moment, idx_comp=idx_comp)

    def get_schema(self) -> list[str]:
        """Return the list of columns required to call this action.

        Returns
        -------
        columns
            The moment columns (xx, yy, xy, integral) for each component in
            turn, followed by the columns needed by ``action_source``.
        """
        names_moments = (
            self.name_moment_xx,
            self.name_moment_yy,
            self.name_moment_xy,
            self.name_moment_integral,
        )
        # The format string's placeholder is {moment}; using any other keyword
        # here raises KeyError.
        columns = [
            self._column(moment=name_moment, idx_comp=idx_comp)
            for idx_comp in range(self.n_components)
            for name_moment in names_moments
        ]
        # extend() accepts any iterable, so this works whether the success
        # action returns a list or lazily yields its column names.
        columns.extend(self.action_source.get_schema())
        return columns

    def __call__(self, source: Mapping[str, Any], *args: Any, **kwargs: Any) -> list[g2.Gaussian]:
        """Build the PSF Gaussians from the moment columns of ``source``.

        Parameters
        ----------
        source
            A mapping (e.g. a catalog row) keyed by column name.
        *args, **kwargs
            Ignored.

        Returns
        -------
        gaussians
            One Gaussian per component, in component order.

        Raises
        ------
        PsfRebuildFitFlagError
            Raised if ``action_source`` reports that the PSF fit failed.
        """
        if not self.action_source(source):
            raise PsfRebuildFitFlagError(
                f"PSF fit failed due to action based on schema: {list(self.action_source.get_schema())}"
            )
        gaussians = []
        for idx_comp in range(self.n_components):
            covariance = g2.Covariance(
                sigma_x_sq=source[self._column(moment=self.name_moment_xx, idx_comp=idx_comp)],
                sigma_y_sq=source[self._column(moment=self.name_moment_yy, idx_comp=idx_comp)],
                cov_xy=source[self._column(moment=self.name_moment_xy, idx_comp=idx_comp)],
            )
            integral = self.get_integral(
                source[self._column(moment=self.name_moment_integral, idx_comp=idx_comp)]
            )
            gaussians.append(
                g2.Gaussian(
                    ellipse=g2.Ellipse(covariance),
                    integral=g2.GaussianIntegralValue(value=integral),
                )
            )
        return gaussians
class MagnitudeDependentSizePriorConfig(pexConfig.Config):
    """Settings for a size prior that varies linearly with magnitude.

    Defaults are for ugrizy total mag and log10(r_eff/arcsec).
    """

    # Pivot magnitude of the linear relations below.
    intercept_mag = pexConfig.Field[float](
        default=18.0,
        doc="The magnitude at which no adjustment is applied",
    )
    # Change in the median size per magnitude away from the pivot.
    slope_median_per_mag = pexConfig.Field[float](
        default=-0.15,
        doc="The slope in the median size, in dex per mag",
    )
    # Change in the scatter of the size per magnitude away from the pivot.
    slope_stddev_per_mag = pexConfig.Field[float](
        default=0.0,
        doc="The slope in the standard deviation of the size, in dex per mag",
    )
class ModelInitializer(ABC, pydantic.BaseModel):
    """An interface for a configurable model initializer based on priors
    and optional external data.
    """

    model_config: ClassVar[pydantic.ConfigDict] = frozen_arbitrary_allowed_config

    inputs: dict[str, Any] = pydantic.Field(
        title="Additional external inputs used in initialization",
        default_factory=dict,
    )
    priors_shape_mag: dict = pydantic.Field(
        title="Magnitude-dependent shape prior configurations",
        default_factory=dict,
    )

    @abstractmethod
    def initialize_model(
        self,
        model: Model,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData,
        values_init: Mapping[g2f.ParameterD, float] | None = None,
        **kwargs,
    ):
        """Initialize a MultiProFit model for a single object corresponding
        to a row in a catalog.

        Parameters
        ----------
        model
            The model to initialize parameter values for.
        source
            A mapping with fields expected to be populated in the
            corresponding source catalog for initialization.
        catexps
            Per-band catalog-exposure pairs.
        config_data
            Fitter configuration and data.
        values_init
            Default initial values for parameters.
        **kwargs
            Additional keyword arguments for any purpose.

        Raises
        ------
        NotImplementedError
            Always raised by this base implementation.
        """
        # Instances have no __name__ attribute; the class name lives on the
        # type, so look it up there to build the message.
        raise NotImplementedError(f"{type(self).__name__} must implement initialize_model")
class MakeInitializerActionBase(ConfigurableAction):
    """An interface for an action that creates an initializer."""

    def __call__(
        self,
        catalog_multi: Sequence,
        catexps: list[fitMB.CatalogExposureInputs],
        config_data: CatalogSourceFitterConfigData,
        **kwargs,
    ) -> ModelInitializer:
        """Make a ModelInitializer object that can initialize model
        parameter values for a given object in a catalog.

        Parameters
        ----------
        catalog_multi
            The multiband catalog with one row per object to fit.
        catexps
            Per-band catalog-exposure pairs.
        config_data
            Fitter configuration and data.
        **kwargs
            Additional arguments to pass to add to ModelInitializer.inputs.

        Returns
        -------
        initializer
            The configured ModelInitializer.

        Raises
        ------
        NotImplementedError
            Always raised by this base implementation.
        """
        # Instances have no __name__ attribute; the class name lives on the
        # type, so look it up there to build the message.
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")
class BasicModelInitializerConfig(pexConfig.Config):
    """Settings controlling how a BasicModelInitializer estimates initial
    sizes and shapes.
    """

    # Must lie in [0, 1).
    psf_factor_shrink = pexConfig.Field[float](
        default=0.9,
        doc="Multiplicative factor to shrink PSF sizes by for deconvolution",
        check=lambda value: 0.0 <= value < 1.0,
    )
    # Must be non-negative.
    psf_factor_minimum = pexConfig.Field[float](
        default=0.5,
        doc="Factor to multiply the PSF size by for a minimum initialize size",
        check=lambda value: value >= 0,
    )
    # Must be non-negative.
    size_minimum = pexConfig.Field[float](
        default=0.5,
        doc="Absolute minimum initial size in pixels",
        check=lambda value: value >= 0,
    )
    # Must be non-negative.
    rho_abs_max = pexConfig.Field[float](
        default=0.8,
        doc="Maximum absolute initial value of rho",
        check=lambda value: value >= 0,
    )
class BasicModelInitializer(ModelInitializer):
    """A generic model initializer that should work on most kinds of models
    with a single source.
    """

    config: BasicModelInitializerConfig = pydantic.Field(title="A BasicModelInitializerConfig to be frozen")

    def _get_params_init(self, model_sources: tuple[g2f.Source]) -> tuple[g2f.ParameterD]:
        """Return an ordered set of free parameters from a model's sources.

        Parameters
        ----------
        model_sources
            The sources in the model.

        Returns
        -------
        params_init
            The parameter objects for sources in the model.

        Notes
        -----
        Only free and/or centroid parameters are returned (centroids are
        always needed even if they are fixed).
        """

        def needs_init(param) -> bool:
            # Free parameters, plus centroids regardless of whether they
            # are free.
            return param.free or isinstance(param, (g2f.CentroidXParameterD, g2f.CentroidYParameterD))

        # TODO: There ought to be a better way to not get the PSF centroids
        # (those are part of model.data's fixed parameters)
        if len(model_sources) == 1:
            return tuple(param for param in get_params_uniq(model_sources[0]) if needs_init(param))
        # With several sources, use dict keys as an insertion-ordered set so
        # that a parameter shared between sources appears only once.
        params_ordered = {
            param: None for source in model_sources for param in get_params_uniq(source) if needs_init(param)
        }
        return tuple(params_ordered.keys())

    def _get_priors_type(
        self,
        priors: tuple[g2f.Prior],
    ) -> tuple[tuple[g2f.GaussianPrior], tuple[g2f.ShapePrior]]:
        """Return the list of priors of known type, by type.

        Parameters
        ----------
        priors
            A list of priors of any type, typically from a model.

        Returns
        -------
        priors_gauss
            A list of all of the Gaussian priors, in the order they occurred.
        priors_shape
            A list of all of the shape priors, in the order they occurred.
        """
        priors_gauss: list[g2f.GaussianPrior] = []
        priors_shape: list[g2f.ShapePrior] = []
        for prior in priors:
            if isinstance(prior, g2f.GaussianPrior):
                priors_gauss.append(prior)
            elif isinstance(prior, g2f.ShapePrior):
                priors_shape.append(prior)
            # Priors of any other type are silently dropped.
        return tuple(priors_gauss), tuple(priors_shape)

    def get_centroid_and_shape(
        self,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData,
        values_init: Mapping[g2f.ParameterD, float] | None = None,
    ) -> tuple[tuple[float, float], tuple[float, float, float]]:
        """Get the centroid and shape for a source.

        Parameters
        ----------
        source
            A mapping with fields expected to be populated in the
            corresponding source catalog for initialization.
        catexps
            A list of (source and psf) catalog-exposure pairs.
        config_data
            Configuration settings and data for fitting and output.
        values_init
            Initial parameter values from the model configuration.

        Returns
        -------
        centroid
            The x- and y-axis centroid values.
        sig_x, sig_y, rho
            The x- and y-axis Gaussian sigma and rho values defining the
            estimated elliptical shape of the source.

        Notes
        -----
        ``catexps``, ``config_data`` and ``values_init`` are not used by this
        implementation.
        """
        centroid = source["slot_Centroid_x"], source["slot_Centroid_y"]
        # Attempt partial deconvolution of observed moments
        # The factors are squared because they are applied to second moments
        psf_factor_shrink = self.config.psf_factor_shrink**2
        psf_factor_minimum = self.config.psf_factor_minimum**2
        rho_min, rho_max = -self.config.rho_abs_max, self.config.rho_abs_max
        psf_xx = source["base_SdssShape_psf_xx"]
        psf_yy = source["base_SdssShape_psf_yy"]
        # NOTE(review): size_minimum is compared against second moments here
        # without being squared, unlike the PSF factors - confirm intended.
        sig_x, sig_y = (
            math.sqrt(
                np.nanmax(
                    (
                        source[f"slot_Shape_{suffix}"] - moment_sq * psf_factor_shrink,
                        moment_sq * psf_factor_minimum,
                        self.config.size_minimum,
                    )
                )
            )
            for suffix, moment_sq in (("xx", psf_xx), ("yy", psf_yy))
        )
        psf_xy = source["base_SdssShape_psf_xy"]
        sig_xy = sig_x * sig_y
        # The negated comparison also catches a NaN product
        if not (sig_xy > 0):
            rho = 0
        else:
            rho = np.clip((source["slot_Shape_xy"] - psf_xy * psf_factor_shrink) / sig_xy, rho_min, rho_max)
        shape = sig_x, sig_y, rho
        return centroid, shape

    def get_params_init(self, model: Model) -> tuple[g2f.ParameterD]:
        """Return the free and/or centroid parameters for a model.

        Parameters
        ----------
        model
            The model to return parameters for.

        Returns
        -------
        parameters
            The ordered list of parameters for the model.
        """
        return self._get_params_init(model_sources=model.sources)

    def get_priors_type(self, model: Model) -> tuple[tuple[g2f.GaussianPrior], tuple[g2f.ShapePrior]]:
        """Return the list of priors of known type, by type.

        Parameters
        ----------
        model
            The model to return priors for.

        Returns
        -------
        priors_gauss
            A list of all of the Gaussian priors, in the order they occurred.
        priors_shape
            A list of all of the shape priors, in the order they occurred.
        """
        return self._get_priors_type(model.priors)

    def initialize_model(
        self,
        model: Model,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData,
        values_init: Mapping[g2f.ParameterD, float] | None = None,
        **kwargs,
    ):
        """Set initial values and limits for the parameters of a model.

        Parameters
        ----------
        model
            The model to initialize parameter values for.
        source
            A mapping with fields expected to be populated in the
            corresponding source catalog for initialization.
        catexps
            Per-band catalog-exposure pairs.
        config_data
            Fitter configuration and data.
        values_init
            Default initial values for parameters.
        **kwargs
            Optional settings: ``set_flux_limits`` (default True),
            ``value_init_min`` (default 1e-10, the initial flux for bands
            without data) and ``flux_limit_min`` (default 1e-12).

        Raises
        ------
        ValueError
            Raised if any unrecognized keyword argument is passed.

        Notes
        -----
        NOTE(review): block nesting in this method was reconstructed from a
        listing without indentation; verify against the repository source.
        """
        if values_init is None:
            values_init = {}
        set_flux_limits = kwargs.pop("set_flux_limits", True)
        flux_init_min = kwargs.pop("value_init_min", 1e-10)
        flux_limit_min = kwargs.pop("flux_limit_min", 1e-12)
        if kwargs:
            raise ValueError(f"Unexpected {kwargs=}")
        centroid_pixel_offset = config_data.config.centroid_pixel_offset
        (cen_x, cen_y), (sig_x, sig_y, rho) = self.get_centroid_and_shape(
            source,
            catexps,
            config_data,
            values_init=values_init,
        )
        # If we couldn't get a shape at all, make it small and roundish
        if not np.isfinite(rho):
            # Note rho=0 (circular) is generally disfavoured by shape priors
            # However, setting it to a non-zero value seems to make scipy
            # fail to move off initial conditions, as do sizes below 2 pixels
            sig_x, sig_y, rho = 2.0, 2.0, 0.0

        # Make restrictive centroid limits (intersection, not union)
        x_min, y_min, x_max, y_max = -np.inf, -np.inf, np.inf, np.inf

        fluxes_init = {}
        fluxes_limits = {}

        # This is the maximum number of potential observations
        # They might not all have made it into the data
        n_catexps = len(catexps)
        n_components = len(model.sources[0].components)

        # If not true, some bands must have no data to fit
        if len(catexps) != len(model.data):
            catexps_obs = []
            for catexp in catexps:
                fluxes_init[catexp.channel] = flux_init_min
                fluxes_limits[catexp.channel] = (0, np.inf)
                # No associated catalog means we can't fit (and should be
                # because there's no exposure for this band in this patch)
                if len(catexp.get_catalog()) > 0:
                    catexps_obs.append(catexp)
        else:
            catexps_obs = catexps

        for idx_obs, observation in enumerate(model.data):
            coordsys = observation.image.coordsys
            catexp = catexps_obs[idx_obs]
            band = catexp.band

            x_min = max(x_min, coordsys.x_min)
            y_min = max(y_min, coordsys.y_min)
            x_max = min(x_max, coordsys.x_min + float(observation.image.n_cols))
            y_max = min(y_max, coordsys.y_min + float(observation.image.n_rows))

            flux_total = np.nansum(observation.image.data[observation.mask_inv.data])

            # Use the multiband row itself if it is flagged as the reference
            # measurement for this band; otherwise look up the per-band row
            column_ref = f"merge_measurement_{band}"
            if column_ref in source.schema.getNames() and source[column_ref]:
                row = source
            else:
                row = catexp.catalog.find(source["id"])

            # Fall back through progressively simpler flux measurements
            if not row["base_SdssShape_flag"]:
                flux_init = row["base_SdssShape_instFlux"]
            else:
                flux_init = row["slot_GaussianFlux_instFlux"]
                if not (flux_init > 0):
                    flux_init = row["slot_PsfFlux_instFlux"]

            calib = catexp.exposure.photoCalib
            flux_init = calib.instFluxToNanojansky(flux_init) if (flux_init > 0) else max(flux_total, 1.0)
            if set_flux_limits:
                flux_max = 10 * max((flux_init, flux_total))
                flux_min = min(flux_limit_min, flux_max / 1000)
            else:
                flux_min, flux_max = 0, np.inf
            if not (flux_init > flux_min):
                flux_upper = flux_max if (flux_max < np.inf) else 10.0 * flux_min
                flux_init = flux_min + 0.01 * (flux_upper - flux_min)
            # The total flux is split evenly between components
            fluxes_init[observation.channel] = flux_init / n_components
            fluxes_limits[observation.channel] = (flux_min, flux_max)

        # These use the last observation from the loop above
        if not np.isfinite(cen_x):
            cen_x = observation.image.n_cols / 2.0
        else:
            cen_x -= centroid_pixel_offset
        if not np.isfinite(cen_y):
            # TODO: Add bbox coords or remove
            cen_y = observation.image.n_rows / 2.0
        else:
            cen_y -= centroid_pixel_offset

        # An R_eff larger than the box size is problematic. This should also
        # stop unreasonable size proposals; a log10 transform isn't enough.
        # TODO: Try logit for r_eff?
        size_major = g2.EllipseMajor(g2.Ellipse(sigma_x=sig_x, sigma_y=sig_y, rho=rho)).r_major
        limits_size = max(5.0 * size_major, 2.0 * np.hypot(x_max - x_min, y_max - y_min))
        limits_xy = (1e-5, limits_size)
        # Maps a parameter type to its (initial value, (min, max) limits)
        params_limits_init = {
            g2f.CentroidXParameterD: (cen_x, (x_min, x_max)),
            g2f.CentroidYParameterD: (cen_y, (y_min, y_max)),
            g2f.ReffXParameterD: (sig_x, limits_xy),
            g2f.ReffYParameterD: (sig_y, limits_xy),
            g2f.SigmaXParameterD: (sig_x, limits_xy),
            g2f.SigmaYParameterD: (sig_y, limits_xy),
            g2f.RhoParameterD: (rho, None),
            # TODO: get guess from configs?
            g2f.SersicMixComponentIndexParameterD: (1.0, None),
        }

        fluxes_init_tuple = tuple(fluxes_init.values())
        fluxes_limits_tuple = tuple(fluxes_limits.values())
        idx_obs = 0
        # Go through the get_params_init hook rather than an attribute, since
        # only caching subclasses define a params_init attribute.
        for param in self.get_params_init(model):
            if param.linear:
                # Linear (flux) parameters cycle through the bands in order
                value_init = fluxes_init_tuple[idx_obs]
                limits_new = fluxes_limits_tuple[idx_obs]
                idx_obs += 1
                if idx_obs == n_catexps:
                    idx_obs = 0
            else:
                type_param = type(param)
                value_init, limits_new = params_limits_init.get(type_param, (values_init.get(param), None))
            if limits_new:
                param.limits = g2f.LimitsD(limits_new[0], limits_new[1])
            if value_init is not None:
                param.value = np.clip(value_init, param.limits.min, param.limits.max)

        priors_shape_mag = self.priors_shape_mag
        has_priors_mag = len(priors_shape_mag) > 0
        if has_priors_mag:
            mag_total = u.nJy.to(u.ABmag, np.nansum(fluxes_init_tuple))

        # TODO: Add centroid prior
        _, priors_shape = self.get_priors_type(model)
        for prior in priors_shape:
            if has_priors_mag and ((prior_adjustments := priors_shape_mag.get(prior)) is not None):
                mag_dep_prior, prior_shape_new = prior_adjustments
                prior_size_new = prior_shape_new.prior_size
                # the size-apparent mag relation probably flattens
                # for very bright/faint objects - maybe not so
                # sharply, but clipping a broad mag range ought to be fine
                prior.prior_size.mean_parameter.value = prior_size_new.mean_parameter.value * 10 ** (
                    mag_dep_prior.slope_median_per_mag
                    * np.clip(
                        mag_total - mag_dep_prior.intercept_mag,
                        -12.5,
                        12.5,
                    )
                )
                # it's uncertain how the intrinsic scatter behaves
                # educated guess is it doesn't change much, also
                # one runs out of bright galaxies to measure it anyway
                prior.prior_size.stddev_parameter.value = prior_size_new.stddev_parameter.value * 10 ** (
                    mag_dep_prior.slope_stddev_per_mag
                    * np.clip(
                        mag_total - mag_dep_prior.intercept_mag,
                        -12.5,
                        12.5,
                    )
                )
            else:
                prior.prior_size.mean_parameter.value = size_major
class CachedBasicModelInitializer(BasicModelInitializer):
    """A BasicModelInitializer that stores the model's sources and priors
    and computes the derived parameter and prior tuples only once.
    """

    priors: tuple[g2f.Prior, ...] = pydantic.Field(title="The gauss2d_fit model priors")
    sources: tuple[g2f.Source, ...] = pydantic.Field(title="The gauss2d_fit model sources")

    @cached_property
    def params_init(self) -> tuple[g2f.ParameterD]:
        """The cached result of _get_params_init for the stored sources."""
        sources_cached = self.sources
        return self._get_params_init(model_sources=sources_cached)

    @cached_property
    def priors_type(self) -> tuple[tuple[g2f.GaussianPrior], tuple[g2f.ShapePrior]]:
        """The cached result of _get_priors_type for the stored priors."""
        priors_cached = self.priors
        return self._get_priors_type(priors_cached)

    def get_params_init(self, model: Model) -> tuple[g2f.ParameterD]:
        """Return the cached parameters after asserting that the model's
        sources match the stored ones.
        """
        sources_model = tuple(model.sources)
        assert sources_model == self.sources
        return self.params_init

    def get_priors_type(self, model: Model) -> tuple[tuple[g2f.GaussianPrior], tuple[g2f.ShapePrior]]:
        """Return the cached priors by type after asserting that the model's
        priors match the stored ones.
        """
        priors_model = tuple(model.priors)
        assert priors_model == self.priors
        return self.priors_type
class InitialInputData(pydantic.BaseModel):
    """A configurable wrapper for reading prefixed columns out of a table.

    This provides a common interface to typical MultiProFit table outputs.
    """

    model_config: ClassVar[pydantic.ConfigDict] = frozen_arbitrary_allowed_config

    column_id: str | None = pydantic.Field(
        title="Override for id column specified in config_input",
        default=None,
    )
    config_input: InputConfig = pydantic.Field(title="Configuration for the data table")
    data: Table = pydantic.Field(title="The data table")
    name_model: str = pydantic.Field(title="The name of the model in columns")
    prefix_column: str = pydantic.Field(title="The prefix for all fitted column names")
    size_column: str = pydantic.Field(title="The name of the size column", default="reff")

    def get_column_id(self):
        """Return the name of the object ID column.

        The override in ``column_id`` is used when it is set (truthy);
        otherwise the name comes from ``config_input``.
        """
        if self.column_id:
            return self.column_id
        return self.config_input.column_id

    def get_column(self, name_column: str, data=None):
        """Look up a column by its un-prefixed name.

        Parameters
        ----------
        name_column
            The name of the column to retrieve, without the prefix.
        data
            The catalog to retrieve the column from. Default is self.data.

        Returns
        -------
        values
            The column values.
        """
        table = self.data if data is None else data
        name_full = f"{self.prefix_column}{name_column}"
        return table[name_full]

    def model_post_init(self, __context: Any) -> None:
        # Build a mapping from object ID value to row number
        # afw catalogs implement this but most other tabular types do not
        id_index = {}
        for idx, idnum in enumerate(self.data[self.get_column_id()]):
            id_index[idnum] = idx
        # The model is frozen, so go through object.__setattr__ directly
        object.__setattr__(self, "id_index", id_index)
class MakeBasicInitializerAction(MakeInitializerActionBase):
    """An action to construct an initializer for a single-component,
    single-source model.
    """

    config = pexConfig.ConfigField[BasicModelInitializerConfig](
        doc="Configuration for the initializer to be constructed",
    )

    def _make_initializer(
        self,
        catalog_multi: Sequence,
        catexps: list[fitMB.CatalogExposureInputs],
        config_data: CatalogSourceFitterConfigData,
    ) -> ModelInitializer:
        """Construct the bare initializer, before any extra inputs are added.

        Subclasses may override this to return a different initializer type.
        """
        return BasicModelInitializer(config=self.config)

    def __call__(
        self,
        catalog_multi: Sequence,
        catexps: list[fitMB.CatalogExposureInputs],
        config_data: CatalogSourceFitterConfigData,
        **kwargs,
    ) -> ModelInitializer:
        """Make an initializer and attach any extra input tables to it.

        Parameters
        ----------
        catalog_multi
            The multiband catalog with one row per object to fit.
        catexps
            Per-band catalog-exposure pairs.
        config_data
            Fitter configuration and data.
        **kwargs
            Extra inputs by name; each value must be a
            ``(config_input, data)`` pair. Inputs whose ``data`` is not an
            astropy Table with a ``meta`` attribute are skipped with a
            warning.

        Returns
        -------
        initializer
            The initializer, with one `InitialInputData` entry in its
            ``inputs`` for each accepted extra input.
        """
        initializer = self._make_initializer(
            catalog_multi=catalog_multi,
            catexps=catexps,
            config_data=config_data,
        )
        for name, (config_input, data) in kwargs.items():
            # Both conditions must hold for the input to be usable; skip it
            # otherwise instead of failing on data.meta below.
            if not (isinstance(data, Table) and hasattr(data, "meta")):
                _LOG.warning(
                    f"Ignoring extra input {name=} because it is of type {type(data)} and is either not an"
                    f" astropy.table.Table or missing a 'meta' attr"
                )
                continue
            # Named differently from the config_data argument to avoid
            # shadowing it.
            config_table = data.meta["config"]
            prefix_column = config_table["prefix_column"]
            # Only the first source and its first component group are used
            config_source = next(iter(config_table["config_model"]["sources"].values()))
            config_group = next(iter(config_source["component_groups"].values()))
            is_sersic = len(config_group["components_sersic"]) > 0
            name_model = next(
                iter(config_group["components_sersic"] if is_sersic else config_group["components_gaussian"])
            )
            initializer.inputs[name] = InitialInputData(
                column_id=config_table.get("column_id"),
                config_input=config_input,
                data=data,
                name_model=name_model,
                prefix_column=prefix_column,
                size_column="reff" if is_sersic else "sig",
            )
        return initializer
class MakeCachedBasicInitializerAction(MakeBasicInitializerAction):
    """A MakeBasicInitializerAction whose initializer keeps references to
    the model's source and prior objects.

    This is solely a performance optimization and should be favored over
    MakeBasicInitializerAction unless the caching is shown to be slower.
    """

    def _make_initializer(
        self,
        catalog_multi: Sequence,
        catexps: list[fitMB.CatalogExposureInputs],
        config_data: CatalogSourceFitterConfigData,
    ) -> ModelInitializer:
        """Construct a CachedBasicModelInitializer from the sources and
        priors held by ``config_data``.
        """
        sources_priors = config_data.sources_priors
        sources, priors = sources_priors
        initializer = CachedBasicModelInitializer(config=self.config, priors=priors, sources=sources)
        return initializer
class MultiProFitSourceConfig(CatalogSourceFitterConfig, fitMB.CoaddMultibandFitSubConfig):
    """Configuration for the MultiProFit profile fitter."""

    action_initializer = ConfigurableActionField[MakeInitializerActionBase](
        default=MakeCachedBasicInitializerAction,
        doc="The action to return an initializer",
    )
    action_psf = ConfigurableActionField[PsfComponentsActionBase](
        default=None,
        doc="The action to return PSF component values from catalogs, if implemented",
    )
    # The dictCheck requires all output column names (values) to be unique.
    columns_copy = pexConfig.DictField[str, str](
        default={},
        doc="Mapping of input/output column names to copy from the input"
        "multiband catalog to the output fit catalog.",
        dictCheck=lambda mapping: len(set(mapping.values())) == len(mapping.values()),
    )
    mask_names_zero = pexConfig.ListField[str](
        default=["BAD", "EDGE", "SAT", "NO_DATA"],
        doc="Mask bits to mask out",
    )
    # Must be finite and non-negative.
    psf_sigma_subtract = pexConfig.Field[float](
        default=0.1,
        doc="PSF x/y sigma value to subtract in quadrature from best-fit values",
        check=lambda value: np.isfinite(value) and (value >= 0),
    )
    prefix_column = pexConfig.Field[str](doc="Column name prefix", default="mpf_")
    size_priors = pexConfig.ConfigDictField[str, MagnitudeDependentSizePriorConfig](
        default={},
        doc="Per-component magnitude-dependent size prior configurations."
        " Will be added to component with existing configs.",
    )

    def bands_read_only(self) -> set[str]:
        """Return the set of read-only bands, which is currently always
        empty.
        """
        # TODO: Re-implement determination of prior-only bands once
        # data-driven priors are re-implemented (DM-4xxxx)
        no_bands: set[str] = set()
        return no_bands

    def requires_psf(self):
        """Return whether ``action_psf`` is exactly a PsfComponentsActionBase.

        Subclass instances and None both give False, because the check is on
        the exact type rather than with isinstance.
        """
        type_action = type(self.action_psf)
        return type_action is PsfComponentsActionBase

    def setDefaults(self):
        """Apply the parent defaults, then override them for this task."""
        super().setDefaults()
        self.defer_radec_conversion = True
        self.compute_radec_covariance = True
        # Maps each error's flag column name to the error class name
        flag_errors = {
            IsParentError.column_name(): "IsParentError",
            NoDataError.column_name(): "NoDataError",
            NotPrimaryError.column_name(): "NotPrimaryError",
            PsfRebuildFitFlagError.column_name(): "PsfRebuildFitFlagError",
        }
        self.flag_errors = flag_errors
        self.centroid_pixel_offset = -0.5
        self.naming_scheme = "lsst"
        # This replaces the field's declared default of "mpf_"
        self.prefix_column = ""
        self.suffix_error = "Err"
905@pydantic.dataclasses.dataclass(frozen=True, kw_only=True, config=fitMB.CatalogExposureConfig)
906class CatalogExposurePsfs(fitMB.CatalogExposureInputs, CatalogExposureSourcesABC):
907 """Input data from lsst pipelines, parsed for MultiProFit."""
909 channel: g2f.Channel = pydantic.Field(title="Channel for the image's band")
910 config_fit: MultiProFitSourceConfig = pydantic.Field(title="Config for fitting options")
@cached_property
def _psf_flux_params(self) -> tuple[list[g2f.ParameterD], bool]:
    """Return the flux parameter of each PSF model component.

    Returns
    -------
    params_flux
        One entry per component: its proper fraction parameter (None for
        the last component of a fractional model), or its single integral
        parameter.
    is_frac_any
        Whether the components are parameterized by fractions.

    Raises
    ------
    RuntimeError
        Raised if a component has an unexpected number of fraction or
        integral parameters, or if the model mixes fractional and
        non-fractional components.
    """
    psf_model = self.psf_model_data.psf_model
    n_comps = len(psf_model.components)
    params_flux = [None] * n_comps
    is_frac = [False] * n_comps
    for idx_comp, comp in enumerate(psf_model.components):
        # TODO: Change to comp.integralmodel when DM-44344 is fixed
        # integralmodels will still need to be handled differently
        params_all = get_params_uniq(comp)
        params_frac = [param for param in params_all if isinstance(param, g2f.ProperFractionParameterD)]
        if params_frac:
            is_last = idx_comp == (n_comps - 1)
            # Expect idx_comp + 1 fraction parameters, or one fewer for the
            # last component
            if len(params_frac) != (idx_comp + 1 - is_last):
                raise RuntimeError(
                    f"Got unexpected {params_frac=} for"
                    f" {self.psf_model_data.psf_model.components[idx_comp]=} ({idx_comp=});"
                    f" len should be idx_comp+1"
                )
            params_flux[idx_comp] = None if is_last else params_frac[idx_comp]
            is_frac[idx_comp] = True
        else:
            params_integral = [param for param in params_all if isinstance(param, g2f.IntegralParameterD)]
            # Compare the length of the list to one; each non-fractional
            # component must have exactly one integral parameter.
            if len(params_integral) != 1:
                raise RuntimeError(
                    f"Got unexpected {params_integral=} != 1 for"
                    f" {self.psf_model_data.psf_model.components[idx_comp]=} ({idx_comp=})"
                )
            params_flux[idx_comp] = params_integral[0]
    is_frac_any = any(is_frac)
    if is_frac_any and not all(is_frac):
        # TODO: This should work by iterating through componentgroups
        # But that's not trivial or supported now
        raise RuntimeError("Got PSF model with a mix of fractional and linear models; cannot initialize")

    return params_flux, is_frac_any
949 def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel | None:
950 psf_model = self.psf_model_data.psf_model
951 # PsfComponentsActionBase is an abstract class, so check if the action
952 # is a subclass that needs to be called
953 if not self.config_fit.requires_psf():
954 try:
955 gaussians = self.config_fit.action_psf(params)
956 except PsfRebuildFitFlagError:
957 return None
958 n_comps = len(psf_model.components)
959 fluxes = [0.0] * n_comps
960 params_flux, is_frac = self._psf_flux_params
961 for idx_comp, (comp, gaussian) in enumerate(zip(psf_model.components, gaussians)):
962 ellipse_out = comp.ellipse
963 ellipse_in = gaussian.ellipse
964 ellipse_out.sigma_x = ellipse_in.sigma_x
965 ellipse_out.sigma_y = ellipse_in.sigma_y
966 ellipse_out.rho = ellipse_in.rho
967 fluxes[idx_comp] = gaussian.integral.value
968 # Apparently negative fluxes are possible. Not much can be done to
969 # fix that but set them to a tiny value (zero might work)
970 fluxes = np.clip(fluxes, 1e-3, np.inf)
971 flux_total = sum(fluxes)
972 if is_frac:
973 flux_remaining = 1.0
974 for flux, param_frac in zip(fluxes, params_flux[:-1]):
975 flux_component = flux / flux_total
976 param_frac.value = flux_component / flux_remaining
977 flux_remaining -= flux_component
978 else:
979 for flux, param_flux in zip(fluxes, params_flux):
980 param_flux.value = flux / flux_total
981 else:
982 # TODO: this should probably use .index or something
983 match = np.argwhere(
984 self.table_psf_fits[self.psf_model_data.config.column_id] == params[self.config_fit.column_id]
985 )[0][0]
986 psf_model = self.psf_model_data.psf_model
987 try:
988 self.psf_model_data.init_psf_model(self.table_psf_fits[match])
989 except PsfRebuildFitFlagError:
990 return None
992 sigma_subtract = self.config_fit.psf_sigma_subtract
993 if sigma_subtract > 0:
994 sigma_subtract_sq = sigma_subtract * sigma_subtract
995 # 1/10 of PSF sigma should suffice as a minimum size
996 sigma_min_sq = sigma_subtract_sq / 100.0
997 for param in self.psf_model_data.parameters.values():
998 if isinstance(
999 param,
1000 g2f.SigmaXParameterD | g2f.SigmaYParameterD | g2f.ReffXParameterD | g2f.ReffYParameterD,
1001 ):
1002 param.value = math.sqrt(max(param.value**2 - sigma_subtract_sq, sigma_min_sq))
1003 return psf_model
1005 def get_source_observation(self, source, **kwargs) -> g2f.ObservationD | None:
1006 if not kwargs.get("skip_flags"):
1007 if (not source["detect_isPrimary"]) or source["merge_peak_sky"]:
1008 raise NotPrimaryError(f"source {source[self.config_fit.column_id]} has invalid flags for fit")
1009 footprint = source.getFootprint()
1010 bbox = footprint.getBBox()
1011 if not (bbox.getArea() > 0):
1012 return None
1013 bitmask = 0
1014 mask = self.exposure.mask[bbox]
1015 spans = footprint.spans.asArray()
1016 for bitname in self.config_fit.mask_names_zero:
1017 bitval = mask.getPlaneBitMask(bitname)
1018 bitmask |= bitval
1019 mask = ((mask.array & bitmask) != 0) & (spans != 0)
1020 mask = ~mask
1022 is_deblended_child = source["parent"] != 0
1024 img, _, sigma_inv = get_spanned_image(
1025 exposure=self.exposure,
1026 footprint=footprint if is_deblended_child else None,
1027 bbox=bbox,
1028 spans=spans,
1029 get_sig_inv=True,
1030 )
1031 x_min_bbox, y_min_bbox = bbox.beginX, bbox.beginY
1032 # Crop to tighter box for deblended model if edges are unusable
1033 # ... this rarely ever seems to happen though
1034 if is_deblended_child:
1035 coords = np.argwhere(np.isfinite(img) & (sigma_inv > 0) & np.isfinite(sigma_inv))
1036 if len(coords) == 0:
1037 return None
1038 x_min, y_min = coords.min(axis=0)
1039 x_max, y_max = coords.max(axis=0)
1040 x_max += 1
1041 y_max += 1
1043 if (x_min > 0) or (y_min > 0) or (x_max < img.shape[0]) or (y_max < img.shape[1]):
1044 # Ensure the nominal centroid is still inside the box
1045 # ... although it's a bad sign if that row/column is all bad
1046 x_cen = source["slot_Centroid_x"] - x_min_bbox
1047 y_cen = source["slot_Centroid_y"] - y_min_bbox
1048 x_min = min(x_min, int(np.floor(x_cen)))
1049 x_max = max(x_max, int(np.ceil(x_cen)))
1050 y_min = min(y_min, int(np.floor(y_cen)))
1051 y_max = max(y_max, int(np.ceil(y_cen)))
1052 x_min_bbox += x_min
1053 y_min_bbox += y_min
1054 img = img[x_min:x_max, y_min:y_max]
1055 sigma_inv = sigma_inv[x_min:x_max, y_min:y_max]
1056 mask = mask[x_min:x_max, y_min:y_max]
1058 sigma_inv[~mask] = 0
1060 coordsys = g2.CoordinateSystem(1.0, 1.0, x_min_bbox, y_min_bbox)
1062 obs = g2f.ObservationD(
1063 image=g2.ImageD(img, coordsys),
1064 sigma_inv=g2.ImageD(sigma_inv, coordsys),
1065 mask_inv=g2.ImageB(mask, coordsys),
1066 channel=self.channel,
1067 )
1068 return obs
1070 def __post_init__(self):
1071 # TODO: Can/should this be the derived type (MultiProFitPsfConfig)?
1072 config = CatalogPsfFitterConfig()
1073 config_dict = self.table_psf_fits.meta.get("config")
1074 if config_dict:
1075 set_config_from_dict(config, config_dict)
1076 else:
1077 # TODO: How should this be set?
1078 # If using external PSF fits, it needs to be configured normally
1079 pass
1080 config_data = CatalogPsfFitterConfigData(config=config)
1081 object.__setattr__(self, "psf_model_data", config_data)
class MultiProFitSourceFitter(CatalogSourceFitterABC):
    """A MultiProFit source fitter.

    Parameters
    ----------
    wcs
        A WCS solution that applies to all exposures.
    errors_expected
        A dictionary of exceptions that are expected to sometimes be raised
        during processing (e.g. for missing data). Keys are the exception
        classes and values are the names of the flag columns used to record
        the failure. The dictionary is copied, not modified.
    add_missing_errors
        Whether to add all of the standard MultiProFit errors with default
        column names to errors_expected, if not already present.
    **kwargs
        Keyword arguments to pass to the superclass constructor.
    """

    initializer: ModelInitializer = pydantic.Field(
        title="The model parameter initializer",
        default_factory=lambda: BasicModelInitializer(),
    )
    wcs: lsst.afw.geom.SkyWcs = pydantic.Field(
        title="The WCS object to use to convert pixel coordinates to RA/dec",
    )

    def __init__(
        self,
        wcs: lsst.afw.geom.SkyWcs,
        errors_expected: dict[type[Exception], str] | None = None,
        add_missing_errors: bool = True,
        **kwargs: Any,
    ):
        # Work on a copy so that the caller's dictionary is left untouched
        errors_expected = {} if errors_expected is None else dict(errors_expected)
        if add_missing_errors:
            for error_catalog in (IsParentError, NoDataError, NotPrimaryError, PsfRebuildFitFlagError):
                if error_catalog not in errors_expected:
                    errors_expected[error_catalog] = error_catalog.column_name()
        super().__init__(wcs=wcs, errors_expected=errors_expected, **kwargs)

    def copy_centroid_errors(
        self,
        columns_cenx_err_copy: tuple[str],
        columns_ceny_err_copy: tuple[str],
        results: Table,
        catalog_multi: Sequence,
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData,
    ):
        """Copy the slot centroid errors into the given result columns.

        Parameters
        ----------
        columns_cenx_err_copy
            Names of result columns to fill with ``slot_Centroid_xErr``.
        columns_ceny_err_copy
            Names of result columns to fill with ``slot_Centroid_yErr``.
        results
            The table of results to modify in place.
        catalog_multi
            The catalog to read the slot centroid errors from.
        catexps
            Unused by this implementation.
        config_data
            Unused by this implementation.
        """
        for column in columns_cenx_err_copy:
            results[column] = catalog_multi["slot_Centroid_xErr"]
        for column in columns_ceny_err_copy:
            results[column] = catalog_multi["slot_Centroid_yErr"]

    def compute_model_radec_err(
        self,
        source_multi: Mapping[str, Any],
        results,
        columns_params_radec_err,
        idx: int,
        set_radec: bool = False,
    ) -> None:
        """Convert pixel centroids and errors to RA/dec for one result row.

        Parameters
        ----------
        source_multi
            Unused by this implementation.
        results
            The table of results, read from and written to at row ``idx``.
        columns_params_radec_err
            An iterable of 9-tuples of column keys, in the order: RA error,
            dec error, x, y, x error, y error, RA-dec covariance (may be
            None to skip), RA and dec.
        idx
            The row index in ``results``.
        set_radec
            Whether to overwrite the RA and dec columns with the converted
            values. If False, the existing values are only compared to the
            converted ones and a warning is logged if they differ.
        """
        for (
            key_ra_err,
            key_dec_err,
            key_cen_x,
            key_cen_y,
            key_cen_x_err,
            key_cen_y_err,
            key_cen_ra_dec_cov,
            key_ra,
            key_dec,
        ) in columns_params_radec_err:
            (ra, dec), (ra_err, dec_err, ra_dec_cov) = afwTable.convertCentroid(
                self.wcs,
                results[key_cen_x][idx],
                results[key_cen_y][idx],
                results[key_cen_x_err][idx],
                results[key_cen_y_err][idx],
                0.0,
            )
            if set_radec:
                results[key_ra][idx], results[key_dec][idx] = ra, dec
            else:
                ra_in, dec_in = results[key_ra][idx], results[key_dec][idx]
                # np.isclose would return one boolean per coordinate, whose
                # truth value is ambiguous; allclose reduces to a single bool
                if not np.allclose((ra, dec), (ra_in, dec_in), rtol=1e-7, atol=1e-8):
                    self._get_logger().warning(
                        "idx=%i ra, dec = %f,%f differ significantly from convertCentroid ra, dec = %f, %f",
                        idx,
                        ra_in,
                        dec_in,
                        ra,
                        dec,
                    )
            results[key_ra_err][idx], results[key_dec_err][idx] = ra_err, dec_err
            if key_cen_ra_dec_cov is not None:
                results[key_cen_ra_dec_cov][idx] = ra_dec_cov

    def get_model_radec(self, source: Mapping[str, Any], cen_x: float, cen_y: float):
        """Return the RA and dec in degrees of a pixel position.

        Parameters
        ----------
        source
            Unused by this implementation.
        cen_x
            The x-axis centroid in catalog pixel coordinates.
        cen_y
            The y-axis centroid in catalog pixel coordinates.

        Returns
        -------
        ra, dec
            The sky coordinates from ``self.wcs``, in degrees.
        """
        # no extra conversions are needed here - cen_x, cen_y are in catalog
        # coordinates already
        ra, dec = self.wcs.pixelToSky(cen_x, cen_y)
        return ra.asDegrees(), dec.asDegrees()

    def initialize_model(
        self,
        model: g2f.ModelD,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData,
        values_init: Mapping[g2f.ParameterD, float] | None = None,
        **kwargs,
    ):
        """Initialize model parameters by delegating to ``self.initializer``.

        All arguments are forwarded unchanged, by keyword, to the
        initializer's own ``initialize_model`` method.
        """
        self.initializer.initialize_model(
            model=model,
            source=source,
            catexps=catexps,
            config_data=config_data,
            values_init=values_init,
            **kwargs,
        )

    def make_CatalogExposurePsfs(
        self,
        catexp: fitMB.CatalogExposureInputs,
        config: MultiProFitSourceConfig,
    ) -> CatalogExposurePsfs:
        """Make a CatalogExposurePsfs from an input and a fit config.

        Parameters
        ----------
        catexp
            The input catalog-exposure pair.
        config
            The MultiProFit source fitting config.

        Returns
        -------
        catexp_psf
            The resulting CatalogExposurePsfs.
        """
        catexp_psf = CatalogExposurePsfs(
            # dataclasses.asdict(catexp) makes a recursive deep copy.
            # That must be avoided.
            **{key: getattr(catexp, key) for key in catexp.__dataclass_fields__.keys()},
            channel=g2f.Channel.get(catexp.band),
            config_fit=config,
        )
        return catexp_psf

    def validate_fit_inputs(
        self,
        catalog_multi: Sequence,
        catexps: list[CatalogExposurePsfs],
        config_data: CatalogSourceFitterConfigData = None,
        logger: logging.Logger = None,
        **kwargs: Any,
    ) -> None:
        """Validate the inputs to a fit and register size prior overrides.

        Parameters
        ----------
        catalog_multi
            Unused by this implementation.
        catexps
            The catalog-exposure inputs; each must be a CatalogExposurePsfs.
        config_data
            The fitter configuration and data. It is required despite the
            None default, as its config and priors are always read.
        logger
            The logger for warnings; defaults to this fitter's logger.
        **kwargs
            Unused by this implementation.

        Raises
        ------
        RuntimeError
            Raised with all of the collected messages if any check failed.
        """
        if logger is None:
            logger = self._get_logger()
        errors = []
        for idx, catexp in enumerate(catexps):
            if not isinstance(catexp, CatalogExposurePsfs):
                errors.append(f"catexps[{idx=} {type(catexp)=} !isinstance(CatalogExposurePsfs)")
        # Pre-validate the model
        config_sources = config_data.config.config_model.sources
        model_sources, priors = config_data.sources_priors
        priors_shape = [prior for prior in priors if isinstance(prior, g2f.ShapePrior)]

        if len(config_sources.keys()) > 1:
            errors.append(f"model config has multiple sources: {list(config_sources.keys())}")
        elif len(priors_shape) > 0:
            idx_prior_found = 0
            name_source, config_source = next(iter(config_sources.items()))
            source = model_sources[0]
            config_groups = config_source.component_groups
            if len(config_groups.keys()) > 1:
                errors.append(f"model {name_source=} has multiple groups: {list(config_groups.keys())}")
            else:
                name_group, config_group = next(iter(config_groups.items()))
                for idx_comp, (name_comp, config_comp) in enumerate(
                    config_group.get_component_configs().items()
                ):
                    ellipse = source.components[idx_comp].ellipse
                    # component.ellipse returns a const ref and must be copied
                    # The ellipse classes might need copy constructors
                    ellipse_copy = type(ellipse)(
                        # No kwargs here, since they are unfortunately not
                        # standardized (e.g. Gaussian is sigma_x not size_x)
                        # but the arg order is
                        ellipse.size_x,
                        ellipse.size_y,
                        ellipse.rho,
                    )
                    prior_shape_new = config_comp.make_shape_prior(ellipse_copy)
                    if prior_shape_new is not None:
                        if idx_prior_found == len(priors_shape):
                            errors.append(
                                f"Could not validate prior for {name_source=} {name_group=} {name_comp=}"
                            )
                            break
                        prior_shape_old = priors_shape[idx_prior_found]
                        ll_new, ll_old = (
                            prior.evaluate().loglike for prior in (prior_shape_new, prior_shape_old)
                        )
                        # The necessary tolerance for this check is uncertain
                        if not np.isclose(ll_new, ll_old):
                            logger.warning(
                                f"shape prior for {name_comp=} got inconsistent {ll_new=} vs {ll_old}"
                            )
                        if (prior_shape_mod := config_data.config.size_priors.get(name_comp)) is not None:
                            self.initializer.priors_shape_mag[prior_shape_old] = (
                                prior_shape_mod,
                                prior_shape_new,
                            )
                        # Advance to the model prior of the next component
                        # with a shape prior; without this, every component
                        # would be matched against the first prior
                        idx_prior_found += 1

        if errors:
            raise RuntimeError("\n".join(errors))
class MultiProFitSourceTask(fitMB.CoaddMultibandFitSubTask):
    """Run MultiProFit on Exposure/SourceCatalog pairs in multiple bands.

    This task uses MultiProFit to fit a single model to all sources in a coadd,
    using a previously-fit PSF model for each exposure. The task may also use
    prior measurements from single- or merged multiband catalogs for
    initialization.
    """

    ConfigClass: ClassVar = MultiProFitSourceConfig
    _DefaultName: ClassVar = "multiProFitSource"

    def make_default_fitter(
        self,
        catalog_multi: Sequence,
        catexps: list[fitMB.CatalogExposureInputs],
        config_data: CatalogSourceFitterConfigData,
        **kwargs,
    ) -> MultiProFitSourceFitter:
        """Make a default MultiProFitSourceFitter.

        Parameters
        ----------
        catalog_multi
            A multi-band, indexable source catalog.
        catexps
            Catalog-exposure-PSF model tuples to fit source models for.
        config_data
            Configuration and data for the initializer.
        **kwargs
            Additional keyword arguments to pass to
            self.config.action_initializer.

        Returns
        -------
        fitter
            A MultiProFitSourceFitter using the first catexp's wcs.

        Raises
        ------
        RuntimeError
            Raised if none of the catexps has an exposure.
        """
        initializer = self.config.action_initializer(
            catalog_multi=catalog_multi, catexps=catexps, config_data=config_data, **kwargs
        )
        # Look for the first WCS - they ought to be identical
        # If they are not, the patch coadd data model must have changed
        wcs = None
        for catexp in catexps:
            if catexp.exposure is not None:
                wcs = catexp.exposure.wcs
                break
        if wcs is None:
            raise RuntimeError(f"Could not find valid wcs in any of {catexps=}")
        fitter = MultiProFitSourceFitter(wcs=wcs, initializer=initializer)
        return fitter

    @utilsTimer.timeMethod
    def run(
        self,
        catalog_multi: Sequence,
        catexps: list[fitMB.CatalogExposureInputs],
        fitter: MultiProFitSourceFitter | None = None,
        **kwargs,
    ) -> pipeBase.Struct:
        """Run the MultiProFit source fit task on catalog-exposure pairs.

        Parameters
        ----------
        catalog_multi
            A multi-band, indexable source catalog.
        catexps
            Catalog-exposure-PSF model tuples to fit source models for.
        fitter
            The fitter instance to use. Default-initialized if not provided.
        **kwargs
            Additional keyword arguments to pass to the fitter's fit method.
            If no fitter is provided, an ``inputs_init`` entry is removed
            and passed as keyword arguments to make_default_fitter instead.

        Returns
        -------
        result : `lsst.pipe.base.Struct`
            A struct whose ``output`` attribute is the catalog returned by
            the fitter, with the configured columns copied from
            catalog_multi, after conversion by astropy_to_arrow.

        Raises
        ------
        ValueError
            Raised if catexps is empty.
        """
        n_catexps = len(catexps)
        if n_catexps == 0:
            raise ValueError("Must provide at least one catexp")
        channels = [g2f.Channel.get(catexp.band) for catexp in catexps]
        config_data = CatalogSourceFitterConfigData(channels=channels, config=self.config)
        if fitter is None:
            # Always remove the entry, even if it is empty or None, so that
            # it is never forwarded to fitter.fit via kwargs
            inputs_init = kwargs.pop("inputs_init", None) or {}
            fitter = self.make_default_fitter(
                catalog_multi=catalog_multi, catexps=catexps, config_data=config_data, **inputs_init
            )
        # Convert any inputs that are not already CatalogExposurePsfs
        catexps_conv: list[CatalogExposurePsfs] = [
            (
                catexp
                if isinstance(catexp, CatalogExposurePsfs)
                else fitter.make_CatalogExposurePsfs(catexp, config=self.config)
            )
            for catexp in catexps
        ]
        catalog = fitter.fit(
            catalog_multi=catalog_multi, catexps=catexps_conv, config_data=config_data, **kwargs
        )
        # Copy the requested input columns along with their descriptions
        for name_in, name_out in self.config.columns_copy.items():
            catalog[name_out] = catalog_multi[name_in]
            catalog[name_out].description = catalog_multi.schema.find(name_in).field.getDoc()
        return pipeBase.Struct(output=astropy_to_arrow(catalog))