Coverage for python / lsst / scarlet / lite / utils.py: 31%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-17 08:40 +0000

1# This file is part of scarlet_lite. 

2# 

3# Developed for the LSST Data Management System. 

4# This product includes software developed by the LSST Project 

5# (https://www.lsst.org). 

6# See the COPYRIGHT file at the top-level directory of this distribution 

7# for details of code ownership. 

8# 

9# This program is free software: you can redistribute it and/or modify 

10# it under the terms of the GNU General Public License as published by 

11# the Free Software Foundation, either version 3 of the License, or 

12# (at your option) any later version. 

13# 

14# This program is distributed in the hope that it will be useful, 

15# but WITHOUT ANY WARRANTY; without even the implied warranty of 

16# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

17# GNU General Public License for more details. 

18# 

19# You should have received a copy of the GNU General Public License 

20# along with this program. If not, see <https://www.gnu.org/licenses/>. 

21 

22from __future__ import annotations 

23 

24import sys 

25from typing import Any, Sequence 

26 

27import numpy as np 

28import numpy.typing as npt 

29from scipy.special import erfc 

30 

31ScalarLike = bool | int | float | complex 

32ScalarTypes = (bool, int, float, complex) 

33 

34 

# Precomputed constants for the pixel-integrated Gaussian below.
sqrt2 = np.sqrt(2)
sqrt_pi = np.sqrt(np.pi)


def integrated_gaussian_value(x: np.ndarray, sigma: float) -> np.ndarray:
    """A Gaussian profile integrated over unit-width pixels centred at `x`.

    Parameters
    ----------
    x:
        The coordinates to evaluate the integrated Gaussian
        (ie. the centers of pixels). Each pixel spans
        ``[x - 0.5, x + 0.5]``.
    sigma:
        The standard deviation of the Gaussian.

    Returns
    -------
    gaussian:
        ``0.5 * sqrt(pi) * sigma * (erfc(a) - erfc(b))`` where
        ``a = (x - 0.5) / (sqrt(2) * sigma)`` and
        ``b = (x + 0.5) / (sqrt(2) * sigma)``. This is proportional to the
        area under a Gaussian of width `sigma` over each pixel.

    Notes
    -----
    NOTE(review): the prefactor ``0.5 * sqrt(pi) * sigma`` is neither the
    unit-peak normalization (``sigma * sqrt(pi / 2)``) nor the unit-area
    normalization (``0.5``). Treat the result as defined only up to a
    constant factor, and confirm no caller relies on its absolute scale.
    """
    scale = sqrt2 * sigma
    # erfc(a) - erfc(b) == erf(b) - erf(a): the Gaussian mass inside the
    # pixel, written with erfc so the tails keep their precision.
    lower_tail = erfc((x - 0.5) / scale)
    upper_tail = erfc((x + 0.5) / scale)
    prefactor = sqrt_pi * 0.5 * sigma
    return prefactor * (lower_tail - upper_tail)

58 

59 

60def integrated_circular_gaussian( 

61 x: np.ndarray | None = None, y: np.ndarray | None = None, sigma: float = 0.8 

62) -> np.ndarray: 

63 """Create a circular Gaussian that is integrated over pixels 

64 

65 This is typically used for the model PSF, 

66 working well with the default parameters. 

67 

68 Parameters 

69 ---------- 

70 x, y: 

71 The x,y-coordinates to evaluate the integrated Gaussian. 

72 If `X` and `Y` are `None` then they will both be given the 

73 default value `numpy.arange(-7, 8)`, resulting in a 

74 `15x15` centered image. 

75 sigma: 

76 The standard deviation of the Gaussian. 

77 

78 Returns 

79 ------- 

80 image: 

81 A Gaussian function integrated over `X` and `Y`. 

82 """ 

83 if x is None: 

84 if y is None: 

85 x = np.arange(-7, 8) 

86 y = x 

87 else: 

88 raise ValueError( 

89 f"Either X and Y must be specified, or neither must be specified, got {x=} and {y=}" 

90 ) 

91 elif y is None: 

