Coverage for python / lsst / meas / extensions / multiprofit / pipetasks_fit.py: 37%

256 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-22 09:34 +0000

1# This file is part of meas_extensions_multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module; config/task pairs are listed together.
__all__ = (
    "component_names_default",
    "model_names_default",
    "MultiProFitCoaddPsfFitConfig",
    "MultiProFitCoaddPsfFitTask",
    "MultiProFitCoaddObjectFitConfig",
    "MultiProFitCoaddObjectFitTask",
    "MultiProFitCoaddPointFitConfig",
    "MultiProFitCoaddSersicFitConfig",
    "MultiProFitCoaddSersicFitTask",
    "MultiProFitCoaddGaussFitConfig",
    "MultiProFitCoaddGaussFitTask",
    "MultiProFitCoaddExpFitConfig",
    "MultiProFitCoaddExpFitTask",
    "MultiProFitCoaddDeVFitConfig",
    "MultiProFitCoaddDeVFitTask",
    "MultiProFitCoaddExpDeVFitConfig",
    "MultiProFitCoaddExpDeVFitTask",
)

40 

41from abc import abstractmethod 

42import itertools 

43import math 

44from types import SimpleNamespace 

45from typing import Any, Mapping, Sequence 

46 

47import lsst.gauss2d.fit as g2f 

48from lsst.multiprofit.componentconfig import ( 

49 GaussianComponentConfig, 

50 ParameterConfig, 

51 SersicComponentConfig, 

52 SersicIndexParameterConfig, 

53) 

54from lsst.multiprofit.fitting.fit_source import CatalogExposureSourcesABC, CatalogSourceFitterConfigData 

55from lsst.multiprofit.modelconfig import ModelConfig 

56from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig 

57from lsst.pex.config import ConfigDictField, Field 

58from lsst.pipe.tasks.fit_coadd_multiband import ( 

59 CatalogExposureInputs, 

60 CoaddMultibandFitConfig, 

61 CoaddMultibandFitConnections, 

62 CoaddMultibandFitTask, 

63) 

64from lsst.pipe.tasks.fit_coadd_psf import CoaddPsfFitConfig, CoaddPsfFitConnections, CoaddPsfFitTask 

65 

66from .fit_coadd_multiband import ( 

67 CachedBasicModelInitializer, 

68 MagnitudeDependentSizePriorConfig, 

69 MakeBasicInitializerAction, 

70 ModelInitializer, 

71 MultiProFitSourceTask, 

72 PsfComponentsActionBase, 

73 SourceTablePsfComponentsAction, 

74) 

75from .fit_coadd_psf import MultiProFitPsfTask 

76from .input_config import InputConfig 

77 

# Default component names, used as keys for component configs and size
# priors. Each attribute's value is identical to its own name.
component_names_default = SimpleNamespace(
    **{key: key for key in ("point", "gauss", "exp", "deV", "sersic")}
)

85 

# Default model names. These are used as (parts of) the output table name,
# e.g. appended to ``connections.name_table`` when options are enabled.
model_names_default = SimpleNamespace(
    **{
        "point": "Point",
        "gauss": "Gauss",
        "exp": "Exp",
        "deV": "DeV",
        "sersic": "Sersic",
        "fixed_cen": "FixedCen",
        "shapelet_psf": "ShapeletPsf",
    }
)

95 

96 

class MultiProFitCoaddPsfFitConfig(
    CoaddPsfFitConfig,
    pipelineConnections=CoaddPsfFitConnections,
):
    """MultiProFit PSF fit task config.

    Retargets the ``fit_coadd_psf`` subtask to `MultiProFitPsfTask` and sets
    ``config_fit.eval_residual`` to False on the retargeted subtask config.
    """

    def setDefaults(self):
        super().setDefaults()
        self.fit_coadd_psf.retarget(MultiProFitPsfTask)
        # Look the fitter config up only after retargeting, so that it is the
        # config of the new target.
        config_fit = self.fit_coadd_psf.config_fit
        config_fit.eval_residual = False

107 

108 

class MultiProFitCoaddPsfFitTask(CoaddPsfFitTask):
    """MultiProFit PSF fit task.

    Only the config class and default name are specialized here; all other
    behavior is inherited from `CoaddPsfFitTask`.
    """

    _DefaultName = "multiProFitCoaddPsfFit"
    ConfigClass = MultiProFitCoaddPsfFitConfig

