lsst.meas.modelfit  15.0-3-g150fc43+8
densityPlot.py
Go to the documentation of this file.
1 #
2 # LSST Data Management System
3 # Copyright 2008-2013 LSST Corporation.
4 #
5 # This product includes software developed by the
6 # LSST Project (http://www.lsst.org/).
7 #
8 # This program is free software: you can redistribute it and/or modify
9 # it under the terms of the GNU General Public License as published by
10 # the Free Software Foundation, either version 3 of the License, or
11 # (at your option) any later version.
12 #
13 # This program is distributed in the hope that it will be useful,
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 # GNU General Public License for more details.
17 #
18 # You should have received a copy of the LSST License Statement and
19 # the GNU General Public License along with this program. If not,
20 # see <http://www.lsstcorp.org/LegalNotices/>.
21 #
22 
23 """A set of matplotlib-based classes that displays a grid of 1-d and 2-d slices through an
24 N-d density.
25 
26 The main class, DensityPlot, manages the grid of matplotlib.axes.Axes objects, and holds
27 a sequence of Layer objects that each know how to draw individual 1-d or 2-d plots and a
28 data object that abstracts away how the N-d density data is actually represented.
29 
30 For simple cases, users can just create a custom data class with an interface like that of
31 the ExampleData class provided here, and use the provided HistogramLayer and SurfaceLayer
32 classes directly. In more complicated cases, users may want to create their own Layer classes,
33 which may define their own relationship with the data object.
34 """
from builtins import range
from builtins import object

import collections
import collections.abc

import numpy
import matplotlib.cm
import matplotlib.pyplot
import matplotlib.ticker
43 
44 __all__ = ("HistogramLayer", "SurfaceLayer", "ScatterLayer", "CrossPointsLayer",
45  "DensityPlot", "ExampleData", "demo")
46 
47 
def hide_xticklabels(axes):
    """Make every x-axis tick label that *axes* currently reports invisible.

    Only the label objects returned by ``axes.get_xticklabels()`` at call time are touched.
    """
    tickLabels = axes.get_xticklabels()
    for tickLabel in tickLabels:
        tickLabel.set_visible(False)
51 
52 
def hide_yticklabels(axes):
    """Make every y-axis tick label that *axes* currently reports invisible.

    Only the label objects returned by ``axes.get_yticklabels()`` at call time are touched.
    """
    tickLabels = axes.get_yticklabels()
    for tickLabel in tickLabels:
        tickLabel.set_visible(False)
56 
57 
def mergeDefaults(kwds, defaults):
    """Return a new dict: a copy of *defaults* with the entries of *kwds* layered on top.

    *defaults* is never modified.  *kwds* may be None, in which case the copy of the
    defaults is returned unchanged.
    """
    merged = defaults.copy()
    if kwds is None:
        return merged
    merged.update(**kwds)
    return merged
