lsst.meas.modelfit  14.0-1-g2fa83af+26
densityPlot.py
Go to the documentation of this file.
1 #
2 # LSST Data Management System
3 # Copyright 2008-2013 LSST Corporation.
4 #
5 # This product includes software developed by the
6 # LSST Project (http://www.lsst.org/).
7 #
8 # This program is free software: you can redistribute it and/or modify
9 # it under the terms of the GNU General Public License as published by
10 # the Free Software Foundation, either version 3 of the License, or
11 # (at your option) any later version.
12 #
13 # This program is distributed in the hope that it will be useful,
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 # GNU General Public License for more details.
17 #
18 # You should have received a copy of the LSST License Statement and
19 # the GNU General Public License along with this program. If not,
20 # see <http://www.lsstcorp.org/LegalNotices/>.
21 #
22 
23 """A set of matplotlib-based classes that displays a grid of 1-d and 2-d slices through an
24 N-d density.
25 
26 The main class, DensityPlot, manages the grid of matplotlib.axes.Axes objects, and holds
27 a sequence of Layer objects that each know how to draw individual 1-d or 2-d plots and a
28 data object that abstracts away how the N-d density data is actually represented.
29 
30 For simple cases, users can just create a custom data class with an interface like that of
31 the ExampleData class provided here, and use the provided HistogramLayer and SurfaceLayer
32 classes directly. In more complicated cases, users may want to create their own Layer classes,
33 which may define their own relationship with the data object.
34 """
from builtins import range
from builtins import object

import collections
try:
    from collections.abc import MutableMapping
except ImportError:  # Python 2: the ABCs are only available directly from ``collections``
    from collections import MutableMapping

import numpy
import matplotlib.cm
import matplotlib.pyplot
import matplotlib.ticker
43 
44 __all__ = ("HistogramLayer", "SurfaceLayer", "ScatterLayer", "CrossPointsLayer",
45  "DensityPlot", "ExampleData", "demo")
46 
47 
def hide_xticklabels(axes):
    """Make every x-axis tick label of ``axes`` invisible (the ticks themselves are untouched)."""
    for tickLabel in axes.get_xticklabels():
        tickLabel.set_visible(False)
51 
52 
def hide_yticklabels(axes):
    """Make every y-axis tick label of ``axes`` invisible (the ticks themselves are untouched)."""
    for tickLabel in axes.get_yticklabels():
        tickLabel.set_visible(False)
56 
57 
def mergeDefaults(kwds, defaults):
    """Return a copy of ``defaults`` with the entries of ``kwds`` layered on top.

    ``kwds`` may be None, in which case an unmodified copy of ``defaults`` is returned.
    Neither argument is modified.
    """
    merged = defaults.copy()
    if kwds is None:
        return merged
    merged.update(**kwds)
    return merged
