Coverage for python/lsst/shapelet/tractor.py: 23%
134 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-13 02:59 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2023-04-13 02:59 -0700
1#
2# LSST Data Management System
3# Copyright 2008-2013 LSST Corporation.
4#
5# This product includes software developed by the
6# LSST Project (http://www.lsst.org/).
7#
8# This program is free software: you can redistribute it and/or modify
9# it under the terms of the GNU General Public License as published by
10# the Free Software Foundation, either version 3 of the License, or
11# (at your option) any later version.
12#
13# This program is distributed in the hope that it will be useful,
14# but WITHOUT ANY WARRANTY; without even the implied warranty of
15# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16# GNU General Public License for more details.
17#
18# You should have received a copy of the LSST License Statement and
19# the GNU General Public License along with this program. If not,
20# see <http://www.lsstcorp.org/LegalNotices/>.
21#
22"""Code to load multi-Gaussian approximations to profiles from "The Tractor"
23into a lsst.shapelet.MultiShapeletBasis.
25Please see the README file in the data directory of the lsst.shapelet
26package for more information.
27"""
28import numpy
29import os
30import re
31import sys
32import warnings
33import pickle
35import lsst.pex.exceptions
36from ._shapeletLib import RadialProfile, MultiShapeletBasis, ShapeletFunction
def registerRadialProfiles():
    """Register the pickled profiles in the data directory with the RadialProfile
    singleton registry.

    This should only be called at import time by this module; it's only a function to
    avoid polluting the module namespace with all the local variables used here.

    Every file in ``$SHAPELET_DIR/data`` whose name is exactly
    ``<name>_K<nComponents>_MR<maxRadius>.pickle`` is unpickled as a flat array
    of ``nComponents`` amplitudes followed by ``nComponents`` variances.  The
    amplitudes are normalized to sum to one, each (amplitude, variance) pair
    becomes one component of a single-element `MultiShapeletBasis` with radius
    ``sqrt(variance)``, and the basis is registered with the C++
    `RadialProfile` called ``name`` under ``(nComponents, maxRadius)``.

    Files whose profile name has no C++ `RadialProfile`, or whose contents do
    not have the expected shape, are skipped with a warning.

    Raises
    ------
    KeyError
        Raised if the ``SHAPELET_DIR`` environment variable is not set.
    """
    dataDir = os.path.join(os.environ["SHAPELET_DIR"], "data")
    regex = re.compile(r"([a-z]+\d?)_K(\d+)_MR(\d+)\.pickle")
    # Sorted so registration (and any warnings) happen in a deterministic order.
    for filename in sorted(os.listdir(dataDir)):
        # fullmatch, not match: match() only anchors the start, so stray files like
        # "exp_K8_MR8.pickle.orig" would otherwise be loaded as duplicates.
        match = regex.fullmatch(filename)
        if not match:
            continue
        name = match.group(1)
        nComponents = int(match.group(2))
        maxRadius = int(match.group(3))
        try:
            profile = RadialProfile.get(name)
        except lsst.pex.exceptions.Exception:
            warnings.warn("No C++ profile for multi-Gaussian pickle file '%s'" % filename)
            continue
        # NOTE: pickle.load can execute arbitrary code; this is only acceptable because
        # these files are data shipped with the package, not external input.
        with open(os.path.join(dataDir, filename), 'rb') as stream:
            # encoding='latin1' lets Python 3 read numpy arrays pickled under
            # Python 2 (presumably how these data files were written).
            array = pickle.load(stream, encoding='latin1')
        array = numpy.asarray(array, dtype=float)
        amplitudes = array[:nComponents]
        variances = array[nComponents:]
        # Validate the layout before doing any arithmetic on it.
        if amplitudes.shape != (nComponents,) or variances.shape != (nComponents,):
            warnings.warn("Unknown format for multi-Gaussian pickle file '%s'" % filename)
            continue
        # Out-of-place so the loaded array itself is not modified.
        amplitudes = amplitudes / amplitudes.sum()
        basis = MultiShapeletBasis(1)
        for amplitude, variance in zip(amplitudes, variances):
            radius = variance**0.5
            # 1x1 matrix mapping the single basis coefficient to this component;
            # dividing by FLUX_FACTOR presumably converts a flux-normalized amplitude
            # into a shapelet coefficient.
            matrix = numpy.array([[amplitude / ShapeletFunction.FLUX_FACTOR]], dtype=float)
            basis.addComponent(radius, 0, matrix)
        profile.registerBasis(basis, nComponents, maxRadius)


