Coverage for python/lsst/shapelet/tractor.py: 23%

134 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-04-07 01:26 -0700

1# 

2# LSST Data Management System 

3# Copyright 2008-2013 LSST Corporation. 

4# 

5# This product includes software developed by the 

6# LSST Project (http://www.lsst.org/). 

7# 

8# This program is free software: you can redistribute it and/or modify 

9# it under the terms of the GNU General Public License as published by 

10# the Free Software Foundation, either version 3 of the License, or 

11# (at your option) any later version. 

12# 

13# This program is distributed in the hope that it will be useful, 

14# but WITHOUT ANY WARRANTY; without even the implied warranty of 

15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

16# GNU General Public License for more details. 

17# 

18# You should have received a copy of the LSST License Statement and 

19# the GNU General Public License along with this program. If not, 

20# see <http://www.lsstcorp.org/LegalNotices/>. 

21# 

22"""Code to load multi-Gaussian approximations to profiles from "The Tractor" 

23into a lsst.shapelet.MultiShapeletBasis. 

24 

25Please see the README file in the data directory of the lsst.shapelet 

26package for more information. 

27""" 

28import numpy 

29import os 

30import re 

31import sys 

32import warnings 

33import pickle 

34 

35import lsst.pex.exceptions 

36from ._shapeletLib import RadialProfile, MultiShapeletBasis, ShapeletFunction 

37 

38 

def registerRadialProfiles():
    """Register the pickled profiles in the data directory with the RadialProfile
    singleton registry.

    This should only be called at import time by this module; it's only a function to
    avoid polluting the module namespace with all the local variables used here.

    Data files live in ``$SHAPELET_DIR/data`` and must be named
    ``<name>_K<nComponents>_MR<maxRadius>.pickle`` (e.g. ``exp_K8_MR8.pickle`` would
    match).  Each must unpickle to a 1-d array of ``2*nComponents`` values: the
    component amplitudes followed by the component variances.  Files whose names do
    not match the pattern are ignored.  Files whose profile name has no C++
    `RadialProfile`, or whose contents have the wrong shape, are skipped with a warning.

    Raises
    ------
    KeyError
        Raised if the ``SHAPELET_DIR`` environment variable is not set.
    """
    dataDir = os.path.join(os.environ["SHAPELET_DIR"], "data")
    regex = re.compile(r"([a-z]+\d?)_K(\d+)_MR(\d+)\.pickle")
    # Sort so the registration order does not depend on filesystem ordering.
    for filename in sorted(os.listdir(dataDir)):
        # fullmatch (not match): match() would also accept names with trailing junk,
        # such as editor backups ("exp_K8_MR8.pickle~") or "*.pickle.bak" files.
        match = regex.fullmatch(filename)
        if not match:
            continue
        name = match.group(1)
        nComponents = int(match.group(2))
        maxRadius = int(match.group(3))
        try:
            profile = RadialProfile.get(name)
        except lsst.pex.exceptions.Exception:
            warnings.warn("No C++ profile for multi-Gaussian pickle file '%s'" % filename)
            continue
        # latin1 decoding lets Python 3 read numpy pickles presumably written under
        # Python 2.  pickle.load can execute arbitrary code; these files are treated as
        # trusted package data, never as external input.
        with open(os.path.join(dataDir, filename), 'rb') as stream:
            array = numpy.asarray(pickle.load(stream, encoding='latin1'))
        amplitudes = array[:nComponents]
        variances = array[nComponents:]
        # Validate the layout before normalizing, so that malformed files are reported
        # with this warning instead of failing on the division below.
        if amplitudes.shape != (nComponents,) or variances.shape != (nComponents,):
            warnings.warn("Unknown format for multi-Gaussian pickle file '%s'" % filename)
            continue
        # Out-of-place division: works for integer arrays too, and doesn't write
        # through the slice into `array`.
        amplitudes = amplitudes / amplitudes.sum()
        basis = MultiShapeletBasis(1)
        for amplitude, variance in zip(amplitudes, variances):
            radius = variance**0.5
            matrix = numpy.array([[amplitude / ShapeletFunction.FLUX_FACTOR]], dtype=float)
            basis.addComponent(radius, 0, matrix)
        profile.registerBasis(basis, nComponents, maxRadius)

77 

78 

# We register all the profiles at module import time, to allow C++ code to access all
# available profiles without having to later call Python code to unpickle them.
# Consequence: importing this module reads every matching pickle file in
# $SHAPELET_DIR/data, and fails with KeyError if SHAPELET_DIR is not set.
registerRadialProfiles()

82 

83 

