Coverage for python / lsst / multiprofit / fitting / fit_source.py: 10%
515 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22from abc import ABC, abstractmethod
23from functools import cached_property
24import logging
25import time
26from typing import Any, ClassVar, Iterable, Mapping, Self, Sequence, Type
28import astropy
29from astropy.table import Table
30import astropy.units as u
31import lsst.gauss2d.fit as g2f
32import lsst.pex.config as pexConfig
33from lsst.utils.logging import PeriodicLogger
34import numpy as np
35import pydantic
37from ..componentconfig import Fluxes, GaussianComponentConfig
38from ..errors import NoDataError, RaDecConversionNotImplementedError
39from ..modelconfig import ModelConfig
40from ..modeller import FitInputsDummy, Modeller
41from ..sourceconfig import ComponentGroupConfig, SourceConfig
42from ..utils import frozen_arbitrary_allowed_config, get_params_uniq
43from .fit_catalog import CatalogExposureABC, CatalogFitterConfig, ColumnInfo
45__all__ = [
46 "CatalogExposureSourcesABC",
47 "CatalogSourceFitterConfig",
48 "CatalogSourceFitterConfigData",
49 "CatalogSourceFitterABC",
50]
class CatalogExposureSourcesABC(CatalogExposureABC):
    """Catalog-exposure interface with the hooks needed to model sources.

    Implementations provide a channel, and build a PSF model and an
    observation for individual source rows.
    """

    @property
    def band(self) -> str:
        """Return the name of the exposure's passband (e.g. 'r')."""
        channel = self.channel
        return channel.name

    # Note: not named band because that's usually a string
    @property
    @abstractmethod
    def channel(self) -> g2f.Channel:
        """Return the channel object associated with this exposure."""

    @abstractmethod
    def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel | None:
        """Build the PSF model for one source row.

        Parameters
        ----------
        params : Mapping[str, Any]
            A mapping with parameter values for the best-fit PSF model at the
            centroid of a single source.

        Returns
        -------
        psf_model : `lsst.gauss2d.fit.PsfModel`
            A PsfModel object initialized with the best-fit parameters, or
            None if PSF rebuilding failed for an expected reason (i.e. the
            input PSF fit table has a flag set).
        """

    @abstractmethod
    def get_source_observation(self, source: Mapping[str, Any], **kwargs: Any) -> g2f.ObservationD | None:
        """Build the Observation for one source row.

        Parameters
        ----------
        source : Mapping[str, Any]
            A mapping with any values needed to retrieve an observation for a
            single source.
        **kwargs
            Additional keyword arguments not used during fitting.

        Returns
        -------
        observation : `lsst.gauss2d.fit.Observation`
            An Observation object with suitable data for fitting parametric
            models of the source, or None if the observation cannot be fit.
        """
