Coverage for python / lsst / meas / photoz / base / estimate_photoz_task_trainz.py: 63%

39 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-28 08:46 +0000

1# This file is part of meas_photoz_base. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

# Public API of this module: the TrainZ algorithm subtask with its config,
# and the top-level TrainZ photo-z task with its config. Kept as a literal
# list so static tooling (linters, type checkers) can read it.
__all__ = [
    "EstimatePhotozTrainZAlgoConfig",
    "EstimatePhotozTrainZAlgoTask",
    "EstimatePhotozTrainZConfig",
    "EstimatePhotozTrainZTask",
]

28 

29import numpy as np 

30from astropy.table import Table 

31from rail.estimation.algos.train_z import TrainZEstimator 

32from rail.estimation.estimator import CatEstimator 

33 

34import lsst.pex.config as pexConfig 

35 

36from .estimate_photoz_task import ( 

37 EstimatePhotozAlgoConfigBase, 

38 EstimatePhotozAlgoTask, 

39 EstimatePhotozTask, 

40 EstimatePhotozTaskConfig, 

41 photozAlgoRegistry, 

42) 

43 

44 

class EstimatePhotozTrainZAlgoConfig(EstimatePhotozAlgoConfigBase):
    """Config for EstimatePhotozTrainZAlgoTask.

    Ties the generic photo-z algorithm config to RAIL's ``TrainZEstimator``
    under the stage name ``"trainz"``.
    """

    @classmethod
    def estimator_class(cls) -> type[CatEstimator]:
        """Return the RAIL estimator class this config is for.

        Returns
        -------
        estimator : `type` [`rail.estimation.estimator.CatEstimator`]
            Always ``TrainZEstimator``.
        """
        return TrainZEstimator

    @classmethod
    def stage_name(cls) -> str:
        """Return the stage name of the TrainZ algorithm.

        This string is also used in this module as the key under which
        the TrainZ subtask is registered in ``photozAlgoRegistry``, and as
        the default algorithm selection of ``EstimatePhotozTrainZConfig``.

        Returns
        -------
        name : `str`
            Always ``"trainz"``.
        """
        return "trainz"

    def setDefaults(self) -> None:
        """Apply the base-class defaults, then the TrainZ-specific ones."""
        super().setDefaults()
        # NOTE(review): ``band_a_env`` presumably maps band name to an
        # environmental (e.g. Galactic extinction) coefficient; only the
        # i band is given a value here. Confirm the meaning of 2.06
        # against EstimatePhotozAlgoConfigBase.
        self.band_a_env = {"i": 2.06}

59 

60 

# Module-level side effect: finish building the config class by calling the
# inherited ``_make_fields`` hook (not defined in this file). It has to run
# after the class body above exists.
# NOTE(review): presumably this generates pex_config fields from the RAIL
# estimator returned by ``estimator_class()`` -- confirm against
# EstimatePhotozAlgoConfigBase.
EstimatePhotozTrainZAlgoConfig._make_fields()

62 

63 

@pexConfig.registerConfigurable(
    EstimatePhotozTrainZAlgoConfig.stage_name(),
    photozAlgoRegistry,
)
class EstimatePhotozTrainZAlgoTask(EstimatePhotozAlgoTask):
    """Subtask running the RAIL TrainZ algorithm for p(z) estimation.

    The algorithm itself lives in RAIL; see
    https://github.com/LSSTDESC/rail_base/blob/main/src/rail/estimation/algos/train_z.py

    TrainZ is only a placeholder algorithm: it assigns the same p(z)
    distribution (taken from the input model file) to every object.

    This class is registered in ``photozAlgoRegistry`` under the name
    returned by ``EstimatePhotozTrainZAlgoConfig.stage_name()``.
    """

    _DefaultName = "estimatePhotozTrainZAlgo"
    ConfigClass = EstimatePhotozTrainZAlgoConfig

    def _get_mags_and_errs(
        self,
        fluxes: Table,
        mag_offset: float,
    ) -> dict[str, np.ndarray]:
        """Convert each configured flux column of ``fluxes`` to magnitudes.

        Parameters
        ----------
        fluxes : `astropy.table.Table`
            Table with one flux column per band; the column names are the
            values of ``self.config.get_flux_names()``.
        mag_offset : `float`
            Forwarded unchanged to ``self._flux_to_mag``; presumably a
            magnitude zero point (TODO confirm).

        Returns
        -------
        mags : `dict` [`str`, `numpy.ndarray`]
            One entry per band, keyed by that band's name in
            ``self.config.get_mag_names()``; values are whatever
            ``self._flux_to_mag`` returns.

        Notes
        -----
        Despite the method name, this override produces no magnitude
        errors: only magnitudes are placed in the returned dict.
        """
        flux_columns = self.config.get_flux_names()
        mag_columns = self.config.get_mag_names()

        mags = {}
        for band, flux_column in flux_columns.items():
            band_fluxes = fluxes[flux_column]
            # NOTE(review): 99.0 is forwarded to ``_flux_to_mag``; it looks
            # like a placeholder magnitude for unusable fluxes -- confirm.
            band_mags = self._flux_to_mag(band_fluxes, mag_offset, 99.0)
            mags[mag_columns[band]] = band_mags
        return mags

96 

97 

class EstimatePhotozTrainZConfig(EstimatePhotozTaskConfig):
    """Config for EstimatePhotozTrainZTask.

    Same as the base task config, except that its defaults select the
    TrainZ algorithm.
    """

    def setDefaults(self) -> None:
        """Apply the base defaults, then select the TrainZ stage.

        Both ``connections.algo`` and ``photoz_algo`` are set to
        ``EstimatePhotozTrainZAlgoConfig.stage_name()``.
        ``photoz_algo`` presumably chooses an entry of
        ``photozAlgoRegistry``, where the TrainZ subtask is registered
        under that same name -- TODO confirm against the base config.
        """
        super().setDefaults()
        trainz_stage = EstimatePhotozTrainZAlgoConfig.stage_name()
        # Chained assignment stores left to right: ``connections.algo``
        # first, then ``photoz_algo``.
        self.connections.algo = self.photoz_algo = trainz_stage

106 

107 

class EstimatePhotozTrainZTask(EstimatePhotozTask):
    """Task estimating p(z) with the RAIL TrainZ algorithm.

    All behavior is inherited from ``EstimatePhotozTask``; this subclass
    only supplies the TrainZ config class and the default task name.
    """

    _DefaultName = "estimatePhotozTrainZ"
    ConfigClass = EstimatePhotozTrainZConfig