Coverage for python / lsst / meas / photoz / base / estimate_photoz_task.py: 48%

177 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 09:01 +0000

1# This file is part of meas_photoz_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24__all__ = [ 

25 "EstimatePhotozAlgoConfigBase", 

26 "EstimatePhotozAlgoTask", 

27 "EstimatePhotozTask", 

28 "EstimatePhotozTaskConfig", 

29 "photozAlgoRegistry", 

30] 

31 

32from abc import ABC, abstractmethod 

33from typing import Any 

34 

35import numpy as np 

36from astropy.table import Table 

37from ceci.config import StageConfig as CeciStageConfig 

38from ceci.config import StageParameter as CeciParam 

39from rail.core.model import Model 

40from rail.estimation.estimator import CatEstimator 

41from rail.interfaces import PZFactory 

42 

43import lsst.pex.config as pexConfig 

44import lsst.pipe.base.connectionTypes as cT 

45from lsst.pipe.base import ( 

46 PipelineTask, 

47 PipelineTaskConfig, 

48 PipelineTaskConnections, 

49 Struct, 

50 Task, 

51) 

52 

53 

class EstimatePhotozConnections(
    PipelineTaskConnections,
    dimensions=("skymap", "tract"),
    defaultTemplates={"algo": "trainz"},
):
    """Butler connections for tasks that produce p(z) estimates.

    Declares one prerequisite input (the pickled model file, treated as a
    "calibration-like" dataset), one regular input (the object table, with
    deferred loading) and one output (the p(z) file in 'qp' format).  The
    ``algo`` template is substituted into the model and ensemble dataset
    type names.
    """

    # Model dataset, flagged as a calibration with instrument dimension.
    photoz_model = cT.PrerequisiteInput(
        name="photoz_model_{algo}",
        storageClass="PhotozModel",
        doc="Model for PZ Estimation",
        dimensions=["instrument"],
        isCalibration=True,
    )

    # Object table; deferLoad=True so the caller decides what to read.
    objects = cT.Input(
        name="object",
        storageClass="ArrowAstropy",
        doc="Object table",
        dimensions=("skymap", "tract"),
        deferLoad=True,
    )

    # Output ensemble of per-object p(z) estimates.
    photoz_ensemble = cT.Output(
        name="photoz_ensemble_{algo}",
        storageClass="QPEnsemble",
        doc="Per-object p(z) estimates",
        dimensions=("skymap", "tract"),
    )

86 

87 

