23 """A set of matplotlib-based classes that displays a grid of 1-d and 2-d slices through an 26 The main class, DensityPlot, manages the grid of matplotlib.axes.Axes objects, and holds 27 a sequence of Layer objects that each know how to draw individual 1-d or 2-d plots and a 28 data object that abstracts away how the N-d density data is actually represented. 30 For simple cases, users can just create a custom data class with an interface like that of 31 the ExampleData class provided here, and use the provided HistogramLayer and SurfaceLayer 32 classes directly. In more complicated cases, users may want to create their own Layer classes, 33 which may define their own relationship with the data object. 35 from builtins
import range
36 from builtins
import object
41 import matplotlib.pyplot
42 import matplotlib.ticker
44 __all__ = (
"HistogramLayer",
"SurfaceLayer",
"ScatterLayer",
"CrossPointsLayer",
45 "DensityPlot",
"ExampleData",
"demo")
49 for label
in axes.get_xticklabels():
50 label.set_visible(
False)
54 for label
in axes.get_yticklabels():
55 label.set_visible(
False)
59 copy = defaults.copy()
66 """A Layer class for DensityPlot for gridded histograms, drawing bar plots in 1-d and 67 colormapped large-pixel images in 2-d. 69 Relies on two data object attributes: 71 values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the 74 weights ---- (optional) an array of weights with shape (M,); if not present, all weights will 77 The need for these data object attributes can be removed by subclassing HistogramLayer and overriding 78 the hist1d and hist2d methods. 81 defaults1d = dict(facecolor=
'b', alpha=0.5)
82 defaults2d = dict(cmap=matplotlib.cm.Blues, vmin=0.0, interpolation=
'nearest')
84 def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=
None, kwds2d=
None):
92 """Extract points from the data object and compute a 1-d histogram. 94 Return value should match that of numpy.histogram: a tuple of (hist, edges), 95 where hist is a 1-d array with size=bins1d, and edges is a 1-d array with 96 size=self.bins1d+1 giving the upper and lower edges of the bins. 98 i = data.dimensions.index(dim)
99 if hasattr(data,
"weights")
and data.weights
is not None:
100 weights = data.weights
103 return numpy.histogram(data.values[:, i], bins=self.
bins1d, weights=weights,
104 range=limits, normed=
True)
def hist2d(self, data, xDim, yDim, xLimits, yLimits):
    """Extract points from the data object and compute a 2-d histogram.

    Return value should match that of numpy.histogram2d: a tuple of (hist, xEdges, yEdges),
    where hist is a 2-d array with shape=bins2d, xEdges is a 1-d array with size=bins2d[0]+1,
    and yEdges is a 1-d array with size=bins2d[1]+1.

    The x and y columns of data.values are selected by looking up xDim and yDim in
    data.dimensions; xLimits and yLimits are passed to numpy.histogram2d as its range.
    """
    i = data.dimensions.index(yDim)
    j = data.dimensions.index(xDim)
    # A missing "weights" attribute and an explicit None both mean "unweighted".
    weights = getattr(data, "weights", None)
    # Use "density" rather than the legacy "normed" keyword, which current NumPy
    # releases no longer accept.
    return numpy.histogram2d(data.values[:, j], data.values[:, i], bins=self.bins2d, weights=weights,
                             range=(xLimits, yLimits), density=True)
123 y, xEdge = self.
hist1d(data, dim, axes.get_xlim())
124 xCenter = 0.5*(xEdge[:-1] + xEdge[1:])
125 width = xEdge[1:] - xEdge[:-1]
126 return axes.bar(xCenter, y, width=width, align=
'center', **self.
kwds1d)
129 x, yEdge = self.
hist1d(data, dim, axes.get_ylim())
130 yCenter = 0.5*(yEdge[:-1] + yEdge[1:])
131 height = yEdge[1:] - yEdge[:-1]
132 return axes.barh(yCenter, x, height=height, align=
'center', **self.
kwds1d)
def plotXY(self, axes, data, xDim, yDim):
    """Draw the 2-d histogram for (xDim, yDim) on the given axes as an image.

    The histogram is computed by self.hist2d over the current x and y limits of the
    axes, and the image extent is taken from the outermost bin edges it returns.
    Returns whatever axes.imshow returns.
    """
    xLimits = axes.get_xlim()
    yLimits = axes.get_ylim()
    z, xEdge, yEdge = self.hist2d(data, xDim, yDim, xLimits, yLimits)
    extent = (xEdge[0], xEdge[-1], yEdge[0], yEdge[-1])
    # hist2d returns an array indexed [x, y]; it is transposed before being drawn.
    image = z.transpose()
    return axes.imshow(image, aspect='auto', extent=extent, origin='lower', **self.kwds2d)
141 """A Layer class that plots individual points in 2-d, and does nothing in 1-d. 143 Relies on two data object attributes: 145 values ----- a (M,N) array of data points, where N is the dimension of the dataset and M is the 146 number of data points 148 weights ---- (optional) an array of weights with shape (M,); will be used to set the color of points 152 defaults = dict(linewidth=0, alpha=0.2)
def plotXY(self, axes, data, xDim, yDim):
    """Scatter-plot the data points in the (xDim, yDim) plane.

    The x and y columns of data.values are selected by looking up xDim and yDim in
    data.dimensions. If the data object has a "weights" attribute that is not None,
    the weights are used to color the points, unless self.kwds already specifies a
    color. Returns whatever axes.scatter returns.
    """
    i = data.dimensions.index(yDim)
    j = data.dimensions.index(xDim)
    # Work on a copy so the layer's stored keyword arguments are never modified.
    kwds = dict(self.kwds)
    weights = getattr(data, "weights", None)
    if weights is not None and "c" not in kwds and "color" not in kwds:
        # The third positional argument of Axes.scatter is the marker size, so the
        # weights must be passed by keyword to control the point colors instead.
        kwds["c"] = weights
    return axes.scatter(data.values[:, j], data.values[:, i], **kwds)