114 

115 

class MultiProFitCoaddObjectFitConnections(CoaddMultibandFitConnections):
    """Connections for MultiProFit coadd object fitting tasks.

    In addition to the parent connections, one connection is added for each
    entry in ``config.inputs_init``, as an attribute named after the entry's
    key and created by that entry's ``get_connection``.
    """

    def __init__(self, *, config=None):
        super().__init__(config=config)
        for name, config_input in config.inputs_init.items():
            # Refuse to shadow an existing connection or other attribute.
            if hasattr(self, name):
                raise ValueError(
                    f"{config_input=} {name=} is invalid, due to being an existing attribute" f" of {self=}"
                )
            unsupported = config_input.is_multipatch or not config_input.is_multiband
            if unsupported:
                raise ValueError(
                    f"Single-band and/or multipatch initialization config_input entries ({name})"
                    f" are not supported yet."
                )
            setattr(self, name, config_input.get_connection(name))

131 

132 

class MultiProFitCoaddObjectFitConfig(
    CoaddMultibandFitConfig,
    pipelineConnections=MultiProFitCoaddObjectFitConnections,
):
    """Generic MultiProFit source fit task config.

    Subclasses must implement `get_model_name_default` and
    `make_default_model_config`; both are called from `setDefaults`.
    """

    inputs_init = ConfigDictField(
        doc="Mapping of optional input dataset configs by name, for initialization",
        keytype=str,
        itemtype=InputConfig,
        default={},
    )

    # This needs to be set, ideally in setDefaults of subclasses
    name_model = Field[str](doc="The name of the model", default=None)

    def _get_source(self):
        """Return the first source config in the model config."""
        return next(iter(self.fit_coadd_multiband.config_model.sources.values()))

    def _get_component_group(self, source: SourceConfig | None = None):
        """Return the first component group of a source config.

        Parameters
        ----------
        source
            The source config to inspect. Defaults to the first source in the
            model config.
        """
        if source is None:
            source = self._get_source()
        return next(iter(source.component_groups.values()))

    def add_point_source(self, name: str | None = None):
        """Add a point source component.

        The component is added to the Gaussian components of the first
        component group of the first source, and the point model name is
        appended to ``self.connections.name_table``.

        Parameters
        ----------
        name
            The name of the component. Defaults to
            ``component_names_default.point``.

        Raises
        ------
        RuntimeError
            Raised if a Gaussian component with this name already exists.
        """
        if name is None:
            name = component_names_default.point
        source = self._get_source()
        group = self._get_component_group(source=source)
        if name in group.components_gauss:
            raise RuntimeError(f"{name=} component already exists in {source=}")
        group.components_gauss[name] = self.make_point_source_component()
        self.connections.name_table += model_names_default.point

    def finalize(
        self,
        add_point_source: bool = False,
        fix_centroid: bool = False,
        use_shapelet_psf: bool = False,
        prior_axrat_stddev: float | str | None = None,
    ):
        """Apply runtime configuration changes to this config.

        Parameters
        ----------
        add_point_source
            Whether to add a point source component.
        fix_centroid
            Whether to fix the centroid.
        use_shapelet_psf
            Whether to initialize PSF parameters from prior shapelet fits.
        prior_axrat_stddev
            The standard deviation for the axis ratio prior. Ignored if None,
            otherwise it must be convertible to a float.
        """
        # The order here determines the order of suffixes appended to
        # connections.name_table.
        if add_point_source:
            self.add_point_source()
        if fix_centroid:
            self.fix_centroid()
        if use_shapelet_psf:
            self.use_shapelet_psf()
        if prior_axrat_stddev is not None:
            self.set_prior_axrat_stddev(float(prior_axrat_stddev))

    def fix_centroid(self):
        """Fix (freeze) the source centroid parameters.

        The x and y parameters of the ``"default"`` centroid of the first
        component group are fixed, and the fixed-centroid model name is
        appended to ``self.connections.name_table``.
        """
        group = self._get_component_group()
        centroids = group.centroids["default"]
        centroids.x.fixed = True
        centroids.y.fixed = True
        self.connections.name_table += model_names_default.fixed_cen

    @classmethod
    @abstractmethod
    def get_model_name_default(cls) -> str:
        """Return the default name for this model in table columns."""
        raise NotImplementedError("Subclasses must implement get_model_name_default")

    @classmethod
    def get_model_name_full(cls) -> str:
        """Return a longer, more descriptive name for the model.

        Defaults to the value of `get_model_name_default`.
        """
        return cls.get_model_name_default()

    @abstractmethod
    def make_default_model_config(self) -> ModelConfig:
        """Make a default configuration object for this model."""
        raise NotImplementedError("Subclasses must implement make_default_model_config")

    @staticmethod
    def make_point_source_component() -> GaussianComponentConfig:
        """Make a point source component config (zero-size Gaussian).

        The size_x, size_y and rho parameters are all fixed at zero.
        """
        return GaussianComponentConfig(
            size_x=ParameterConfig(value_initial=0.0, fixed=True),
            size_y=ParameterConfig(value_initial=0.0, fixed=True),
            rho=ParameterConfig(value_initial=0.0, fixed=True),
        )

    @staticmethod
    def make_sersic_component(**kwargs) -> SersicComponentConfig:
        """Make a default Sersic component config.

        The axis ratio and size prior standard deviations are set to 1.0 and
        0.2 respectively.

        Parameters
        ----------
        **kwargs
            Keyword arguments to pass to the SersicIndexParameterConfig.

        Returns
        -------
        config
            The default-initialized config.
        """
        return SersicComponentConfig(
            prior_axrat_stddev=1.0,
            prior_size_stddev=0.2,
            sersic_index=SersicIndexParameterConfig(**kwargs),
        )

    @staticmethod
    def make_single_model_config(group: ComponentGroupConfig) -> ModelConfig:
        """Make a default single-source, single component group config.

        Parameters
        ----------
        group
            The component group config for the single source.

        Returns
        -------
        config
            A model config with a single nameless source and component group.
        """
        return ModelConfig(
            sources={
                "": SourceConfig(
                    component_groups={
                        "": group,
                    }
                )
            }
        )

    def set_prior_axrat_stddev(self, stddev: float) -> None:
        """Set the standard deviation for all axis ratio priors.

        This applies to every Gaussian and Sersic component of every
        component group of every source. It does not modify centroids or the
        output table name.

        Parameters
        ----------
        stddev
            The standard deviation.
        """
        for source in self.fit_coadd_multiband.config_model.sources.values():
            for group in source.component_groups.values():
                for comp in itertools.chain(
                    group.components_gauss.values(),
                    group.components_sersic.values(),
                ):
                    comp.prior_axrat_stddev = stddev

    def setDefaults(self):
        super().setDefaults()
        self.fit_coadd_multiband.retarget(MultiProFitSourceTask)
        self.fit_coadd_multiband.action_psf = PsfComponentsActionBase()
        self.fit_coadd_multiband.bands_fit = ("u", "g", "r", "i", "z", "y")

        # Both of these are provided by subclasses.
        self.fit_coadd_multiband.config_model = self.make_default_model_config()
        self.name_model = self.get_model_name_default()
        self.connections.name_table = self.name_model

    def use_shapelet_psf(self):
        """Reconfigure self to use prior shapelet PSF fit parameters.

        This switches the PSF action to `SourceTablePsfComponentsAction`,
        sets ``drop_psf_connection`` to True and appends the shapelet PSF
        model name to ``self.connections.name_table``.
        """
        self.fit_coadd_multiband.action_psf = SourceTablePsfComponentsAction()
        self.drop_psf_connection = True
        self.connections.name_table += model_names_default.shapelet_psf