63 
64 
class HistogramLayer(object):
    """A Layer class for DensityPlot for gridded histograms, drawing bar plots in 1-d and
    colormapped large-pixel images in 2-d.

    Relies on two data object attributes:

       values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the
                    number of data points

       weights ---- (optional) an array of weights with shape (M,); if not present, all weights will
                    be set to unity

    The need for these data object attributes can be removed by subclassing HistogramLayer and overriding
    the hist1d and hist2d methods.
    """

    defaults1d = dict(facecolor='b', alpha=0.5)
    defaults2d = dict(cmap=matplotlib.cm.Blues, vmin=0.0, interpolation='nearest')

    def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=None, kwds2d=None):
        """Construct the layer.

        tag ------- key of the data object (in DensityPlot.data) this layer draws
        bins1d ---- ``bins`` argument passed to numpy.histogram
        bins2d ---- ``bins`` argument passed to numpy.histogram2d
        kwds1d ---- keyword arguments for the 1-d bar plots, merged over ``defaults1d``
        kwds2d ---- keyword arguments for the 2-d images, merged over ``defaults2d``
        """
        self.tag = tag
        self.bins1d = bins1d
        self.bins2d = bins2d
        self.kwds1d = mergeDefaults(kwds1d, self.defaults1d)
        self.kwds2d = mergeDefaults(kwds2d, self.defaults2d)

    @staticmethod
    def _getWeights(data):
        """Return the optional ``weights`` attribute of the data object, or None if it has none."""
        return getattr(data, "weights", None)

    def hist1d(self, data, dim, limits):
        """Extract points from the data object and compute a 1-d histogram.

        Return value should match that of numpy.histogram: a tuple of (hist, edges),
        where hist is a 1-d array with size=bins1d, and edges is a 1-d array with
        size=self.bins1d+1 giving the upper and lower edges of the bins.
        """
        i = data.dimensions.index(dim)
        # ``density`` replaces the ``normed`` keyword, which has been removed from NumPy.
        return numpy.histogram(data.values[:, i], bins=self.bins1d, weights=self._getWeights(data),
                               range=limits, density=True)

    def hist2d(self, data, xDim, yDim, xLimits, yLimits):
        """Extract points from the data object and compute a 2-d histogram.

        Return value should match that of numpy.histogram2d: a tuple of (hist, xEdges, yEdges),
        where hist is a 2-d array with shape=bins2d, xEdges is a 1-d array with size=bins2d[0]+1,
        and yEdges is a 1-d array with size=bins2d[1]+1.
        """
        i = data.dimensions.index(yDim)
        j = data.dimensions.index(xDim)
        # ``density`` replaces the ``normed`` keyword, which has been removed from NumPy.
        return numpy.histogram2d(data.values[:, j], data.values[:, i], bins=self.bins2d,
                                 weights=self._getWeights(data), range=(xLimits, yLimits), density=True)

    def plotX(self, axes, data, dim):
        """Draw the 1-d histogram of ``dim`` as vertical bars spanning the current x limits of ``axes``."""
        y, xEdge = self.hist1d(data, dim, axes.get_xlim())
        xCenter = 0.5*(xEdge[:-1] + xEdge[1:])
        width = xEdge[1:] - xEdge[:-1]
        return axes.bar(xCenter, y, width=width, align='center', **self.kwds1d)

    def plotY(self, axes, data, dim):
        """Draw the 1-d histogram of ``dim`` as horizontal bars spanning the current y limits of ``axes``."""
        x, yEdge = self.hist1d(data, dim, axes.get_ylim())
        yCenter = 0.5*(yEdge[:-1] + yEdge[1:])
        height = yEdge[1:] - yEdge[:-1]
        return axes.barh(yCenter, x, height=height, align='center', **self.kwds1d)

    def plotXY(self, axes, data, xDim, yDim):
        """Draw the 2-d histogram of (xDim, yDim) as an image covering the current limits of ``axes``."""
        z, xEdge, yEdge = self.hist2d(data, xDim, yDim, axes.get_xlim(), axes.get_ylim())
        # hist2d puts x along the first array axis, while imshow expects rows to run along y
        return axes.imshow(z.transpose(), aspect='auto', extent=(xEdge[0], xEdge[-1], yEdge[0], yEdge[-1]),
                           origin='lower', **self.kwds2d)
138 
139 
class ScatterLayer(object):
    """A Layer class that plots individual points in 2-d, and does nothing in 1-d.

    Relies on two data object attributes:

       values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the
                    number of data points

       weights ---- (optional) an array of weights with shape (M,); will be used to set the color of points

    """

    defaults = dict(linewidth=0, alpha=0.2)

    def __init__(self, tag, **kwds):
        """Construct the layer.

        tag ------- key of the data object (in DensityPlot.data) this layer draws
        **kwds ---- keyword arguments for matplotlib's scatter, merged over ``defaults``
        """
        self.tag = tag
        self.kwds = mergeDefaults(kwds, self.defaults)

    def plotX(self, axes, data, dim):
        """Do nothing: individual points are only drawn in the 2-d panels."""
        pass

    def plotY(self, axes, data, dim):
        """Do nothing: individual points are only drawn in the 2-d panels."""
        pass

    def plotXY(self, axes, data, xDim, yDim):
        """Scatter-plot the (xDim, yDim) columns of ``data.values`` and return the resulting artist.

        If the data object has non-None ``weights`` they are passed as the point colors (``c``),
        unless the layer's keyword arguments already specify ``c`` or ``color``.
        """
        i = data.dimensions.index(yDim)
        j = data.dimensions.index(xDim)
        kwds = dict(self.kwds)  # copy, so the per-call color does not leak into the layer defaults
        weights = getattr(data, "weights", None)
        # The third positional argument of scatter is the marker size, so the weights must be
        # passed by keyword to act as colors.
        if weights is not None and "c" not in kwds and "color" not in kwds:
            kwds["c"] = weights
        return axes.scatter(data.values[:, j], data.values[:, i], **kwds)