class CatalogSourceFitterConfig(CatalogFitterConfig):
    """Configuration for the MultiProFit profile fitter."""

    centroid_pixel_offset = pexConfig.Field[float](
        doc="Number to add to MultiProFit centroids (bottom-left corner is 0,0) to convert to catalog"
        " coordinates (e.g. set to -0.5 if the bottom-left corner is -0.5, -0.5)",
        default=0,
    )
    compute_radec_covariance = pexConfig.Field[bool](
        doc="Whether to compute the RA/dec covariance. Ignore if convert_cen_xy_to_radec is False.",
        default=False,
    )
    config_model = pexConfig.ConfigField[ModelConfig](doc="Source model configuration")
    convert_cen_xy_to_radec = pexConfig.Field[bool](
        doc="Convert pixel x/y centroid params to RA/dec",
        default=True,
    )
    defer_radec_conversion = pexConfig.Field[bool](
        doc="Whether to defer conversion of pixel x/y centroid params to RA/dec to compute_model_radec_err."
        " Only effective if convert_cen_xy_to_radec and compute_errors is not NONE, and requires that the"
        " overloaded compute_model_radec_err method sets RA/dec values itself.",
        default=False,
    )
    fit_psmodel_final = pexConfig.Field[bool](
        default=False,
        doc="Fit a point source model after optimization",
    )
    prior_cen_x_stddev = pexConfig.Field[float](
        default=0, doc="Prior std. dev. on x centroid (ignored if not >0)"
    )
    prior_cen_y_stddev = pexConfig.Field[float](
        default=0, doc="Prior std. dev. on y centroid (ignored if not >0)"
    )
    unit_flux = pexConfig.Field[str](default=None, doc="Flux unit", optional=True)

    def make_model_data(
        self,
        idx_row: int,
        catexps: list[CatalogExposureSourcesABC],
    ) -> tuple[g2f.DataD, list[g2f.PsfModel]]:
        """Make data and psf_models for a catalog row.

        Parameters
        ----------
        idx_row
            The index of the row in each catalog.
        catexps
            Catalog-exposure pairs to initialize observations from.

        Returns
        -------
        data
            The resulting data object.
        psf_models
            A list of psf_models, one per observation in ``data``.

        Notes
        -----
        Only observations with good data and valid PSF models will be
        returned; bad data will be excluded from the return values, so the
        outputs may be shorter than ``catexps``.
        """
        observations = []
        psf_models = []

        for catexp in catexps:
            catalog = catexp.get_catalog()
            # This indicates that there's no corresponding exposure
            # (the catexp interface expects a tabular type for catalog but
            # no interface for an exposure has been defined, yet)
            if len(catalog) == 0:
                continue
            source = catalog[idx_row]
            # If the observation or PSF model is bad enough that it cannot be
            # fit, do not add it to the data.
            observation = catexp.get_source_observation(source)
            if observation is None:
                continue
            psf_model = catexp.get_psf_model(source)
            if psf_model is None:
                continue
            observations.append(observation)
            # PSF model parameters cannot be fit along with sources
            for param in get_params_uniq(psf_model):
                param.fixed = True
            psf_models.append(psf_model)

        data = g2f.DataD(observations)
        return data, psf_models

    def make_point_sources(
        self,
        channels: Iterable[g2f.Channel],
        sources: list[g2f.Source],
    ) -> tuple[list[g2f.Source], list[g2f.Prior]]:
        """Make initialized point sources given channels.

        Parameters
        ----------
        channels
            The channels to initialize fluxes for.
        sources
            List of sources. Only its length is used here: one point source
            is made per entry, up to the number of configured sources.

        Returns
        -------
        sources
            The list of initialized sources.
        priors
            The list of priors.

        Notes
        -----
        Each point source reuses the centroids of the first component group
        of the matching source config and holds a single default Gaussian
        component with unit flux in every channel.

        The prior list is always empty, but is returned to keep this function
        consistent with make_sources.
        """
        point_sources = []
        fluxes = [[{channel: 1.0 for channel in channels}]]

        # zip() stops at the shorter input, so the number of point sources is
        # limited by len(sources) exactly as before.
        for config_src, _ in zip(self.config_model.sources.values(), sources):
            centroids = next(iter(config_src.component_groups.values())).centroids
            config_src_psf = SourceConfig(
                component_groups={
                    "": ComponentGroupConfig(
                        centroids=centroids,
                        components_gauss={"": GaussianComponentConfig()},
                    )
                }
            )
            point_source, _ = config_src_psf.make_source(fluxes)
            point_sources.append(point_source)

        return point_sources, []

    def make_sources(
        self,
        channels: Iterable[g2f.Channel],
        source_fluxes: list[list[list[Fluxes]]] | None = None,
    ) -> tuple[list[g2f.Source], list[g2f.Prior]]:
        """Make initialized sources given channels using `self.config_model`.

        Parameters
        ----------
        channels
            The channels to initialize fluxes for. Only used when
            ``source_fluxes`` is None.
        source_fluxes
            A list of fluxes by channel for each component group in each
            source. The default is to initialize using
            `ComponentGroupConfig.get_fluxes_default`.

        Returns
        -------
        sources
            The list of initialized sources.
        priors
            The list of priors, including Gaussian centroid priors if
            ``prior_cen_x_stddev`` and/or ``prior_cen_y_stddev`` are finite
            and positive.

        Raises
        ------
        ValueError
            Raised if ``source_fluxes`` is given but its length differs from
            the number of configured sources.
        """
        n_sources = len(self.config_model.sources)
        if source_fluxes is None:
            source_fluxes = [
                [
                    component_group.get_fluxes_default(
                        channels=channels,
                        component_configs=component_group.get_component_configs(),
                        is_fractional=component_group.is_fractional,
                    )
                    for component_group in config_source.component_groups.values()
                ]
                for config_source in self.config_model.sources.values()
            ]
        elif len(source_fluxes) != n_sources:
            raise ValueError(f"{len(source_fluxes)=} != {len(self.config_model.sources)=}")

        sources, priors = self.config_model.make_sources(
            component_group_fluxes_srcs=source_fluxes,
        )

        has_prior_x = self.prior_cen_x_stddev > 0 and np.isfinite(self.prior_cen_x_stddev)
        has_prior_y = self.prior_cen_y_stddev > 0 and np.isfinite(self.prior_cen_y_stddev)
        if has_prior_x or has_prior_y:
            for source in sources:
                for param in get_params_uniq(source, fixed=False):
                    if has_prior_x and isinstance(param, g2f.CentroidXParameterD):
                        priors.append(g2f.GaussianPrior(param.x_param_ptr, 0, self.prior_cen_x_stddev))
                    elif has_prior_y and isinstance(param, g2f.CentroidYParameterD):
                        priors.append(g2f.GaussianPrior(param.y_param_ptr, 0, self.prior_cen_y_stddev))

        return sources, priors

    def schema_configurable(self) -> list[ColumnInfo]:
        """Return the columns whose presence depends on config switches.

        Returns
        -------
        columns
            Optional columns, in order: ``n_eval_jac`` if
            ``config_fit.eval_residual``, ``delta_lnL_fit_linear`` if
            ``fit_linear_final`` and ``delta_lnL_fit_ps`` if
            ``fit_psmodel_final``.
        """
        columns = []
        if self.config_fit.eval_residual:
            columns.append(ColumnInfo(key="n_eval_jac", dtype="i4"))
        if self.fit_linear_final:
            columns.append(ColumnInfo(key="delta_lnL_fit_linear", dtype="f8"))
        if self.fit_psmodel_final:
            columns.append(ColumnInfo(key="delta_lnL_fit_ps", dtype="f8"))
        return columns

    def schema(
        self,
        bands: list[str] | None = None,
    ) -> list[ColumnInfo]:
        """Return the output column definitions for the given bands.

        Parameters
        ----------
        bands
            The names of the bands being fit. At least one is required.

        Returns
        -------
        columns
            The parent schema, followed by one column per model parameter,
            RA/dec columns for each x/y centroid parameter (if
            ``convert_cen_xy_to_radec``), an error column for each of those
            (if ``compute_errors`` is not "NONE", with an RA/dec covariance
            column directly after each dec error column), and finally the
            columns from `schema_configurable`.

        Raises
        ------
        ValueError
            Raised if ``bands`` is None or empty.
        """
        if bands is None or len(bands) == 0:
            raise ValueError("CatalogSourceFitter must provide at least one band")
        columns = super().schema(bands)

        parameters = CatalogSourceFitterConfigData(
            config=self,
            channels=tuple(g2f.Channel.get(band) for band in bands),
        ).parameters
        unit_size = u.Unit("pix")
        units = {
            g2f.IntegralParameterD: self.unit_flux,
            g2f.ReffXParameterD: unit_size,
            g2f.ReffYParameterD: unit_size,
            g2f.SizeXParameterD: unit_size,
            g2f.SizeYParameterD: unit_size,
        }
        idx_start = len(columns)
        columns.extend(
            ColumnInfo(key=key, dtype="f8", unit=units.get(type(param))) for key, param in parameters.items()
        )

        # Keep track of covariance keys by the index of the declination column
        # If we want to add RA/dec covariance, it'll need to come after decErr
        keys_cov: dict[int, str] = {}
        compute_errors = self.compute_errors != "NONE"
        if self.convert_cen_xy_to_radec:
            label_cen = self.get_key_cen()
            cen_underscored = label_cen.startswith("_")
            suffix_x = f"{label_cen}{self.get_suffix_x()}"
            suffix_y = f"{label_cen}{self.get_suffix_y()}"
            suffix_ra = f"{label_cen}{self.get_suffix_ra()}"
            suffix_dec = f"{label_cen}{self.get_suffix_dec()}"
            suffix_cov = f"{label_cen}{self.get_suffix_ra_dec_cov()}" if compute_errors else None

            def key_from_xy(key: str, suffix_xy: str, suffix_out: str) -> str:
                # Swap the x/y suffix of key for suffix_out, keeping whatever
                # prefix precedes it, and drop the leading underscore if
                # there is no prefix at all.
                if cen_underscored and (key == suffix_xy[1:]):
                    return suffix_out[1:]
                return f"{key.split(suffix_xy)[0]}{suffix_out}"

            for key, param in parameters.items():
                # TODO: Update if allowing x, y <-> dec, RA mappings
                # ... or arbitrary rotations
                is_y = isinstance(param, g2f.CentroidYParameterD)
                if isinstance(param, g2f.CentroidXParameterD):
                    suffix_radec, suffix_xy = suffix_ra, suffix_x
                elif is_y:
                    suffix_radec, suffix_xy = suffix_dec, suffix_y
                else:
                    continue
                columns.append(
                    ColumnInfo(key=key_from_xy(key, suffix_xy, suffix_radec), dtype="f8", unit=u.deg)
                )
                if compute_errors and is_y:
                    keys_cov[len(columns) - 1] = key_from_xy(key, suffix_xy, suffix_cov)

        if compute_errors:
            suffix = self.suffix_error
            # Fix the end index first: the loop appends to the same list
            idx_end = len(columns)
            for idx in range(idx_start, idx_end):
                column = columns[idx]
                columns.append(ColumnInfo(key=f"{column.key}{suffix}", dtype=column.dtype, unit=column.unit))
                if (key_cov := keys_cov.get(idx)) is not None:
                    columns.append(ColumnInfo(key=key_cov, dtype="f8", unit=u.deg**2))

        columns.extend(self.schema_configurable())
        return columns
