Coverage for python / lsst / multiprofit / transforms.py: 22%
27 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-26 08:58 +0000
1# This file is part of multiprofit.
2#
3# Developed for the LSST Data Management System.
4# This product includes software developed by the LSST Project
5# (https://www.lsst.org).
6# See the COPYRIGHT file at the top-level directory of this distribution
7# for details of code ownership.
8#
9# This program is free software: you can redistribute it and/or modify
10# it under the terms of the GNU General Public License as published by
11# the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# This program is distributed in the hope that it will be useful,
15# but WITHOUT ANY WARRANTY; without even the implied warranty of
16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17# GNU General Public License for more details.
18#
19# You should have received a copy of the GNU General Public License
20# along with this program. If not, see <https://www.gnu.org/licenses/>.
22__all__ = ["get_logit_limited", "verify_transform_derivative", "transforms_ref"]
25from typing import Any, Iterable
27import lsst.gauss2d.fit as g2f
28import numpy as np
30from .limits import limits_ref
def get_logit_limited(
    lower: float, upper: float, factor: float = 1.0, name: str | None = None
) -> g2f.LogitLimitedTransformD:
    """Get a logit transform stretched to span [lower, upper] instead of [0, 1].

    Parameters
    ----------
    lower
        Lower limit of the range to span.
    upper
        Upper limit of the range to span.
    factor
        Multiplicative factor applied to the transformed result.
    name
        Descriptive name, attached to the transform's limits. If None, a
        name is generated from ``lower``, ``upper`` and ``factor``.

    Returns
    -------
    transform
        The limited logit transform.
    """
    limits_name = name
    if limits_name is None:
        # The generated label records every parameter, so the limits object
        # is self-describing when printed.
        limits_name = f"LogitLimitedTransformD(min={lower}, max={upper}, factor={factor})"
    limits = g2f.LimitsD(min=lower, max=upper, name=limits_name)
    return g2f.LogitLimitedTransformD(limits=limits, factor=factor)
def verify_transform_derivative(
    transform: g2f.TransformD,
    value_transformed: float,
    derivative: float | None = None,
    abs_max: float = 1e6,
    dx_ratios: Iterable[float] | None = None,
    **kwargs: Any,
) -> None:
    """Verify that the derivative of a transform class is correct.

    Parameters
    ----------
    transform
        The transform to verify.
    value_transformed
        The transformed value at which to verify the transform. It is
        mapped back to the untransformed value with ``transform.reverse``,
        and finite differences are taken in untransformed space.
    derivative
        The nominal derivative at the untransformed value, i.e.
        ``transform.derivative(transform.reverse(value_transformed))``.
        Computed if not provided.
    abs_max
        Verification is skipped (the function returns without checking)
        if ``np.abs(derivative) > abs_max``.
    dx_ratios
        Iterable of signed ratios to set dx for finite differencing, where
        dx = value*ratio (untransformed), or dx = ratio if value == 0.
        Any iterable, including a generator, is accepted; it is consumed
        exactly once.
    **kwargs
        Keyword arguments to pass to np.isclose when comparing derivatives to
        finite differences.

    Raises
    ------
    ValueError
        Raised if dx_ratios is given but empty.
    RuntimeError
        Raised if the transform derivative doesn't match finite differences
        within the specified tolerances.

    Notes
    -----
    derivative should only be specified if it has previously been computed for
    the exact value_transformed, to avoid re-computing it unnecessarily.

    Default dx_ratios are [1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14].
    Verification will test all ratios until at least one passes.
    """
    value = transform.reverse(value_transformed)
    if derivative is None:
        derivative = transform.derivative(value)
    if np.abs(derivative) > abs_max:
        # Skip testing finite differencing if the derivative is very large
        # This might happen e.g. near the limits of the transformation
        # TODO: Check if finite differencing can be improved for large values
        return
    if dx_ratios is None:
        dx_ratios = [1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14]
    else:
        # Materialize first: a generic Iterable (e.g. a generator) has no
        # __len__, and must be checked for emptiness before the loop below.
        dx_ratios = list(dx_ratios)
        if not dx_ratios:
            raise ValueError(f"{dx_ratios=} must not be empty")
    for ratio in dx_ratios:
        # A relative step is zero at value == 0, which would divide by zero;
        # use the ratio as an absolute step there instead.
        dx = value * ratio if value != 0 else ratio
        fin_diff = (transform.forward(value + dx) - value_transformed) / dx
        if not np.isfinite(fin_diff):
            # Fall back to a backward difference; presumably the forward step
            # left the transform's valid domain (e.g. crossed a limit).
            fin_diff = -(transform.forward(value - dx) - value_transformed) / dx
        if np.isclose(derivative, fin_diff, **kwargs):
            return
    raise RuntimeError(
        f"{transform} derivative={derivative:.8e} != last "
        f"finite diff.={fin_diff:.8e} with {dx=} and dx_abs_max={abs_max}"
    )
# Reference transforms, keyed by a short descriptive name. The instances are
# created once, when the module is imported, and are shared by every user of
# this dict.
transforms_ref = {
    # Unit transform; presumably the identity mapping - TODO confirm.
    "none": g2f.UnitTransformD(),
    "log": g2f.LogTransformD(),
    "log10": g2f.Log10TransformD(),
    "inverse": g2f.InverseTransformD(),
    "logit": g2f.LogitTransformD(),
    # Logit transforms whose span is taken from the reference limits in
    # limits_ref; the name embeds the resulting [min, max] for readability.
    "logit_fluxfrac": get_logit_limited(
        limits_ref["fluxfrac"].min,
        limits_ref["fluxfrac"].max,
        name=f"ref_logit_fluxfrac[{limits_ref['fluxfrac'].min}, {limits_ref['fluxfrac'].max}]",
    ),
    "logit_rho": get_logit_limited(
        limits_ref["rho"].min,
        limits_ref["rho"].max,
        name=f"ref_logit_rho[{limits_ref['rho'].min}, {limits_ref['rho'].max}]",
    ),
    # Logit transforms with hard-coded spans.
    "logit_axrat": get_logit_limited(1e-4, 1, name="ref_logit_axrat[1e-4, 1]"),
    # NOTE(review): this span extends slightly past [0, 1], presumably so that
    # values at the physical bounds still map to finite transformed values -
    # confirm intent.
    "logit_axrat_prior": get_logit_limited(-0.0001, 1.1, name="ref_logit_axrat_prior[-0.0001, 1.1]"),
    # NOTE(review): looks like a small pad around a Sersic index range of
    # roughly [0.5, 6] - confirm.
    "logit_sersic": get_logit_limited(0.49, 6.01, name="ref_logit_sersic[0.49, 6.01]"),
}