lsst.shapelet g1c9b347f51+c02d18c6ed
Loading...
Searching...
No Matches
tractor.py
Go to the documentation of this file.
2# LSST Data Management System
3# Copyright 2008-2013 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22"""Code to load multi-Gaussian approximations to profiles from "The Tractor"
23into a lsst.shapelet.MultiShapeletBasis.
24
25Please see the README file in the data directory of the lsst.shapelet
26package for more information.
27"""
28import numpy
29import os
30import re
31import sys
32import warnings
33import pickle
34
36from ._shapeletLib import RadialProfile, MultiShapeletBasis, ShapeletFunction
37
38
def _readMultiGaussianPickle(path, nComponents):
    """Read one multi-Gaussian approximation from a pickle file.

    The pickle must hold a flat sequence of ``2*nComponents`` numbers: the
    component amplitudes followed by the component variances.

    Parameters
    ----------
    path : `str`
        Path to the pickle file.
    nComponents : `int`
        Number of Gaussian components encoded in the file.

    Returns
    -------
    result : `tuple` [`numpy.ndarray`, `numpy.ndarray`] or `None`
        ``(amplitudes, variances)`` with the amplitudes normalized to sum to
        one, or `None` if the contents do not have the expected layout (wrong
        length or shape, non-numeric data, or amplitudes summing to zero or a
        non-finite value).
    """
    # Unpickling can execute arbitrary code: this is only safe because the
    # files come from the package's own data directory.  Never point it at
    # user-supplied files.
    with open(path, 'rb') as stream:
        if sys.version_info[0] >= 3:
            # latin1 lets Python 3 read pickles containing Python 2 byte
            # strings (presumably how these data files were produced).
            data = pickle.load(stream, encoding='latin1')
        else:
            data = pickle.load(stream)
    try:
        array = numpy.asarray(data, dtype=float)
    except (TypeError, ValueError):
        return None
    # Validate the layout *before* normalizing.  Divide out of place, so
    # integer-typed input cannot break an in-place true division.
    if array.shape != (2*nComponents,):
        return None
    amplitudes = array[:nComponents]
    variances = array[nComponents:]
    total = amplitudes.sum()
    if total == 0 or not numpy.isfinite(total):
        return None
    return amplitudes / total, variances


def registerRadialProfiles():
    """Register the pickled profiles in the data directory with the RadialProfile
    singleton registry.

    Files in ``$SHAPELET_DIR/data`` named ``<profile>_K<nComponents>_MR<maxRadius>.pickle``
    are loaded, converted to a single-element `MultiShapeletBasis`, and registered
    on the C++ `RadialProfile` of the same name.  Files with an unknown profile
    name or an unrecognized format are skipped with a warning.

    This should only be called at import time by this module; it's only a function to
    avoid polluting the module namespace with all the local variables used here.

    Raises
    ------
    KeyError
        If the ``SHAPELET_DIR`` environment variable is not set.
    """
    dataDir = os.path.join(os.environ["SHAPELET_DIR"], "data")
    regex = re.compile(r"([a-z]+\d?)_K(\d+)_MR(\d+)\.pickle")
    for filename in os.listdir(dataDir):
        match = regex.match(filename)
        if not match:
            continue
        name = match.group(1)
        nComponents = int(match.group(2))
        maxRadius = int(match.group(3))
        try:
            profile = RadialProfile.get(name)
        except Exception:
            # NOTE(review): the exact exception type raised for an unknown
            # profile name is not visible here; catch broadly so one stray
            # data file cannot break module import.
            warnings.warn("No C++ profile for multi-Gaussian pickle file '%s'" % filename)
            continue
        parameters = _readMultiGaussianPickle(os.path.join(dataDir, filename), nComponents)
        if parameters is None:
            warnings.warn("Unknown format for multi-Gaussian pickle file '%s'" % filename)
            continue
        amplitudes, variances = parameters
        basis = MultiShapeletBasis(1)
        for amplitude, variance in zip(amplitudes, variances):
            radius = variance**0.5
            # Each component is a single zeroth-order shapelet; dividing by
            # FLUX_FACTOR converts the desired flux amplitude to a coefficient.
            matrix = numpy.array([[amplitude / ShapeletFunction.FLUX_FACTOR]], dtype=float)
            basis.addComponent(radius, 0, matrix)
        profile.registerBasis(basis, nComponents, maxRadius)
77
78
79# We register all the profiles at module import time, to allow C++ code to access all available profiles
80# without having to later call Python code to unpickle them.
82
83
def evaluateRadial(basis, r, sbNormalize=False, doComponents=False):
    """Evaluate a single-element MultiShapeletBasis as a radial profile.

    The basis is turned into a function using a length-1 coefficient vector
    of ones, and that function is sampled along the x axis (y = 0) at each
    radius in ``r``.

    Parameters
    ----------
    basis : `lsst.shapelet.MultiShapeletBasis`
        Basis with a single element.
    r : array-like of `float`
        One-dimensional sequence of radii at which to evaluate the profile.
    sbNormalize : `bool`, optional
        If `True`, divide every row of the result by the value of the full
        profile at r = 1.
    doComponents : `bool`, optional
        If `True`, also evaluate each component of the function separately.

    Returns
    -------
    z : `numpy.ndarray`
        Array of shape ``(n,) + r.shape``.  Row 0 is the full profile.  When
        ``doComponents`` is `True`, rows 1 to n-1 hold the individual
        components; otherwise n = 1.
    """
    r = numpy.asarray(r, dtype=float)
    # TODO(review): ``ellipse`` is not assigned anywhere in this function as
    # shown; the line that constructed it appears to have been lost.  It must
    # be the ellipse the basis is evaluated on (presumably a unit circle) and
    # has to be restored before this function can run.
    coefficients = numpy.ones(1, dtype=float)
    msf = basis.makeFunction(ellipse, coefficients)
    ev = msf.evaluate()
    # Fetch the components once; they fill rows 1..n-1 when requested.
    components = list(msf.getComponents()) if doComponents else []
    z = numpy.zeros((1 + len(components),) + r.shape, dtype=float)
    # Sample along y = 0, so the x coordinate is the radius.
    for j, x in enumerate(r):
        z[0, j] = ev(x, 0.0)
    for i, sf in enumerate(components):
        evc = sf.evaluate()
        for j, x in enumerate(r):
            z[i + 1, j] = evc(x, 0.0)
    if sbNormalize:
        # Normalize to the full profile's surface brightness at r = 1.
        z /= ev(1.0, 0.0)
    return z
112
113
def integrateNormalizedFluxes(maxRadius=20.0, nSteps=5000):
    """Integrate the profiles to compare relative fluxes between the true profiles
    and their multi-Gaussian approximations.

    The approximations are normalized to unit surface brightness at r = 1 r_e
    before integrating.  The true profiles are integrated as returned by
    `RadialProfile.evaluate` (presumably already normalized the same way;
    confirm).

    Parameters
    ----------
    maxRadius : `float`, optional
        Maximum radius to integrate the profile, in units of r_e.
    nSteps : `int`, optional
        Number of concrete points at which to evaluate the profile to
        do the integration (we just use the trapezoidal rule).

    Returns
    -------
    fluxes : `dict` [`str`, `float`]
        Dictionary of fluxes keyed by profile name (``exp``, ``lux``, ``dev``,
        ``luv``, ``ser2``, ``ser3``, ``ser5``) for the true profiles, and by
        the same names prefixed with ``g`` (``gexp``, ``glux``, ...) for their
        8-component multi-Gaussian approximations.  Each value is the integral
        of I(r) r dr over [0, maxRadius]; the 2*pi factor of the full 2-d
        integral is omitted, which does not affect ratios between fluxes.
    """
    # ``numpy.trapz`` is deprecated since NumPy 2.0 in favor of
    # ``numpy.trapezoid``; prefer the new name and fall back on older NumPy.
    trapezoid = getattr(numpy, "trapezoid", None)
    if trapezoid is None:
        trapezoid = numpy.trapz
    radii = numpy.linspace(0.0, maxRadius, nSteps)
    names = ("exp", "lux", "dev", "luv", "ser2", "ser3", "ser5")
    evaluated = {}
    for name in names:
        profile = RadialProfile.get(name)
        evaluated[name] = profile.evaluate(radii)
        basis = profile.getBasis(8)
        evaluated["g" + name] = evaluateRadial(basis, radii, sbNormalize=True, doComponents=False)[0, :]
    return {name: trapezoid(z*radii, radii) for name, z in evaluated.items()}
145
146
def plotSuite(doComponents=False):
    """Plot all the profiles defined in this module together.

    Plot the true exponential (``exp``) and de Vaucouleurs (``dev``)
    profiles, the SDSS softened/truncated ``lux`` and ``luv`` profiles, and
    the 8-component multi-Gaussian approximations to all of these.

    The figure is a 2x4 grid:

    - The top row shows the profiles; the bottom row shows the relative error
      of every curve with respect to the true ``exp`` profile (columns 0-1)
      or the true ``dev`` profile (columns 2-3).
    - Within each column pair, the left panel covers r in [1e-3, 1] on a
      logarithmic x axis and the right panel covers r in [1, 10] on a linear
      x axis.

    Parameters
    ----------
    doComponents : `bool`, optional
        True, to plot the individual Gaussians that form the multi-Gaussian approximations.

    Returns
    -------
    figure : `matplotlib.figure.Figure`
        Figure that contains the plot.
    axes : `numpy.ndarray` of `matplotlib.axes.Axes`
        A 2x4 NumPy array of matplotlib axes objects.
    """
    from matplotlib import pyplot
    fig = pyplot.figure(figsize=(9, 4.7))
    axes = numpy.zeros((2, 4), dtype=object)
    r1 = numpy.logspace(-3, 0, 1000, base=10)
    r2 = numpy.linspace(1, 10, 1000)
    r = [r1, r2]
    for i in range(2):
        for j in range(4):
            axes[i, j] = fig.add_subplot(2, 4, i*4+j+1)
    profiles = {name: RadialProfile.get(name) for name in ("exp", "lux", "dev", "luv")}
    basis = {name: profiles[name].getBasis(8) for name in profiles}
    # z[0, j] holds, per column, a list of 4 arrays (each with a leading row
    # axis): [true exp/dev, true lux/luv, approx exp/dev, approx lux/luv].
    # z[1, j] holds the corresponding relative errors.
    z = numpy.zeros((2, 4), dtype=object)
    colors = ("k", "g", "b", "r")
    fig.subplots_adjust(wspace=0.025, hspace=0.025, bottom=0.15, left=0.1, right=0.98, top=0.92)
    centers = [None, None]
    for i in range(2):  # 0=profile, 1=relative error
        for j in range(0, 4, 2):  # grid columns: 0=exp-like, 2=dev-like
            # Make the two panels of each pair abut, shifting 0.06 of width
            # from the log-x (left) panel to the linear-x (right) panel.
            bbox0 = axes[i, j].get_position()
            bbox1 = axes[i, j+1].get_position()
            bbox1.x0 = bbox0.x1 - 0.06
            bbox0.x1 = bbox1.x0
            centers[j//2] = 0.5*(bbox0.x0 + bbox1.x1)
            axes[i, j].set_position(bbox0)
            axes[i, j+1].set_position(bbox1)
    for j in range(0, 2):
        z[0, j] = [evaluateRadial(basis[k], r[j], sbNormalize=True, doComponents=doComponents)
                   for k in ("exp", "lux")]
        # Prepend the true profiles so they come first in the list.
        z[0, j][0:0] = [profiles[k].evaluate(r[j])[numpy.newaxis, :] for k in ("exp", "lux")]
        z[0, j+2] = [evaluateRadial(basis[k], r[j], sbNormalize=True, doComponents=doComponents)
                     for k in ("dev", "luv")]
        z[0, j+2][0:0] = [profiles[k].evaluate(r[j])[numpy.newaxis, :] for k in ("dev", "luv")]
    methodNames = [["loglog", "semilogy"], ["semilogx", "plot"]]
    for j in range(0, 4):  # grid columns
        reference = z[0, j][0][0, :]  # true exp (columns 0-1) or dev (columns 2-3)
        z[1, j] = [(reference - z[0, j][k][0, :])/reference for k in range(0, 4)]
        handles = []
        method0 = getattr(axes[0, j], methodNames[0][j%2])
        method1 = getattr(axes[1, j], methodNames[1][j%2])
        for k in range(4):
            y0 = z[0, j][k]
            handles.append(method0(r[j%2], y0[0, :], color=colors[k])[0])
            if doComponents:
                # Rows beyond 0 are individual Gaussian components (only the
                # approximations have more than one row).
                for ll in range(1, y0.shape[0]):
                    method0(r[j%2], y0[ll, :], color=colors[k], alpha=0.25)
            method1(r[j%2], z[1, j][k], color=colors[k])
        axes[0, j].set_xticklabels([])
        axes[0, j].set_ylim(1E-6, 1E3)
        axes[1, j].set_ylim(-0.2, 1.0)
    for i, label in enumerate(("profile", "relative error")):
        axes[i, 0].set_ylabel(label)
        for t in axes[i, 0].get_yticklabels():
            t.set_fontsize(11)
    for j in range(1, 4):
        axes[0, j].set_yticklabels([])
        axes[1, j].set_yticklabels([])
    # NOTE(review): these labels assume matplotlib's default tick locations
    # (decades 1e-3..1e0 on the log panels, integers 1..10 on the linear
    # ones); the last label of each is blanked to avoid overlap between panels.
    xticks = [['$\\mathdefault{10^{%d}}$' % i for i in range(-3, 1)],
              [str(i) for i in range(1, 11)]]
    xticks[0][-1] = ""
    xticks[1][-1] = ""
    for j in range(0, 4):
        axes[1, j].set_xticklabels(xticks[j%2])
        for t in axes[1, j].get_xticklabels():
            t.set_fontsize(11)
    # All columns use the same colors, so the last column's handles serve
    # for the shared legend.
    fig.legend(handles, ["exp/dev", "lux/luv", "approx exp/dev", "approx lux/luv"],
               loc='lower center', ncol=4)
    fig.text(centers[0], 0.95, "exponential", ha='center', weight='bold')
    fig.text(centers[1], 0.95, "de Vaucouleurs", ha='center', weight='bold')
    return fig, axes
def integrateNormalizedFluxes(maxRadius=20.0, nSteps=5000)
Definition: tractor.py:114
def evaluateRadial(basis, r, sbNormalize=False, doComponents=False)
Definition: tractor.py:84
def registerRadialProfiles()
Definition: tractor.py:39
def plotSuite(doComponents=False)
Definition: tractor.py:147