# We register all the profiles at module import time, to allow C++ code to access all available profiles
# without having to later call Python code to unpickle them.
registerRadialProfiles()
def evaluateRadial(basis, r, sbNormalize=False, doComponents=False):
    """Evaluate a single-element MultiShapeletBasis as a radial profile.

    The basis is turned into a function on a default-constructed
    `lsst.afw.geom.ellipses.Axes` ellipse (presumably the unit circle) with a
    single coefficient equal to 1.  That function is then evaluated along the
    positive x axis at each radius in ``r``.

    Parameters
    ----------
    basis : `lsst.shapelet.MultiShapeletBasis`
        Basis with a single element (exactly one coefficient is supplied).
    r : array-like of `float`
        1-d sequence of radii at which to evaluate the profile.
    sbNormalize : `bool`, optional
        `True` to divide every returned value by the total profile's value at
        ``r = 1``.
    doComponents : `bool`, optional
        `True` to also evaluate each component of the multi-shapelet function
        separately.

    Returns
    -------
    z : `numpy.ndarray`
        Float array of shape ``(n,) + r.shape``.  Row 0 is the full profile.
        If ``doComponents`` is `True`, rows ``1 .. n-1`` are the individual
        components; otherwise ``n == 1``.
    """
    # The module itself only imports lsst.pex.exceptions; import afw explicitly so this
    # function does not depend on some other code having imported it first.
    import lsst.afw.geom.ellipses

    r = numpy.asarray(r)
    ellipse = lsst.afw.geom.ellipses.Ellipse(lsst.afw.geom.ellipses.Axes())
    coefficients = numpy.ones(1, dtype=float)
    msf = basis.makeFunction(ellipse, coefficients)
    ev = msf.evaluate()
    components = msf.getComponents() if doComponents else []
    z = numpy.zeros((1 + len(components),) + r.shape, dtype=float)
    for j, x in enumerate(r):
        z[0, j] = ev(x, 0.0)
    for i, sf in enumerate(components, start=1):
        evc = sf.evaluate()
        for j, x in enumerate(r):
            z[i, j] = evc(x, 0.0)
    if sbNormalize:
        # Normalize everything (including components) by the total profile at r=1.
        z /= ev(1.0, 0.0)
    return z
def integrateNormalizedFluxes(maxRadius=20.0, nSteps=5000, nComponents=8):
    """Integrate the profiles to compare relative fluxes between the true profiles
    and their approximations.

    After normalizing by surface brightness at r=1 r_e, integrate the profiles to compare
    relative fluxes between the true profiles and their approximations.  Each
    value is the trapezoidal-rule estimate of the integral of ``I(r) * r`` over
    ``[0, maxRadius]``.  This is the flux up to a common factor of 2*pi, which
    cancels in any ratio.

    Parameters
    ----------
    maxRadius : `float`, optional
        Maximum radius to integrate the profile, in units of r_e.
    nSteps : `int`, optional
        Number of concrete points at which to evaluate the profile to
        do the integration (we just use the trapezoidal rule).
    nComponents : `int`, optional
        Number of Gaussians in the multi-Gaussian approximations; passed to
        ``RadialProfile.getBasis``.

    Returns
    -------
    fluxes : `dict` of `float` values
        Dictionary of fluxes (``exp``, ``lux``, ``dev``, ``luv``, ``ser2``, ``ser3``,
        ``ser5``, ``gexp``, ``glux``, ``gdev``, ``gluv``, ``gser2``, ``gser3``, ``gser5``).
        Keys prefixed with ``g`` are the multi-Gaussian approximations.
    """
    # numpy.trapz was renamed numpy.trapezoid in NumPy 2.0 (old name deprecated);
    # prefer the new name and fall back for older NumPy.
    trapezoid = getattr(numpy, "trapezoid", None) or numpy.trapz
    radii = numpy.linspace(0.0, maxRadius, nSteps)
    profiles = {name: RadialProfile.get(name)
                for name in ("exp", "lux", "dev", "luv", "ser2", "ser3", "ser5")}
    evaluated = {}
    for name, profile in profiles.items():
        evaluated[name] = profile.evaluate(radii)
        basis = profile.getBasis(nComponents)
        evaluated["g" + name] = evaluateRadial(basis, radii, sbNormalize=True, doComponents=False)[0, :]
    return {name: trapezoid(z*radii, radii) for name, z in evaluated.items()}