318 

319 

class MultiProFitCoaddObjectFitTask(CoaddMultibandFitTask):
    """MultiProFit coadd object model fitting task."""

    _DefaultName = "multiProFitCoaddObjectFit"
    ConfigClass = MultiProFitCoaddObjectFitConfig

    def make_kwargs(self, butlerQC, inputRefs, inputs):
        """Return extra keyword arguments built from initialization inputs.

        Parameters
        ----------
        butlerQC
            Unused here.
        inputRefs
            Unused here.
        inputs
            Mapping of loaded inputs by connection name; for each name in
            ``self.config.inputs_init``, ``inputs[name][0]`` is taken.

        Returns
        -------
        kwargs
            ``{"inputs_init": {name: (config, input)}}`` if any initialization
            inputs are configured, otherwise an empty dict.
        """
        pairs = {}
        for key, config_input in self.config.inputs_init.items():
            pairs[key] = (config_input, inputs[key][0])
        extra = {}
        if pairs:
            extra["inputs_init"] = pairs
        return extra

333 

334 

class MultiProFitCoaddPointFitConfig(
    MultiProFitCoaddObjectFitConfig,
    pipelineConnections=MultiProFitCoaddObjectFitConnections,
):
    """MultiProFit single point source model fit task config."""

    @classmethod
    def get_model_name_default(cls) -> str:
        return model_names_default.point

    @classmethod
    def get_model_name_full(cls) -> str:
        return "Point Source"

    def make_default_model_config(self) -> ModelConfig:
        """Make a single-source model with one point source component.

        Notes
        -----
        As a side effect, this assigns the new model to
        ``self.fit_coadd_multiband.config_model`` and (via
        ``add_point_source``) appends the point model name to
        ``self.connections.name_table``.
        """
        config_group = ComponentGroupConfig()
        # This is a bit silly but add_point_source will look for the first
        # source so it must be added now. Perhaps add_point_source should
        # add to a config instance or only self by default
        self.fit_coadd_multiband.config_model = self.make_single_model_config(group=config_group)
        self.add_point_source()
        return self.fit_coadd_multiband.config_model