175 """A Layer class for analytic N-d distributions that can be evaluated in 1-d or 2-d slices. 177 The 2-d slices are drawn as contours, and the 1-d slices are drawn as simple curves. 179 Relies on eval1d and eval2d methods in the data object; this can be avoided by subclassing 180 SurfaceLayer and reimplementing its own eval1d and eval2d methods. 183 defaults1d = dict(linewidth=2, color=
'r') 184 defaults2d = dict(linewidths=2, cmap=matplotlib.cm.Reds) 186 def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None):
195 """Return analytic function values for the given values.""" 196 return data.eval1d(dim, x)
def eval2d(self, data, xDim, yDim, x, y):
    """Evaluate the analytic 2-d slice by delegating to the data object's eval2d method."""
    values = data.eval2d(xDim, yDim, x, y)
    return values
203 xMin, xMax = axes.get_xlim()
204 x = numpy.linspace(xMin, xMax, self.
steps1d)
205 z = self.
eval1d(data, dim, x)
208 return axes.plot(x, z, **self.
kwds1d)
211 yMin, yMax = axes.get_ylim()
212 y = numpy.linspace(yMin, yMax, self.
steps1d)
213 z = self.
eval1d(data, dim, y)
216 return axes.plot(z, y, **self.
kwds1d)
218 def plotXY(self, axes, data, xDim, yDim):
219 xMin, xMax = axes.get_xlim()
220 yMin, yMax = axes.get_ylim()
221 xc = numpy.linspace(xMin, xMax, self.
steps2d)
222 yc = numpy.linspace(yMin, yMax, self.
steps2d)
223 xg, yg = numpy.meshgrid(xc, yc)
224 z = self.
eval2d(data, xDim, yDim, xg, yg)
228 return axes.contourf(xg, yg, z, 6, **self.
kwds2d)
230 return axes.contour(xg, yg, z, 6, **self.
kwds2d)
234 """A layer that marks a few points with axis-length vertical and horizontal lines. 236 This relies on a "points" data object attribute. 239 defaults = dict(alpha=0.8)
241 def __init__(self, tag, colors=(
"y",
"m",
"c",
"r", "g", "b"), **kwds):
247 i = data.dimensions.index(dim)
249 for n, point
in enumerate(data.points):
250 artists.append(axes.axvline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
254 i = data.dimensions.index(dim)
256 for n, point
in enumerate(data.points):
257 artists.append(axes.axhline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
260 def plotXY(self, axes, data, xDim, yDim):
261 i = data.dimensions.index(yDim)
262 j = data.dimensions.index(xDim)
264 for n, point
in enumerate(data.points):
265 artists.append(axes.axvline(point[j], color=self.
colors[n % len(self.
colors)], **self.
kwds))
266 artists.append(axes.axhline(point[i], color=self.
colors[n % len(self.
colors)], **self.
kwds))
271 """An object that manages a matrix of matplotlib.axes.Axes objects that represent a set of 1-d and 2-d 272 slices through an N-d density. 282 layer = self.
_dict.pop(name)
283 self.
_parent._dropLayer(name, layer)
287 self.
_dict[name] = layer
288 self.
_parent._plotLayer(name, layer)
291 return self.
_dict[name]
294 return iter(self.
_dict)
297 return len(self.
_dict)
300 return str(self.
_dict)
303 return repr(self.
_dict)
306 layer = self.
_dict[name]
307 self.
_parent._dropLayer(name, layer)
308 self.
_parent._plotLayer(name, layer)
318 for v
in self.
data.values():
319 for dim
in v.dimensions:
320 if dim
not in active:
322 self.
_lower[dim] = v.lower[dim]
323 self.
_upper[dim] = v.upper[dim]
329 self.
figure.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95, hspace=0.01, wspace=0.01)
333 def _dropLayer(self, name, layer):
334 def removeArtist(*key):
336 self.
_objs.pop(key).remove()
337 except AttributeError:
344 for i, yDim
in enumerate(self.
_active):
345 removeArtist(
None, i, name)
346 removeArtist(i,
None, name)
347 for j, xDim
in enumerate(self.
_active):
350 removeArtist(i, j, name)
352 def _plotLayer(self, name, layer):
353 for i, yDim
in enumerate(self.
_active):
354 if yDim
not in self.
data[layer.tag].dimensions:
356 self.
_objs[
None, i, name] = layer.plotX(self.
_axes[
None, i], self.
data[layer.tag], yDim)
357 self.
_objs[i,
None, name] = layer.plotY(self.
_axes[i,
None], self.
data[layer.tag], yDim)
358 for j, xDim
in enumerate(self.
_active):
359 if xDim
not in self.
data[layer.tag].dimensions:
363 self.
_objs[i, j, name] = layer.plotXY(self.
_axes[i, j], self.
data[layer.tag], xDim, yDim)
364 self.
_axes[
None, i].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
365 self.
_axes[i,
None].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(nbins=5, prune=
'both'))
366 self.
_axes[
None, i].xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
367 self.
_axes[i,
None].yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator())
369 def _get_active(self):
372 def _set_active(self, active):
374 if len(s) != len(active):
375 raise ValueError(
"Active set contains duplicates")
377 raise ValueError(
"Invalid values in active set")
380 active = property(_get_active, _set_active, doc=
"sequence of active dimensions to plot (sequence of str)")
386 for name, layer
in self.
layers.items():
389 def _build_axes(self):
400 axesX = self.
_axes[
None, j] = self.
figure.add_subplot(n+1, n+1, jStart+j*jStride)
401 axesX.autoscale(
False, axis=
'x')
402 axesX.xaxis.tick_top()
405 bbox = axesX.get_position()
407 axesX.set_position(bbox)
408 axesY = self.
_axes[i,
None] = self.
figure.add_subplot(n+1, n+1, iStart + iStart+i*iStride)
409 axesY.autoscale(
False, axis=
'y')
410 axesY.yaxis.tick_right()
413 bbox = axesY.get_position()
415 axesY.set_position(bbox)
418 axesXY = self.
_axes[i, j] = self.
figure.add_subplot(
419 n+1, n+1, iStart+i*iStride + jStart+j*jStride,
420 sharex=self.
_axes[
None, j],
421 sharey=self.
_axes[i,
None]
423 axesXY.autoscale(
False)
430 xbox = self.
_axes[
None, j].get_position()
431 ybox = self.
_axes[i,
None].get_position()
432 self.
figure.text(0.5*(xbox.x0 + xbox.x1), 0.5*(ybox.y0 + ybox.y1), self.
active[i],
433 ha=
'center', va=
'center', weight=
'bold')
434 self.
_axes[i, j].get_frame().set_facecolor(
'none')
441 """An example data object for DensityPlot, demonstrating the necessary interface. 443 There are two levels of requirements for a data object. First are the attributes 444 required by the DensityPlot object itself; these must be present on every data object: 446 dimensions ------ a sequence of strings that provide names for the dimensions 448 lower ----------- a dictionary of {dimension-name: lower-bound} 450 upper ----------- a dictionary of {dimension-name: upper-bound} 452 The second level of requirements are those of the Layer objects provided here. These 453 may be absent if the associated Layer is not used or is subclassed to reimplement the 454 Layer method that calls the data object method. Currently, these include: 456 eval1d, eval2d -- methods used by the SurfaceLayer class; see their docs for more info 458 values ---------- attribute used by the HistogramLayer and ScatterLayer classes, an array 459 with shape (M,N), where N is the number of dimensions and M is the number 462 weights --------- optional attribute used by the HistogramLayer and ScatterLayer classes, 463 a 1-d array with size=M that provides weights for each data point 468 self.
mu = numpy.array([-10.0, 0.0, 10.0])
469 self.
sigma = numpy.array([3.0, 2.0, 1.0])
472 self.
values = numpy.random.randn(2000, 3) * self.
sigma[numpy.newaxis, :] + self.
mu[numpy.newaxis, :]
475 """Evaluate the 1-d analytic function for the given dim at points x (a 1-d numpy array; 476 this method must be numpy-vectorized). 479 return numpy.exp(-0.5*((x-self.
mu[i])/self.
sigma[i])**2) / ((2.0*numpy.pi)**0.5 * self.
sigma[i])
482 """Evaluate the 2-d analytic function for the given xDim and yDim at points x,y 483 (2-d numpy arrays with the same shape; this method must be numpy-vectorized). 487 return (numpy.exp(-0.5*(((x-self.
mu[j])/self.
sigma[j])**2 + ((y-self.
mu[i])/self.
sigma[i])**2)) /
492 """Create and return a DensityPlot with example data.""" 493 fig = matplotlib.pyplot.figure()
def __delitem__(self, name)
def hist2d(self, data, xDim, yDim, xLimits, yLimits)
def plotY(self, axes, data, dim)
def __init__(self, tag, colors=("y", "m", "c", "r", "g", "b"), kwds)
def plotXY(self, axes, data, xDim, yDim)
def __init__(self, figure, kwds)
def hide_xticklabels(axes)
def plotXY(self, axes, data, xDim, yDim)
def hist1d(self, data, dim, limits)
def __init__(self, tag, bins1d=20, bins2d=(20, 20), kwds1d=None, kwds2d=None)
def __getitem__(self, name)
def plotY(self, axes, data, dim)
def plotX(self, axes, data, dim)
def plotX(self, axes, data, dim)
def __setitem__(self, name, layer)
def __init__(self, tag, kwds)
def hide_yticklabels(axes)
def _plotLayer(self, name, layer)
def plotY(self, axes, data, dim)
def eval2d(self, data, xDim, yDim, x, y)
def mergeDefaults(kwds, defaults)
def __init__(self, parent)
def plotX(self, axes, data, dim)
def plotX(self, axes, data, dim)
def __init__(self, tag, steps1d=200, steps2d=200, filled=False, kwds1d=None, kwds2d=None)
def eval1d(self, data, dim, x)
def plotY(self, axes, data, dim)
def eval2d(self, xDim, yDim, x, y)
def plotXY(self, axes, data, xDim, yDim)
def plotXY(self, axes, data, xDim, yDim)