Coverage for python / lsst / multiprofit / fitting / fit_catalog.py: 50%

109 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:46 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["CatalogExposureABC", "ColumnInfo", "CatalogFitterConfig"] 

23 

24from abc import ABC, abstractmethod 

25from collections.abc import Iterable 

26from typing import ClassVar 

27 

28from astropy.table import Table 

29import astropy.units as u 

30import lsst.pex.config as pexConfig 

31import numpy as np 

32import pydantic 

33 

34from ..componentconfig import GaussianComponentConfig, SersicComponentConfig 

35from ..modeller import ModelFitConfig 

36from ..utils import frozen_arbitrary_allowed_config 

37 

38 

class CatalogExposureABC(ABC):
    """Abstract interface pairing a catalog with the exposure it covers.

    Concrete subclasses must implement `get_catalog`.
    """

    # TODO: add get_exposure (with Any return type?)

    @abstractmethod
    def get_catalog(self) -> Iterable:
        """Return the catalog for this exposure as an iterable of rows."""

47 

48 

class ColumnInfo(pydantic.BaseModel):
    """Description of a single catalog column.

    Holds the column's key (name) and dtype name, plus an optional
    description and astropy unit.
    """

    # Shared pydantic settings imported from the package utils module.
    model_config: ClassVar[pydantic.ConfigDict] = frozen_arbitrary_allowed_config

    # Required fields.
    dtype: str = pydantic.Field(title="Column data type name (numpy or otherwise)")
    key: str = pydantic.Field(title="Column key (name)")
    # Optional fields; defaults are an empty description and no unit.
    description: str = pydantic.Field(default="", title="Column description")
    unit: u.UnitBase | None = pydantic.Field(default=None, title="Column unit (astropy)")

58 

59 

