Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

import numpy as np 

import numbers 

 

# Names exported by ``from <this module> import *``; note that the
# factorial helper is exported here despite its leading underscore.
__all__ = ["_FactorialGenerator", "ZernikePolynomialGenerator"]

 

 

class _FactorialGenerator(object):
    """
    A class that generates factorials
    and stores them in a dict to be referenced
    as needed.
    """

    def __init__(self):
        # cache of already-computed factorials, keyed by integer argument
        self._values = {0: 1, 1: 1}
        # largest argument for which self._values is fully populated
        # (every key 0..self._max_i is present)
        self._max_i = 1

    def evaluate(self, num):
        """
        Return the factorial of num.

        num is rounded to the nearest integer before evaluation, so
        integer-valued floats (e.g. 5.0) and numpy integers are accepted.

        Raises RuntimeError if num is negative.
        """
        if num < 0:
            raise RuntimeError("Cannot handle negative factorial")

        i_num = int(np.round(num))
        if i_num in self._values:
            return self._values[i_num]

        # Extend the cache from the largest known value up to i_num.
        # The rounded integer is used throughout (not the raw num) so
        # that float input neither breaks range() nor creates float keys.
        val = self._values[self._max_i]
        for ii in range(self._max_i + 1, i_num + 1):
            val *= ii
            self._values[ii] = val

        self._max_i = i_num
        return self._values[i_num]

 

 