class EstimatePhotozAlgoConfigBase(
    pexConfig.Config,
):
    """Base class for configurations of algorithm-specific p(z)
    estimation tasks.

    This class mostly just translates the RAIL configuration
    parameters to pex.config parameters.

    Subclasses will just have to set `estimator_class` and `stage_name`
    and invoke `_make_fields` once in the module.
    """

    @classmethod
    @abstractmethod
    def estimator_class(cls) -> type[CatEstimator]:
        """Return the type of the estimator's RAIL class.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation; subclasses override it.
        """
        raise NotImplementedError("Subclasses must specify an estimator class")

    # This should be a property but py3.13+ don't allow it
    @classmethod
    @abstractmethod
    def stage_name(cls) -> str:
        """Return the RAIL stage name for the estimator.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation; subclasses override it.
        """
        raise NotImplementedError("Subclasses must define a RAIL stage name")

    # Extinction coefficients; see https://ui.adsabs.harvard.edu/abs/1989ApJ...345..245C/abstract
    # Also in rail.utils.catalog_utils.RubinCatalogConfig.a_env
    default_a_env_values = dict(
        u=4.81,
        g=3.64,
        r=2.70,
        i=2.06,
        z=1.58,
        y=1.31,
    )

    # These appear in many DESC repos, for example:
    # https://github.com/LSSTDESC/TXPipe/blob/00ebe7476fd5d9529f5bbc4d73fcef0629d134c7/examples/dp0.2/config.yml#L47
    # They seem to be 10y WFD limits. Origin unclear.
    default_mag_limit_10y_values = dict(
        u=27.79,
        g=29.04,
        r=29.06,
        i=28.62,
        z=27.98,
        y=27.05,
    )

    # These appear to be from Roman-Rubin simulations:
    # https://github.com/LSSTDESC/rail_base/blob/v1.2.1/src/rail/utils/catalog_utils.py#L207
    # Presumably max 5y depth, and more useful for now
    default_mag_limit_values = dict(
        u=24.0,
        g=27.66,
        r=27.25,
        i=26.6,
        z=26.24,
        y=25.35,
    )

    def get_band_a_env_dict(self) -> dict[str, float]:
        """Return the a_env value for each band in ``bands_to_convert``.

        Values always come from `default_a_env_values`, keyed by band name.
        """
        return {band_: self.default_a_env_values[band_] for band_ in self.bands_to_convert}

    def get_mag_lim_dict(self) -> dict[str, float]:
        """Return the magnitude limits, keyed by magnitude column name.

        Values come from `default_mag_limit_values` for each band in
        ``bands_to_convert``.
        """
        return {
            self.mag_template.format(band=band): self.default_mag_limit_values[band]
            for band in self.bands_to_convert
        }

    def get_flux_names(self) -> dict[str, str]:
        """Return a dict mapping band to flux column name."""
        return {band: self.flux_column_template.format(band=band) for band in self.bands_to_convert}

    def get_flux_err_names(self) -> dict[str, str]:
        """Return a dict mapping band to flux error column name."""
        return {band: self.flux_err_column_template.format(band=band) for band in self.bands_to_convert}

    def get_mag_names(self) -> dict[str, str]:
        """Return a dict mapping band to mag column name."""
        return {band: self.mag_template.format(band=band) for band in self.bands_to_convert}

    def get_mag_err_names(self) -> dict[str, str]:
        """Return a dict mapping band to mag error column name."""
        return {band: self.mag_err_template.format(band=band) for band in self.bands_to_convert}

    mag_offset = pexConfig.Field(doc="Magnitude offset", dtype=float, default=31.4)
    deredden = pexConfig.Field[bool](
        doc="Apply dereddening",
        default=True,
    )
    band_ref = pexConfig.Field[str](
        doc="Name of the most reliable reference band, if needed",
        default="i",
    )
    bands_to_convert = pexConfig.ListField[str](
        doc="Names of bands to convert fluxs to mags for RAIL",
        default=["u", "g", "r", "i", "z", "y"],
    )
    # The column templates below default to the gaap1p0 columns; the
    # corresponding cModel columns ("{band}_cModelFlux", "{band}_cModelFluxErr",
    # "{band}_cModelMag", "{band}_cModelMagErr") were the noted alternative.
    flux_column_template = pexConfig.Field[str](
        doc="Template for flux column names",
        default="{band}_gaap1p0Flux",
    )
    flux_err_column_template = pexConfig.Field[str](
        doc="Template for flux error column names",
        default="{band}_gaap1p0FluxErr",
    )
    mag_template = pexConfig.Field[str](
        doc="Template for magnitude names",
        default="{band}_gaap1p0Mag",
    )
    mag_err_template = pexConfig.Field[str](
        doc="Template for magntitude error names",
        default="{band}_gaap1p0MagErr",
    )
    nondetect_val = pexConfig.Field[float](
        doc="Magnitude to set for non-detections",
        default=np.nan,
    )
    band_a_env = pexConfig.DictField[str, float](
        doc="Reddening parameters",
        default=default_a_env_values,
    )

    def freeze(self) -> None:
        """Fill in the derived fields, then freeze the config.

        `_finalize` assigns to config fields, so it has to run before the
        parent class freezes the instance; it is skipped when the config is
        already frozen.
        """
        if not self._frozen:
            self._finalize()
        super().freeze()

    def _finalize(self) -> None:
        """Derive the RAIL-facing fields from the band/template settings.

        Each field is only touched if it exists on this config, since most
        of them are created by `_make_fields` from the RAIL stage options.
        """
        # These calls will fail if it's already frozen.
        if hasattr(self, "ref_band"):
            self.ref_band = self.mag_template.format(band=self.band_ref)
        if hasattr(self, "bands"):
            # This is a list of mag columns in RAIL, not bands
            self.bands = list(self.get_mag_names().values())
        if hasattr(self, "err_bands"):
            self.err_bands = list(self.get_mag_err_names().values())
        if hasattr(self, "mag_limits"):
            self.mag_limits = self.get_mag_lim_dict()
        if hasattr(self, "band_a_env"):
            # Replaces whatever was configured with band-keyed default values.
            self.band_a_env = self.get_band_a_env_dict()

    @classmethod
    def _make_fields(cls) -> None:
        """Import the RAIL estimation stage.

        This method loops through the stage config parameters and converts
        RAIL/Ceci parameters to corresponding pex.config parameters.

        It should be called exactly once, immediately after the definition
        of every subclass of this base class.

        Raises
        ------
        RuntimeError
            If it was already called on this class, or if a scalar RAIL
            parameter collides with an existing class attribute that is
            not a `pexConfig.Field` or has a different dtype.
        """
        # Look only at this class's own namespace: the marker is inherited,
        # so hasattr() would wrongly reject a subclass of a class that has
        # already been processed.
        if "__fields_made__" in vars(cls):
            if cls.__fields_made__ is not True:
                raise RuntimeError(f"{cls.__fields_made__=} exists but is not True")
            raise RuntimeError(f"{cls=} called _make_fields twice")

        stage_class = cls.estimator_class()
        for key, val in stage_class.config_options.items():
            if isinstance(val, CeciStageConfig):
                val = val.get(key)
            if not isinstance(val, CeciParam):
                continue

            if val.dtype in (bool, int, float, str):
                attr = getattr(cls, key, None)
                if attr is None:
                    setattr(
                        cls,
                        key,
                        pexConfig.Field(doc=val.msg, dtype=val.dtype, default=val.default),
                    )
                    continue
                if not isinstance(attr, pexConfig.Field):
                    # Report the type of the clashing attribute, not of the key.
                    raise RuntimeError(f"{cls=} {key=} exists but is of {type(attr)=}, not Field")
                if attr.dtype != val.dtype:
                    raise RuntimeError(f"{cls=} {key=} exists but {attr.dtype=} != {val.dtype=}")
                # NOTE(review): attr may be the Field object inherited from a
                # parent class, in which case this also changes the default
                # seen by that parent and its other subclasses - confirm
                # that this is intended.
                attr.default = val.default
                attr.doc = f"{val.msg} (overriding base doc='{attr.doc}')"
            elif val.dtype in (list,):
                # this is a hack, but it works: infer the item type from the
                # first default element, falling back to str when empty.
                if val.default:
                    item_type = type(val.default[0])
                else:
                    item_type = str
                setattr(
                    cls,
                    key,
                    pexConfig.ListField(doc=val.msg, dtype=item_type, default=val.default),
                )
            elif val.dtype in (dict,):
                setattr(
                    cls,
                    key,
                    pexConfig.DictField(doc=val.msg, keytype=str, default=val.default),
                )
        cls.__fields_made__ = True