class CatalogSourceFitterConfigData(pydantic.BaseModel):
    """Configuration data for a fitter that can initialize lsst.gauss2d.fit
    models and images thereof.

    This class relies on cached properties being computed once, mostly shortly
    after initialization. Therefore, it and the config field must be frozen to
    ensure that the model remains unchanged.
    """

    model_config: ClassVar[pydantic.ConfigDict] = frozen_arbitrary_allowed_config

    channels: list[g2f.Channel] = pydantic.Field(title="The list of channels")
    config: CatalogSourceFitterConfig = pydantic.Field(title="A CatalogSourceFitterConfig to be frozen")

    @pydantic.model_validator(mode="after")
    def validate_config(self) -> Self:
        """Run the wrapped config's own validation after model construction."""
        self.config.validate()
        return self

    @cached_property
    def components(self) -> tuple[g2f.Component, ...]:
        """Return the components of every source, concatenated in order.

        A tuple is returned (as annotated) so that the cached value cannot be
        mutated in place.
        """
        sources = self.sources_priors[0]
        return tuple(component for source in sources for component in source.components)

    @cached_property
    def parameters(self) -> dict[str, g2f.ParameterD]:
        """Return the model parameters keyed by output column name.

        Returns
        -------
        parameters
            A dict of parameter objects by (unprefixed) column key, in order
            of source, component group and component. Centroids are always
            included; sizes, rho, fluxes and Sersic index only if they are
            not fixed in the component config.

        Raises
        ------
        ValueError
            Raised if a component does not have exactly one free linear
            (flux) parameter per channel.
        """
        config = self.config
        config_model = config.config_model
        components = self.components
        # Index into components of the first component in the current group
        idx_comp_first = 0
        has_prefix_source = config_model.has_prefix_source()
        n_channels = len(self.channels)
        parameters = {}

        label_cen = config.get_key_cen()
        label_rho = config.get_key_rho()
        label_sersic = config.get_key_sersicindex()
        label_x, label_y = config.get_suffix_x(), config.get_suffix_y()

        for name_source, config_source in config_model.sources.items():
            prefix_source = f"{name_source}_" if has_prefix_source else ""
            has_prefix_group = config_source.has_prefix_group()

            for name_group, config_group in config_source.component_groups.items():
                prefix_group = f"{prefix_source}{name_group}_" if has_prefix_group else prefix_source
                multicen = len(config_group.centroids) > 1
                configs_comp = config_group.get_component_configs().items()

                is_multicomp = len(configs_comp) > 1

                for idx_comp_group, (name_comp, config_comp) in enumerate(configs_comp):
                    component = components[idx_comp_first + idx_comp_group]

                    key_comp = name_comp if is_multicomp else ""
                    prefix_comp = f"{prefix_group}{key_comp}"
                    key_size = config.get_prefixed_label(
                        config.get_key_size(config_comp.get_size_label()),
                        prefix_comp,
                    )
                    key_rho = config.get_prefixed_label(label_rho, prefix_comp)

                    if multicen or (idx_comp_group == 0):
                        prefix_cen = prefix_comp if multicen else prefix_group
                        # Avoid double-underscoring if there's nothing to
                        # prefix or an existing prefix
                        key_cen = config.get_prefixed_label(label_cen, prefix_cen)
                        parameters[f"{key_cen}{label_x}"] = component.centroid.x_param
                        parameters[f"{key_cen}{label_y}"] = component.centroid.y_param
                    if not config_comp.size_x.fixed:
                        parameters[f"{key_size}{label_x}"] = component.ellipse.size_x_param
                    if not config_comp.size_y.fixed:
                        parameters[f"{key_size}{label_y}"] = component.ellipse.size_y_param
                    if not config_comp.rho.fixed:
                        parameters[key_rho] = component.ellipse.rho_param
                    if not config_comp.flux.fixed:
                        # TODO: return this to component.integralmodel
                        # when binding for g2f.FractionalIntegralModel is fixed
                        params_flux = get_params_uniq(component, fixed=False, nonlinear=False)
                        if len(params_flux) != n_channels:
                            raise ValueError(f"{params_flux=} len={len(params_flux)} != {n_channels=}")
                        for channel, param_flux in zip(self.channels, params_flux):
                            key_flux = config.get_key_flux(label=prefix_comp, band=channel.name)
                            parameters[key_flux] = param_flux
                    if hasattr(config_comp, "sersic_index") and not config_comp.sersic_index.fixed:
                        parameters[config.get_prefixed_label(label_sersic, prefix_comp)] = (
                            component.sersicindex_param
                        )

                # Advance past this group's components. Without this, every
                # later group/source would be keyed to the first group's
                # component objects.
                # NOTE(review): assumes one component per component config,
                # ordered by source then group - confirm against make_sources.
                idx_comp_first += len(configs_comp)

        return parameters

    @cached_property
    def sources_priors(self) -> tuple[tuple[g2f.Source, ...], tuple[g2f.Prior, ...]]:
        """Return the sources and priors made by the config for the channels."""
        sources, priors = self.config.make_sources(channels=self.channels)
        return tuple(sources), tuple(priors)