class CatalogFitterConfig(pexConfig.Config):
    """Configuration for generic MultiProFit fitting tasks.

    Besides the configurable fields, this class provides helpers that build
    column names according to ``naming_scheme`` and a factory
    (`make_catalog`) for an output table whose columns follow `schema`.
    """

    column_id = pexConfig.Field[str](default="id", doc="Catalog index column key")
    compute_errors = pexConfig.ChoiceField[str](
        default="INV_HESSIAN_BESTFIT",
        doc="Whether/how to compute sqrt(variances) of each free parameter",
        allowed={
            "NONE": "no errors computed",
            "INV_HESSIAN": "inverse hessian using noisy image as data",
            "INV_HESSIAN_BESTFIT": "inverse hessian using best-fit model as data",
        },
    )
    compute_errors_from_jacobian = pexConfig.Field[bool](
        default=True,
        doc="Whether to estimate the Hessian from the Jacobian first, with finite differencing as a backup",
    )
    compute_errors_no_covar = pexConfig.Field[bool](
        default=True,
        doc="Whether to compute parameter errors independently, ignoring covariances",
    )
    config_fit = pexConfig.ConfigField[ModelFitConfig](default=ModelFitConfig, doc="Fitter configuration")
    fit_centroid = pexConfig.Field[bool](default=True, doc="Fit centroid parameters")
    fit_linear_init = pexConfig.Field[bool](default=True, doc="Fit linear parameters after initialization")
    fit_linear_final = pexConfig.Field[bool](default=True, doc="Fit linear parameters after optimization")
    float_fill_value = pexConfig.Field[float](
        default=np.nan, doc="Fill value for float fields when creating the output table."
    )
    # NOTE(review): the doc string below says the dict is keyed by exception
    # name, but schema() turns the *keys* into column keys and make_catalog()
    # requires those columns to end in "_flag". Confirm the intended
    # orientation and reword the doc accordingly.
    flag_errors = pexConfig.DictField(
        default={},
        keytype=str,
        itemtype=str,
        doc="Flag column names to set, keyed by name of exception to catch",
    )
    integer_fill_value = pexConfig.Field[int](
        default=-1, doc="Fill value for integer fields when creating the output table."
    )
    naming_scheme = pexConfig.ChoiceField[str](
        doc="Naming scheme for column names",
        allowed={
            "default": "snake_case with {component_name}[_{band}]_{parameter}[_err]",
            "camel": "CamelCase with {component_name}[_{band}]_{parameter}[Err]",
            "lsst": "snake_case with [{band}_]{component_name}_{parameter}[Err]",
        },
        default="default",
    )
    prefix_column = pexConfig.Field[str](default="mpf_", doc="Column name prefix")
    suffix_error = pexConfig.Field[str](
        default="_err",
        doc="Default suffix for error columns. Can be overridden by naming_scheme.",
    )

    # Per-naming-scheme lookup tables. Each maps a naming_scheme value to the
    # text used for one part of a column name; _get_label falls back to the
    # "default" entry when a scheme has no entry of its own.
    _format_flux = {
        "default": "{label}{band}_flux",
        "lsst": "{band}_{label}Flux",
        "camel": "{label}{band}Flux",
    }
    _key_cen = {"default": "_cen", "lsst": "_cen", "camel": "Cen"}
    _key_reff = {"default": f"_{SersicComponentConfig._size_label}", "lsst": "_reff", "camel": "Reff"}
    _key_rho = {"default": "_rho", "lsst": "_rho", "camel": "Rho"}
    _key_sigma = {"default": f"_{GaussianComponentConfig._size_label}", "lsst": "_sigma", "camel": "Sigma"}
    # NOTE(review): the "lsst" entry lacks the leading underscore that the
    # other "lsst" keys carry. Left unchanged because altering it would
    # rename output columns; confirm whether this is intentional.
    _key_sersicindex = {"default": "_sersic_index", "lsst": "sersic_index", "camel": "SersicIndex"}
    _suffix_dec = {"default": "_dec", "lsst": "_dec", "camel": "Dec"}
    _suffix_ra = {"default": "_ra", "lsst": "_ra", "camel": "Ra"}
    _suffix_ra_dec_cov = {"default": "_ra_dec_cov", "lsst": "_ra_dec_Cov", "camel": "RaDecCov"}
    _suffix_x = {"default": "_x", "lsst": "_x", "camel": "X"}
    _suffix_y = {"default": "_y", "lsst": "_y", "camel": "Y"}

    def _get_label(self, format_name: str, values: dict[str, str]) -> str:
        """Get the label for part of a column name for a given format.

        Parameters
        ----------
        format_name
            The name of the format to get the label for.
        values
            The values of the name by format. Must contain a "default" key.

        Returns
        -------
        label
            The formatted label, if specified for that format, else the
            value for the default format.

        Raises
        ------
        KeyError
            Raised if ``values`` has no "default" entry.
        """
        return values.get(format_name, values["default"])

    def get_key_cen(self) -> str:
        """Get the key for centroid columns."""
        return self._get_label(self.naming_scheme, self._key_cen)

    def get_key_flux(self, band: str, label: str = "") -> str:
        """Get the key for a flux column.

        Parameters
        ----------
        band
            The band of the flux column.
        label
            A label for this flux, e.g. a component name.

        Returns
        -------
        key_flux
            The flux column key.
        """
        template = self._get_label(self.naming_scheme, self._format_flux)
        return template.format(band=band, label=label)

    def get_key_reff(self) -> str:
        """Get the key for Sersic effective radius columns."""
        return self._get_label(self.naming_scheme, self._key_reff)

    def get_key_rho(self) -> str:
        """Get the key for ellipse rho columns."""
        return self._get_label(self.naming_scheme, self._key_rho)

    def get_key_sersicindex(self) -> str:
        """Get the key for Sersic index columns."""
        return self._get_label(self.naming_scheme, self._key_sersicindex)

    def get_key_sigma(self) -> str:
        """Get the key for Gaussian sigma columns."""
        return self._get_label(self.naming_scheme, self._key_sigma)

    def get_key_size(self, label_size: str) -> str:
        """Get the key for a size column by its label.

        Parameters
        ----------
        label_size
            The label of the size, usually specified in a ComponentConfig.

        Returns
        -------
        key_size
            The size column key. Labels matching neither the Gaussian nor
            the Sersic size label are returned unchanged.
        """
        if label_size == GaussianComponentConfig._size_label:
            return self.get_key_sigma()
        if label_size == SersicComponentConfig._size_label:
            return self.get_key_reff()
        return label_size

    def get_prefixed_label(self, label: str, prefix: str) -> str:
        """Get a prefixed label with redundant underscores removed.

        Parameters
        ----------
        label
            The label to format.
        prefix
            The prefix to prepend.

        Returns
        -------
        label_prefixed
            The prefixed label, with redundant underscores removed.
        """
        # A leading underscore on the label is redundant when there is no
        # prefix at all or when the prefix already ends in an underscore.
        if label.startswith("_") and ((prefix == "") or prefix.endswith("_")):
            return f"{prefix}{label[1:]}"
        return f"{prefix}{label}"

    def get_suffix_dec(self) -> str:
        """Get the suffix for declination columns."""
        return self._get_label(self.naming_scheme, self._suffix_dec)

    def get_suffix_ra(self) -> str:
        """Get the suffix for right ascension columns."""
        return self._get_label(self.naming_scheme, self._suffix_ra)

    def get_suffix_ra_dec_cov(self) -> str:
        """Get the suffix for right ascension-declination covariance columns."""
        return self._get_label(self.naming_scheme, self._suffix_ra_dec_cov)

    def get_suffix_x(self) -> str:
        """Get the suffix for x-axis columns."""
        return self._get_label(self.naming_scheme, self._suffix_x)

    def get_suffix_y(self) -> str:
        """Get the suffix for y-axis columns."""
        return self._get_label(self.naming_scheme, self._suffix_y)

    def make_catalog(self, n_rows: int, **kwargs) -> tuple[Table, list[ColumnInfo]]:
        """Make a catalog with default-initialized column values.

        Float columns are filled with ``float_fill_value`` and integer
        columns with ``integer_fill_value``; every other column (including
        the boolean flags) starts zero-initialized.

        Parameters
        ----------
        n_rows
            The number of rows to create.
        **kwargs
            Keyword arguments to pass to self.schema.

        Returns
        -------
        catalog
            The initialized catalog.
        columns
            The columns as returned by self.schema.

        Raises
        ------
        ValueError
            Raised if the schema has no "unknown_flag" column.
        RuntimeError
            Raised if any of the columns expected to be flags is not boolean
            or has a name that does not end with "_flag".
        """
        columns = self.schema(**kwargs)
        prefix = self.prefix_column

        # Every column but the id column gets the configured prefix.
        dtypes = [
            (col.key if col.key == self.column_id else f"{prefix}{col.key}", col.dtype)
            for col in columns
        ]

        # Zero-initialize rather than np.empty so that columns which receive
        # no explicit fill value below never expose uninitialized memory.
        results = Table(np.zeros(n_rows, dtype=dtypes))
        for colname in results.colnames:
            column = results[colname]
            if np.issubdtype(column.dtype, np.floating):
                fill_value = self.float_fill_value
            elif np.issubdtype(column.dtype, np.integer):
                fill_value = self.integer_fill_value
            else:
                fill_value = None
            if fill_value is not None:
                column[:] = fill_value

        # The flag columns are expected to be "unknown_flag" immediately
        # followed by one column per flag_errors entry; check that assumption.
        idx_flag_first = [column.key for column in columns].index("unknown_flag")
        idx_flag_end = idx_flag_first + len(self.flag_errors) + 1
        errors = []
        for flag in columns[idx_flag_first:idx_flag_end]:
            column = results[f"{prefix}{flag.key}"]
            if not ((column.dtype == bool) and column.name.endswith("_flag")):
                errors.append(f"{column.name=} should end with _flag and {column.dtype=} must be bool")
        if errors:
            errors.append(f"These may be logic errors in {self=}")
            raise RuntimeError("\n".join(errors))
        results.meta["config"] = self.toDict()

        return results, columns

    def schema(
        self,
        bands: list[str] | None = None,
    ) -> list[ColumnInfo]:
        """Return the schema as an ordered list of columns.

        Parameters
        ----------
        bands
            A list of band names to prefix band-dependent columns with.
            Band prefixes should not be used if None. Not used by this
            implementation, which returns only band-independent columns.

        Returns
        -------
        schema
            An ordered list of ColumnInfo instances.
        """
        schema = [
            ColumnInfo(key=self.column_id, dtype="i8"),
            ColumnInfo(key="n_iter", dtype="i4"),
            ColumnInfo(key="time_eval", dtype="f8", unit=u.s),
            ColumnInfo(key="time_fit", dtype="f8", unit=u.s),
            ColumnInfo(key="time_full", dtype="f8", unit=u.s),
            ColumnInfo(key="chisq_reduced", dtype="f8"),
            ColumnInfo(key="unknown_flag", dtype="bool"),
        ]
        # One boolean flag column per flag_errors key, placed directly after
        # unknown_flag; make_catalog relies on this ordering.
        schema.extend(ColumnInfo(key=key, dtype="bool") for key in self.flag_errors)
        # Subclasses should always write out centroids even if not fitting
        # They are helpful for reconstructing models
        return schema

320 # Subclasses should always write out centroids even if not fitting 

321 # They are helpful for reconstructing models 

322 return schema