288 

289 

# Registry through which algorithm-specific estimation subtasks are made
# selectable via a registry config field.
_PHOTOZ_ALGO_REGISTRY_DOC = "A registry of photometric redshift estimation algorithm subtasks"
photozAlgoRegistry = pexConfig.makeRegistry(doc=_PHOTOZ_ALGO_REGISTRY_DOC)

293 

294 

class EstimatePhotozAlgoTask(Task, ABC):
    """Task for algorithm-specific p(z) estimation.

    This provides almost all of the functionality
    needed to run RAIL p(z) algorithms: converting fluxes to magnitudes,
    optionally dereddening them, building the RAIL stage and calling it.

    Parameters
    ----------
    **kwargs
        Additional keyword arguments to pass to super().__init__.
    """

    ConfigClass = EstimatePhotozAlgoConfigBase

    # Factor converting a fractional flux error into a magnitude error:
    # sigma_mag = sigma_flux / (flux * mag_conv).
    mag_conv = np.log(10) * 0.4

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

    @staticmethod
    def _flux_to_mag(
        flux_vals: np.ndarray,
        mag_offset: float,
        nondetect_val: float,
    ) -> np.ndarray:
        """Convert flux to magnitude.

        Parameters
        ----------
        flux_vals : np.array
            Input flux values (presumably nJy, given the 31.4 default
            offset in the config - TODO confirm).

        mag_offset : float
            Magnitude offset (corresponding to a flux of 1.)

        nondetect_val : float
            Value to set for non-detections

        Returns
        -------
        mags : np.array
            Magnitude values, same shape and dtype as ``flux_vals``.
            Entries whose flux is not strictly positive (zero, negative
            or NaN) are set to ``nondetect_val``.
        """
        mags = np.empty_like(flux_vals)
        # Strictly positive: a flux of exactly zero has no finite magnitude
        # and would otherwise go through log10(0).
        detected = flux_vals > 0
        mags[detected] = -2.5 * np.log10(flux_vals[detected]) + mag_offset
        mags[~detected] = nondetect_val
        return mags

    @staticmethod
    def _flux_err_to_mag_err(
        flux_vals: np.ndarray,
        flux_err_vals: np.ndarray,
        mag_conv: float,
        nondetect_val: float = np.nan,
    ) -> np.ndarray:
        """Convert flux error to magnitude error.

        Parameters
        ----------
        flux_vals : np.array
            Input flux values.

        flux_err_vals : np.array
            Input flux errors, in the same units as ``flux_vals``.

        mag_conv : float
            Magnitude to flux conversion (typically np.log(10)*0.4)

        nondetect_val : float
            Value to set for non-detections

        Returns
        -------
        mags_errs : np.array
            Magnitude errors, same shape and dtype as ``flux_vals``.
            Entries whose flux is not strictly positive are set to
            ``nondetect_val``.
        """
        mag_errs = np.empty_like(flux_vals)
        # Strictly positive, matching _flux_to_mag and avoiding a division
        # by zero for zero fluxes.
        detected = flux_vals > 0
        mag_errs[detected] = flux_err_vals[detected] / (flux_vals[detected] * mag_conv)
        mag_errs[~detected] = nondetect_val
        return mag_errs

    @staticmethod
    def _deredden_mags(
        data: dict[str, np.ndarray],
        a_env_dict: dict[str, float],
        mag_names: dict[str, str],
        nondetect_val: float,
    ) -> dict[str, np.ndarray]:
        """Deredden the magnitudes.

        Parameters
        ----------
        data: dict[str, np.array]
            Input data; must contain an ``"ebv"`` entry and one entry per
            magnitude column named in ``mag_names``.  Modified in place.

        a_env_dict: dict[str, float],
            Reddening parameters for bands

        mag_names: dict[str, str]
            Mapping from bands to magnitudes

        nondetect_val : float
            Value to set for non-detections

        Returns
        -------
        mags: dict[str, np.array]
            The same dict object as ``data``, with dereddened mags.
            Non-finite input magnitudes are replaced by ``nondetect_val``.
        """
        ebv = data["ebv"]
        for band_, a_env_ in a_env_dict.items():
            mag_name = mag_names[band_]
            raw_mag = data[mag_name]
            data[mag_name] = np.where(
                np.isfinite(raw_mag),
                raw_mag - ebv * a_env_,
                nondetect_val,
            )
        return data

    def _get_mags_and_errs(
        self,
        fluxes: Table,
        mag_offset: float,
    ) -> dict[str, np.ndarray]:
        """Fill and return a numpy dict with mags and mag errors.

        Parameters
        ----------
        fluxes : Table
            Input fluxes and flux errors

        mag_offset : float
            Magnitude offset (corresponding to a flux of 1.)

        Returns
        -------
        mags: dict[str, np.array]
            Numpy dict with mags and mag errors, keyed by the magnitude
            and magnitude-error column names from the config.
        """
        # get all the column names we will use
        flux_names = self.config.get_flux_names()
        mag_names = self.config.get_mag_names()
        flux_err_names = self.config.get_flux_err_names()
        mag_err_names = self.config.get_mag_err_names()
        nondetect_val = self.config.nondetect_val

        mag_dict = {}
        # All four name dicts are keyed by the same bands, so the error
        # columns are always available for every band visited here.
        for band, flux_name in flux_names.items():
            # asarray drops any astropy column wrapper
            flux_vals = np.asarray(fluxes[flux_name])
            flux_err_vals = np.asarray(fluxes[flux_err_names[band]])
            mag_dict[mag_names[band]] = self._flux_to_mag(
                flux_vals,
                mag_offset,
                nondetect_val,
            )
            mag_dict[mag_err_names[band]] = self._flux_err_to_mag_err(
                flux_vals,
                flux_err_vals,
                self.mag_conv,
                nondetect_val,
            )

        return mag_dict

    def init(
        self,
        photoz_model: Model,
    ) -> None:
        """Set up the RAIL stage to compute photo-zs.

        Parameters
        ----------
        photoz_model : Model
            Model used by the p(z) estimation algorithm.
        """
        # pop the pipeline task config options
        # so that we can pass the rest to RAIL
        rail_kwargs = self.config.toDict()
        for key in ["saveLogOutput", "stage_name", "mag_offset", "connections"]:
            rail_kwargs.pop(key, None)
        rail_kwargs["output_mode"] = "return"

        # Build the RAIL stage
        self._stage = PZFactory.build_stage_instance(
            self.config.stage_name(),
            self.config.estimator_class(),
            model_path=photoz_model.data,
            input_path="dummy.in",
            **rail_kwargs,
        )
        self._stage._initialize_run()

    def col_names(
        self,
    ) -> list[str]:
        """Get the list of column names to read from the input data.

        Returns
        -------
        columns : list[str]
            Flux columns, then flux error columns, then ``"ebv"`` if
            dereddening is enabled.
        """
        columns = list(self.config.get_flux_names().values()) + list(
            self.config.get_flux_err_names().values()
        )
        if self.config.deredden:
            columns += ["ebv"]

        return columns

    def run(
        self,
        fluxes: Table,
    ) -> Struct:
        """Run a p(z) estimation algorithm.

        `init` must have been called first so that the RAIL stage exists.

        Parameters
        ----------
        fluxes : Table
            Fluxes used to compute the redshifts.

        Returns
        -------
        result : Struct
            Struct with a single ``photoz_ensemble`` attribute holding the
            p(z) PDFs as a qp.Ensemble.
        """
        n_obj = len(fluxes)
        # Convert fluxes to mags
        mags = self._get_mags_and_errs(fluxes, self.config.mag_offset)

        # De-redden
        if self.config.deredden:
            # asarray will convert an astropy column to an array w/o units
            mags["ebv"] = np.asarray(fluxes["ebv"])
            mags = self._deredden_mags(
                mags,
                self.config.band_a_env,
                self.config.get_mag_names(),
                self.config.nondetect_val,
            )

        # Pass the mags to RAIL and get back the p(z) pdfs
        # as a qp.Ensemble object
        photoz_pdfs = PZFactory.estimate_single_pz(self._stage, mags, n_obj)
        return Struct(photoz_ensemble=photoz_pdfs)