63 
64 
class HistogramLayer(object):
    """A Layer class for DensityPlot for gridded histograms, drawing bar plots in 1-d and
    colormapped large-pixel images in 2-d.

    Relies on two data object attributes:

    values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the
                 number of data points

    weights ---- (optional) an array of weights with shape (M,); if not present (or None), all
                 weights will be set to unity

    The need for these data object attributes can be removed by subclassing HistogramLayer and overriding
    the hist1d and hist2d methods.
    """

    defaults1d = dict(facecolor='b', alpha=0.5)
    defaults2d = dict(cmap=matplotlib.cm.Blues, vmin=0.0, interpolation='nearest')

    def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=None, kwds2d=None):
        """
        tag ------ key of the data object this layer draws
        bins1d --- number of bins for the 1-d histograms
        bins2d --- (x, y) numbers of bins for the 2-d histograms
        kwds1d --- extra keyword arguments for Axes.bar/Axes.barh, merged over defaults1d
        kwds2d --- extra keyword arguments for Axes.imshow, merged over defaults2d
        """
        self.tag = tag
        self.bins1d = bins1d
        self.bins2d = bins2d
        self.kwds1d = mergeDefaults(kwds1d, self.defaults1d)
        self.kwds2d = mergeDefaults(kwds2d, self.defaults2d)

    def hist1d(self, data, dim, limits):
        """Extract points from the data object and compute a 1-d histogram.

        Return value should match that of numpy.histogram: a tuple of (hist, edges),
        where hist is a 1-d array with size=self.bins1d, and edges is a 1-d array with
        size=self.bins1d+1 giving the upper and lower edges of the bins.  hist is
        normalized as a probability density over ``limits``.
        """
        i = data.dimensions.index(dim)
        weights = getattr(data, "weights", None)
        # 'normed' was removed from numpy.histogram in numpy 1.24; 'density' is its replacement.
        return numpy.histogram(data.values[:, i], bins=self.bins1d, weights=weights,
                               range=limits, density=True)

    def hist2d(self, data, xDim, yDim, xLimits, yLimits):
        """Extract points from the data object and compute a 2-d histogram.

        Return value should match that of numpy.histogram2d: a tuple of (hist, xEdges, yEdges),
        where hist is a 2-d array with shape=bins2d, xEdges is a 1-d array with size=bins2d[0]+1,
        and yEdges is a 1-d array with size=bins2d[1]+1.  hist is normalized as a probability
        density over the (xLimits, yLimits) rectangle.
        """
        i = data.dimensions.index(yDim)
        j = data.dimensions.index(xDim)
        weights = getattr(data, "weights", None)
        # 'normed' was removed from numpy.histogram2d in numpy 1.24; 'density' is its replacement.
        return numpy.histogram2d(data.values[:, j], data.values[:, i], bins=self.bins2d, weights=weights,
                                 range=(xLimits, yLimits), density=True)

    def plotX(self, axes, data, dim):
        """Draw a vertical-bar histogram of *dim* over the current x limits of *axes*."""
        y, xEdge = self.hist1d(data, dim, axes.get_xlim())
        xCenter = 0.5*(xEdge[:-1] + xEdge[1:])
        width = xEdge[1:] - xEdge[:-1]
        return axes.bar(xCenter, y, width=width, align='center', **self.kwds1d)

    def plotY(self, axes, data, dim):
        """Draw a horizontal-bar histogram of *dim* over the current y limits of *axes*."""
        x, yEdge = self.hist1d(data, dim, axes.get_ylim())
        yCenter = 0.5*(yEdge[:-1] + yEdge[1:])
        height = yEdge[1:] - yEdge[:-1]
        return axes.barh(yCenter, x, height=height, align='center', **self.kwds1d)

    def plotXY(self, axes, data, xDim, yDim):
        """Draw the 2-d histogram of (xDim, yDim) as an image filling the current limits of *axes*."""
        z, xEdge, yEdge = self.hist2d(data, xDim, yDim, axes.get_xlim(), axes.get_ylim())
        # histogram2d indexes hist as [x, y]; imshow expects [row=y, column=x], hence the transpose.
        return axes.imshow(z.transpose(), aspect='auto', extent=(xEdge[0], xEdge[-1], yEdge[0], yEdge[-1]),
                           origin='lower', **self.kwds2d)
138 
139 
class ScatterLayer(object):
    """A Layer class that plots individual points in 2-d, and does nothing in 1-d.

    Relies on two data object attributes:

    values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the
                 number of data points

    weights ---- (optional) an array of weights with shape (M,); will be used to set the color of points
                 (passed as the ``c`` argument of Axes.scatter)

    """

    defaults = dict(linewidth=0, alpha=0.2)

    def __init__(self, tag, **kwds):
        """tag is the key of the data object to draw; kwds are passed to Axes.scatter (over defaults)."""
        self.tag = tag
        self.kwds = mergeDefaults(kwds, self.defaults)

    def plotX(self, axes, data, dim):
        """1-d panels are left empty by this layer."""
        pass

    def plotY(self, axes, data, dim):
        """1-d panels are left empty by this layer."""
        pass

    def plotXY(self, axes, data, xDim, yDim):
        """Scatter-plot the (xDim, yDim) columns of data.values, colored by data.weights if present."""
        i = data.dimensions.index(yDim)
        j = data.dimensions.index(xDim)
        x = data.values[:, j]
        y = data.values[:, i]
        weights = getattr(data, "weights", None)
        if weights is None:
            return axes.scatter(x, y, **self.kwds)
        # The third positional parameter of Axes.scatter is 's' (marker size); the weights are
        # meant to set the color, so they must be passed explicitly as 'c'.
        return axes.scatter(x, y, c=weights, **self.kwds)
