Coverage for python / lsst / multiprofit / transforms.py: 22%

27 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-14 23:46 +0000

1# This file is part of multiprofit. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22__all__ = ["get_logit_limited", "verify_transform_derivative", "transforms_ref"] 

23 

24 

25from typing import Any, Iterable 

26 

27import lsst.gauss2d.fit as g2f 

28import numpy as np 

29 

30from .limits import limits_ref 

31 

32 

def get_logit_limited(
    lower: float, upper: float, factor: float = 1.0, name: str | None = None
) -> g2f.LogitLimitedTransformD:
    """Build a logit transform rescaled to span [lower, upper].

    Parameters
    ----------
    lower
        The minimum of the range the transform should span.
    upper
        The maximum of the range the transform should span.
    factor
        A multiplicative factor to apply to the transformed result.
    name
        A descriptive name attached to the limits. If None, a name is
        generated from lower, upper and factor.

    Returns
    -------
    transform
        The logit transform built from the requested limits and factor.
    """
    # Fall back to an auto-generated label when the caller gives none.
    if name is None:
        name = f"LogitLimitedTransformD(min={lower}, max={upper}, factor={factor})"
    limits = g2f.LimitsD(min=lower, max=upper, name=name)
    return g2f.LogitLimitedTransformD(limits=limits, factor=factor)

66 

67 

def verify_transform_derivative(
    transform: g2f.TransformD,
    value_transformed: float,
    derivative: float | None = None,
    abs_max: float = 1e6,
    dx_ratios: Iterable[float] | None = None,
    **kwargs: Any,
) -> None:
    """Verify that the derivative of a transform class is correct.

    Parameters
    ----------
    transform
        The transform to verify.
    value_transformed
        The transformed value at which to verify the transform. It is
        mapped back through transform.reverse to obtain the untransformed
        value used for finite differencing.
    derivative
        The nominal derivative at the untransformed value. If specified, it
        must equal transform.derivative(transform.reverse(value_transformed)).
    abs_max
        The x value to skip verification if np.abs(derivative) > x.
    dx_ratios
        Iterable of signed ratios to set dx for finite differencing, where
        dx = value*ratio (untransformed). Any iterable is accepted,
        including generators. If the untransformed value is exactly zero,
        the ratio itself is used as dx, since value*ratio would be zero.
    **kwargs
        Keyword arguments to pass to np.isclose when comparing derivatives to
        finite differences.

    Raises
    ------
    RuntimeError
        Raised if the transform derivative doesn't match finite differences
        within the specified tolerances.
    ValueError
        Raised if dx_ratios is specified but contains no ratios.

    Notes
    -----
    derivative should only be specified if it has previously been computed for
    the exact value_transformed, to avoid re-computing it unnecessarily.

    Default dx_ratios are [1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14].
    Verification will test all ratios until at least one passes.
    """
    value = transform.reverse(value_transformed)
    if derivative is None:
        derivative = transform.derivative(value)
    if np.abs(derivative) > abs_max:
        # Skip testing finite differencing if the derivative is very large
        # This might happen e.g. near the limits of the transformation
        # TODO: Check if finite differencing can be improved for large values
        return
    if dx_ratios is None:
        ratios: tuple[float, ...] = (1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14)
    else:
        # Materialize once: a generic Iterable need not support len().
        ratios = tuple(dx_ratios)
        if not ratios:
            raise ValueError(f"{dx_ratios=} must not be empty")
    for ratio in ratios:
        # A relative step collapses to zero at value == 0 and would divide
        # by zero below, so fall back to an absolute step there.
        dx = value * ratio if value != 0 else ratio
        fin_diff = (transform.forward(value + dx) - value_transformed) / dx
        if not np.isfinite(fin_diff):
            # The forward step gave a non-finite result; retry with a
            # backward difference.
            fin_diff = -(transform.forward(value - dx) - value_transformed) / dx
        if np.isclose(derivative, fin_diff, **kwargs):
            return
    raise RuntimeError(
        f"{transform} derivative={derivative:.8e} != last "
        f"finite diff.={fin_diff:.8e} with {dx=} and dx_abs_max={abs_max}"
    )

134 

135 

# Reference transforms keyed by short name.
transforms_ref = {
    # Transforms constructed without arguments.
    "none": g2f.UnitTransformD(),
    "log": g2f.LogTransformD(),
    "log10": g2f.Log10TransformD(),
    "inverse": g2f.InverseTransformD(),
    "logit": g2f.LogitTransformD(),
    # Limited logit transforms whose bounds come from limits_ref.
    "logit_fluxfrac": get_logit_limited(
        lower=limits_ref["fluxfrac"].min,
        upper=limits_ref["fluxfrac"].max,
        name=f"ref_logit_fluxfrac[{limits_ref['fluxfrac'].min}, {limits_ref['fluxfrac'].max}]",
    ),
    "logit_rho": get_logit_limited(
        lower=limits_ref["rho"].min,
        upper=limits_ref["rho"].max,
        name=f"ref_logit_rho[{limits_ref['rho'].min}, {limits_ref['rho'].max}]",
    ),
    # Limited logit transforms with bounds written out literally.
    "logit_axrat": get_logit_limited(
        lower=1e-4,
        upper=1,
        name="ref_logit_axrat[1e-4, 1]",
    ),
    "logit_axrat_prior": get_logit_limited(
        lower=-0.0001,
        upper=1.1,
        name="ref_logit_axrat_prior[-0.0001, 1.1]",
    ),
    "logit_sersic": get_logit_limited(
        lower=0.49,
        upper=6.01,
        name="ref_logit_sersic[0.49, 6.01]",
    ),
}