489class CatalogSourceFitterABC(ABC, pydantic.BaseModel):
490 """Fit a Gaussian mixture source model to an image with a PSF model.
492 Notes
493 -----
494 Any exceptions raised and not in errors_expected will be logged in a
495 generic unknown_flag failure column.
496 """
498 model_config: ClassVar[pydantic.ConfigDict] = frozen_arbitrary_allowed_config
500 errors_expected: dict[Type[Exception], str] = pydantic.Field(
501 default_factory=dict,
502 title="A dictionary of Exceptions with the name of the flag column key to fill if raised.",
503 )
504 modeller: Modeller = pydantic.Field(
505 default_factory=Modeller,
506 title="A Modeller instance to use for fitting.",
507 )
def _get_columns_params_radec(
    self,
    params_radec: dict[str, tuple[g2f.CentroidXParameterD, g2f.CentroidYParameterD]],
    compute_errors: bool,
    config: CatalogSourceFitterConfig,
) -> tuple[
    list[tuple[str, str, str, str]],
    list[tuple[str, str, str, str, str, str, str | None, str, str]],
]:
    """Get a list of the columns needed for conversion of x/y centroid
    parameters into ra/dec.

    Parameters
    ----------
    params_radec
        Dict of tuple of x, y parameter objects by name.
    compute_errors
        Whether errors will be computed.
    config
        The configuration with column formatting parameters.

    Returns
    -------
    columns_params_radec
        One tuple per entry of ``params_radec`` with the column names for
        RA, dec, x, and y.
    columns_params_radec_err
        Empty unless ``compute_errors``; otherwise one tuple per entry with
        the column names for RA_err, dec_err, x, y, x_err, y_err, the RA/dec
        covariance (None unless ``config.compute_radec_covariance``), RA
        and dec.

    Raises
    ------
    RuntimeError
        Raised if the y parameter of any entry is None.
    """
    columns_params_radec = []
    columns_params_radec_err = []
    suffix_err = config.suffix_error
    key_cen = config.get_key_cen()
    suffix_x, suffix_y = config.get_suffix_x(), config.get_suffix_y()
    suffix_ra, suffix_dec = config.get_suffix_ra(), config.get_suffix_dec()
    # The covariance suffix does not depend on the entry; look it up once
    suffix_cov = (
        config.get_suffix_ra_dec_cov() if (compute_errors and config.compute_radec_covariance) else None
    )

    for key_base, (_, param_cen_y) in params_radec.items():
        # This removes redundant underscores
        key_base_cen = config.get_prefixed_label(key_cen, key_base)

        if param_cen_y is None:
            raise RuntimeError(
                f"Fitter failed to find corresponding cen_y param for {key_base=}; is it fixed?"
            )
        column_ra = f"{key_base_cen}{suffix_ra}"
        column_dec = f"{key_base_cen}{suffix_dec}"
        column_x = f"{key_base_cen}{suffix_x}"
        column_y = f"{key_base_cen}{suffix_y}"

        columns_params_radec.append((column_ra, column_dec, column_x, column_y))
        if compute_errors:
            key_cov = None if suffix_cov is None else f"{key_base_cen}{suffix_cov}"
            columns_params_radec_err.append(
                (
                    f"{column_ra}{suffix_err}",
                    f"{column_dec}{suffix_err}",
                    column_x,
                    column_y,
                    f"{column_x}{suffix_err}",
                    f"{column_y}{suffix_err}",
                    key_cov,
                    column_ra,
                    column_dec,
                )
            )
    return columns_params_radec, columns_params_radec_err
581 @staticmethod
582 def _get_logger() -> logging.Logger:
583 logger = logging.getLogger(__name__)
585 return logger
def _validate_errors_expected(self, config: CatalogSourceFitterConfig) -> None:
    """Check that self.errors_expected is set correctly.

    ``errors_expected`` must have as many entries as ``config.flag_errors``
    and no two of its values (the flag column keys) may be equal.

    Parameters
    ----------
    config
        The fitting configuration.

    Raises
    ------
    ValueError
        Raised if the configuration is invalid.
    """
    if len(self.errors_expected) != len(config.flag_errors):
        raise ValueError(f"{self.errors_expected=} keys not same len as {config.flag_errors=}")
    # A set collapses repeated flag names, so a shorter set means that at
    # least two exception types share one flag.
    flag_names = list(self.errors_expected.values())
    if len(set(flag_names)) != len(flag_names):
        raise ValueError(f"{self.errors_expected=} keys contain duplicates from {config.flag_errors=}")
def compute_model_radec_err(
    self,
    source_multi: Mapping[str, Any],
    results,
    columns_params_radec_err,
    idx: int,
    set_radec: bool = False,
) -> None:
    """Compute right ascension and declination errors for a source.

    This default implementation is naive, assuming only that
    get_model_radec is implemented, and should be overridden. It converts
    the centroid offset by its pixel errors to RA/dec and stores the
    absolute difference from the RA/dec already in the results.

    Parameters
    ----------
    source_multi
        A mapping with fields expected to be populated in the
        corresponding multiband source catalog.
    results
        The output catalog to read/write from/to.
    columns_params_radec_err
        A list of tuples containing nine keys, in this order:
        ra_err, dec_err: RA/Dec error outputs.
        cen_x, cen_y: Pixel x/y centroid inputs.
        cen_x_err, cen_y_err: Pixel x/y centroid error inputs.
        ra_dec_cov: RA/Dec covariance key (not used here).
        ra, dec: RA/Dec inputs.
    idx
        The integer index of this source in the results catalog.
    set_radec
        Whether this method should set RA, dec values instead of reading
        them (should be True if defer_radec_conversion is True). This
        default implementation does not use it and always reads them.
    """
    for keys in columns_params_radec_err:
        (
            key_ra_err,
            key_dec_err,
            key_cen_x,
            key_cen_y,
            key_cen_x_err,
            key_cen_y_err,
            _,
            key_ra,
            key_dec,
        ) = keys
        cen_x = results[key_cen_x][idx]
        cen_y = results[key_cen_y][idx]
        # TODO: improve this in DM-45682
        # For one, it won't work right at limits:
        # RA=359.99... or dec=+89.99...
        # Could also consider dividing by sqrt(2)
        # ...but that factor would multiply out later
        cen_x_offset = cen_x + results[key_cen_x_err][idx]
        cen_y_offset = cen_y + results[key_cen_y_err][idx]
        ra_offset, dec_offset = self.get_model_radec(source_multi, cen_x_offset, cen_y_offset)
        ra = results[key_ra][idx]
        dec = results[key_dec][idx]
        results[key_ra_err][idx] = abs(ra_offset - ra)
        results[key_dec_err][idx] = abs(dec_offset - dec)