172 
173 
class SurfaceLayer(object):
    """A Layer class for analytic N-d distributions that can be evaluated in 1-d or 2-d slices.

    The 2-d slices are drawn as contours (filled contours if ``filled`` is True), and the 1-d
    slices are drawn as simple curves.

    Relies on eval1d and eval2d methods in the data object; this can be avoided by subclassing
    SurfaceLayer and reimplementing its own eval1d and eval2d methods.  If an eval method returns
    None, the corresponding panel is left empty.
    """

    defaults1d = dict(linewidth=2, color='r')
    defaults2d = dict(linewidths=2, cmap=matplotlib.cm.Reds)

    def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None, levels=6):
        """
        tag ------- key of the data object this layer draws
        steps1d --- number of sample points along each 1-d curve
        steps2d --- number of sample points along each axis of the 2-d evaluation grid
        filled ---- if True, draw 2-d slices with Axes.contourf instead of Axes.contour
        kwds1d ---- extra keyword arguments for Axes.plot, merged over defaults1d
        kwds2d ---- extra keyword arguments for Axes.contour/contourf, merged over defaults2d
        levels ---- contour levels passed positionally to contour/contourf (a count or a
                    sequence of level values); defaults to 6, the previously fixed value
        """
        self.tag = tag
        self.steps1d = int(steps1d)
        self.steps2d = int(steps2d)
        self.filled = bool(filled)
        self.levels = levels
        self.kwds1d = mergeDefaults(kwds1d, self.defaults1d)
        self.kwds2d = mergeDefaults(kwds2d, self.defaults2d)

    def eval1d(self, data, dim, x):
        """Return analytic function values for the given values."""
        return data.eval1d(dim, x)

    def eval2d(self, data, xDim, yDim, x, y):
        """Return analytic function values for the given values."""
        return data.eval2d(xDim, yDim, x, y)

    def plotX(self, axes, data, dim):
        """Draw the 1-d slice of *dim* as a curve across the current x limits of *axes*."""
        xMin, xMax = axes.get_xlim()
        x = numpy.linspace(xMin, xMax, self.steps1d)
        z = self.eval1d(data, dim, x)
        if z is None:
            return
        return axes.plot(x, z, **self.kwds1d)

    def plotY(self, axes, data, dim):
        """Draw the 1-d slice of *dim* as a curve across the current y limits (rotated)."""
        yMin, yMax = axes.get_ylim()
        y = numpy.linspace(yMin, yMax, self.steps1d)
        z = self.eval1d(data, dim, y)
        if z is None:
            return
        return axes.plot(z, y, **self.kwds1d)

    def plotXY(self, axes, data, xDim, yDim):
        """Draw the 2-d slice of (xDim, yDim) as contours over the current limits of *axes*."""
        xMin, xMax = axes.get_xlim()
        yMin, yMax = axes.get_ylim()
        xc = numpy.linspace(xMin, xMax, self.steps2d)
        yc = numpy.linspace(yMin, yMax, self.steps2d)
        xg, yg = numpy.meshgrid(xc, yc)
        z = self.eval2d(data, xDim, yDim, xg, yg)
        if z is None:
            return
        if self.filled:
            # Filled contours draw no lines; matplotlib warns that 'linewidths' is ignored by
            # contourf, so drop it (it is present by default via defaults2d).
            kwds = {k: v for k, v in self.kwds2d.items() if k != 'linewidths'}
            return axes.contourf(xg, yg, z, self.levels, **kwds)
        return axes.contour(xg, yg, z, self.levels, **self.kwds2d)
