At the end of the day, any scientific experiment, whether performed in the lab or in silico, tends to produce data, which is then processed and analyzed. Often, this analysis involves visualizing the data. In this lesson, you'll learn how to use Python's matplotlib package to make some visualizations.
In this lesson, we'll be working with some data derived from a data set published by the World Bank on carbon dioxide emissions by country and year, measured in metric tons per capita.
Let's take a look at one of the main tools used to do data visualization in Python: matplotlib. matplotlib is designed to be Python's answer to plotting in MATLAB. For this reason, if you've had to make plots in MATLAB, you might find working with matplotlib to be somewhat familiar.
Generally speaking, people interact with matplotlib in one of two ways:
They import a module, matplotlib.pyplot, which exposes a whole bunch of plotting functions, and allows for more fine-tuned control over aspects of the plot.
They import a module, pylab, which itself imports matplotlib and some other Python packages, to make the generation of plots a bit more like how MATLAB works.
For the purposes of this lesson, you will be using the former approach, as matplotlib.pyplot provides a much richer set of commands with which you can create plots. The matplotlib documentation also discourages the use of pylab. (For more of a discussion on this point, see this blog post from DataCamp.)
Ready to start making some plots?
The matplotlib.pyplot figure object
The first thing you need to know about making a plot is the figure object. Think of it like a canvas that an artist paints on - it's a blank slate upon which you can put text, images, etc. The figure object is the fundamental thing we need to do plotting.
Let's import matplotlib.pyplot
and instantiate a figure. To do so, we first need to import the module. In addition, because we are working in a Jupyter notebook, it's often convenient to tell matplotlib that we want the plots to be rendered "inline" inside the notebook.
#Import matplotlib's pyplot module, and set rendering to be "inline"
import matplotlib.pyplot as plt
%matplotlib inline
#Import numpy for later in the notebook
import numpy as np
Now, we can instantiate a figure.
#Instantiate a figure, and show it
fig = plt.figure()
plt.show()
<matplotlib.figure.Figure at 0x7fad9de692b0>
Hmm...that's interesting! We made a figure, but when we show it, we simply get a memory address! Why is that? Because, since the figure is a canvas, and we haven't put anything on it, there's nothing to show!
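One way to confirm that a new figure really is an empty canvas is to inspect it directly. The following is a minimal sketch, assuming only matplotlib.pyplot; the name empty_fig is illustrative and not part of the lesson's code.

```python
import matplotlib.pyplot as plt

# Make a fresh figure without adding anything to it
empty_fig = plt.figure()

# A figure keeps a list of the axes placed on it;
# for a brand-new figure, that list is empty
print(empty_fig.axes)

# The canvas itself still has a size, given in inches as (width, height)
print(empty_fig.get_size_inches())
```

Once axes are added to the figure, they show up in that list.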
Adding axes to the figure
Having instantiated a figure object, we can use the .add_axes() method to put axes on the figure.
#Instantiate a figure, and add a standard set of axes
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
plt.show()
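The add_axes() method provides us with a non-trivial degree of control over the figure. The list we pass in gives the [left, bottom, width, height] of the axes, each as a fraction of the figure's size. For example, suppose we wanted to put two sets of axes on the figure. To do so, we call add_axes() twice.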
#Instantiate a figure
fig = plt.figure()
#Add an axis for the left half of the figure
ax1 = fig.add_axes([0, 0, .5, 1])
#Add an axis for the right half of the figure
ax2 = fig.add_axes([.5, 0, .5, 1])
plt.show()
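Because each rectangle is given as [left, bottom, width, height] in fractions of the figure, axes do not have to tile the figure; they can also overlap. The sketch below places a small inset on top of a larger set of axes. The names fig_demo, main_ax, and inset_ax and the specific numbers are illustrative choices, not part of the lesson.

```python
import matplotlib.pyplot as plt

fig_demo = plt.figure()

# Each list is [left, bottom, width, height], in fractions of the figure
main_ax = fig_demo.add_axes([0.1, 0.1, 0.8, 0.8])     # large axes with a margin
inset_ax = fig_demo.add_axes([0.6, 0.6, 0.25, 0.25])  # small axes, upper right

# Ask the inset where it sits on the figure
position = inset_ax.get_position()
print(position.x0, position.y0, position.width, position.height)
```

The printed values should match the rectangle passed to add_axes(), up to floating-point rounding.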
Now that we can add axes to our figures, we can start making some plots!
Line Plot
One of the most basic -- but useful -- plots is a line plot. If we have an Axes object, we can use its plot method to plot data. Below, you will take some data about historical US CO2 emissions, and make a line plot.
#Data derived from the World Bank data set
years = [1990, 2000, 2007, 2008, 2009, 2010, 2011, 2012, 2013]
#Uses a Python *list comprehension* to cast the original data
#(which is formatted as strings!) into floats rounded to 2 decimal
#places
us_emissions = [np.round(float(x), 2) for x in ['19.3233681671961', '20.2076147591466', '19.2374604467856',
'18.4892337521751', '17.1923671443447', '17.4848031479796',
'17.0194385160343', '16.28705288144', '16.389757994879']]
#Instantiate a figure
fig = plt.figure()
#Add an axis
ax = fig.add_axes([0, 0, 1, 1])
#Plot emissions versus years
ax.plot(years, us_emissions)
[<matplotlib.lines.Line2D at 0x7fad9a72fcf8>]
This is a pretty minimal plot. There are several things we might like to do to it, to make it a bit more legible:
Add markers at the data points
Change the line style
Label the axes
Add a title
Using the plot method of an Axes object, doing so is straightforward! The first two points are accomplished using the marker, markersize, and linestyle keywords, and the last two using other methods built into the Axes object.
#Draw the minimal figure
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
#Change marker, markersize, and linestyle
#Use clip_on=False to prevent matplotlib
#from clipping markers which lie on the edge of the plot
ax.plot(years, us_emissions, marker='o', linestyle='--', markersize=8, clip_on=False)
#Add a label to the x-axis
ax.set_xlabel('Year', fontsize=15)
#Add a label to the y-axis
ax.set_ylabel('Emissions', fontsize=15)
#Add a title
ax.set_title('US CO2 Emissions (Metric Tons per Capita)', fontsize=20)
#Adjust x limits for easier viewing (1990, 2015)
ax.set_xlim([1990, 2015])
#Turn off the frame (the box around the plot)
ax.set_frame_on(False)
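The marker and linestyle keywords used above can also be written in a more compact form: plot accepts a short format string as its third positional argument. The sketch below uses a few rounded values from the US data; the names demo_years, demo_emissions, fig_demo, and ax_demo are illustrative.

```python
import matplotlib.pyplot as plt

# A few of the (year, emissions) pairs from the lesson's US data
demo_years = [1990, 2000, 2007]
demo_emissions = [19.32, 20.21, 19.24]

fig_demo = plt.figure()
ax_demo = fig_demo.add_axes([0, 0, 1, 1])

# 'o--' means: circle markers ('o') joined by a dashed line ('--')
line, = ax_demo.plot(demo_years, demo_emissions, 'o--', markersize=8)

# The resulting line carries the same settings as the keyword version
print(line.get_marker(), line.get_linestyle())
```

The keyword form is more verbose but easier to read, which is why it may be preferable in teaching material and shared code.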
Bar Plot
To make a bar plot, we use the bar method of an Axes object. Below, you'll make a bar plot of carbon dioxide emissions in 2012 for a few countries.
The bar method is a bit strange in that it requests as input a list of x positions (scalars) and a list of bar heights, and then outputs the bar plot. This means that if we want our x-labels to be categorical, we need to explicitly manipulate the plot to do so.
#2012 carbon dioxide emissions (metric tons per capita) for a few countries
countries = ['Qatar', 'Kuwait', 'Luxembourg', 'United Arab Emirates', 'Saudi Arabia', 'Australia', 'United States',
             'Canada', 'Russian Federation']
emissions = [46.697477, 29.578729, 20.084219, 19.252223, 19.188937, 16.519386, 16.287053, 13.858827, 12.818345]
#Make a figure
fig = plt.figure()
#Add an axis
ax = fig.add_axes([0, 0, 1, 1])
#Use the bar() method to make the bar plot,
#placing one bar at each integer position 0, 1, 2, ...
ax.bar(range(len(countries)), emissions)
#Adjust the x-ticks:
#put one tick under each bar, and
#make the tick labels the country names, rotated at 45 degrees
ax.set_xticks(range(len(countries)))
ax.set_xticklabels(countries, rotation=45)
#Put labels on the axes, and make a title
ax.set_xlabel('Country', fontsize=15)
ax.set_ylabel('Emissions', fontsize=15)
ax.set_title('2012 CO2 Emissions for Select Countries', fontsize=20)
#Turn off the frame
ax.set_frame_on(False)
#Suppress some output, and show the plot
plt.show()
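As an alternative to setting the ticks by hand, newer versions of matplotlib accept the category names directly as x values. The sketch below assumes matplotlib 2.1 or later and uses three of the countries from the data above; the names demo_countries, demo_emissions, fig_demo, and ax_demo are illustrative.

```python
import matplotlib.pyplot as plt

# Three of the countries from the lesson's 2012 data
demo_countries = ['Qatar', 'Kuwait', 'Luxembourg']
demo_emissions = [46.697477, 29.578729, 20.084219]

fig_demo = plt.figure()
ax_demo = fig_demo.add_axes([0, 0, 1, 1])

# Assumes matplotlib 2.1 or later, where strings are accepted as x values;
# one bar is drawn per name, and the names become the tick labels
bars = ax_demo.bar(demo_countries, demo_emissions)

# Rotate the automatically generated tick labels
ax_demo.tick_params(axis='x', labelrotation=45)
```

The explicit approach shown in the lesson still has its place, since it works on older versions and makes the bar positions visible in the code.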
Using seaborn to change default aesthetics
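Seaborn is a useful package for improving figure aesthetics. Even if you don't want to use some of its more advanced functionality, know that applying its default theme overrides a lot of matplotlib's default behaviors, making your plots much more visually pleasing. In seaborn versions before 0.8, simply calling import seaborn was enough to do this; in newer versions, you also need to call seaborn.set_theme() (or seaborn.set() in versions before 0.11).
#Import seaborn
import seaborn
#Apply seaborn's default theme
#(needed in seaborn 0.8 and later; use seaborn.set() before 0.11)
seaborn.set_theme()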
#Make a figure and standard axes
#See how they are different from earlier?
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
plt.show()
Choosing a Good Colormap with seaborn.color_palette
Another compelling reason to use seaborn is that it gives you access to many color palettes, which is useful when you want to make sure the colors you use on your plots are harmonious with the kind of data you are plotting. We will look at three major data types: categorical, sequential, and diverging, and show how seaborn can be used to pick a good color palette for each type.
Data is categorical when it is related to a set of fixed classifications, or categories. For example, the US Census collects a substantial amount of categorical data about the US population.
Sometimes, we want to take numerical data and make sense of it as a function of some category. In the World Bank data set, we will look at CO2 emissions by country.
In this part of the lesson, we make use of the pandas library, which is intended to help researchers work with data, through the use of "dataframes" (think "tables"). In particular, pandas has a variety of read_X functions, which allow you to read CSV files, or even Excel files!
#Import pandas
import pandas as pd
#Read in worldwide_emissions.csv
df = pd.read_csv('worldwide_emissions.csv')
#Display the first few rows
df.head()
| | Country Name | Year | emissions |
|---|---|---|---|
| 0 | China | 1990 | 2.168 |
| 1 | India | 1990 | 0.711 |
| 2 | Iran, Islamic Rep. | 1990 | 3.757 |
| 3 | United States | 1990 | 19.323 |
| 4 | China | 2000 | 2.697 |
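If you don't have worldwide_emissions.csv at hand, you can still experiment with read_csv by feeding it text held in memory. The sketch below assumes the file is laid out with the three columns shown above; csv_text and df_demo are illustrative names, and the rows are a small subset of the lesson's data.

```python
import io

import pandas as pd

# A few rows in the same layout as the table above
csv_text = """Country Name,Year,emissions
China,1990,2.168
India,1990,0.711
China,2000,2.697
"""

# read_csv accepts any file-like object, not just a path on disk
df_demo = pd.read_csv(io.StringIO(csv_text))

# Each column gets its own type: text, integers, and floats
print(df_demo.dtypes)
print(df_demo.head())
```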
Colormaps for categorical data should not themselves be distracting. In particular, if we use a colormap with substantial variation in the perceived brightness of its colors, viewers of our plot will tend to focus more on that variation than on the data conveyed!
To that end, we make use of the "HUSL" colormap. The perceived brightness of its colors stays roughly constant as we scan across them, making this colormap great for categorical data.
Seaborn provides access to a function, color_palette, which allows us to create a HUSL color palette with enough colors, one for each country in our dataframe.
Once you have a color palette, you can use the palplot function to display it. (Great for checking that the colors are what you want!)
#Determine how many unique country names are in the dataframe
num_countries = len(df['Country Name'].unique())
#Make a HUSL color palette
colors = seaborn.color_palette('husl', n_colors=num_countries)
#Display the palette
seaborn.palplot(colors)
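It can help to know what a palette actually is underneath: a list of colors, each stored as a (red, green, blue) triple of numbers between 0 and 1. That is why a single entry can be handed straight to matplotlib's color keyword. A small sketch, with the illustrative name demo_colors and an arbitrary palette size of 5:

```python
import seaborn

# Ask for five evenly spaced HUSL colors
demo_colors = seaborn.color_palette('husl', n_colors=5)

# The palette behaves like a list with one entry per color
print(len(demo_colors))

# Each entry is a (red, green, blue) triple
for red, green, blue in demo_colors:
    print(round(red, 2), round(green, 2), round(blue, 2))
```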
Now that we have the colors and the data, we are going to make a figure showing CO2 emissions over the years for each country in the dataset.
#Make a standard figure and axes
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
#Initialize a counter for the color,
#noting Python uses 0-indexing
counter = 0
#Loop over the countries in the dataframe
#and plot their emissions vs year
for country in df['Country Name'].unique():
    X = df[df['Country Name'] == country]['Year']
    Y = df[df['Country Name'] == country]['emissions']
    #Plot X vs Y, setting the color, linestyle, marker, clip_on, and **label**
    ax.plot(X, Y, color=colors[counter], linestyle='--', marker='o', clip_on=False, label=country)
    #Don't forget to increment the counter!
    counter += 1
#Put in a legend
ax.legend()
<matplotlib.legend.Legend at 0x7fad9a64ad68>
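The loop above filters the dataframe twice per country and tracks the color with a counter. One possible alternative is to let pandas do the splitting with groupby. The sketch below uses a tiny hand-made dataframe (demo_df, an illustrative name) and matplotlib's default colors rather than the HUSL palette, so it runs on its own.

```python
import matplotlib.pyplot as plt
import pandas as pd

# A small stand-in for the lesson's dataframe
demo_df = pd.DataFrame({
    'Country Name': ['China', 'India', 'China', 'India'],
    'Year': [1990, 1990, 2000, 2000],
    'emissions': [2.168, 0.711, 2.697, 0.979],
})

fig_demo = plt.figure()
ax_demo = fig_demo.add_axes([0, 0, 1, 1])

# groupby yields (name, sub-dataframe) pairs, one per country,
# so no counter or repeated filtering is needed
for country, group in demo_df.groupby('Country Name'):
    ax_demo.plot(group['Year'], group['emissions'],
                 linestyle='--', marker='o', label=country)

ax_demo.legend()
```

Note that groupby visits the countries in sorted order, which may differ from the order in which they first appear in the file.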
A sequential color palette is useful when you have an ordering on numeric data, and your data do not take both positive and negative values.
The plot above does a good job of using lines to indicate increases/decreases in CO2 emissions over time. However, because the amount of emissions can never go below 0, it might make sense to plot this data in a different way, using a heatmap.
#Reshape the data, so that countries are now the vertical axis, and years the horizontal.
PT = df.pivot_table(index='Year', columns='Country Name', values='emissions').transpose()
PT
| Country Name / Year | 1990 | 2000 | 2007 | 2008 | 2009 | 2010 | 2011 | 2012 | 2013 |
|---|---|---|---|---|---|---|---|---|---|
| China | 2.168 | 2.697 | 5.154 | 5.417 | 5.723 | 6.554 | 7.235 | 7.419 | 7.551 |
| India | 0.711 | 0.979 | 1.193 | 1.310 | 1.432 | 1.397 | 1.480 | 1.597 | 1.590 |
| Iran, Islamic Rep. | 3.757 | 5.656 | 7.821 | 8.023 | 8.061 | 8.151 | 8.235 | 8.454 | 7.997 |
| Russian Federation | NaN | 10.627 | 11.672 | 12.015 | 11.024 | 11.726 | 12.368 | 12.818 | 12.467 |
| United States | 19.323 | 20.208 | 19.237 | 18.489 | 17.192 | 17.485 | 17.019 | 16.287 | 16.390 |
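To see what the reshaping step does, and where a NaN like the one in the Russian Federation row comes from, it may help to try it on a handful of rows. The sketch below puts countries on the index directly, which is one way to get the same layout without a transpose; demo_df and demo_pt are illustrative names.

```python
import pandas as pd

# Long-format data: one row per (country, year) measurement.
# Note there is no 1990 row for the Russian Federation.
demo_df = pd.DataFrame({
    'Country Name': ['China', 'China', 'Russian Federation'],
    'Year': [1990, 2000, 2000],
    'emissions': [2.168, 2.697, 10.627],
})

# Wide format: one row per country, one column per year
demo_pt = demo_df.pivot_table(index='Country Name', columns='Year', values='emissions')

# Combinations with no measurement are filled in as NaN
print(demo_pt)
```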
For this plot, we use seaborn's cubehelix_palette function to provide access to the cubehelix color map. Because later we'll be making a heatmap, which takes as input a matplotlib colormap object, we duplicate some code here to show the color palette, and then make the color map.
#Show the colors, by calling cubehelix_palette with as_cmap=False
colors = seaborn.cubehelix_palette()
seaborn.palplot(colors)
#Duplicate the code above, but use as_cmap=True to get a colormap
cmap = seaborn.cubehelix_palette(as_cmap=True)
#Make a standard figure and axes
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
#Use the seaborn.heatmap() function to display PT,
#setting cmap as the cmap above.
#Passing the dataframe itself (rather than a bare array)
#lets seaborn label the ticks from PT's index and columns
seaborn.heatmap(PT, ax=ax, cmap=cmap)
#Adjust the size and rotation of the tick labels
ax.tick_params(axis='both', labelsize=15)
ax.tick_params(axis='y', labelrotation=0)
#Add a title
ax.set_title('CO2 Emissions over Time', fontsize=20)
#Show the plot
plt.show()
A diverging colormap is good when your data takes positive and negative values relative to some reference value. (So, for example, the temperature relative to room temperature, or elevation relative to sea level, are data sets where using a diverging colormap would be useful.)
The plot above shows each country's absolute CO2 emissions. What if we wanted to look at, say, how much higher/lower they were relative to some reference country? In that case, a diverging colormap would be good.
#Pick a reference country, and assign it
#to sub_country
sub_country = 'China'
#Some pandas code which subtracts the row for
#sub_country from every row in the pivot table,
#while keeping the indexing correct.
#(.loc is used here because the older .ix indexer was removed in pandas 1.0)
sub_df = PT.sub(PT.loc[sub_country, :], axis=1)
sub_df
| Country Name / Year | 1990 | 2000 | 2007 | 2008 | 2009 | 2010 | 2011 | 2012 | 2013 |
|---|---|---|---|---|---|---|---|---|---|
| China | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 | 0.000 |
| India | -1.457 | -1.718 | -3.961 | -4.107 | -4.291 | -5.157 | -5.755 | -5.822 | -5.961 |
| Iran, Islamic Rep. | 1.589 | 2.959 | 2.667 | 2.606 | 2.338 | 1.597 | 1.000 | 1.035 | 0.446 |
| Russian Federation | NaN | 7.930 | 6.518 | 6.598 | 5.301 | 5.172 | 5.133 | 5.399 | 4.916 |
| United States | 17.155 | 17.511 | 14.083 | 13.072 | 11.469 | 10.931 | 9.784 | 8.868 | 8.839 |
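The subtraction step is easier to follow on a table small enough to check by hand. The sketch below builds a two-country pivot table from values shown earlier and subtracts the reference row from every row; demo_pt, reference_row, and demo_sub are illustrative names.

```python
import pandas as pd

# A two-country, two-year slice of the pivot table
demo_pt = pd.DataFrame(
    {1990: [2.168, 0.711], 2000: [2.697, 0.979]},
    index=['China', 'India'],
)

# Pull out the reference country's row; it is indexed by year
reference_row = demo_pt.loc['China']

# axis=1 lines up the reference row's years with the table's columns,
# so each row has the reference values subtracted year by year
demo_sub = demo_pt.sub(reference_row, axis=1)
print(demo_sub)
```

The reference country's own row comes out as all zeros, and every other row shows how far above or below the reference it sits.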
For this plot, we use seaborn's diverging_palette function to choose a diverging colormap. Again, we need to re-call the code using as_cmap=True to get a colormap.
#Create a color map using the diverging_palette function,
#with inputs (150, 275, s=80, l=55, n=9)
colors = seaborn.diverging_palette(150, 275, s=80, l=55, n=9)
#Display the color map
seaborn.palplot(colors)
#Repeat the code above, but use as_cmap=True to get a cmap
cmap = seaborn.diverging_palette(150, 275, s=80, l=55, n=9, as_cmap=True)
#Instantiate a standard figure and axes
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1])
#Make a heatmap of sub_df, passing the dataframe itself so that
#seaborn labels the ticks from its index and columns.
#center=0 puts the neutral middle color of the diverging
#colormap at zero difference
seaborn.heatmap(sub_df, ax=ax, cmap=cmap, center=0)
#Adjust the size and rotation of the tick labels
ax.tick_params(axis='both', labelsize=15)
ax.tick_params(axis='y', labelrotation=0)
#Set the title.
#Use .format(sub_country) to indicate what country
#the data is being compared to
ax.set_title('CO2 Emissions Compared to {0}'.format(sub_country), fontsize=20)
#Show the plot
plt.show()
What trends do you spot in CO2 emissions relative to the country you chose?
While Jupyter notebooks are a great way to display and interact with code and data, sometimes you may find yourself needing to send a figure to your advisor, or put it in your paper. How do you get figures out of matplotlib, and into other file formats? We use the figure.savefig() method to do so. But first, we should take a look at some common file formats, and when it's appropriate to use them.
Choosing a Good File Format
We've all seen it before - a figure which is pixellated, or doesn't quite look right. Such issues commonly arise from using the wrong file format. Some useful formats are:
Portable Document Format (PDF)
Almost everyone has some kind of pdf reader on their computer, making the pdf format useful when you want something "which just works".
Scalable Vector Graphics (SVG)
SVG is really nice when you want to make sure people can zoom in/out of your figure. This file format is supported by modern browsers. SVG is also stored as text, making it easy to parse SVG files using computer code.
Portable Network Graphics (PNG)
A raster (pixel-based) format with lossless compression, widely supported by browsers and image viewers. Personally, I have never used it.
Saving Figures with figure.savefig()
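Once we've made a matplotlib Figure, we can save it using its .savefig() method. One super-useful keyword argument is bbox_inches, which controls how much of the figure matplotlib tries to save. When set to 'tight', matplotlib attempts to find the tightest bounding box which contains all the elements of the figure. This matters here: because the axes in this lesson fill the whole figure, their labels and titles sit outside the default canvas, and would be cut off without it.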
#Save the figure from the previous section in pdf, svg, and png
#You can use a loop!
for file_format in ['pdf', 'svg', 'png']:
    fig.savefig('emissions.{0}'.format(file_format), format=file_format, bbox_inches='tight')
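To check that the files were really written, you can save into a scratch directory and look at what ends up on disk. The sketch below is self-contained and makes a few assumptions of its own: it draws a throwaway figure, writes to a temporary directory rather than the working directory, and sets dpi=300, which mainly affects raster output such as png. All names (fig_demo, output_dir, saved_paths) are illustrative.

```python
import os
import tempfile

import matplotlib.pyplot as plt

# A throwaway figure with a few rounded values from the US data
fig_demo = plt.figure()
ax_demo = fig_demo.add_axes([0, 0, 1, 1])
ax_demo.plot([1990, 2000, 2007], [19.32, 20.21, 19.24])

# Write into a fresh temporary directory so nothing is overwritten
output_dir = tempfile.mkdtemp()
saved_paths = []
for file_format in ['pdf', 'svg', 'png']:
    path = os.path.join(output_dir, 'demo.{0}'.format(file_format))
    # savefig infers the format from the file extension
    fig_demo.savefig(path, bbox_inches='tight', dpi=300)
    saved_paths.append(path)

# Report where each file went and how large it is
for path in saved_paths:
    print(path, os.path.getsize(path), 'bytes')
```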