357 

358 

class MultiProFitCoaddSersicFitConfig(
    MultiProFitCoaddObjectFitConfig,
    pipelineConnections=MultiProFitCoaddObjectFitConnections,
):
    """MultiProFit single Sersic model fit task config."""

    def _rename_defaults(
        self,
        name_new: str,
        name_model: str | None = None,
        name_old: str | None = None,
        index_new: float | None = None,
        fix_index: bool = False,
    ):
        """Rename the default Sersic component to something more specific.

        This is intended for fixed index models such as exponential and
        deVaucouleurs. Any size prior registered under the old component name
        is moved to the new name, and both ``self.name_model`` and
        ``self.connections.name_table`` are set to the new model name.

        Parameters
        ----------
        name_new
            The new name for the component.
        name_model
            The new name of the model. Default is to capitalize name_new.
        name_old
            The old name of the component. Default is to set to
            component_names_default.sersic.
        index_new
            The initial value for the Sersic index. Left unchanged if None.
        fix_index
            Whether to fix the index (at ``index_new``, if given).

        Raises
        ------
        RuntimeError
            Raised if a Sersic component named ``name_new`` already exists.
        """
        if name_old is None:
            name_old = component_names_default.sersic
        if name_model is None:
            name_model = name_new.capitalize()
        group = self._get_component_group()
        comps_sersic = group.components_sersic

        if name_new in comps_sersic:
            raise RuntimeError(f"{name_new=} is already in {comps_sersic=}")

        comp_sersic = comps_sersic[name_old]
        del comps_sersic[name_old]
        if index_new is not None:
            comp_sersic.sersic_index.value_initial = index_new
        if fix_index:
            comp_sersic.sersic_index.fixed = True
        comps_sersic[name_new] = comp_sersic

        # Compare against None explicitly rather than relying on the
        # truthiness of a config object.
        size_priors = self.fit_coadd_multiband.size_priors
        if (prior_old := size_priors.get(name_old)) is not None:
            size_priors[name_new] = prior_old
            del size_priors[name_old]

        self.name_model = name_model
        self.connections.name_table = name_model

    @classmethod
    def get_model_name_default(cls) -> str:
        return model_names_default.sersic

    @classmethod
    def get_model_name_full(cls) -> str:
        return "Sersic"

    def make_default_model_config(self) -> ModelConfig:
        """Make a single-source model with one free-index Sersic component."""
        config_group = ComponentGroupConfig(
            components_sersic={
                component_names_default.sersic: self.make_sersic_component(),
            },
        )
        return self.make_single_model_config(group=config_group)

    def setDefaults(self):
        super().setDefaults()
        # This is in pixels and based on DC2. See DM-46498 for details.
        self.fit_coadd_multiband.size_priors[component_names_default.sersic] = (
            MagnitudeDependentSizePriorConfig(
                intercept_mag=22.6,
                slope_median_per_mag=-0.15,
                slope_stddev_per_mag=0,
            )
        )

443 

444 

class MultiProFitCoaddSersicFitTask(MultiProFitCoaddObjectFitTask):
    """MultiProFit single Sersic model fit task.

    Behavior is inherited unchanged; only the config class and default name
    are specialized.
    """

    _DefaultName = "multiProFitCoaddSersicFit"
    ConfigClass = MultiProFitCoaddSersicFitConfig