92 raise ValueError(f"Either X and Y must be specified, or neither must be specified, got {x=} and {y=}") 

93 

94 _x = integrated_gaussian_value(np.abs(x), sigma)[None, :] 

95 _y = integrated_gaussian_value(np.abs(y), sigma)[:, None] 

96 result = _x * _y 

97 return result / np.sum(result) 

98 

99 

def get_circle_mask(diameter: int, dtype: npt.DTypeLike = np.float64):
    """Get a square image containing a filled circle.

    Parameters
    ----------
    diameter:
        The diameter of the circle, which is also the width and height
        of the image.
    dtype:
        The `dtype` of the image. The default gives a float image; pass
        ``bool`` for a boolean mask.

    Returns
    -------
    circle:
        A ``(diameter, diameter)`` array of `dtype` that is one for pixels
        whose centres lie inside or on the circle, and zero elsewhere.
    """
    center = (diameter - 1) / 2
    # With an odd diameter the centre sits on a pixel, and the radius reaches
    # the centre of the edge pixels. With an even diameter the centre falls
    # between pixels, so the radius is taken as diameter / 2: half a pixel
    # larger than the centre offset.
    radius = center if diameter % 2 else diameter / 2

    offsets = np.arange(diameter) - center
    # Broadcast column offsets (axis 1) against row offsets (axis 0).
    distance = np.sqrt(offsets[None, :] ** 2 + offsets[:, None] ** 2)

    mask = np.ones((diameter, diameter), dtype=dtype)
    mask[distance > radius] = 0
    return mask

133 

134 

# Dunder attributes that `is_attribute_safe_to_transfer` always refuses to
# copy from one class onto another.
INTRINSIC_SPECIAL_ATTRIBUTES = frozenset(
    {
        "__class__",
        "__dict__",
        "__doc__",
        "__metaclass__",
        "__module__",
        "__name__",
        "__qualname__",
        "__subclasshook__",
        "__weakref__",
    }
)


def is_attribute_safe_to_transfer(name, value):
    """Decide whether a class attribute may be monkeypatched onto another
    class.

    Only names starting with ``__`` are ever rejected. They are rejected
    when they appear in `INTRINSIC_SPECIAL_ATTRIBUTES`, or when `value` is
    the very object that `object` provides under that name (an inherited
    default rather than something defined explicitly, e.g. in a class
    decorated by `continue_class`). Because the lookup on `object` defaults
    to `None`, a ``__``-prefixed attribute whose value is `None` and that
    `object` lacks is rejected as well.

    Parameters
    ----------
    name : str
        The attribute name.
    value : Any
        The attribute value found on the source class.

    Returns
    -------
    safe : bool
        True if the attribute should be transferred.
    """
    if not name.startswith("__"):
        return True
    if name in INTRINSIC_SPECIAL_ATTRIBUTES:
        return False
    return value is not getattr(object, name, None)

162 

163 