def plotSuite(doComponents=False):
    """Plot all the profiles defined in this module together.

    Plot all the profiles defined in this module together: true exp and dev,
    the SDSS softened/truncated lux and luv, and the multi-Gaussian approximations
    to all of these.

    Parameters
    ----------
    doComponents : `bool`, optional
        True, to plot the individual Gaussians that form the multi-Gaussian approximations.

    Returns
    -------
    figure : `matplotlib.figure.Figure`
        Figure that contains the plot.
    axes : `numpy.ndarray` of `matplotlib.axes.Axes`
        A 2x4 NumPy array of matplotlib axes objects.  Row 0 holds the profiles
        and row 1 the relative errors.  Columns 0-1 show exp/lux and columns 2-3
        show dev/luv.  Even columns cover r in [1e-3, 1] on a log x axis; odd
        columns cover r in [1, 10] on a linear x axis.
    """
    from matplotlib import pyplot
    fig = pyplot.figure(figsize=(9, 4.7))
    axes = numpy.zeros((2, 4), dtype=object)
    r1 = numpy.logspace(-3, 0, 1000, base=10)  # inner radii (log x axis)
    r2 = numpy.linspace(1, 10, 1000)  # outer radii (linear x axis)
    r = [r1, r2]
    for i in range(2):
        for j in range(4):
            axes[i, j] = fig.add_subplot(2, 4, i*4+j+1)
    profiles = {name: RadialProfile.get(name) for name in ("exp", "lux", "dev", "luv")}
    basis = {name: profiles[name].getBasis(8) for name in profiles}
    z = numpy.zeros((2, 4), dtype=object)
    colors = ("k", "g", "b", "r")
    fig.subplots_adjust(wspace=0.025, hspace=0.025, bottom=0.15, left=0.1, right=0.98, top=0.92)
    centers = [None, None]
    for i in range(2):  # 0=profile, 1=relative error
        for j in range(0, 4, 2):  # grid columns: 0=exp-like, 2=dev-like
            # Make each inner/outer pair of panels abut, and remember the pair's
            # horizontal center for the title placed above it.
            bbox0 = axes[i, j].get_position()
            bbox1 = axes[i, j+1].get_position()
            bbox1.x0 = bbox0.x1 - 0.06
            bbox0.x1 = bbox1.x0
            centers[j//2] = 0.5*(bbox0.x0 + bbox1.x1)
            axes[i, j].set_position(bbox0)
            axes[i, j+1].set_position(bbox1)
    for j in range(0, 2):
        for col, names in ((j, ("exp", "lux")), (j + 2, ("dev", "luv"))):
            # z[0, col] = [true, true SDSS variant, approx, approx SDSS variant], each a
            # 2-d array whose row 0 is the total profile; with doComponents the
            # approximations carry one extra row per Gaussian.
            true = [profiles[k].evaluate(r[j])[numpy.newaxis, :] for k in names]
            approx = [evaluateRadial(basis[k], r[j], sbNormalize=True, doComponents=doComponents)
                      for k in names]
            z[0, col] = true + approx
    methodNames = [["loglog", "semilogy"], ["semilogx", "plot"]]
    for j in range(0, 4):  # grid columns
        # Relative error of every curve with respect to the true exp (cols 0-1) or dev (cols 2-3).
        reference = z[0, j][0][0, :]
        z[1, j] = [(reference - curve[0, :])/reference for curve in z[0, j]]
        handles = []
        method0 = getattr(axes[0, j], methodNames[0][j % 2])
        method1 = getattr(axes[1, j], methodNames[1][j % 2])
        for k in range(4):
            y0 = z[0, j][k]
            handles.append(method0(r[j % 2], y0[0, :], color=colors[k])[0])
            if doComponents:
                for ll in range(1, y0.shape[0]):
                    method0(r[j % 2], y0[ll, :], color=colors[k], alpha=0.25)
            method1(r[j % 2], z[1, j][k], color=colors[k])
        axes[0, j].tick_params(axis='x', labelbottom=False)
        axes[0, j].set_ylim(1E-6, 1E3)
        axes[1, j].set_ylim(-0.2, 1.0)
    for i, label in enumerate(("profile", "relative error")):
        axes[i, 0].set_ylabel(label)
        # tick_params persists for ticks created later, unlike styling the current Text objects.
        axes[i, 0].tick_params(axis='y', labelsize=11)
    for j in range(1, 4):
        axes[0, j].tick_params(axis='y', labelleft=False)
        axes[1, j].tick_params(axis='y', labelleft=False)
    # Explicit tick positions to match the positional labels: attaching a label list to an
    # automatic locator can put labels on ticks outside the view (mislabeling the axis).
    xtickValues = [[10.0**i for i in range(-3, 1)],
                   list(range(1, 11))]
    xticks = [['$\\mathdefault{10^{%d}}$' % i for i in range(-3, 1)],
              [str(i) for i in range(1, 11)]]
    # Blank the last label of each panel so it doesn't collide with the next panel's first label.
    xticks[0][-1] = ""
    xticks[1][-1] = ""
    for j in range(0, 4):
        axes[1, j].set_xticks(xtickValues[j % 2])
        axes[1, j].set_xticklabels(xticks[j % 2])
        axes[1, j].tick_params(axis='x', labelsize=11)
    # Colors are shared across columns, so the handles from the last column serve for the legend.
    fig.legend(handles, ["exp/dev", "lux/luv", "approx exp/dev", "approx lux/luv"],
               loc='lower center', ncol=4)
    fig.text(centers[0], 0.95, "exponential", ha='center', weight='bold')
    fig.text(centers[1], 0.95, "de Vaucouleurs", ha='center', weight='bold')
    return fig, axes