def copy_centroid_errors(
    self,
    columns_cenx_err_copy: tuple[str],
    columns_ceny_err_copy: tuple[str],
    results: Table,
    catalog_multi: Sequence,
    catexps: list[CatalogExposureSourcesABC],
    config_data: CatalogSourceFitterConfigData,
) -> None:
    """Copy centroid errors from an input catalog.

    This hook supports fitting models whose centroids are fixed to values
    from an input catalog. Overriding classes can copy an existing column
    into the results catalog or derive the values from the data; there is
    no reasonable default, so this base implementation only succeeds when
    there is nothing to copy.

    Parameters
    ----------
    columns_cenx_err_copy
        X-axis result centroid columns to copy errors for.
    columns_ceny_err_copy
        Y-axis result centroid columns to copy errors for.
    results
        The table of fit results to copy errors into.
    catalog_multi
        The input multiband catalog.
    catexps
        The input data.
    config_data
        The fitter config and data.

    Raises
    ------
    NotImplementedError
        Raised if columns need to be copied but no implementation is
        available.
    """
    nothing_to_copy = not columns_cenx_err_copy and not columns_ceny_err_copy
    if nothing_to_copy:
        return
    raise NotImplementedError(
        f"Fitter of {type(self)=} got {columns_cenx_err_copy=} and/or {columns_ceny_err_copy=}"
        f" but has not overriden copy_centroid_errors"
    )