231 
232 
class CrossPointsLayer(object):
    """A layer that marks a few points with lines spanning the whole axes.

    The data object must provide a ``points`` attribute: a sequence of points, each indexable by the
    position of a dimension within ``data.dimensions``.  Point number n is drawn in color
    ``colors[n % len(colors)]``; 1-d panels get one line per point, and 2-d panels get a
    vertical and a horizontal line per point.
    """

    defaults = dict(alpha=0.8)

    def __init__(self, tag, colors=("y", "m", "c", "r", "g", "b"), **kwds):
        self.tag = tag
        self.colors = colors
        self.kwds = mergeDefaults(kwds, self.defaults)

    def _colorFor(self, index):
        """Return the color of the index-th point, cycling through self.colors."""
        return self.colors[index % len(self.colors)]

    def plotX(self, axes, data, dim):
        """Draw a vertical line at each point's *dim* coordinate; return the list of lines."""
        column = data.dimensions.index(dim)
        return [axes.axvline(point[column], color=self._colorFor(n), **self.kwds)
                for n, point in enumerate(data.points)]

    def plotY(self, axes, data, dim):
        """Draw a horizontal line at each point's *dim* coordinate; return the list of lines."""
        column = data.dimensions.index(dim)
        return [axes.axhline(point[column], color=self._colorFor(n), **self.kwds)
                for n, point in enumerate(data.points)]

    def plotXY(self, axes, data, xDim, yDim):
        """Draw a vertical (xDim) and a horizontal (yDim) line per point; return all lines."""
        row = data.dimensions.index(yDim)
        col = data.dimensions.index(xDim)
        lines = []
        for n, point in enumerate(data.points):
            color = self._colorFor(n)
            vertical = axes.axvline(point[col], color=color, **self.kwds)
            horizontal = axes.axhline(point[row], color=color, **self.kwds)
            lines.extend((vertical, horizontal))
        return lines