450 

451 

class MultiProFitCoaddGaussFitConfig(
    MultiProFitCoaddSersicFitConfig,
    pipelineConnections=MultiProFitCoaddObjectFitConnections,
):
    """MultiProFit single Gaussian model fit task config.

    The default Sersic component is renamed to the Gaussian component name
    and its index is fixed at 0.5.
    """

    @classmethod
    def get_model_name_full(cls) -> str:
        return "Gaussian"

    @classmethod
    def get_model_name_default(cls) -> str:
        return model_names_default.gauss

    def setDefaults(self):
        super().setDefaults()
        # A Sersic index of 0.5 is the Gaussian case of the Sersic profile.
        rename_kwargs = dict(
            name_new=component_names_default.gauss,
            name_model=model_names_default.gauss,
            index_new=0.5,
            fix_index=True,
        )
        self._rename_defaults(**rename_kwargs)

474 

475 

class MultiProFitCoaddGaussFitTask(MultiProFitCoaddObjectFitTask):
    """MultiProFit single Gaussian model fit task.

    Behavior is inherited unchanged; only the config class and default name
    are specialized.
    """

    _DefaultName = "multiProFitCoaddGaussFit"
    ConfigClass = MultiProFitCoaddGaussFitConfig

481 

482 

class MultiProFitCoaddExpFitConfig(
    MultiProFitCoaddSersicFitConfig,
    pipelineConnections=MultiProFitCoaddObjectFitConnections,
):
    """MultiProFit single exponential model fit task config.

    The default Sersic component is renamed to the exponential component
    name, its index is fixed at 1, and its size prior is adjusted.
    """

    @classmethod
    def get_model_name_full(cls) -> str:
        return "Exponential"

    @classmethod
    def get_model_name_default(cls) -> str:
        return model_names_default.exp

    def setDefaults(self):
        super().setDefaults()
        name_component = component_names_default.exp
        self._rename_defaults(
            name_new=name_component,
            name_model=model_names_default.exp,
            index_new=1.0,
            fix_index=True,
        )
        # Typical values from DC2; these could/should be switched to a more
        # data-driven prior (from HSC?). The prior was moved to the new
        # component name by _rename_defaults.
        size_prior = self.fit_coadd_multiband.size_priors[name_component]
        size_prior.intercept_mag = 23.4
        size_prior.slope_median_per_mag = -0.14

510 

511 

class MultiProFitCoaddExpFitTask(MultiProFitCoaddObjectFitTask):
    """MultiProFit single exponential model fit task.

    Behavior is inherited unchanged; only the config class and default name
    are specialized.
    """

    _DefaultName = "multiProFitCoaddExpFit"
    ConfigClass = MultiProFitCoaddExpFitConfig

517 

518 

class MultiProFitCoaddDeVFitConfig(
    MultiProFitCoaddSersicFitConfig,
    pipelineConnections=MultiProFitCoaddObjectFitConnections,
):
    """MultiProFit single DeVaucouleurs model fit task config.

    The default Sersic component is renamed to the de Vaucouleurs component
    name, its index is fixed at 4, and its size prior is adjusted.
    """

    @classmethod
    def get_model_name_full(cls) -> str:
        return "de Vaucouleurs"

    @classmethod
    def get_model_name_default(cls) -> str:
        return model_names_default.deV

    def setDefaults(self):
        super().setDefaults()
        name_component = component_names_default.deV
        self._rename_defaults(
            name_new=name_component,
            name_model=model_names_default.deV,
            index_new=4.0,
            fix_index=True,
        )
        # Typical values from DC2; these could/should be switched to a more
        # data-driven prior (from HSC?). See DM-46498 for details.
        size_prior = self.fit_coadd_multiband.size_priors[name_component]
        size_prior.intercept_mag = 21.2
        size_prior.slope_median_per_mag = -0.14

546 

547 

class MultiProFitCoaddDeVFitTask(MultiProFitCoaddObjectFitTask):
    """MultiProFit single DeVaucouleurs model fit task.

    Behavior is inherited unchanged; only the config class and default name
    are specialized.
    """

    _DefaultName = "multiProFitCoaddDeVFit"
    ConfigClass = MultiProFitCoaddDeVFitConfig

553 

554 