712 def fit(
713 self,
714 catalog_multi: Sequence,
715 catexps: list[CatalogExposureSourcesABC],
716 config_data: CatalogSourceFitterConfigData | None = None,
717 logger: logging.Logger | None = None,
718 **kwargs: Any,
719 ) -> astropy.table.Table:
720 """Fit PSF-convolved source models with MultiProFit.
722 Each source has a single PSF-convolved model fit, given PSF model
723 parameters from a catalog, and a combination of initial source
724 model parameters and a deconvolved source image from the
725 CatalogExposureSources.
727 Parameters
728 ----------
729 catalog_multi
730 A multi-band source catalog to fit a model to.
731 catexps
732 A list of (source and psf) catalog-exposure pairs.
733 config_data
734 Configuration settings and data for fitting and output.
735 logger
736 The logger. Defaults to calling `_getlogger`.
737 **kwargs
738 Additional keyword arguments to pass to self.modeller.
740 Returns
741 -------
742 catalog : `astropy.Table`
743 A table with fit parameters for the PSF model at the location
744 of each source.
745 """
746 if config_data is None:
747 config_data = CatalogSourceFitterConfigData(
748 config=CatalogSourceFitterConfig(),
749 channels=[catexp.channel for catexp in catexps],
750 )
751 if logger is None:
752 logger = self._get_logger()
754 config = config_data.config
755 self._validate_errors_expected(config)
756 self.validate_fit_inputs(
757 catalog_multi=catalog_multi, catexps=catexps, config_data=config_data, logger=logger, **kwargs
758 )
760 model_sources, priors = config_data.sources_priors
762 # TODO: If free Observation params are ever supported, make null Data
763 # Because config_data knows nothing about the Observation(s)
764 params = config_data.parameters
765 values_init = {param: param.value for param in params.values() if param.free}
766 prefix = config.prefix_column
767 columns_param_fixed: dict[str, tuple[g2f.ParameterD, float]] = {}
768 columns_param_free: dict[str, tuple[g2f.ParameterD, float]] = {}
769 columns_param_flux: dict[str, g2f.IntegralParameterD] = {}
770 params_cen_x: dict[str, g2f.CentroidXParameterD] = {}
771 params_cen_y: dict[str, g2f.CentroidYParameterD] = {}
772 columns_err = []
774 errors_hessian: bool = config.compute_errors == "INV_HESSIAN"
775 errors_hessian_bestfit: bool = config.compute_errors == "INV_HESSIAN_BESTFIT"
776 compute_errors: bool = errors_hessian or errors_hessian_bestfit
778 columns_cenx_err_copy = []
779 columns_ceny_err_copy = []
781 suffix_err = config.suffix_error
782 key_cen = config.get_key_cen()
783 cen_underscored = key_cen.startswith("_")
784 suffix_cenx = f"{key_cen}{config.get_suffix_x()}"
785 suffix_ceny = f"{key_cen}{config.get_suffix_y()}"
787 # Add each param to appropriate and more specific pre-computed lists
788 for key, param in params.items():
789 key_full = f"{prefix}{key}"
790 is_cenx = isinstance(param, g2f.CentroidXParameterD)
791 is_ceny = isinstance(param, g2f.CentroidYParameterD)
793 # Add the corresponding error key to the appropriate list
794 if compute_errors:
795 if param.free:
796 columns_err.append(f"{key_full}{suffix_err}")
797 elif is_cenx:
798 columns_cenx_err_copy.append(f"{key_full}{suffix_err}")
799 elif is_ceny:
800 columns_ceny_err_copy.append(f"{key_full}{suffix_err}")
802 # Add this param to the appropriate dict
803 (columns_param_fixed if param.fixed else columns_param_free)[key_full] = (
804 param,
805 config_data.config.centroid_pixel_offset if (is_cenx or is_ceny) else 0,
806 )
807 if isinstance(param, g2f.IntegralParameterD):
808 columns_param_flux[key_full] = param
809 elif config.convert_cen_xy_to_radec:
810 # Infer the prefix if possible, after checking for a dropped
811 # leading underscore in case there's no prefix
812 if is_cenx:
813 prefix_cen, suffix_cen = (
814 ("", key_full)
815 if (cen_underscored and (key_full == suffix_cenx[1:]))
816 else key_full.split(suffix_cenx)
817 )
818 params_cen_x[prefix_cen] = param
819 elif is_ceny:
820 prefix_cen, suffix_cen = (
821 ("", key_full)
822 if (cen_underscored and (key_full == suffix_ceny[1:]))
823 else key_full.split(suffix_ceny)
824 )
825 params_cen_y[prefix_cen] = param
827 if config.convert_cen_xy_to_radec or config.fit_psmodel_final:
828 assert params_cen_x.keys() == params_cen_y.keys()
829 columns_params_radec, columns_params_radec_err = self._get_columns_params_radec(
830 {k: (x, params_cen_y[k]) for k, x in params_cen_x.items()},
831 compute_errors,
832 config=config,
833 )
835 fit_psmodel_final = False
836 if config.fit_psmodel_final:
837 # This should never be True until DM-46497 is merged, but models
838 # in other/future derived classes might have multiple centroids
839 if (len(set(params_cen_x.values())) > 1) or (len(set(params_cen_y.values())) > 1):
840 raise ValueError(
841 f"Got {params_cen_x=} and {params_cen_y} with > 1 unique elements, so "
842 f"config.fit_psmodel_final may not be set to True"
843 )
844 fit_psmodel_final = True
846 key_cen_x_psmodel, key_cen_y_psmodel = columns_params_radec[0][2:4]
848 channels = config_data.channels
849 sources_psmodel, priors_psmodel = config.make_point_sources(channels, model_sources)
850 params_psmodel = sources_psmodel[0].parameters()
851 cenx_psmodel, ceny_psmodel = None, None
852 fluxes_psmodel = {}
853 idx_band = 0
854 for param in params_psmodel:
855 if isinstance(param, g2f.CentroidXParameterD):
856 if cenx_psmodel is not None:
857 raise RuntimeError("Point source model found multiple x centroids")
858 cenx_psmodel = param
859 elif isinstance(param, g2f.CentroidYParameterD):
860 if ceny_psmodel is not None:
861 raise RuntimeError("Point source model found multiple y centroids")
862 ceny_psmodel = param
863 elif isinstance(param, g2f.IntegralParameterD):
864 fluxes_psmodel[channels[idx_band]] = param
865 idx_band += 1
867 convert_cen_xy_to_radec_first = config.convert_cen_xy_to_radec and not (
868 config.compute_errors and config.defer_radec_conversion
869 )
871 # Setup the results table with correct column names
872 n_rows = len(catalog_multi)
873 channels = self.get_channels(catexps)
874 results, columns = config.make_catalog(n_rows, bands=list(channels.keys()))
876 # Copy centroid error columns into results ( if needed)
877 self.copy_centroid_errors(
878 columns_cenx_err_copy=columns_cenx_err_copy,
879 columns_ceny_err_copy=columns_ceny_err_copy,
880 results=results,
881 catalog_multi=catalog_multi,
882 catexps=catexps,
883 config_data=config_data,
884 )
886 # dummy size for first iteration
887 size, size_new = 0, 0
888 fitInputs = FitInputsDummy()
889 plot = False
891 # Configure default options for calls to compute_variances
892 # keys are for values of return_negative
893 kwargs_err_default = {
894 True: {
895 "options": g2f.HessianOptions(findiff_add=1e-3, findiff_frac=1e-3),
896 "use_diag_only": config.compute_errors_no_covar,
897 },
898 False: {"options": g2f.HessianOptions(findiff_add=1e-6, findiff_frac=1e-6)},
899 }
901 range_idx = range(n_rows)
903 # TODO: Do this check with dummy data
904 # It might not work with real data if the first row is bad
905 # data, psf_models = config.make_model_data(
906 # idx_row=range_idx[0], catexps=catexps)
907 # model = g2f.ModelD(data=data, psfmodels=psf_models,
908 # sources=model_sources, priors=priors)
909 # Remember to filter out fixed centroids from params
910 # assert list(params.values()) == get_params_uniq(model, fixed=False)
912 time_init_all = time.process_time()
913 logger_periodic = PeriodicLogger(logger)
914 n_skipfail = 0
916 for idx in range_idx:
917 time_init = time.process_time()
918 row = results[idx]
919 source_multi = catalog_multi[idx]
920 id_source = source_multi[config.column_id]
921 row[config.column_id] = id_source
922 time_final = time_init
924 try:
925 data, psf_models = config.make_model_data(idx_row=idx, catexps=catexps)
926 if data.size == 0:
927 raise NoDataError("make_model_data returned empty data")
928 model = g2f.ModelD(data=data, psfmodels=psf_models, sources=model_sources, priors=priors)
929 self.initialize_model(
930 model,
931 source_multi,
932 catexps,
933 config_data=config_data,
934 values_init=values_init,
935 )
937 # Caches the jacobian residual if the data size is unchanged
938 # Note: this will need to change with priors
939 # (data should report its own size)
940 size_new = np.sum([datum.image.size for datum in data])
941 if size_new != size:
942 fitInputs = None
943 size = size_new
944 # Some algorithms might not even use fitInputs
945 elif fitInputs is not None:
946 fitInputs = fitInputs if not fitInputs.validate_for_model(model) else None
948 # TODO: Check if flux param limits and transforms are set
949 # appropriately if config.fit_linear_init is False
950 if config.fit_linear_init:
951 self.modeller.fit_model_linear(model=model, ratio_min=0.01)
953 for observation in data:
954 observation.image.data[~np.isfinite(observation.image.data)] = 0
956 result_full = self.modeller.fit_model(
957 model, fitinputs=fitInputs, config=config.config_fit, **kwargs
958 )
959 fitInputs = result_full.inputs
960 results[f"{prefix}n_iter"][idx] = result_full.n_eval_func
961 results[f"{prefix}time_eval"][idx] = result_full.time_eval
962 results[f"{prefix}time_fit"][idx] = result_full.time_run
963 if config.config_fit.eval_residual:
964 results[f"{prefix}n_eval_jac"][idx] = result_full.n_eval_jac
966 params_free_missing = result_full.params_free_missing or tuple()
968 # Set all params to best fit values
969 # In case the optimizer doesn't
970 for (key, (param, offset)), value in zip(
971 columns_param_free.items(),
972 result_full.params_best,
973 ):
974 param.value_transformed = value
975 if param not in params_free_missing:
976 results[key][idx] = param.value + offset
978 # Also add any offset to the fixed parameters
979 # (usually centroids, if any)
980 for key, (param, offset) in columns_param_fixed.items():
981 results[key][idx] = param.value + offset
983 # Do a final linear fit
984 # If the nonlinear fit is good, the values won't change much
985 if config.fit_linear_final:
986 loglike_init, loglike_new = self.modeller.fit_model_linear(
987 model=model, ratio_min=0.01, validate=True
988 )
989 loglike_final = max(loglike_init, loglike_new)
990 results[f"{prefix}delta_lnL_fit_linear"][idx] = np.sum(loglike_new) - np.sum(loglike_init)
992 if params_free_missing:
993 columns_param_flux_fit = {
994 column: param
995 for column, param in columns_param_flux.items()
996 if param not in params_free_missing
997 }
998 else:
999 columns_param_flux_fit = columns_param_flux
1001 for column, param in columns_param_flux_fit.items():
1002 results[column][idx] = param.value
1003 else:
1004 loglike_final = model.evaluate()
1006 if convert_cen_xy_to_radec_first:
1007 for key_ra, key_dec, key_cen_x, key_cen_y in columns_params_radec:
1008 # These will have been converted back if necessary
1009 cen_x, cen_y = results[key_cen_x][idx], results[key_cen_y][idx]
1010 radec = self.get_model_radec(source_multi, cen_x, cen_y)
1011 results[key_ra][idx], results[key_dec][idx] = radec
1013 if fit_psmodel_final:
1014 cen_x, cen_y = results[key_cen_x_psmodel][idx], results[key_cen_y_psmodel][idx]
1015 cenx_psmodel.value = cen_x
1016 ceny_psmodel.value = cen_y
1017 model_psf = g2f.ModelD(data=data, psfmodels=psf_models, sources=sources_psmodel)
1018 _ = self.modeller.fit_model_linear(model_psf)
1019 model_psf.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
1020 loglike_psfmodel = model_psf.evaluate()
1021 # Reset fluxes for the next fit
1022 for param in fluxes_psmodel.values():
1023 param.value = 1.0
1024 results[f"{prefix}delta_lnL_fit_ps"][idx] = loglike_final[0] - loglike_psfmodel[0]
1026 if compute_errors:
1027 errors = []
1028 model_eval = model
1029 errors_iter = None
1030 for param in params_free_missing:
1031 param.fixed = True
1033 if config.compute_errors_from_jacobian:
1034 try:
1035 errors_iter = np.sqrt(
1036 self.modeller.compute_variances(
1037 model_eval,
1038 transformed=False,
1039 use_diag_only=config.compute_errors_no_covar,
1040 )
1041 )
1042 errors.append((errors_iter, np.sum(~(errors_iter > 0))))
1043 except Exception:
1044 pass
1045 # If computing errors from the Jacobian didn't work, or if
1046 # it was disabled in the config, try the Hessian
1047 if errors_iter is None:
1048 img_data_old = []
1049 if errors_hessian_bestfit:
1050 # Model sans prior
1051 model_eval = g2f.ModelD(
1052 data=model.data, psfmodels=model.psfmodels, sources=model.sources
1053 )
1054 model_eval.setup_evaluators(evaluatormode=g2f.EvaluatorMode.image)
1055 model_eval.evaluate()
1056 # Compute the errors by setting the data to the
1057 # best-fit model (a quasi-parametric bootstrap
1058 # with one iteration)
1059 for obs, output in zip(model_eval.data, model_eval.outputs):
1060 img_data_old.append(obs.image.data.copy())
1061 img = obs.image.data
1062 img.flat = output.data.flat
1063 # To make this a real bootstrap, could do this
1064 # (but would need to iterate):
1065 # + rng.standard_normal(img.size)*(
1066 # obs.sigma_inv.data.flat)
1068 # Try without forcing all of the Hessian terms to be
1069 # negative first. At the optimum they should be, but
1070 # in practice the best-fit values are always at least
1071 # a little off and so the sign is equally likely to be
1072 # positive as negative.
1073 for return_negative in (False, True):
1074 kwargs_err = kwargs_err_default[return_negative]
1075 if errors and errors[-1][1] == 0:
1076 break
1077 try:
1078 errors_iter = np.sqrt(
1079 self.modeller.compute_variances(
1080 model_eval, transformed=False, **kwargs_err
1081 )
1082 )
1083 errors.append((errors_iter, np.sum(~(errors_iter > 0))))
1084 except Exception:
1085 try:
1086 errors_iter = np.sqrt(
1087 self.modeller.compute_variances(
1088 model_eval,
1089 transformed=False,
1090 use_svd=True,
1091 **kwargs_err,
1092 )
1093 )
1094 errors.append((errors_iter, np.sum(~(errors_iter > 0))))
1095 except Exception:
1096 pass
1097 # Return the data to its original noisy values
1098 # (it was replaced by the model earlier)
1099 if errors_hessian_bestfit:
1100 for obs, img_datum_old in zip(model.data, img_data_old):
1101 obs.image.data.flat = img_datum_old.flat
1102 # Save and optionally plot the errors
1103 if errors:
1104 idx_min = np.argmax([err[1] for err in errors])
1105 errors = errors[idx_min][0]
1106 if plot:
1107 errors_plot = np.clip(errors, 0, 1000)
1108 errors_plot[~np.isfinite(errors_plot)] = 0
1109 from ..plotting import ErrorValues, plot_loglike
1111 try:
1112 plot_loglike(model, errors={"err": ErrorValues(values=errors_plot)})
1113 except Exception:
1114 for param in params:
1115 param.fixed = False
1117 if params_free_missing:
1118 columns_err_fitted = [
1119 column
1120 for column, param in zip(columns_err, params.values())
1121 if param not in params_free_missing
1122 ]
1123 else:
1124 columns_err_fitted = columns_err
1126 for value, column_err in zip(errors, columns_err_fitted):
1127 results[column_err][idx] = value
1129 for param in params_free_missing:
1130 param.fixed = False
1132 # Convert the x/y errors to ra/dec errors
1133 if config.convert_cen_xy_to_radec:
1134 self.compute_model_radec_err(
1135 source_multi,
1136 results,
1137 columns_params_radec_err,
1138 idx,
1139 set_radec=not convert_cen_xy_to_radec_first,
1140 )
1142 results[f"{prefix}chisq_reduced"][idx] = result_full.chisq_best / size
1143 time_final = time.process_time()
1144 results[f"{prefix}time_full"][idx] = time_final - time_init
1145 except Exception as e:
1146 n_skipfail += 1
1147 size = 0 if fitInputs is None else size_new
1148 column = self.errors_expected.get(e.__class__, "")
1149 if column:
1150 row[f"{prefix}{column}"] = True
1151 logger.debug(
1152 "id_source=%i (idx=%i/%i) fit failed with known exception: %s",
1153 id_source,
1154 idx,
1155 n_rows,
1156 e,
1157 )
1158 else:
1159 row[f"{prefix}unknown_flag"] = True
1160 logger.info(
1161 "id_source=%i (idx=%i/%i) fit failed with unexpected exception: %s",
1162 id_source,
1163 idx,
1164 n_rows,
1165 e,
1166 exc_info=1,
1167 )
1168 logger_periodic.log(
1169 "Fit idx=%i/%i sources (%i skipped/failed) in %.2f",
1170 idx,
1171 n_rows,
1172 n_skipfail,
1173 time_final - time_init_all,
1174 )
1176 n_unknown = np.sum(row[f"{prefix}unknown_flag"])
1177 if n_unknown > 0:
1178 logger.warning("%i/%i source fits failed with unexpected exceptions", n_unknown, n_rows)
1180 return results
def get_channels(
    self,
    catexps: list[CatalogExposureSourcesABC],
) -> dict[str, g2f.Channel]:
    """Collect the distinct channels of a list of catalog-exposure pairs.

    Parameters
    ----------
    catexps
        The catalog-exposure pairs to get channels for. Each one must
        expose either a ``channel`` attribute, or a ``band`` attribute
        (optionally a callable returning the band) from which the channel
        is looked up with `g2f.Channel.get`.

    Returns
    -------
    channels
        The channels keyed by their name, in order of first appearance.
    """
    channels = {}
    for catexp in catexps:
        try:
            channel = catexp.channel
        except AttributeError:
            band = catexp.band
            if callable(band):
                band = band()
            channel = g2f.Channel.get(band)
        # The dict is keyed by channel name, so the membership test must
        # use the name as well rather than the channel object.
        if channel.name not in channels:
            channels[channel.name] = channel
    return channels