268 
269 
class DensityPlot(object):
    """An object that manages a matrix of matplotlib.axes.Axes objects that represent a set of 1-d and 2-d
    slices through an N-d density.

    Data objects are passed to the constructor as keyword arguments; the keyword is the "tag" that a
    Layer uses (via its ``tag`` attribute) to select its data object.  Every data object must provide
    ``dimensions``, ``lower`` and ``upper`` attributes (see ExampleData).  Layers are managed through
    the ``layers`` mapping: assigning a layer draws it, and deleting one removes its artists.
    """

    class LayerDict(collections.abc.MutableMapping):
        """Mapping of layer name -> Layer that keeps the parent DensityPlot's drawing in sync.

        Setting an item removes any previous layer of the same name and draws the new one;
        deleting an item removes that layer's artists from the axes.
        """

        def __init__(self, parent):
            self._dict = dict()
            self._parent = parent

        def __delitem__(self, name):
            layer = self._dict.pop(name)
            self._parent._dropLayer(name, layer)

        def __setitem__(self, name, layer):
            # Remove (and un-draw) any existing layer with this name before drawing the new one.
            self.pop(name, None)
            self._dict[name] = layer
            self._parent._plotLayer(name, layer)

        def __getitem__(self, name):
            return self._dict[name]

        def __iter__(self):
            return iter(self._dict)

        def __len__(self):
            return len(self._dict)

        def __str__(self):
            return str(self._dict)

        def __repr__(self):
            return repr(self._dict)

        def replot(self, name):
            """Remove and redraw the named layer."""
            layer = self._dict[name]
            self._parent._dropLayer(name, layer)
            self._parent._plotLayer(name, layer)

    def __init__(self, figure, **kwds):
        self.figure = figure
        self.data = dict(kwds)
        # We merge the dimension name lists manually rather than using sets to preserve the order.
        # Most of the time we expect all data objects to have the same dimensions anyway.
        active = []
        for v in self.data.values():
            for dim in v.dimensions:
                if dim not in active:
                    active.append(dim)
        self._active = tuple(active)
        self._all_dims = frozenset(self._active)
        self._computeLimits()
        self.figure.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, hspace=0.01, wspace=0.01)
        self._build_axes()
        self.layers = self.LayerDict(self)

    def _computeLimits(self):
        """Set self._lower/self._upper, for each active dimension, to the widest bounds over the
        data objects that actually have that dimension (data objects may have different dimensions).
        """
        self._lower = dict()
        self._upper = dict()
        for dim in self._active:
            holders = [v for v in self.data.values() if dim in v.dimensions]
            self._lower[dim] = min(v.lower[dim] for v in holders)
            self._upper[dim] = max(v.upper[dim] for v in holders)

    @staticmethod
    def _removeArtists(obj):
        """Best-effort removal of whatever a Layer plot method returned.

        Layers may return None, a single artist or container, or a list of artists (e.g. the result
        of Axes.plot), so all of these are handled; objects that cannot be removed are skipped.
        """
        if obj is None:
            return
        if isinstance(obj, list):
            for item in obj:
                DensityPlot._removeArtists(item)
            return
        remove = getattr(obj, "remove", None)
        if remove is None:
            # A plain tuple of artists has no remove(); remove its members instead.
            if isinstance(obj, tuple):
                for item in obj:
                    DensityPlot._removeArtists(item)
            return
        try:
            remove()
        except (NotImplementedError, ValueError):
            # matplotlib raises these when an artist cannot be (or was already) removed.
            pass

    def _dropLayer(self, name, layer):
        """Remove all artists recorded for layer *name* from the axes and forget them.

        Panels the layer never drew on (because its data lacks a dimension) are skipped.
        """
        n = len(self._active)
        for i in range(n):
            self._removeArtists(self._objs.pop((None, i, name), None))
            self._removeArtists(self._objs.pop((i, None, name), None))
            for j in range(n):
                if i == j:
                    continue
                self._removeArtists(self._objs.pop((i, j, name), None))

    def _plotLayer(self, name, layer):
        """Draw *layer* on every panel whose dimensions are present in its data object.

        The objects returned by the layer's plotX/plotY/plotXY methods are recorded in self._objs,
        keyed by (row, column, name) with None for the 1-d panels, so _dropLayer can remove them.
        """
        data = self.data[layer.tag]
        for i, yDim in enumerate(self._active):
            if yDim not in data.dimensions:
                continue
            self._objs[None, i, name] = layer.plotX(self._axes[None, i], data, yDim)
            self._objs[i, None, name] = layer.plotY(self._axes[i, None], data, yDim)
            for j, xDim in enumerate(self._active):
                if xDim not in data.dimensions:
                    continue
                if i == j:
                    continue
                self._objs[i, j, name] = layer.plotXY(self._axes[i, j], data, xDim, yDim)
            self._axes[None, i].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune='both'))
            self._axes[i, None].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune='both'))
            self._axes[None, i].xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
            self._axes[i, None].yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())

    def _get_active(self):
        return self._active

    def _set_active(self, active):
        s = set(active)
        if len(s) != len(active):
            raise ValueError("Active set contains duplicates")
        if not self._all_dims.issuperset(s):
            raise ValueError("Invalid values in active set")
        self._active = tuple(active)
        self.replot()
    active = property(_get_active, _set_active, doc="sequence of active dimensions to plot (sequence of str)")

    def replot(self):
        """Recompute limits for the active dimensions, rebuild all axes, and redraw every layer."""
        self._computeLimits()
        self._build_axes()
        for name, layer in self.layers.items():
            self._plotLayer(name, layer)

    def _build_axes(self):
        """Clear the figure and create the (n+1)x(n+1) subplot grid for the n active dimensions.

        Row 1 holds the 1-d x-panels (dimension j in column n-j), the last column holds the 1-d
        y-panels (dimension i in row i+2), and the remaining n x n cells hold the 2-d panels, which
        share their x axis with the 1-d panel above and their y axis with the one to the right.
        """
        self.figure.clear()
        self._axes = dict()
        self._objs = dict()
        n = len(self._active)
        iStride = n + 1
        jStride = -1
        iStart = n + 1
        jStart = n
        for i in range(n):
            j = i
            axesX = self._axes[None, j] = self.figure.add_subplot(n+1, n+1, jStart+j*jStride)
            axesX.autoscale(False, axis='x')
            axesX.xaxis.tick_top()
            axesX.set_xlim(self._lower[self._active[j]], self._upper[self._active[j]])
            hide_yticklabels(axesX)
            bbox = axesX.get_position()
            bbox.y1 -= 0.035
            axesX.set_position(bbox)
            axesY = self._axes[i, None] = self.figure.add_subplot(n+1, n+1, iStart + iStart+i*iStride)
            axesY.autoscale(False, axis='y')
            axesY.yaxis.tick_right()
            axesY.set_ylim(self._lower[self._active[i]], self._upper[self._active[i]])
            hide_xticklabels(axesY)
            bbox = axesY.get_position()
            bbox.x1 -= 0.035
            axesY.set_position(bbox)
        for i in range(n):
            for j in range(n):
                axesXY = self._axes[i, j] = self.figure.add_subplot(
                    n+1, n+1, iStart+i*iStride + jStart+j*jStride,
                    sharex=self._axes[None, j],
                    sharey=self._axes[i, None]
                )
                axesXY.autoscale(False)
                if j < n - 1:
                    hide_yticklabels(axesXY)
                if i < n - 1:
                    hide_xticklabels(axesXY)
        for i in range(n):
            j = i
            xbox = self._axes[None, j].get_position()
            ybox = self._axes[i, None].get_position()
            self.figure.text(0.5*(xbox.x0 + xbox.x1), 0.5*(ybox.y0 + ybox.y1), self.active[i],
                             ha='center', va='center', weight='bold')
            # Axes.get_frame() no longer exists in matplotlib; Axes.patch is the background
            # rectangle it used to return.
            self._axes[i, j].patch.set_facecolor('none')

    def draw(self):
        self.figure.canvas.draw()