172 
173 
class SurfaceLayer(object):
    """A Layer class for analytic N-d distributions that can be evaluated in 1-d or 2-d slices.

    The 2-d slices are drawn as contours, and the 1-d slices are drawn as simple curves.

    Relies on eval1d and eval2d methods in the data object; this can be avoided by subclassing
    SurfaceLayer and reimplementing its own eval1d and eval2d methods.
    """

    defaults1d = dict(linewidth=2, color='r')
    defaults2d = dict(linewidths=2, cmap=matplotlib.cm.Reds)

    def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None):
        """Construct the layer.

        tag ------- key of the data object (in DensityPlot.data) this layer draws
        steps1d --- number of sample points used for the 1-d curves
        steps2d --- number of sample points per axis used for the 2-d contours
        filled ---- if True draw filled contours (contourf) instead of contour lines
        kwds1d ---- keyword arguments for the 1-d curves, merged over ``defaults1d``
        kwds2d ---- keyword arguments for the 2-d contours, merged over ``defaults2d``
        """
        self.tag = tag
        self.steps1d = int(steps1d)
        self.steps2d = int(steps2d)
        self.filled = bool(filled)
        self.kwds1d = mergeDefaults(kwds1d, self.defaults1d)
        self.kwds2d = mergeDefaults(kwds2d, self.defaults2d)

    def eval1d(self, data, dim, x):
        """Return analytic function values for the given values."""
        return data.eval1d(dim, x)

    def eval2d(self, data, xDim, yDim, x, y):
        """Return analytic function values for the given values."""
        return data.eval2d(xDim, yDim, x, y)

    def plotX(self, axes, data, dim):
        """Draw the 1-d slice for ``dim`` as a curve over the current x limits of ``axes``.

        Returns None without drawing if eval1d returns None.
        """
        xMin, xMax = axes.get_xlim()
        x = numpy.linspace(xMin, xMax, self.steps1d)
        z = self.eval1d(data, dim, x)
        if z is None:
            return
        return axes.plot(x, z, **self.kwds1d)

    def plotY(self, axes, data, dim):
        """Draw the 1-d slice for ``dim`` as a curve over the current y limits of ``axes``, with the
        function value along the horizontal axis.

        Returns None without drawing if eval1d returns None.
        """
        yMin, yMax = axes.get_ylim()
        y = numpy.linspace(yMin, yMax, self.steps1d)
        z = self.eval1d(data, dim, y)
        if z is None:
            return
        return axes.plot(z, y, **self.kwds1d)

    def plotXY(self, axes, data, xDim, yDim):
        """Draw the 2-d slice for (xDim, yDim) as contours over the current limits of ``axes``.

        Returns None without drawing if eval2d returns None.
        """
        xMin, xMax = axes.get_xlim()
        yMin, yMax = axes.get_ylim()
        xc = numpy.linspace(xMin, xMax, self.steps2d)
        yc = numpy.linspace(yMin, yMax, self.steps2d)
        xg, yg = numpy.meshgrid(xc, yc)
        z = self.eval2d(data, xDim, yDim, xg, yg)
        if z is None:
            return
        if self.filled:
            # Filled contours have no lines: matplotlib ignores ``linewidths`` for contourf (and
            # recent versions warn about it), so do not forward the default.
            kwds = {key: value for key, value in self.kwds2d.items() if key != "linewidths"}
            return axes.contourf(xg, yg, z, 6, **kwds)
        return axes.contour(xg, yg, z, 6, **self.kwds2d)
231 
232 
class CrossPointsLayer(object):
    """A layer that marks a few points with axis-length vertical and horizontal lines.

    This relies on a "points" data object attribute: a sequence of points, each indexable by
    dimension index.  Line colors cycle through ``colors`` in point order.
    """

    defaults = dict(alpha=0.8)

    def __init__(self, tag, colors=("y", "m", "c", "r", "g", "b"), **kwds):
        self.tag = tag
        self.colors = colors
        self.kwds = mergeDefaults(kwds, self.defaults)

    def _colorFor(self, n):
        """Return the color assigned to the n-th point, wrapping around the color sequence."""
        return self.colors[n % len(self.colors)]

    def plotX(self, axes, data, dim):
        """Draw one vertical line per point at its ``dim`` coordinate; return the list of artists."""
        column = data.dimensions.index(dim)
        return [axes.axvline(pt[column], color=self._colorFor(k), **self.kwds)
                for k, pt in enumerate(data.points)]

    def plotY(self, axes, data, dim):
        """Draw one horizontal line per point at its ``dim`` coordinate; return the list of artists."""
        column = data.dimensions.index(dim)
        return [axes.axhline(pt[column], color=self._colorFor(k), **self.kwds)
                for k, pt in enumerate(data.points)]

    def plotXY(self, axes, data, xDim, yDim):
        """Draw a vertical and a horizontal line through each point; return the list of artists,
        ordered as (vertical, horizontal) pairs in point order.
        """
        yColumn = data.dimensions.index(yDim)
        xColumn = data.dimensions.index(xDim)
        drawn = []
        for k, pt in enumerate(data.points):
            drawn.append(axes.axvline(pt[xColumn], color=self._colorFor(k), **self.kwds))
            drawn.append(axes.axhline(pt[yColumn], color=self._colorFor(k), **self.kwds))
        return drawn
