Coverage for tests/test_sourceconfig.py: 23%
61 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:48 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-30 08:48 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22import lsst.gauss2d.fit as g2f
23from lsst.multiprofit.componentconfig import (
24 GaussianComponentConfig,
25 ParameterConfig,
26 SersicComponentConfig,
27 SersicIndexParameterConfig,
28)
29from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
30from lsst.multiprofit.utils import get_params_uniq
31import numpy as np
32import pytest
@pytest.fixture(scope="module")
def centroid_limits():
    """Provide limits spanning -inf to +inf for the centroid parameters."""
    return g2f.LimitsD(min=-np.inf, max=np.inf)
@pytest.fixture(scope="module")
def centroid(centroid_limits):
    """Build fixed x/y centroid parameters, both constructed with the value 0.

    Both parameters share the limits object supplied by ``centroid_limits``.
    """
    x_param = g2f.CentroidXParameterD(0, limits=centroid_limits, fixed=True)
    y_param = g2f.CentroidYParameterD(0, limits=centroid_limits, fixed=True)
    return g2f.CentroidParameters(x_param, y_param)
@pytest.fixture(scope="module")
def channels():
    """Map the band names R, G and B to the result of ``g2f.Channel.get``."""
    by_band = {}
    for band in ("R", "G", "B"):
        by_band[band] = g2f.Channel.get(band)
    return by_band
def test_ComponentGroupConfig(centroid):
    """Expect a ValueError for a group whose component dicts reuse one key.

    The Gaussian and Sersic component dicts both use the key "x"; building
    and validating such a group must raise (presumably because of the
    duplicated name -- the exact cause is not checked here).
    """
    shared_name = "x"
    with pytest.raises(ValueError):
        # Construction stays inside the context manager so that an error
        # raised by either the constructor or validate() satisfies the test.
        group = ComponentGroupConfig(
            components_gauss={shared_name: GaussianComponentConfig()},
            components_sersic={shared_name: SersicComponentConfig()},
        )
        group.validate()
def test_SourceConfig_base():
    """Expect a ValueError for a default config and for empty groups."""
    # Default-constructed config.
    with pytest.raises(ValueError):
        SourceConfig().validate()

    # Explicitly empty component_groups.
    with pytest.raises(ValueError):
        SourceConfig(component_groups={}).validate()
def test_SourceConfig_fractional(centroid):
    """Check that a fractional two-Gaussian group yields a PSF model.

    The model built by ``make_psf_model`` must have one component per
    configured Gaussian and come with no priors.
    """
    rho, size_x, size_y = -0.3, 1.4, 1.6
    drho, dsize_x, dsize_y = 0.5, 1.6, 1.3

    n_components = 2
    # Each successive component is offset from the base values by one delta.
    gaussians = {}
    for idx in range(n_components):
        gaussians[str(idx)] = GaussianComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx * drho),
            size_x=ParameterConfig(value_initial=size_x + idx * dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx * dsize_y),
        )
    group = ComponentGroupConfig(components_gauss=gaussians, is_fractional=True)
    config = SourceConfig(component_groups={"src": group})
    config.validate()

    channel = g2f.Channel.NONE
    # One entry per component group; each holds one flux dict per component.
    group_fluxes = [{channel: 1.0}, {channel: 0.5}]
    psf_model, priors = config.make_psf_model([group_fluxes])

    assert len(priors) == 0
    assert len(psf_model.components) == n_components
def test_SourceConfig_linear(centroid, channels):
    """Check that ``make_source`` applies configured initial values and fluxes.

    Two Sersic components named "PS" and "Sersic" are configured with
    distinct initial shape values and per-channel fluxes. Every unique
    parameter of each resulting component is then compared against the
    value it was configured with.
    """
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 0.5, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 2.8, 13.9

    names = ("PS", "Sersic")
    config = SourceConfig(
        component_groups={
            "src": ComponentGroupConfig(
                components_sersic={
                    name: SersicComponentConfig(
                        rho=ParameterConfig(value_initial=rho + idx * drho),
                        size_x=ParameterConfig(value_initial=size_x + idx * dsize_x),
                        size_y=ParameterConfig(value_initial=size_y + idx * dsize_y),
                        sersic_index=SersicIndexParameterConfig(
                            value_initial=sersicn + idx * dsersicn,
                            # Only the first component has a fixed Sersic index.
                            fixed=idx == 0,
                            prior_mean=None,
                        ),
                    )
                    for idx, name in enumerate(names)
                }
            ),
        }
    )
    # One flux dict per component, keyed by channel object.
    fluxes = [
        {
            channel: flux + idx_channel * dflux * idx_comp
            for idx_channel, channel in enumerate(channels.values())
        }
        for idx_comp in range(len(config.component_groups["src"].components_sersic))
    ]
    source, priors = config.make_source([fluxes])
    assert len(priors) == 0
    for idx, component in enumerate(source.components):
        params = get_params_uniq(component)
        # Expected initial value for each parameter type of this component.
        values_init = {
            g2f.RhoParameterD: rho + idx * drho,
            g2f.ReffXParameterD: size_x + idx * dsize_x,
            g2f.ReffYParameterD: size_y + idx * dsize_y,
            g2f.SersicIndexParameterD: sersicn + idx * dsersicn,
        }
        # These depend only on the component index, not on the group.
        fluxes_comp = fluxes[idx]
        name_comp = names[idx]
        for name_group, component_group in config.component_groups.items():
            config_comp = component_group.components_sersic[name_comp]
            # Expected flux keyed by the fully formatted integral label.
            fluxes_label = {
                config.format_label(
                    component_group.format_label(
                        label=config_comp.format_label(
                            label=config.get_integral_label_default(),
                            name_channel=channel.name,
                        ),
                        name_component=name_comp,
                    ),
                    name_group=name_group,
                ): fluxes_comp[channel]
                for channel in channels.values()
            }
            for param in params:
                if isinstance(param, g2f.IntegralParameterD):
                    assert fluxes_label[param.label] == param.value
                # Compare against None explicitly: a bare truthiness test would
                # silently skip the check for an expected value of 0.0.
                elif (value_init := values_init.get(param.__class__)) is not None:
                    assert param.value == value_init