matplotlib
In this lecture, we'll introduce the matplotlib package for data visualization. matplotlib is the second core library for computational and data science in Python. It works very well with numpy.
In this course, we'll only make use of tools in the matplotlib.pyplot module. The standard way to import these tools is:
import numpy as np
from matplotlib import pyplot as plt
We've seen a few examples of using matplotlib functions before, e.g. when visualizing random walks. In these cases, we called plt.plot() and it "just worked." In order to obtain finer-grained control over our visualizations, we are going to start with the "object-oriented interface" to matplotlib. Its main ingredients are the figure, the axes contained in the figure, and the data markings and annotations drawn on each axis.
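The basic workflow for matplotlib figures is: create a figure, add one or more axes to it, add data markings to the axes, add labels and other annotations, and save the result. Let's work through these stages.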
Unless you need something very customized, it's easiest to create the figure and its axes in a single step, using plt.subplots():
# create a single figure with a single axis (1 row x 1 column)
fig, ax = plt.subplots(1, 1)
# two plots side by side, using figsize to stretch
# them out to a better aspect ratio
fig, ax = plt.subplots(1, 2, figsize = (5, 2))
# 2x2 grid
fig, ax = plt.subplots(2, 2)
Let's take a moment to notice what kind of object ax is:
type(ax)
numpy.ndarray
It's a numpy array! Its shape is:
ax.shape
(2, 2)
This means, for example, that if we want the axis in the first row and second column, we can access it like this:
ax[0, 1]
<matplotlib.axes._subplots.AxesSubplot at 0x7f9e2f97f1f0>
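The type of ax depends on the shape of the grid you ask for, so it is worth checking once. The sketch below is not part of the lecture's own code: it compares the three grids created above and uses the squeeze argument of plt.subplots() to force a 2D array in every case. The matplotlib.use("Agg") line is only an assumption for running without a display, and the variable names are illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; not needed in a notebook
from matplotlib import pyplot as plt

# 1 x 1 grid: ax is a single Axes object, not an array
fig1, ax1 = plt.subplots(1, 1)

# 1 x 2 grid: ax is a 1D array, so a single index is enough (ax2[0], ax2[1])
fig2, ax2 = plt.subplots(1, 2)

# 2 x 2 grid: ax is a 2D array indexed by [row, column]
fig3, ax3 = plt.subplots(2, 2)

# squeeze=False always returns a 2D array, whatever the grid shape
fig4, ax4 = plt.subplots(1, 2, squeeze=False)

print(ax2.shape, ax3.shape, ax4.shape)

# .flat visits every axis of a grid in row-major order
for i, a in enumerate(ax3.flat):
    a.set_title(f"panel {i}")
```

Passing squeeze=False can be convenient when the grid size is computed at run time, since the same [row, column] indexing then works for any grid.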
To add data markings, we need some data to plot and an axis method, such as scatter() or plot(), to draw it:
fig, ax = plt.subplots(1, 2, figsize = (7, 3))
# something to plot
x = np.linspace(0, 2*np.pi, 21)
y = np.sin(x)
z = np.log(x)
<ipython-input-9-efc66d13071f>:5: RuntimeWarning: divide by zero encountered in log
  z = np.log(x)
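The warning appears because x starts at exactly 0, and the logarithm of 0 evaluates to -inf in numpy. As a minimal sketch (not part of the original lecture), here are two ways to deal with it: silencing the warning locally with np.errstate, or starting the grid slightly above 0. The starting value 0.1 and the names x_pos and z_pos are arbitrary choices for illustration, and the second option changes the data being plotted.

```python
import numpy as np

x = np.linspace(0, 2*np.pi, 21)

# x[0] is exactly 0, so the first entry of log(x) is -inf;
# np.errstate silences the warning inside this block only
with np.errstate(divide="ignore"):
    z = np.log(x)

print(z[0])  # -inf

# alternative (changes the data): start the grid just above 0
x_pos = np.linspace(0.1, 2*np.pi, 21)  # 0.1 is an arbitrary illustrative value
z_pos = np.log(x_pos)                  # every entry is finite
```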
ax[0].scatter(x, y)
fig
ax[1].plot(x,z, label = "#bestLogarithm")
fig
You can add as many layers as you want:
ax[1].scatter(x, z)
fig
Don't forget to label your axes!! We can also add axis titles and an overall figure title.
ax[0].set(xlabel = "x",
          ylabel = "y",
          title = "plot 1")
ax[1].set(xlabel = "x",
          title = "plot 2")
fig.suptitle("My awesome plots")
fig
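The set() method used above bundles several properties into one call. For comparison, here is a small sketch (an illustration, not the lecture's own code) that sets the same kinds of labels with the individual setter methods set_xlabel(), set_ylabel(), and set_title(). The Agg backend line is an assumption for running without a display, and the call to fig.tight_layout() is an optional extra that reduces overlap between labels and neighboring axes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; not needed in a notebook
from matplotlib import pyplot as plt

fig, ax = plt.subplots(1, 2, figsize=(7, 3))

# ax.set() accepts several properties at once
ax[0].set(xlabel="x", ylabel="y", title="plot 1")

# the same kind of labeling, one setter method per property
ax[1].set_xlabel("x")
ax[1].set_ylabel("y")
ax[1].set_title("plot 2")

# a title for the whole figure, above both axes
fig.suptitle("My awesome plots")

# optional: adjust spacing so labels do not collide
fig.tight_layout()
```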
When layering similar kinds of data markings, matplotlib will usually change the color each time. If you label your markings, then you'll be able to tell them apart with a legend().
ax[0].scatter(x, np.cos(x))
ax[1].scatter(x, np.exp(x/3), label = "#bestExponential")
fig
ax[1].legend()
fig
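Only markings that were given a label show up in the legend. The following self-contained sketch (with illustrative labels and data that are not from the lecture) layers two labeled scatter plots and one unlabeled line, then reads the legend entries back. The Agg backend line and the loc argument are assumptions made for the example.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; not needed in a notebook
import numpy as np
from matplotlib import pyplot as plt

fig, ax = plt.subplots(1, 1)
x = np.linspace(0, 2*np.pi, 21)

# two labeled layers; each gets the next color in the default cycle
ax.scatter(x, np.sin(x), label="sin")
ax.scatter(x, np.cos(x), label="cos")

# an unlabeled layer is drawn but left out of the legend
ax.plot(x, np.sin(x))

# loc chooses where the legend box is placed
legend = ax.legend(loc="upper right")

labels = [t.get_text() for t in legend.get_texts()]
print(labels)  # ['sin', 'cos']
```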
Once you're happy with how the figure looks, you can save it in one of several file formats using the savefig() method. PNG is usually a good choice.
fig.savefig("my_awesome_plot.png")
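savefig() infers the file format from the extension of the file name. As a hedged sketch, the example below saves one figure as both PNG and PDF into a temporary directory; the dpi and bbox_inches arguments are optional extras that were not used in the lecture, and the plotted data, the paths, and the Agg backend line are assumptions for illustration.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless run; not needed in a notebook
from matplotlib import pyplot as plt

fig, ax = plt.subplots(1, 1, figsize=(4, 3))
ax.plot([0, 1, 2], [0, 1, 4])
ax.set(xlabel="x", ylabel="y", title="a small example")

# write into a temporary directory so nothing in the working folder is touched
out_dir = tempfile.mkdtemp()
png_path = os.path.join(out_dir, "my_awesome_plot.png")
pdf_path = os.path.join(out_dir, "my_awesome_plot.pdf")

# dpi sets the resolution of raster output; bbox_inches="tight" trims margins
fig.savefig(png_path, dpi=150, bbox_inches="tight")

# the .pdf extension selects a vector format
fig.savefig(pdf_path)

# release the figure's memory once it has been saved
plt.close(fig)
```

A vector format such as PDF tends to suit line plots that will be scaled in a document, while PNG remains a reasonable default for sharing.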
We now have the basic components in place to create figures with multiple axes, populate them with data markings, annotate them, and save the results. In the next lecture, we'll look at plotting 2d data, including mathematical functions and images.