268 
269 
class DensityPlot(object):
    """An object that manages a matrix of matplotlib.axes.Axes objects that represent a set of 1-d and 2-d
    slices through an N-d density.

    Data objects are passed to the constructor as keyword arguments; the keyword is the tag that a Layer
    (through its ``tag`` attribute) uses to select the data object it draws.  Layers are added, replaced
    and removed through the ``layers`` mapping, which plots and removes artists as it is modified.
    """

    class LayerDict(MutableMapping):
        """Mapping from layer name to Layer object that keeps the parent DensityPlot in sync.

        Assigning an item plots the layer, deleting an item removes its artists, and assigning to a
        name that is already present removes the previous layer first.
        """

        def __init__(self, parent):
            self._dict = dict()
            self._parent = parent

        def __delitem__(self, name):
            layer = self._dict.pop(name)
            self._parent._dropLayer(name, layer)

        def __setitem__(self, name, layer):
            # pop() goes through __delitem__, so a previous layer of the same name is un-plotted first
            self.pop(name, None)
            self._dict[name] = layer
            self._parent._plotLayer(name, layer)

        def __getitem__(self, name):
            return self._dict[name]

        def __iter__(self):
            return iter(self._dict)

        def __len__(self):
            return len(self._dict)

        def __str__(self):
            return str(self._dict)

        def __repr__(self):
            return repr(self._dict)

        def replot(self, name):
            """Remove and redraw the artists of the layer called ``name``."""
            layer = self._dict[name]
            self._parent._dropLayer(name, layer)
            self._parent._plotLayer(name, layer)

    def __init__(self, figure, **kwds):
        """Construct the plot.

        figure ---- matplotlib figure to draw into; it is cleared and its subplot spacing is adjusted
        **kwds ---- data objects, keyed by the tag that layers use to refer to them
        """
        self.figure = figure
        self.data = dict(kwds)
        active = []
        self._lower = dict()
        self._upper = dict()
        # We merge the dimension name lists manually rather than using sets to preserve the order.
        # Most of the time we expect all data objects to have the same dimensions anyway.
        for v in self.data.values():
            for dim in v.dimensions:
                if dim not in active:
                    active.append(dim)
                    self._lower[dim] = v.lower[dim]
                    self._upper[dim] = v.upper[dim]
                else:
                    self._lower[dim] = min(v.lower[dim], self._lower[dim])
                    self._upper[dim] = max(v.upper[dim], self._upper[dim])
        self._active = tuple(active)
        self._all_dims = frozenset(self._active)
        self.figure.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, hspace=0.01, wspace=0.01)
        self._build_axes()
        self.layers = self.LayerDict(self)

    @staticmethod
    def _removeArtists(obj):
        """Best-effort removal of whatever a Layer plot method returned.

        Handles None (nothing was drawn), lists of artists, objects with a ``remove()`` method, and
        objects that have no ``remove()`` but expose their artists through a ``collections``
        attribute (as contour sets did in older matplotlib versions).
        """
        if obj is None:
            return
        remover = getattr(obj, "remove", None)
        # list.remove(x) is unrelated to artist removal, so lists must be unpacked explicitly
        if isinstance(obj, list) or (isinstance(obj, tuple) and remover is None):
            for item in obj:
                DensityPlot._removeArtists(item)
            return
        if remover is None:
            for item in getattr(obj, "collections", ()):
                DensityPlot._removeArtists(item)
            return
        try:
            remover()
        except (AttributeError, TypeError):
            # Kept from the original implementation: remove() sometimes raises inside matplotlib
            # even though the artist is removed, so treat removal as best-effort.
            pass

    def _dropLayer(self, name, layer):
        """Remove all artists recorded for the layer called ``name``.

        Entries that were never plotted (e.g. because the layer's data object lacks a dimension)
        are simply absent and are skipped.
        """
        for key in [k for k in self._objs if k[2] == name]:
            self._removeArtists(self._objs.pop(key))

    def _plotLayer(self, name, layer):
        """Draw ``layer`` on every panel whose dimensions its data object provides, recording the
        returned artists under (row, column, name) keys so they can be removed later.
        """
        data = self.data[layer.tag]
        for i, yDim in enumerate(self._active):
            if yDim not in data.dimensions:
                continue
            self._objs[None, i, name] = layer.plotX(self._axes[None, i], data, yDim)
            self._objs[i, None, name] = layer.plotY(self._axes[i, None], data, yDim)
            for j, xDim in enumerate(self._active):
                if i == j or xDim not in data.dimensions:
                    continue
                self._objs[i, j, name] = layer.plotXY(self._axes[i, j], data, xDim, yDim)
            self._axes[None, i].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune='both'))
            self._axes[i, None].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune='both'))
            self._axes[None, i].xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
            self._axes[i, None].yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())

    def _get_active(self):
        return self._active

    def _set_active(self, active):
        s = set(active)
        if len(s) != len(active):
            raise ValueError("Active set contains duplicates")
        if not self._all_dims.issuperset(s):
            raise ValueError("Invalid values in active set")
        self._active = tuple(active)
        self.replot()
    active = property(_get_active, _set_active, doc="sequence of active dimensions to plot (sequence of str)")

    def replot(self):
        """Recompute the axis limits from the data objects, rebuild all axes and redraw every layer."""
        self._lower = dict()
        self._upper = dict()
        for dim in self._active:
            # As in __init__, only data objects that actually have this dimension contribute
            owners = [v for v in self.data.values() if dim in v.dimensions]
            self._lower[dim] = min(v.lower[dim] for v in owners)
            self._upper[dim] = max(v.upper[dim] for v in owners)
        self._build_axes()
        for name, layer in self.layers.items():
            self._plotLayer(name, layer)

    def _build_axes(self):
        """Clear the figure and create the (n+1) x (n+1) grid of axes for the n active dimensions.

        The top row holds the 1-d "X" axes (keyed ``[None, j]``), the last column holds the 1-d "Y"
        axes (keyed ``[i, None]``), and the remaining cells hold the 2-d axes (keyed ``[i, j]``),
        which share their x axis with ``[None, j]`` and their y axis with ``[i, None]``.  Columns
        run right-to-left with increasing ``j``.
        """
        self.figure.clear()
        self._axes = dict()
        self._objs = dict()
        n = len(self._active)
        iStride = n + 1
        jStride = -1
        iStart = n + 1
        jStart = n
        for i in range(n):
            j = i
            axesX = self._axes[None, j] = self.figure.add_subplot(n+1, n+1, jStart+j*jStride)
            axesX.autoscale(False, axis='x')
            axesX.xaxis.tick_top()
            axesX.set_xlim(self._lower[self._active[j]], self._upper[self._active[j]])
            hide_yticklabels(axesX)
            # shrink the 1-d axes slightly, leaving a gap between them and the 2-d panels
            bbox = axesX.get_position()
            bbox.y1 -= 0.035
            axesX.set_position(bbox)
            axesY = self._axes[i, None] = self.figure.add_subplot(n+1, n+1, iStart + iStart+i*iStride)
            axesY.autoscale(False, axis='y')
            axesY.yaxis.tick_right()
            axesY.set_ylim(self._lower[self._active[i]], self._upper[self._active[i]])
            hide_xticklabels(axesY)
            bbox = axesY.get_position()
            bbox.x1 -= 0.035
            axesY.set_position(bbox)
        for i in range(n):
            for j in range(n):
                axesXY = self._axes[i, j] = self.figure.add_subplot(
                    n+1, n+1, iStart+i*iStride + jStart+j*jStride,
                    sharex=self._axes[None, j],
                    sharey=self._axes[i, None]
                )
                axesXY.autoscale(False)
                if j < n - 1:
                    hide_yticklabels(axesXY)
                if i < n - 1:
                    hide_xticklabels(axesXY)
        for i in range(n):
            j = i
            xbox = self._axes[None, j].get_position()
            ybox = self._axes[i, None].get_position()
            # the diagonal panels carry no plots; they show the dimension name instead
            self.figure.text(0.5*(xbox.x0 + xbox.x1), 0.5*(ybox.y0 + ybox.y1), self.active[i],
                             ha='center', va='center', weight='bold')
            # Axes.patch is the background rectangle (what the old Axes.get_frame() returned)
            self._axes[i, j].patch.set_facecolor('none')

    def draw(self):
        """Redraw the figure's canvas."""
        self.figure.canvas.draw()