def evaluateRadial(basis, r, sbNormalize=False, doComponents=False):
    """Evaluate a single-element MultiShapeletBasis as a radial profile.

    The basis is expanded into a function using an ellipse built from a
    default-constructed `lsst.afw.geom.ellipses.Axes` and a single coefficient of 1.
    That function is then evaluated at the points ``(x, 0)`` for each ``x`` in ``r``.

    Parameters
    ----------
    basis : `lsst.shapelet.MultiShapeletBasis`
        Basis to evaluate; only one coefficient is supplied, so it should have a
        single element.
    r : array-like of `float`
        1-d sequence of radii at which to evaluate the profile.
    sbNormalize : `bool`, optional
        `True` to divide every output row by the value of the full profile at r=1.
    doComponents : `bool`, optional
        `True` to also evaluate each component of the expanded function separately.

    Returns
    -------
    z : `numpy.ndarray`
        Array of shape ``(n, len(r))``.  Row 0 holds the full profile.  When
        ``doComponents`` is `True`, rows 1 through n-1 hold the individual components
        in the order returned by ``getComponents()``; otherwise n is 1.
    """
    # Import explicitly rather than relying on some other module having already
    # imported lsst.afw.geom.ellipses as a side effect.
    from lsst.afw.geom.ellipses import Axes, Ellipse

    r = numpy.asarray(r, dtype=float)
    ellipse = Ellipse(Axes())
    coefficients = numpy.ones(1, dtype=float)
    msf = basis.makeFunction(ellipse, coefficients)
    ev = msf.evaluate()
    components = msf.getComponents() if doComponents else []
    z = numpy.zeros((1 + len(components),) + r.shape, dtype=float)
    for j, x in enumerate(r):
        z[0, j] = ev(x, 0.0)
    for i, sf in enumerate(components):
        evc = sf.evaluate()
        for j, x in enumerate(r):
            z[i+1, j] = evc(x, 0.0)
    if sbNormalize:
        z /= ev(1.0, 0.0)
    return z

112 

113 

def integrateNormalizedFluxes(maxRadius=20.0, nSteps=5000, nComponents=8):
    """Integrate the profiles to compare relative fluxes between the true profiles
    and their approximations.

    The multi-Gaussian approximations are normalized by their surface brightness at
    r=1 r_e (``evaluateRadial(..., sbNormalize=True)``).  The true profiles are used
    exactly as returned by `RadialProfile.evaluate`; presumably these are already
    normalized the same way (TODO confirm).  Each profile ``I(r)`` is then integrated
    as ``integral of r*I(r) dr`` from 0 to ``maxRadius``.  The constant 2*pi factor of
    a circular aperture is omitted, so the results are only meaningful relative to
    one another.

    Parameters
    ----------
    maxRadius : `float`, optional
        Maximum radius to integrate the profile, in units of r_e.
    nSteps : `int`, optional
        Number of concrete points at which to evaluate the profile to
        do the integration (we just use the trapezoidal rule).
    nComponents : `int`, optional
        Number of Gaussian components of the approximation to use
        (passed to ``RadialProfile.getBasis``).

    Returns
    -------
    fluxes : `dict` of `float` values
        Dictionary of fluxes (``exp``, ``lux``, ``dev``, ``luv``, ``ser2``, ``ser3``,
        ``ser5``, ``gexp``, ``glux``, ``gdev``, ``gluv``, ``gser2``, ``gser3``, ``gser5``).
        Keys with a ``g`` prefix are the multi-Gaussian approximations.
    """
    # numpy.trapz is deprecated as of NumPy 2.0 in favor of numpy.trapezoid.
    trapezoid = getattr(numpy, "trapezoid", None) or numpy.trapz
    radii = numpy.linspace(0.0, maxRadius, nSteps)
    profiles = {name: RadialProfile.get(name) for name in ("exp", "lux", "dev", "luv",
                                                           "ser2", "ser3", "ser5")}
    evaluated = {}
    for name, profile in profiles.items():
        evaluated[name] = profile.evaluate(radii)
        basis = profile.getBasis(nComponents)
        evaluated["g" + name] = evaluateRadial(basis, radii, sbNormalize=True, doComponents=False)[0, :]
    # Weight by r: the area element of a circularly symmetric profile is 2*pi*r dr.
    return {name: trapezoid(z*radii, radii) for name, z in evaluated.items()}

145 

146 