542 

543 

class EstimatePhotozTaskConfig(
    PipelineTaskConfig,
    pipelineConnections=EstimatePhotozConnections,
):
    """Config for the EstimatePhotozTask PipelineTask.

    Its only field selects and configures the algorithm-specific subtask
    out of `photozAlgoRegistry`.
    """

    photoz_algo = photozAlgoRegistry.makeField(doc="Algorithm specific configuration p(z) estimation task")

550 

551 

class EstimatePhotozTask(PipelineTask):
    """PipelineTask for p(z) estimation.

    The actual work is delegated to the ``photoz_algo`` subtask, which is
    initialized with the model the first time `run` is called through
    `runQuantum`.

    Parameters
    ----------
    initInputs
        Initialization inputs to pass to super().__init__.
    **kwargs
        Additional keyword arguments to pass to super().__init__.
    """

    ConfigClass = EstimatePhotozTaskConfig
    _DefaultName = "estimatePhotoz"

    def __init__(self, initInputs: dict, **kwargs):
        super().__init__(initInputs=initInputs, **kwargs)
        # Becomes True once the subtask has been set up with a model.
        self._initialized = False
        self.makeSubtask("photoz_algo")

    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Load the inputs, run the estimation and store the outputs.

        Only the columns requested by the subtask are read from the
        deferred-load object table, which is passed on as ``fluxes``.
        """
        inputs = butlerQC.get(inputRefs)
        inputs["fluxes"] = inputs.pop("objects").get(
            parameters=dict(columns=self.photoz_algo.col_names()),
        )
        # NOTE(review): once initialized, the subtask keeps the first model
        # it was given even if a later quantum supplies a different one -
        # confirm that this is intended.
        outputs = self.run(**inputs, skip_init=self._initialized)
        butlerQC.put(outputs, outputRefs)

    def run(
        self,
        *,
        photoz_model: Model,
        fluxes: Table,
        skip_init: bool = False,
    ) -> Struct:
        """Estimate p(z) for the objects in ``fluxes``.

        Parameters
        ----------
        photoz_model : Model
            Model used by the p(z) estimation algorithm.
        fluxes : Table
            Fluxes (and flux errors) used to compute the redshifts.
        skip_init : bool, optional
            If True, assume the subtask was already initialized with a
            model and do not initialize it again.

        Returns
        -------
        result : Struct
            Struct with a single ``photoz_ensemble`` attribute.
        """
        if not skip_init:
            self.photoz_algo.init(photoz_model)
            # Only record success after init() returns, so that a failed
            # initialization is retried rather than silently skipped.
            self._initialized = True

        ret_struct = self.photoz_algo.run(fluxes)
        return Struct(photoz_ensemble=ret_struct.photoz_ensemble)