434 
435  def draw(self):
436  self.figure.canvas.draw()
437 
438 
class ExampleData(object):
    """An example data object for DensityPlot, demonstrating the necessary interface.

    There are two levels of requirements for a data object.  First are the attributes
    required by the DensityPlot object itself; these must be present on every data object:

       dimensions ------ a sequence of strings that provide names for the dimensions

       lower ----------- a dictionary of {dimension-name: lower-bound}

       upper ----------- a dictionary of {dimension-name: upper-bound}

    The second level of requirements are those of the Layer objects provided here.  These
    may be absent if the associated Layer is not used or is subclassed to reimplement the
    Layer method that calls the data object method.  Currently, these include:

       eval1d, eval2d -- methods used by the SurfaceLayer class; see their docs for more info

       values ---------- attribute used by the HistogramLayer and ScatterLayer classes, an array
                         with shape (M,N), where N is the number of dimension and M is the number
                         of data points

       weights --------- optional attribute used by the HistogramLayer and ScatterLayer classes,
                         a 1-d array with size=M that provides weights for each data point
    """

    def __init__(self):
        self.dimensions = ["a", "b", "c"]
        self.mu = numpy.array([-10.0, 0.0, 10.0])
        self.sigma = numpy.array([3.0, 2.0, 1.0])
        # plot limits: three standard deviations either side of the mean, per dimension
        self.lower = {}
        self.upper = {}
        for k, name in enumerate(self.dimensions):
            self.lower[name] = -3*self.sigma[k] + self.mu[k]
            self.upper[name] = 3*self.sigma[k] + self.mu[k]
        # 2000 independent Gaussian draws, scaled and shifted per dimension
        noise = numpy.random.randn(2000, 3)
        self.values = noise * self.sigma[numpy.newaxis, :] + self.mu[numpy.newaxis, :]

    def eval1d(self, dim, x):
        """Evaluate the 1-d analytic function for the given dim at points x (a 1-d numpy array;
        this method must be numpy-vectorized).
        """
        k = self.dimensions.index(dim)
        scaled = (x-self.mu[k])/self.sigma[k]
        norm = (2.0*numpy.pi)**0.5 * self.sigma[k]
        return numpy.exp(-0.5*scaled**2) / norm

    def eval2d(self, xDim, yDim, x, y):
        """Evaluate the 2-d analytic function for the given xDim and yDim at points x,y
        (2-d numpy arrays with the same shape; this method must be numpy-vectorized).
        """
        row = self.dimensions.index(yDim)
        col = self.dimensions.index(xDim)
        xScaled = (x-self.mu[col])/self.sigma[col]
        yScaled = (y-self.mu[row])/self.sigma[row]
        norm = 2.0*numpy.pi * self.sigma[col]*self.sigma[row]
        return numpy.exp(-0.5*(xScaled**2 + yScaled**2)) / norm
488 
489 
def demo():
    """Build a DensityPlot of ExampleData with a histogram and a surface layer, draw it, and return it."""
    plot = DensityPlot(matplotlib.pyplot.figure(), primary=ExampleData())
    layers = plot.layers
    # both layers draw the data object registered under the 'primary' tag
    layers['histogram'] = HistogramLayer('primary')
    layers['surface'] = SurfaceLayer('primary')
    plot.draw()
    return plot
def hist2d(self, data, xDim, yDim, xLimits, yLimits)
Definition: densityPlot.py:106
def __init__(self, tag, colors=("y", "m", "c", "r", "g", "b"), **kwds)
Definition: densityPlot.py:241
def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=None, kwds2d=None)
Definition: densityPlot.py:84
def eval2d(self, data, xDim, yDim, x, y)
Definition: densityPlot.py:198
def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None)
Definition: densityPlot.py:186
def plotXY(self, axes, data, xDim, yDim)
Definition: densityPlot.py:218
def plotXY(self, axes, data, xDim, yDim)
Definition: densityPlot.py:164