Catalog of patterns for Matplotlib

NumPy and Matplotlib are great pieces of software for interacting with data and creating plots. Over the years I have been working with Matplotlib, I have noticed two things:

  1. I keep returning to a handful of patterns in order to carry out my work.
  2. I have a hard time remembering these patterns.

At some point in almost every project I end up stumbling over some seemingly simple task that I know I have solved before. Hopefully, by creating my own Matplotlib cookbook, I can spend less time hunting for details about subplots, tick spacing, colors, etc., and more time analyzing whatever problem I am working on.

The web already has a lot of resources for Matplotlib, so there likely isn't anything new here. This is simply my personal collection of Matplotlib examples, written in my own style, so even if they don't make sense in the future, at least I understood them at some point!

Ideas for more topics/sections to add:

  • tick spacing
  • handling colorbars
  • images
  • 3d plots
  • pandas

Setup

Make sure you have at least NumPy and pyplot imported. Other packages and libraries are imported as needed.

In [14]:
import matplotlib.pyplot as plt
import numpy as np

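Use the %matplotlib inline magic function so plots show up inline in a Jupyter (née IPython) notebook. Magic functions are IPython-specific commands prefixed with %; this one makes the notebook render each figure directly below the cell that creates it, and other magics smooth over various IPython idiosyncrasies.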

In [15]:
# inline plots in notebooks
%matplotlib inline 
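The inline magic only exists inside IPython. For the same plotting code in a plain Python script, a minimal sketch: build the figure as usual, then either call plt.show() to open a window or write the figure to disk with savefig. The filename sine.png is a placeholder.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 3 * np.pi, 100)

fig, ax = plt.subplots()
ax.plot(x, np.sin(x))

# In a script, either open a window with plt.show() or write the figure to disk
fig.savefig("sine.png", dpi=100)  # "sine.png" is a placeholder filename
plt.close(fig)  # release the figure; matters when a script makes many plots
```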

Generate some sample data

Then let's generate a bit of junk data to play with:

In [16]:
# A 1D variable, a basic timeseries for example
x = np.linspace(0, 3*np.pi, 100)
sinwave = np.sin(x)
noisy_sin = sinwave + np.random.uniform(-.75, 0.75, sinwave.shape)

# A 2D variable, for example, could be spatial data like a map
random_data = np.random.uniform(0,10,(5,5))
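Because the data are random, every run of the notebook draws different junk. If you want the plots to come out identical each time, a sketch using a seeded NumPy Generator (np.random.default_rng, available in NumPy 1.17 and later); the seed value 42 is arbitrary.

```python
import numpy as np

# A seeded Generator makes the "junk" data identical on every run
rng = np.random.default_rng(seed=42)  # the seed value is arbitrary

x = np.linspace(0, 3 * np.pi, 100)
sinwave = np.sin(x)
noisy_sin = sinwave + rng.uniform(-0.75, 0.75, sinwave.shape)
random_data = rng.uniform(0, 10, (5, 5))
```

Re-creating the Generator with the same seed and drawing in the same order reproduces the same arrays.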

More than one plot

Frequently when investigating some issue, I find myself wanting to crunch some numbers and quickly throw up a few plots side by side. There are a bunch of ways to do this, including at least:

subplots command

Get a Figure instance along with a 2-D array of Axes instances, then explicitly unpack the array into named variables. This may not be ideal in every context, but it gives you nice, terse variable names for quick exploration in an IPython session.

In [17]:
fig, ((ax0,ax1,ax2),(ax3,ax4,ax5),(ax6,ax7,ax8)) = plt.subplots(3,3)

im3 = ax3.imshow(random_data, interpolation='none')
plt.colorbar(im3, ax=ax3) # associates the colorbar with a specific dataset and axes

p4 = ax4.plot(x, sinwave, label='sin(x)')
p4a = ax4.plot(x, noisy_sin, label="sin(x)+noise")
ax4.legend(fontsize='x-small')  # show the labels; small font so it fits the tiny axes

# fix some spacing issues with overlapping ticks
plt.tight_layout()
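When unpacking nine names gets unwieldy, the other common style is to keep the array that plt.subplots returns and loop over axs.flat. A sketch; the sin(kx) panels are filler data, and sharex=True is optional (it links the x-axes of all panels).

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 3 * np.pi, 100)

# Keep the 3x3 grid as a NumPy array of Axes instead of unpacking nine names
fig, axs = plt.subplots(3, 3, figsize=(8, 8), sharex=True)

# axs.flat walks the grid row by row, so the loop does not depend on the layout
for i, ax in enumerate(axs.flat):
    ax.plot(x, np.sin((i + 1) * x))
    ax.set_title(f"sin({i + 1}x)", fontsize=9)

fig.tight_layout()
```

Individual panels are still reachable by index when needed, for example axs[1, 1] for the center plot.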

subplot command

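This might be the easiest approach for simple multi-plot layouts. The three-digit argument encodes the number of rows, the number of columns, and the plot index.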

In [18]:
ax_top = plt.subplot(211)     # <-- 2 rows, one column, first (top) plot
ax_bottom = plt.subplot(212)  # <-- 2 rows, one column, second (bottom) plot

ax_top.plot(x, sinwave)
ax_bottom.imshow(random_data)
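For plots that really are side by side, swap the row and column counts. A sketch using the three-argument form plt.subplot(nrows, ncols, index), which does the same thing as the three-digit shorthand but still works when a grid has more than nine cells.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 3 * np.pi, 100)
random_data = np.random.uniform(0, 10, (5, 5))

plt.figure(figsize=(10, 4))
# (nrows, ncols, index) is equivalent to 121 / 122 here
ax_left = plt.subplot(1, 2, 1)   # 1 row, 2 columns, first (left) plot
ax_right = plt.subplot(1, 2, 2)  # 1 row, 2 columns, second (right) plot

ax_left.plot(x, np.sin(x))
ax_right.imshow(random_data)
```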

gridspec object

GridSpec allows some handy tricks for getting differently sized or proportioned axes. You pass a slice of a GridSpec instance to plt.subplot() to say which grid cells an Axes should span.

In [37]:
import matplotlib.gridspec as gridspec

ROWS, COLS = 5, 5

gs = gridspec.GridSpec(ROWS, COLS)

# use the first few rows, spanning the whole width of the grid
ax1 = plt.subplot(gs[0:2,:])
ax1.plot(sinwave, label="sin(x)", linewidth=1.0)
ax1.plot(noisy_sin, label="sin(x)+noise", linewidth=1.0)

# The first two columns, beneath the upper plot
ax2 = plt.subplot(gs[2:,0:2])
ax2.plot(sinwave, label='sin(x)')

# The remaining space in the lower right of the grid
ax3 = plt.subplot(gs[2:,2:])
ax3.plot(noisy_sin, label='sin(x)+noise')

# Turn on the legend for select axes
for ax in [ax1, ax3]:
    ax.legend()
    
# fix some spacing issues with overlapping ticks
plt.tight_layout()
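Equal-sized cells are not the only option: GridSpec also accepts width_ratios and height_ratios, which set the relative widths of the columns and heights of the rows. A sketch of a large main panel with a thin strip on top and a sideways histogram on the right; the layout and data are illustrative, not taken from the plots above.

```python
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

x = np.linspace(0, 3 * np.pi, 100)
noisy_sin = np.sin(x) + np.random.uniform(-0.75, 0.75, x.shape)

fig = plt.figure(figsize=(8, 6))
# Main panel gets roughly 3/4 of the width and height; the side panels share the rest
gs = gridspec.GridSpec(2, 2, figure=fig, width_ratios=[3, 1], height_ratios=[1, 3])

ax_main = fig.add_subplot(gs[1, 0])
ax_main.plot(x, noisy_sin, '.', alpha=0.6)

# Narrow strip above the main panel, sharing its x-axis
ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
ax_top.plot(x, np.sin(x))

# Narrow panel to the right, sharing the y-axis: a sideways histogram
ax_side = fig.add_subplot(gs[1, 1], sharey=ax_main)
ax_side.hist(noisy_sin, bins=20, orientation='horizontal')

fig.tight_layout()
```

The upper-right cell gs[0, 1] is deliberately left empty.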

Histograms

Histograms are a standard tool for looking at the distribution of a dataset. Below are two convenience functions that plot a histogram with Matplotlib's hist function; see the plt.hist documentation for optional arguments you can add to them.

First, generate a random dataset with an interesting distribution:

In [7]:
nobs = 1000
x = np.random.uniform(-4, 4, nobs)
y = x + 0.25 * x**2 + 0.1 * np.exp(1 + np.abs(x)) * np.random.randn(nobs)

Here's what our random data looks like:

In [8]:
plt.figure(figsize=(6 * 1.618, 6))
plt.plot(x, y, 'o', alpha=0.5)

The following produces a standard histogram; the data, the number of bins, and the plot title are passed to the function. I also draw vertical lines at the median (red) and the mean (black) to get a sense of the central tendency of the data. For a symmetric distribution, such as a normal distribution, the median and mean nearly coincide, so a visible gap between the two lines hints at skew.

In [9]:
def histoplot(data, num_bins, titlestring):
    fig1, ax1 = plt.subplots(figsize=(6*1.618, 6))  # golden-ratio aspect
    # density=True replaces normed=1, which was removed from Matplotlib
    n, bins, patches = ax1.hist(data, num_bins, density=True, facecolor='blue', alpha=0.5)
    # axvline's ymin/ymax are axes fractions (0 to 1); the defaults span the full height
    ax1.axvline(x=np.median(data), linewidth=1, color='r', label='median')
    ax1.axvline(x=np.mean(data), linewidth=1, color='k', label='mean')
    ax1.set_title(titlestring, fontsize=12)
    ax1.set_xlabel('xlabel', fontsize=12)
    ax1.set_ylabel('probability density', fontsize=12)  # bars are density-scaled
    ax1.tick_params(labelsize=12)
    ax1.legend()
In [10]:
histoplot(y, 100, 'Title')
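The mean-versus-median comparison can also be done numerically. A minimal sketch: the helper mean_minus_median and the two synthetic samples are hypothetical, chosen only to contrast a symmetric distribution with a right-skewed one. Treat the gap as a rough heuristic for skew, not a formal test.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed for repeatability

symmetric = rng.normal(loc=0.0, scale=1.0, size=10_000)
skewed = rng.exponential(scale=1.0, size=10_000)  # long right tail

def mean_minus_median(data):
    """Hypothetical helper: typically > 0 for a long right tail, near 0 if symmetric."""
    return np.mean(data) - np.median(data)

print(f"normal:      {mean_minus_median(symmetric):+.3f}")
print(f"exponential: {mean_minus_median(skewed):+.3f}")
```

For the exponential sample the mean (about 1) sits well to the right of the median (about ln 2), which is exactly the gap the red and black lines show on a histogram.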

As an additional check for normality, the following plot overlays a dashed line showing the normal distribution with the same mean and standard deviation as the dataset, computed with scipy.stats.norm.pdf (the mlab.normpdf function used in older examples no longer exists in Matplotlib). Comparing the bars to this curve lets you inspect the skewness and kurtosis of your data relative to a true normal distribution.

In [11]:
from scipy.stats import norm
In [12]:
def histoplot_normal(data, num_bins, titlestring):
    fig2, ax1 = plt.subplots(figsize=(6*1.618, 6))
    # density=True scales the bars so their total area is 1, matching the pdf
    n, bins, patches = ax1.hist(data, num_bins, density=True, facecolor='blue', alpha=0.5)
    # Normal pdf with the sample's mean and standard deviation, evaluated at the bin edges
    pdf = norm.pdf(bins, loc=np.mean(data), scale=np.std(data))
    ax1.plot(bins, pdf, 'k--', label='normal fit')
    ax1.axvline(x=np.median(data), linewidth=1, color='r', label='median')
    ax1.axvline(x=np.mean(data), linewidth=1, color='k', label='mean')
    ax1.set_title(titlestring, fontsize=12)
    ax1.set_xlabel('xlabel', fontsize=12)
    ax1.set_ylabel('probability density', fontsize=12)
    ax1.tick_params(labelsize=12)
    ax1.legend()
In [13]:
histoplot_normal(y, 100, 'Title')
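The overlay gives a visual read on skewness and kurtosis; to get numbers as well, a sketch using scipy.stats.skew and scipy.stats.kurtosis (this assumes SciPy is installed). The Laplace sample is a made-up stand-in for heavy-tailed data.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(1)  # arbitrary seed
normal_sample = rng.normal(size=20_000)
heavy_tailed = rng.laplace(size=20_000)  # Laplace: heavier tails than a normal

results = {}
for name, data in [("normal", normal_sample), ("laplace", heavy_tailed)]:
    skew = stats.skew(data)
    # fisher=True (the default) gives *excess* kurtosis: about 0 for a normal
    excess_kurt = stats.kurtosis(data, fisher=True)
    results[name] = (skew, excess_kurt)
    print(f"{name:8s} skew={skew:+.3f}  excess kurtosis={excess_kurt:+.3f}")
```

Calling stats.skew(y) and stats.kurtosis(y) on the notebook's own y gives the corresponding numbers for the histogram above.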