def plotSuite(doComponents=False):
    """Plot all the profiles defined in this module together.

    Plot all the profiles defined in this module together: true exp and dev,
    the SDSS softened/truncated lux and luv, and the multi-Gaussian approximations
    to all of these.

    The figure is a 2x4 grid of panels:

    - The top row shows the profiles.
    - The bottom row shows each curve's relative error with respect to the true exp
      profile (columns 0-1) or the true dev profile (columns 2-3).
    - Even columns cover r in [1E-3, 1] on a logarithmic x axis; odd columns cover
      r in [1, 10] on a linear x axis.

    The approximations use each profile's 8-component basis.

    Parameters
    ----------
    doComponents : `bool`, optional
        True, to plot the individual Gaussians that form the multi-Gaussian approximations.

    Returns
    -------
    figure : `matplotlib.figure.Figure`
        Figure that contains the plot.
    axes : `numpy.ndarray` of `matplotlib.axes.Axes`
        A 2x4 NumPy array of matplotlib axes objects.
    """
    from matplotlib import pyplot
    fig = pyplot.figure(figsize=(9, 4.7))
    axes = numpy.zeros((2, 4), dtype=object)
    r1 = numpy.logspace(-3, 0, 1000, base=10)
    r2 = numpy.linspace(1, 10, 1000)
    r = [r1, r2]
    for i in range(2):
        for j in range(4):
            axes[i, j] = fig.add_subplot(2, 4, i*4+j+1)
    profiles = {name: RadialProfile.get(name) for name in ("exp", "lux", "dev", "luv")}
    basis = {name: profiles[name].getBasis(8) for name in profiles}
    z = numpy.zeros((2, 4), dtype=object)
    colors = ("k", "g", "b", "r")
    fig.subplots_adjust(wspace=0.025, hspace=0.025, bottom=0.15, left=0.1, right=0.98, top=0.92)
    centers = [None, None]
    for i in range(2):  # 0=profile, 1=relative error
        for j in range(0, 4, 2):  # grid columns: 0=exp-like, 2=dev-like
            # Make each log/linear pair of panels abut (no gap between them), widening
            # the right-hand panel at the expense of the left-hand one.
            bbox0 = axes[i, j].get_position()
            bbox1 = axes[i, j+1].get_position()
            bbox1.x0 = bbox0.x1 - 0.06
            bbox0.x1 = bbox1.x0
            centers[j//2] = 0.5*(bbox0.x0 + bbox1.x1)
            axes[i, j].set_position(bbox0)
            axes[i, j+1].set_position(bbox1)
    for j in range(0, 2):
        # Each z[0, col] is a list of four 2-d arrays:
        # [true exp/dev, true lux/luv, approx exp/dev, approx lux/luv].
        # Row 0 of each array is the full profile; any further rows are the
        # individual Gaussian components (approximations with doComponents only).
        z[0, j] = [evaluateRadial(basis[k], r[j], sbNormalize=True, doComponents=doComponents)
                   for k in ("exp", "lux")]
        z[0, j][0:0] = [profiles[k].evaluate(r[j])[numpy.newaxis, :] for k in ("exp", "lux")]
        z[0, j+2] = [evaluateRadial(basis[k], r[j], sbNormalize=True, doComponents=doComponents)
                     for k in ("dev", "luv")]
        z[0, j+2][0:0] = [profiles[k].evaluate(r[j])[numpy.newaxis, :] for k in ("dev", "luv")]
    methodNames = [["loglog", "semilogy"], ["semilogx", "plot"]]
    # Explicit tick positions: fixed labels must be attached to known locations, or
    # they end up on whatever positions the automatic locator happens to choose.
    xticks = [[10.0**i for i in range(-3, 1)], list(range(1, 11))]
    xtickLabels = [['$\\mathdefault{10^{%d}}$' % i for i in range(-3, 1)],
                   [str(i) for i in range(1, 11)]]
    # Blank the last label of each range (presumably to avoid colliding with the
    # first label of the panel to its right).
    xtickLabels[0][-1] = ""
    xtickLabels[1][-1] = ""
    for j in range(0, 4):  # grid columns
        # Relative error of every curve with respect to the first (true exp/dev) one.
        z[1, j] = [(z[0, j][0][0, :] - z[0, j][i][0, :])/z[0, j][0][0, :] for i in range(0, 4)]
        handles = []
        method0 = getattr(axes[0, j], methodNames[0][j%2])
        method1 = getattr(axes[1, j], methodNames[1][j%2])
        for k in range(4):
            y0 = z[0, j][k]
            handles.append(method0(r[j%2], y0[0, :], color=colors[k])[0])
            if doComponents:
                for ll in range(1, y0.shape[0]):
                    method0(r[j%2], y0[ll, :], color=colors[k], alpha=0.25)
            method1(r[j%2], z[1, j][k], color=colors[k])
        axes[0, j].set_xticks(xticks[j % 2])
        axes[0, j].tick_params(axis="x", labelbottom=False)
        axes[0, j].set_ylim(1E-6, 1E3)
        axes[1, j].set_ylim(-0.2, 1.0)
        axes[1, j].set_xticks(xticks[j % 2])
        axes[1, j].set_xticklabels(xtickLabels[j % 2])
        axes[1, j].tick_params(axis="x", labelsize=11)
    for i, label in enumerate(("profile", "relative error")):
        axes[i, 0].set_ylabel(label)
        axes[i, 0].tick_params(axis="y", labelsize=11)
    for j in range(1, 4):
        axes[0, j].tick_params(axis="y", labelleft=False)
        axes[1, j].tick_params(axis="y", labelleft=False)
    # All columns use the same colors, so the handles from the last column serve
    # for the whole figure.
    fig.legend(handles, ["exp/dev", "lux/luv", "approx exp/dev", "approx lux/luv"],
               loc='lower center', ncol=4)
    fig.text(centers[0], 0.95, "exponential", ha='center', weight='bold')
    fig.text(centers[1], 0.95, "de Vaucouleurs", ha='center', weight='bold')
    return fig, axes