class CachedChainedModelInitializer(CachedBasicModelInitializer):
    """Initialize centroids and shapes from the best prior fit of a source.

    Among the entries of ``self.inputs`` that have a row for the source's
    ``id``, the row with the lowest ``chisq_reduced`` value is used. If no
    entry has such a row, the parent class implementation is used instead.
    """

    def get_centroid_and_shape(
        self,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData,
        values_init: Mapping[g2f.ParameterD, float] | None = None,
    ) -> tuple[tuple[float, float], tuple[float, float, float]]:
        best = None
        chisq_red_best = math.inf
        for _name, candidate in self.inputs.items():
            table = candidate.data
            idx = candidate.id_index.get(source["id"])
            if idx is None:
                continue
            row = table[idx]
            chisq_red = candidate.get_column("chisq_reduced", data=row)
            # Strict comparison: ties keep the earlier input, and a NaN
            # chi-squared never compares less, so such rows are not selected.
            if chisq_red < chisq_red_best:
                best = (row, candidate)
                chisq_red_best = chisq_red

        if best is None:
            # No prior fit for this source; fall back to the basic initializer.
            return super().get_centroid_and_shape(
                source=source,
                catexps=catexps,
                config_data=config_data,
                values_init=values_init,
            )

        row, chosen = best
        column_names = (
            chosen.get_column("cen_x").name,
            chosen.get_column("cen_y").name,
            chosen.get_column(f"{chosen.size_column}_x").name,
            chosen.get_column(f"{chosen.size_column}_y").name,
            chosen.get_column("rho").name,
        )
        cen_x, cen_y, reff_x, reff_y, rho = [
            chosen.get_column(column_name, data=row) for column_name in column_names
        ]
        return (cen_x, cen_y), (reff_x, reff_y, rho)

593 

594 

class MakeCachedChainedInitializerAction(MakeBasicInitializerAction):
    """Action that builds a `CachedChainedModelInitializer`."""

    def _make_initializer(
        self,
        catalog_multi: Sequence,
        catexps: list[CatalogExposureInputs],
        config_data: CatalogSourceFitterConfigData,
    ) -> ModelInitializer:
        # catalog_multi and catexps are accepted for interface compatibility
        # but are not used here.
        sources, priors = config_data.sources_priors
        initializer = CachedChainedModelInitializer(
            config=self.config,
            priors=priors,
            sources=sources,
        )
        return initializer

604 

605 

class MultiProFitCoaddExpDeVFitConfig(
    MultiProFitCoaddObjectFitConfig,
    pipelineConnections=MultiProFitCoaddObjectFitConnections,
):
    """MultiProFit single Exponential+DeVaucouleurs model fit task config."""

    @classmethod
    def get_model_name_default(cls) -> str:
        return f"{model_names_default.exp}{model_names_default.deV}"

    @classmethod
    def get_model_name_full(cls) -> str:
        return "Exponential + de Vaucouleurs"

    def make_default_model_config(self) -> ModelConfig:
        """Make a single-source model with exponential and deV components.

        The exponential component's Sersic index is fixed at 1 and the
        de Vaucouleurs component's at 4.
        """
        config_group = ComponentGroupConfig(
            components_sersic={
                component_names_default.exp: self.make_sersic_component(value_initial=1.0, fixed=True),
                component_names_default.deV: self.make_sersic_component(value_initial=4.0, fixed=True),
            },
        )
        return self.make_single_model_config(group=config_group)

    def setDefaults(self):
        # The parent setDefaults already assigns config_model from
        # make_default_model_config and sets name_model and
        # connections.name_table from get_model_name_default, so those
        # assignments are not repeated here.
        super().setDefaults()
        self.fit_coadd_multiband.action_initializer = MakeCachedChainedInitializerAction()

        size_priors = self.fit_coadd_multiband.size_priors
        size_priors[component_names_default.exp] = MagnitudeDependentSizePriorConfig(
            intercept_mag=23.3,
            slope_median_per_mag=-0.14,
        )
        size_priors[component_names_default.deV] = MagnitudeDependentSizePriorConfig(
            intercept_mag=21.2,
            slope_median_per_mag=-0.14,
        )

645 

646 

class MultiProFitCoaddExpDeVFitTask(MultiProFitCoaddObjectFitTask):
    """MultiProFit single ExpDeV model fit task.

    Behavior is inherited unchanged; only the config class and default name
    are specialized.
    """

    _DefaultName = "multiProFitCoaddExpDeVFit"
    ConfigClass = MultiProFitCoaddExpDeVFitConfig