def convert_indices(sequence: Sequence, indices: Any, inclusive: bool = True) -> tuple[int, ...] | slice:
    """Get either a tuple of indices or a slice object from the given sequence.

    Parameters
    ----------
    sequence : Sequence
        The sequence to get the indices from. This sequence should have
        unique hashable elements. If an element occurs more than once, the
        position of its first occurrence is used (as with
        ``sequence.index``).

    indices : Any
        The indices or slice to use. Can be:
        - A single element from sequence
        - A slice with start/stop elements from sequence
        - An iterable of elements from sequence (e.g. a list, tuple or
          numpy array)

    inclusive : bool, optional
        If True, the stop element of a slice is inclusive.

    Returns
    -------
    tuple[int, ...] | slice
        A slice object when `indices` is a slice, otherwise a tuple of
        positions in `sequence`.

    Raises
    ------
    TypeError :
        If `sequence` does not support `index` and `in` operations, or if
        its elements are not hashable when `indices` is an iterable.
    IndexError :
        If any requested element is not found in `sequence`.
    """
    # Validate that sequence has the required methods
    if not hasattr(sequence, "index") or not hasattr(sequence, "__contains__"):
        raise TypeError(f"'sequence' must support 'index' and 'in' operations, got {type(sequence)}")

    # Convert a slice of elements into a slice of positions. When inclusive,
    # the stop is advanced by one so the stop element itself is kept.
    if isinstance(indices, slice):
        try:
            start = None if indices.start is None else sequence.index(indices.start)
        except ValueError as e:
            raise IndexError(f"Element {indices.start} not found in sequence {sequence}.") from e
        try:
            stop = None if indices.stop is None else sequence.index(indices.stop) + (1 if inclusive else 0)
        except ValueError as e:
            raise IndexError(f"Element {indices.stop} not found in sequence {sequence}.") from e
        return slice(start, stop, indices.step)

    # Try to handle `indices` as a single element first.
    try:
        is_single_element = indices in sequence
    except (TypeError, ValueError):
        # The membership test itself can fail when `indices` is a container:
        # a ``str`` sequence rejects non-str operands (TypeError), and
        # comparing a multi-element numpy array with an element produces an
        # array whose truth value is ambiguous (ValueError). Neither is a
        # single element, so fall through to the iterable handling below.
        is_single_element = False
    if is_single_element:
        return (sequence.index(indices),)

    # Anything else must be an iterable of elements.
    # NOTE(review): a ``str`` that is not itself an element is iterated
    # character by character here; confirm that is intended.
    if not hasattr(indices, "__iter__"):
        raise IndexError(f"Element {indices} not found in sequence {sequence}.")

    # Map each element to the position of its FIRST occurrence, so this
    # branch agrees with the ``sequence.index`` lookups used above.
    index_map: dict[Any, int] = {}
    for position, value in enumerate(sequence):
        index_map.setdefault(value, position)

    new_indices = []
    for i in indices:
        try:
            new_indices.append(index_map[i])
        except (KeyError, TypeError) as e:
            # KeyError: `i` is not an element of the sequence.
            # TypeError: `i` is unhashable (e.g. a nested list), so it cannot
            # match any of the sequence's hashable elements.
            raise IndexError(f"Element {i} not found in sequence {sequence}.") from e
    return tuple(new_indices)

232 

233 

def continue_class(cls):
    """Re-open the decorated class, adding any new definitions into the
    original.

    For example:

    .. code-block:: python

        class Foo:
            pass

        @continue_class
        class Foo:
            def run(self):
                return None

    is equivalent to:

    .. code-block:: python

        class Foo:
            def run(self):
                return None

    .. warning::

        Python's built-in `super` function does not behave properly in
        classes decorated with `continue_class`. Base class methods must be
        invoked directly using their explicit types instead.

    This is adapted from lsst.utils. If any additional functions are
    used from that repo we should remove this function and make lsst.utils
    a dependency. But for now, it is easier to copy this single wrapper
    than to include lsst.utils and all of its dependencies.

    Parameters
    ----------
    cls : type
        The new class body whose attributes are copied onto the original.

    Returns
    -------
    orig : type
        The original class, found under ``cls.__name__`` in the module
        ``cls.__module__``, with every attribute accepted by
        `is_attribute_safe_to_transfer` set on it. Returning it rebinds the
        decorated name to the original class.
    """
    orig = getattr(sys.modules[cls.__module__], cls.__name__)
    namespace = cls.__dict__
    for name in dir(cls):
        # Common descriptors like classmethod and staticmethod can only be
        # accessed without invoking their magic if we use __dict__; if we use
        # getattr on those we'll get e.g. a bound method instance on the dummy
        # class rather than a classmethod instance we can put on the target
        # class. An explicit membership test is used rather than
        # ``__dict__.get(name) or getattr(...)``: the ``or`` form calls
        # ``bool()`` on the attribute, which raises for e.g. multi-element
        # numpy arrays and silently takes the getattr path for falsy values.
        if name in namespace:
            attr = namespace[name]
        else:
            # Inherited names (from bases or `object`) are not in __dict__.
            attr = getattr(cls, name)
        if is_attribute_safe_to_transfer(name, attr):
            setattr(orig, name, attr)
    return orig