438 
439 
class ExampleData(object):
    """An example data object for DensityPlot, demonstrating the necessary interface.

    A data object has two tiers of requirements.  The first tier is read by the DensityPlot
    object itself, so these attributes must be present on every data object:

    dimensions ------ a sequence of strings that name the dimensions

    lower ----------- a dictionary of {dimension-name: lower-bound}

    upper ----------- a dictionary of {dimension-name: upper-bound}

    The second tier is read by the Layer classes provided here.  These may be omitted if the
    associated Layer is not used, or is subclassed to reimplement the Layer method that reads
    them from the data object.  Currently, these include:

    eval1d, eval2d -- methods used by the SurfaceLayer class; see their docs for more info

    values ---------- attribute used by the HistogramLayer and ScatterLayer classes, an array
                      with shape (M,N), where N is the number of dimensions and M is the number
                      of data points

    weights --------- optional attribute used by the HistogramLayer and ScatterLayer classes,
                      a 1-d array with size=M that provides a weight for each data point

    This example is an axis-aligned 3-d Gaussian (means mu, standard deviations sigma) with 2000
    random samples drawn from it; the bounds are mu -/+ 3*sigma in each dimension.
    """

    def __init__(self):
        names = ["a", "b", "c"]
        self.dimensions = names
        self.mu = numpy.array([-10.0, 0.0, 10.0])
        self.sigma = numpy.array([3.0, 2.0, 1.0])
        self.lower = {}
        self.upper = {}
        for name, center, width in zip(names, self.mu, self.sigma):
            self.lower[name] = -3*width + center
            self.upper[name] = 3*width + center
        samples = numpy.random.randn(2000, 3)
        self.values = samples * self.sigma[numpy.newaxis, :] + self.mu[numpy.newaxis, :]

    def eval1d(self, dim, x):
        """Evaluate the 1-d Gaussian marginal for dimension *dim* at points x (a 1-d numpy array;
        this method must be numpy-vectorized).
        """
        k = self.dimensions.index(dim)
        mean = self.mu[k]
        std = self.sigma[k]
        scaled = (x - mean) / std
        normalization = (2.0*numpy.pi)**0.5 * std
        return numpy.exp(-0.5*scaled**2) / normalization

    def eval2d(self, xDim, yDim, x, y):
        """Evaluate the 2-d Gaussian marginal for xDim and yDim at points x,y
        (2-d numpy arrays with the same shape; this method must be numpy-vectorized).
        """
        row = self.dimensions.index(yDim)
        col = self.dimensions.index(xDim)
        dx = (x - self.mu[col]) / self.sigma[col]
        dy = (y - self.mu[row]) / self.sigma[row]
        normalization = 2.0*numpy.pi * self.sigma[col]*self.sigma[row]
        return numpy.exp(-0.5*(dx**2 + dy**2)) / normalization
489 
490 
def demo():
    """Build a DensityPlot of ExampleData on a new pyplot figure, overlay a histogram layer and an
    analytic surface layer (both on the 'primary' data object), draw it, and return the plot.
    """
    figure = matplotlib.pyplot.figure()
    example = ExampleData()
    densityPlot = DensityPlot(figure, primary=example)
    layers = densityPlot.layers
    layers['histogram'] = HistogramLayer('primary')
    layers['surface'] = SurfaceLayer('primary')
    densityPlot.draw()
    return densityPlot
def hist2d(self, data, xDim, yDim, xLimits, yLimits)
Definition: densityPlot.py:106
def __init__(self, tag, colors=("y", "m", "c", "r", "g", "b"), kwds)
Definition: densityPlot.py:241
def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=None, kwds2d=None)
Definition: densityPlot.py:84
def eval2d(self, data, xDim, yDim, x, y)
Definition: densityPlot.py:198
def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None)
Definition: densityPlot.py:186
def plotXY(self, axes, data, xDim, yDim)
Definition: densityPlot.py:218
def plotXY(self, axes, data, xDim, yDim)
Definition: densityPlot.py:164