class ZernikePolynomialGenerator(object):
    """
    A class to generate and evaluate the Zernike
    polynomials.  Definitions of Zernike polynomials
    are taken from

    https://en.wikipedia.org/wiki/Zernike_polynomials

    Radial-polynomial coefficients are computed lazily and cached
    per (n, m) pair.  Points outside the unit disk (r > 1.0)
    evaluate to NaN; r == 1.0 is treated as inside the disk.
    """

    def __init__(self):
        self._factorial = _FactorialGenerator()
        # (n, m) -> numpy arrays of coefficients / powers of r for the
        # radial polynomial R_n^m, filled in by _make_polynomial
        self._coeffs = {}
        self._powers = {}

    def _validate_nm(self, n, m):
        """
        Make sure that n, m are a valid pair of indices for
        a Zernike polynomial.

        n is the radial order

        m is the angular order (must be non-negative here;
        public callers pass abs(m))

        Returns (n, m) converted to Python ints.

        Raises RuntimeError if either index is not an integer,
        is negative, or if n < m.
        """
        # numbers.Integral covers Python ints and every numpy integer
        # width (numpy registers np.integer with it), not only np.int64
        if not isinstance(n, numbers.Integral):
            raise RuntimeError('Zernike polynomial n must be int')
        if not isinstance(m, numbers.Integral):
            raise RuntimeError('Zernike polynomial m must be int')

        if n < 0:
            raise RuntimeError('Radial Zernike n cannot be negative')
        if m < 0:
            raise RuntimeError('Radial Zernike m cannot be negative')
        if n < m:
            raise RuntimeError('Radial Zernike n must be >= m')

        return (int(n), int(m))

    def _make_polynomial(self, n, m):
        """
        Make the radial part of the n, m Zernike
        polynomial and cache it in self._coeffs / self._powers.

        n is the radial order

        m is the angular order

        The cached radial part of the Zernike polynomial is

        sum([coeffs[ii]*power(r, powers[ii])
             for ii in range(len(coeffs))])

        The formula assumes n - m is even; _evaluate_radial
        handles the odd case without calling this method.
        """
        n, m = self._validate_nm(n, m)

        # coefficients taken from
        # https://en.wikipedia.org/wiki/Zernike_polynomials
        n_coeffs = 1 + (n - m) // 2
        half_sum = (n + m) // 2
        half_diff = (n - m) // 2
        local_coeffs = np.zeros(n_coeffs, dtype=float)
        local_powers = np.zeros(n_coeffs, dtype=float)
        for k in range(n_coeffs):
            sgn = -1.0 if k % 2 else 1.0
            num_fac = self._factorial.evaluate(n - k)
            k_fac = self._factorial.evaluate(k)
            d1_fac = self._factorial.evaluate(half_sum - k)
            d2_fac = self._factorial.evaluate(half_diff - k)

            local_coeffs[k] = sgn * num_fac / (k_fac * d1_fac * d2_fac)
            local_powers[k] = n - 2 * k

        self._coeffs[(n, m)] = local_coeffs
        self._powers[(n, m)] = local_powers

    def _evaluate_radial_number(self, r, nm_tuple):
        """
        Evaluate the radial part of a Zernike polynomial.

        r is a scalar value

        nm_tuple is a tuple of the form (radial order, angular order)
        denoting the polynomial to evaluate; it must already be cached.

        Return the value of the radial part of the polynomial at r
        (np.nan if r > 1.0)
        """
        if r > 1.0:
            return np.nan

        r_term = np.power(r, self._powers[nm_tuple])
        return (self._coeffs[nm_tuple] * r_term).sum()

    def _evaluate_radial_array(self, r, nm_tuple):
        """
        Evaluate the radial part of a Zernike polynomial.

        r is an array-like of radial values of any shape

        nm_tuple is a tuple of the form (radial order, angular order)
        denoting the polynomial to evaluate; it must already be cached.

        Return an array with the same shape as r holding the values of
        the radial part of the polynomial (np.nan where r > 1.0)
        """
        r = np.asarray(r, dtype=float)

        # Raise every r to every power along a new trailing axis so that
        # inputs of any shape work (np.outer would flatten them).
        # np.power(0.0, 0.0) == 1.0, so r == 0 needs no special casing.
        # Overflow/invalid values can only arise for r > 1, which are
        # masked to NaN below, so those warnings are suppressed.
        with np.errstate(over='ignore', invalid='ignore'):
            r_power = np.power(r[..., np.newaxis], self._powers[nm_tuple])
            results = np.dot(r_power, self._coeffs[nm_tuple])

        # r == 1.0 is inside the disk, matching _evaluate_radial_number
        return np.where(r <= 1.0, results, np.nan)

    def _evaluate_radial(self, r, n, m):
        """
        Evaluate the radial part of a Zernike polynomial

        r is a radial value or an array of radial values

        n is the radial order of the polynomial

        m is the angular order of the polynomial (non-negative)

        Return the value(s) of the radial part of the polynomial at r
        (returns np.nan if r > 1.0)
        """
        is_array = not isinstance(r, numbers.Number)
        nm_tuple = self._validate_nm(n, m)

        if (nm_tuple[0] - nm_tuple[1]) % 2 == 1:
            # R_n^m is identically zero when n - m is odd; still honour
            # the NaN-outside-the-unit-disk contract
            if is_array:
                return np.where(np.asarray(r, dtype=float) <= 1.0,
                                0.0, np.nan)
            return np.nan if r > 1.0 else 0.0

        if nm_tuple not in self._coeffs:
            self._make_polynomial(nm_tuple[0], nm_tuple[1])

        if is_array:
            return self._evaluate_radial_array(r, nm_tuple)

        return self._evaluate_radial_number(r, nm_tuple)

    def evaluate(self, r, phi, n, m):
        """
        Evaluate a Zernike polynomial in polar coordinates

        r is the radial coordinate (a scalar or an array)

        phi is the angular coordinate in radians (a scalar or an array)

        n is the radial order of the polynomial

        m is the angular order of the polynomial; m >= 0 selects the
        cosine (even) term, m < 0 the sine (odd) term

        Return the value(s) of the polynomial at r, phi
        (returns np.nan if r > 1.0)
        """
        radial_part = self._evaluate_radial(r, n, np.abs(m))
        if m >= 0:
            return radial_part * np.cos(m * phi)
        # NOTE(review): Wikipedia defines Z_n^{-|m|} = R_n^{|m|} sin(|m| phi).
        # sin(m*phi) with m < 0 equals -sin(|m| phi), i.e. the opposite
        # sign.  Left unchanged so existing coefficient fits keep their
        # meaning -- confirm which convention is intended.
        return radial_part * np.sin(m * phi)

    def norm(self, n, m):
        """
        Return the normalization of the n, m Zernike
        polynomial: eps_m * pi / (2n + 2), with eps_m = 2 for
        m == 0 and 1 otherwise (the orthogonality constant from
        the Wikipedia article cited on this class).

        n is the radial order

        m is the angular order (its sign is ignored)
        """
        n_int, m_int = self._validate_nm(n, np.abs(m))
        eps = 2.0 if m_int == 0 else 1.0
        return eps * np.pi / (n_int * 2 + 2)

    def evaluate_xy(self, x, y, n, m):
        """
        Evaluate a Zernike polynomial at a point in
        Cartesian space.

        x and y are the Cartesian coordinates (either scalars
        or arrays of matching/broadcastable shape)

        n is the radial order of the polynomial

        m is the angular order of the polynomial

        Return the value(s) of the polynomial at x, y
        (returns np.nan if sqrt(x**2+y**2) > 1.0)
        """
        # hypot avoids intermediate overflow; arctan2 is accurate for
        # every angle (arccos(x/r) loses precision near 0 and pi) and
        # returns 0 at the origin, so no r == 0 special case is needed
        r = np.hypot(x, y)
        phi = np.arctan2(y, x)
        return self.evaluate(r, phi, n, m)