def get_model(
    self,
    idx_row: int,
    catalog_multi: Sequence,
    catexps: list[CatalogExposureSourcesABC],
    config_data: CatalogSourceFitterConfigData | None = None,
    results: astropy.table.Table | None = None,
    **kwargs: Any,
) -> g2f.ModelD:
    """Reconstruct the model for a single row of a fit catalog.

    Parameters
    ----------
    idx_row
        The index of the row in the catalog.
    catalog_multi
        The multi-band catalog originally used for initialization.
    catexps
        The catalog-exposure pairs to reconstruct the model for.
    config_data
        The configuration used to generate sources.
        Default-initialized if None.
    results
        The corresponding best-fit parameter catalog to initialize
        parameter values from. If None, the model params will be set by
        `self.initialize_model`, as they would be when calling `self.fit`.
    **kwargs
        Additional keyword arguments to pass to initialize_model. Not
        used during fitting.

    Returns
    -------
    model
        The reconstructed model.

    Raises
    ------
    ValueError
        Raised if idx_row is negative, or is not a valid index into
        catalog_multi or (when provided) results.
    """
    channels = self.get_channels(catexps)
    if config_data is None:
        config_data = CatalogSourceFitterConfigData(
            config=CatalogSourceFitterConfig(),
            channels=list(channels.values()),
        )
    config = config_data.config

    if not idx_row >= 0:
        raise ValueError(f"{idx_row=} !>=0")
    if not len(catalog_multi) > idx_row:
        raise ValueError(f"{len(catalog_multi)=} !> {idx_row=}")
    if (results is not None) and not (len(results) > idx_row):
        raise ValueError(f"{len(results)=} !> {idx_row=}")

    model_sources, priors = config_data.sources_priors
    source_multi = catalog_multi[idx_row]

    data, psf_models = config.make_model_data(
        idx_row=idx_row,
        catexps=catexps,
    )
    model = g2f.ModelD(data=data, psfmodels=psf_models, sources=model_sources, priors=priors)
    # initialize_model declares config_data as a required parameter, so it
    # must be forwarded here exactly as it is when fitting.
    self.initialize_model(model, source_multi, catexps, config_data=config_data, **kwargs)

    if results is not None:
        # Overwrite the initialized values with the stored best-fit values
        row = results[idx_row]
        for column, param in config_data.parameters.items():
            param.value = row[f"{config.prefix_column}{column}"]

    return model
def get_model_radec(self, source: Mapping[str, Any], cen_x: float, cen_y: float) -> tuple[float, float]:
    """Convert a pixel centroid into sky coordinates for a source.

    This base implementation always raises. Subclasses only need to
    override it when the data being fit come with an accompanying WCS.

    Parameters
    ----------
    source
        A mapping with the fields expected to be populated in the
        corresponding source catalog.
    cen_x
        The centroid along the x-axis, in pixel coordinates.
    cen_y
        The centroid along the y-axis, in pixel coordinates.

    Returns
    -------
    ra, dec
        The right ascension and declination.

    Raises
    ------
    RaDecConversionNotImplementedError
        Always raised by this default implementation.
    """
    message = "get_model_radec has no default implementation"
    raise RaDecConversionNotImplementedError(message)
@abstractmethod
def initialize_model(
    self,
    model: g2f.ModelD,
    source: Mapping[str, Any],
    catexps: list[CatalogExposureSourcesABC],
    config_data: CatalogSourceFitterConfigData,
    values_init: Mapping[g2f.ParameterD, float] | None = None,
    **kwargs: Any,
) -> None:
    """Set up a Model for one row of the source catalog.

    This is an abstract method with no default behaviour; concrete
    fitters must provide the implementation.

    Parameters
    ----------
    model
        The model object that is to be initialized.
    source
        A mapping holding the fields that are expected to be populated
        in the corresponding source catalog for initialization.
    catexps
        The list of (source and psf) catalog-exposure pairs.
    config_data
        The configuration settings and data for fitting and output.
    values_init
        Initial values of parameters, taken from the model configuration.
    **kwargs
        Additional keyword arguments that cannot be required for fitting.
    """
@abstractmethod
def validate_fit_inputs(
    self,
    catalog_multi: Sequence,
    catexps: list[CatalogExposureSourcesABC],
    config_data: CatalogSourceFitterConfigData | None = None,
    logger: logging.Logger | None = None,
    **kwargs: Any,
) -> None:
    """Validate inputs to self.fit.

    This method is called before any fitting is done. It may be used for
    any purpose, including checking that the inputs are a particular
    subclass of the base classes. It is abstract: concrete fitters must
    provide the implementation.

    Parameters
    ----------
    catalog_multi
        A multi-band source catalog to fit a model to.
    catexps
        A list of (source and psf) catalog-exposure pairs.
    config_data
        Configuration settings and data for fitting and output.
        May be None.
    logger
        The logger. Defaults to calling `_getlogger`.
    **kwargs
        Additional keyword arguments to pass to self.modeller.
    """