This tutorial focuses on how to quickly and easily create a range of helpful data visualizations in Python with seaborn, a plotting package that "provides a high-level interface for drawing attractive statistical graphics". While the examples show some of the parameters you might use in different functions, they shouldn't be considered a substitute for the excellent documentation created by Michael Waskom.
There is a wide range of plotting libraries available in Python, but matplotlib is probably the most commonly used. Although matplotlib allows users to customize every object on their plot, it does not provide easy functions for quickly exploring patterns or trends within data. Seaborn fills this gap by making a few difficult tasks easy to accomplish.
For a better understanding of how seaborn differs from some other Python plotting libraries, check out this entertaining and informative dramatic tour.
Humans (including data scientists) are great at recognizing visual patterns. Using this built-in ability allows a data scientist to better understand the data they are working with, and can help guide them in their work. Exploratory visualizations don't have to be publication-ready, but it does help to understand some basic guidelines about what types of plots and color choices are appropriate for the data. Understanding the different types of data, and what plots are appropriate for each type, will make it easier to visually explore your data. And being able to easily create useful visualizations means that you will be more likely to use them in the course of your work.
Different types of plots might be used depending on the types of data available. In the seaborn library, graphs are generally split into three types: distribution plots, regression plots, and categorical plots.
It is important to begin any data science task with a visual exploration of the data. And to figure out how best to look at your data, it helps to understand the different types of data you are working with. One helpful set of data categories (borrowed from Altair) includes quantitative, ordinal, nominal, and temporal data.
This tutorial focuses on plots for the first three types of data.
This tutorial uses individual household microdata from the Residential Energy Consumption Survey (RECS). The data are minimally pre-processed and have not been properly weighted. Import and processing code are included at the end of the notebook.
All figures below use these data categories:
Seaborn can be installed with either pip or conda:
$ pip install seaborn
$ conda install seaborn
Dependencies include NumPy, SciPy, matplotlib, and pandas.
%matplotlib inline
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats  # load the stats submodule so that sp.stats is available below
sns.set_context('notebook') # Other options include "talk", "poster", and "paper"
These types of plots are used to visualize the distribution or density of quantitative data over their range. They are useful for visually determining the general underlying distribution (e.g. normal, log-normal, uniform) that data may be drawn from. The most basic type of distribution plot is the histogram. By default, the seaborn distplot() function draws both a histogram and a kernel density estimate (KDE) of the data.
One especially useful aspect of distplot() and other seaborn functions that show histograms is that they use the Freedman-Diaconis rule as a smart default to set the number of histogram bins.
# iqr and _freedman_diaconis_bins functions from seaborn
# https://github.com/mwaskom/seaborn
def iqr(a):
    """Calculate the IQR for an array of numbers."""
    a = np.asarray(a)
    q1 = sp.stats.scoreatpercentile(a, 25)
    q3 = sp.stats.scoreatpercentile(a, 75)
    return q3 - q1
def _freedman_diaconis_bins(a):
    """Calculate number of hist bins using Freedman-Diaconis rule."""
    # From http://stats.stackexchange.com/questions/798/
    a = np.asarray(a)
    # 3.0 keeps the exponent fractional even under Python 2 integer division
    h = 2 * iqr(a) / (len(a) ** (1 / 3.0))
    # fall back to sqrt(n) bins if iqr is 0
    if h == 0:
        return int(np.sqrt(a.size))
    else:
        return int(np.ceil((a.max() - a.min()) / h))
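As a sanity check on the rule, the sketch below computes the Freedman-Diaconis bin count with NumPy alone and compares it with NumPy's built-in `bins='fd'` estimator in `np.histogram_bin_edges` (NumPy 1.15 and later). The log-normal sample is made-up data standing in for a skewed variable such as `TOTALBTUSPH`, and `fd_bin_count` is a hypothetical helper, not part of seaborn.

```python
import numpy as np


def fd_bin_count(values):
    """Number of histogram bins from the Freedman-Diaconis rule (hypothetical helper)."""
    a = np.asarray(values, dtype=float)
    q75, q25 = np.percentile(a, [75, 25])
    # Bin width h = 2 * IQR * n^(-1/3)
    width = 2.0 * (q75 - q25) * a.size ** (-1.0 / 3.0)
    if width == 0:
        # Same fallback as the seaborn helper above: sqrt(n) bins
        return int(np.sqrt(a.size))
    return int(np.ceil((a.max() - a.min()) / width))


rng = np.random.default_rng(0)
sample = rng.lognormal(mean=0.0, sigma=1.0, size=1000)  # made-up skewed data

n_bins = fd_bin_count(sample)
# histogram_bin_edges returns bin edges, so the bin count is one less than their number
numpy_bins = len(np.histogram_bin_edges(sample, bins='fd')) - 1
print(n_bins, numpy_bins)
```

When most values are identical the IQR can be zero and the width collapses, which is why the helper falls back to sqrt(n) bins.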
It's helpful to view the distributions of a few continuous variables from the RECS data. We see that many houses use little or no space heating, and electricity use follows something close to a beta distribution.
# These four plots show a few of the distplot() parameters.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2, figsize=(10,8))
sns.distplot(house['TOTALBTUSPH'], ax=ax1)
# Fit the data to a distribution (black line)
sns.distplot(house['KWH'], fit=sp.stats.beta, ax=ax2)
sns.distplot(house['HDD65'], ax=ax3)
# Add named labels, adjust the bandwidth (smoothing function)
sns.distplot(house['CDD65'], ax=ax4, label='Default kde bandwidth')
sns.distplot(house['CDD65'], hist=False, kde_kws={'bw':500},
label='kde bandwidth = 500', ax=ax4)
sns.distplot(house['CDD65'], hist=False, kde_kws={'bw':1000},
label='kde bandwidth = 1000', ax=ax4)
plt.tight_layout()
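The bandwidth passed through kde_kws controls how wide each kernel is. The NumPy sketch below evaluates a Gaussian KDE by hand for three bandwidths, assuming the bandwidth is the kernel's standard deviation in data units. Depending on the seaborn version, the estimate was delegated to statsmodels or SciPy, and SciPy's gaussian_kde treats a scalar bandwidth as a multiple of the data's standard deviation instead, so this is an illustration of the idea rather than seaborn's exact computation. The gamma-distributed sample is made up and only stands in for CDD65.

```python
import numpy as np


def gaussian_kde_1d(values, grid, bandwidth):
    """Gaussian KDE evaluated on `grid`; `bandwidth` is the kernel std in data units."""
    values = np.asarray(values, dtype=float)
    grid = np.asarray(grid, dtype=float)
    # One Gaussian bump centred on each observation, averaged over observations
    z = (grid[:, None] - values[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
    return kernels.mean(axis=1)


rng = np.random.default_rng(1)
cdd = rng.gamma(shape=2.0, scale=600.0, size=1000)  # made-up stand-in for CDD65
grid = np.linspace(-3000, 10000, 1301)
step = grid[1] - grid[0]

areas, peaks = {}, {}
for bw in (200, 500, 1000):
    density = gaussian_kde_1d(cdd, grid, bw)
    areas[bw] = density.sum() * step  # Riemann sum; should be close to 1
    peaks[bw] = density.max()         # wider kernels flatten the peak
    print(bw, round(areas[bw], 3), peaks[bw])
```

Every curve still integrates to about one; a larger bandwidth only spreads the same area more smoothly, which is why the bandwidth = 1000 curve in the CDD65 panel looks flatter than the default.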
Joint distributions of two variables can be viewed with the jointplot() function, using "kde" or "hex" as the kind.
sns.jointplot(x='HDD65', y='CDD65', data=house, kind='kde')
sns.jointplot(x='HDD65', y='CDD65', data=house, kind='hex')
Regression plots are used to show how one quantitative variable in a dataset changes with respect to another. The simplest regression plot in seaborn is regplot(). It draws a scatter plot with a regression line and a shaded 95% confidence interval.
sns.regplot(x='TOTSQFT', y='KWH', data=house)
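regplot() fits an ordinary least-squares line and estimates the confidence band by bootstrapping. The NumPy sketch below mimics that idea on made-up data standing in for TOTSQFT and KWH: refit the line on resampled rows many times and take the 2.5th and 97.5th percentiles of the predictions at each x. The 1,000 resamples match regplot's default n_boot, but seaborn's own implementation details may differ.

```python
import numpy as np

rng = np.random.default_rng(42)
# Made-up data standing in for TOTSQFT (x) and KWH (y)
x = rng.uniform(500, 4000, size=300)
y = 4000 + 2.5 * x + rng.normal(0, 2500, size=300)

# Ordinary least-squares line; for deg=1 np.polyfit returns [slope, intercept]
slope, intercept = np.polyfit(x, y, deg=1)
grid = np.linspace(x.min(), x.max(), 50)
fit = intercept + slope * grid

# Bootstrap: refit on rows resampled with replacement, keep each predicted line
n_boot = 1000
preds = np.empty((n_boot, grid.size))
for i in range(n_boot):
    idx = rng.integers(0, x.size, size=x.size)
    b1, b0 = np.polyfit(x[idx], y[idx], deg=1)
    preds[i] = b0 + b1 * grid

# Percentile interval: the shaded 95% band around the fitted line
lower, upper = np.percentile(preds, [2.5, 97.5], axis=0)
print(round(slope, 2), round(intercept))
```

The band is usually narrowest near the middle of the x-range and widens toward the ends, where fewer points constrain the slope.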
lmplot() also plots regressions. It has more features than regplot(), such as the ability to add hues, rows, and columns for different categorical variables. These plotting features are explored further in the sections on factorplots and data-aware FacetGrids.
The plot below uses hue to show urban and rural houses as different colors, and col to split by climate region across columns. The col_wrap parameter forces two of the climates onto a second row as a way to keep the figure a reasonable size.
sns.lmplot(x='TOTSQFT', y='KWH', data=house, hue='UR', col='Climate',
col_wrap=3, scatter_kws={'alpha':0.5}) # alpha is the transparency
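With col_wrap, facets fill a row from left to right before wrapping to the next one. The pure-Python sketch below works out the resulting grid for the five climate regions, assuming the facets appear in the order listed (the labels come from the climate translation table at the end of the notebook, without their line breaks). wrapped_layout is a hypothetical helper, not a seaborn function.

```python
import math


def wrapped_layout(levels, col_wrap):
    """Return the number of rows and a (row, col) slot for each facet level."""
    n_rows = math.ceil(len(levels) / col_wrap)
    # divmod(i, col_wrap) -> (row index, column index) when filling row by row
    slots = {level: divmod(i, col_wrap) for i, level in enumerate(levels)}
    return n_rows, slots


climates = ['Very Cold/Cold', 'Hot-Dry/Mixed-Dry', 'Hot-Humid', 'Mixed-Humid', 'Marine']
n_rows, slots = wrapped_layout(climates, col_wrap=3)
print(n_rows)           # 2
print(slots['Marine'])  # (1, 1)
```

Three facets land on the first row and the remaining two on the second, which is the layout described above.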
If one of your variables is ordinal (discrete ordered values), the resulting figure might not be very helpful.
sns.lmplot(x='STOVEN', y='KWH', data=house)
With lmplot(), it's possible to show the mean (or any other estimator) at each x-value. Seaborn automatically calculates a bootstrapped 95% confidence interval.
sns.lmplot(x='STOVEN', y='KWH', data=house, x_estimator=np.mean)
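With x_estimator=np.mean, each x-value is collapsed to a single point with an error bar. A minimal sketch of that summary, assuming a plain percentile bootstrap; the STOVEN-like groups and KWH-like values are made up, and bootstrap_ci is a hypothetical helper rather than seaborn's internal function.

```python
import numpy as np


def bootstrap_ci(values, estimator=np.mean, n_boot=1000, ci=95, seed=0):
    """Percentile bootstrap interval for `estimator` applied to `values`."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    boot = np.array([estimator(rng.choice(values, size=values.size, replace=True))
                     for _ in range(n_boot)])
    tail = (100 - ci) / 2
    return np.percentile(boot, [tail, 100 - tail])


rng = np.random.default_rng(7)
# Made-up KWH readings for houses with 0, 1, and 2 stoves
groups = {n: rng.normal(9000 + 800 * n, 3000, size=80) for n in (0, 1, 2)}

summary = {}
for n_stoves, kwh in groups.items():
    low, high = bootstrap_ci(kwh)
    summary[n_stoves] = (kwh.mean(), low, high)
    print(n_stoves, round(kwh.mean()), round(low), round(high))
```

Swapping estimator=np.median into bootstrap_ci gives the interval for a median instead, mirroring how any estimator can be passed to x_estimator.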
Categorical plots are used to show quantitative values across ordinal or nominal data types. The regression and distribution plots above made use of categorical splits in the data to show multiple plots, but they didn't use those categories on either of the axes.
The four figures below show some of the most common categorical plots, and their features in seaborn. From top-left to bottom-right, they are a boxplot, a barplot, a pointplot, and a violinplot.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2, figsize=(10,8))
sns.boxplot(x='Climate', y='HDD65', data=house, ax=ax1)
sns.barplot(x='Climate', y='HDD65', data=house, ax=ax2)
# Shrink the point sizes (scale), change the estimator from mean to median
sns.pointplot(x='Climate', y='HDD65', data=house, scale=0.7, estimator=np.median, ax=ax3)
# Adjust the bandwidth (smoothing)
sns.violinplot(x='Climate', y='HDD65', data=house, ax=ax4, bw=0.4)
plt.tight_layout()
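The box in a boxplot spans the first to the third quartile with a line at the median, and by matplotlib's default (whis=1.5, which seaborn's boxplot also uses) the whiskers reach the most extreme observations within 1.5 IQR of the box; anything beyond is drawn as an outlier point. The NumPy sketch below computes those numbers for a small made-up sample. box_stats is a hypothetical helper; matplotlib has its own matplotlib.cbook.boxplot_stats.

```python
import numpy as np


def box_stats(values, whis=1.5):
    """Quartiles, whisker ends, and outliers as drawn by a standard boxplot."""
    a = np.sort(np.asarray(values, dtype=float))
    q1, median, q3 = np.percentile(a, [25, 50, 75])
    iqr = q3 - q1
    low_fence, high_fence = q1 - whis * iqr, q3 + whis * iqr
    # Whiskers stop at the most extreme observations still inside the fences
    inside = a[(a >= low_fence) & (a <= high_fence)]
    outliers = a[(a < low_fence) | (a > high_fence)]
    return {'q1': q1, 'median': median, 'q3': q3,
            'whisker_low': inside.min(), 'whisker_high': inside.max(),
            'outliers': outliers.tolist()}


stats = box_stats([1, 2, 3, 4, 5, 6, 7, 8, 9, 100])  # made-up values
print(stats)
```

The single extreme value (100) becomes an outlier point while the whisker stops at 9, which is why a few very large HDD65 values show up as dots rather than stretching the box.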
Moving to a slightly more complicated example, the figures below add the house type as a hue category. The barplot is beginning to look cluttered: it is difficult to compare the same house type across climates, or to see how different house types respond to the change in climate. By comparison, the pointplot is easier to read, especially with the dodge=True parameter.
sns.barplot(x='Climate', y='TOTALBTUSPH', data=house, hue='House')
# Difficult to see full range of overlapping confidence intervals
sns.pointplot(x='Climate', y='TOTALBTUSPH', data=house, hue='House', scale=0.8)
# Set dodge=True to offset points on the x-axis and avoid overlap
sns.pointplot(x='Climate', y='TOTALBTUSPH', data=house, hue='House', dodge=True, scale=0.8)
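Dodging shifts each hue level by a small, fixed amount along the categorical axis so that points and error bars no longer sit on top of each other. The sketch below shows one way to compute symmetric offsets; dodge_offsets and its total_width value are hypothetical choices for illustration, not the numbers seaborn uses internally.

```python
def dodge_offsets(n_hue, total_width=0.4):
    """Evenly spaced x-offsets for n_hue levels, symmetric around zero (hypothetical)."""
    if n_hue == 1:
        return [0.0]
    step = total_width / (n_hue - 1)
    # Start half the total width to the left and step right, so offsets are symmetric
    return [-total_width / 2 + i * step for i in range(n_hue)]


# Three hue levels (for example, house types) drawn around each climate position x = 0..4
offsets = dodge_offsets(3)
positions = [[x + off for off in offsets] for x in range(5)]
print([round(o, 3) for o in offsets])  # [-0.2, 0.0, 0.2]
```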
Violinplots may not be as common as bar or boxplots, but they can be just as useful. The figure below shows income distribution across a range of education categories. It splits each violin by a second category: whether the individual is able to telework. With this plot we can quickly tell that:
There are many other observations that can be made about the data. Knowing that violinplot assigns equal area by default, is it possible to infer anything about the proportion of employees who have the option to telework within each education category?
# "Split" places the hue categories on the same object
# "Cut" limits how far past the last datapoint the violin will go
sns.violinplot(y='School', x='Income', hue='TELLWORK', data=house,
split=True, cut=.1, orient='h')
The figure below recreates the violinplot as a horizontal boxplot. Picking between these plots is largely a matter of preference, and will depend on what type of information you are trying to extract from the data. The boxplot figures clearly show the IQR and median values for the TELLWORK options within each education category. But they omit any information about the overall income distribution within each education category.
sns.boxplot(y='School', x='Income', hue='TELLWORK', data=house)
# The legend is created manually so that it doesn't overlap with data
plt.legend(bbox_to_anchor=(1, 1), loc=2, title='TELLWORK')
factorplot
Any of the categorical plots described above, and a couple more that weren't mentioned, can be used with the factorplot() function to show data across rows, columns, and hues.
sns.factorplot('Climate', 'KWH', data=house, col='UR', hue='STOVEN', kind='bar', aspect=1.25)
# This uses a countplot to show the number of datapoints underlying each of the bars
# in the figure above
sns.factorplot('Climate', data=house, col='UR', hue='STOVEN', kind='count', aspect=1.25)
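A countplot is simply the number of rows in each combination of facet, x, and hue values. The pure-Python sketch below does that tally with collections.Counter on a few made-up records; with the real DataFrame, house.groupby(['UR', 'Climate', 'STOVEN']).size() should give the same kind of table.

```python
from collections import Counter

# Made-up (UR, Climate, STOVEN) records standing in for rows of `house`
records = [
    ('U', 'Marine', 1), ('U', 'Marine', 1), ('U', 'Marine', 0),
    ('R', 'Marine', 1), ('U', 'Hot-Humid', 1), ('R', 'Hot-Humid', 0),
    ('R', 'Hot-Humid', 0), ('U', 'Hot-Humid', 1),
]

# One count per (facet, x, hue) combination
counts = Counter(records)
for (ur, climate, stoven), n in sorted(counts.items()):
    print(ur, climate, stoven, n)
```

Small counts are a warning sign: a bar built from only a handful of houses will carry a wide confidence interval in the KWH figure.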
# By default, factorplots return a pointplot
sns.factorplot('Climate', 'Income', data=house, col='House')
FacetGrid objects in seaborn
We have seen that some functions in seaborn (e.g. lmplot() and factorplot()) can use categorical data to split plots into hues, rows, or columns. This behavior can be generalized to any type of seaborn or matplotlib plot with FacetGrid() objects.
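Conceptually, a FacetGrid partitions the rows by the values of the row and col variables and then calls the plotting function once per subset. The pure-Python sketch below performs that partition-then-apply step on made-up records, using a mean in place of a plot; facet_apply is a hypothetical helper and not how seaborn is implemented.

```python
from collections import defaultdict
from statistics import mean


def facet_apply(records, row, col, func, value):
    """Group records by (row, col) values and apply func to each group's `value` field."""
    panels = defaultdict(list)
    for rec in records:
        panels[(rec[row], rec[col])].append(rec[value])
    return {key: func(vals) for key, vals in sorted(panels.items())}


# Made-up rows using the same column names as the `house` DataFrame
records = [
    {'UR': 'U', 'Climate': 'Marine', 'KWH': 5000},
    {'UR': 'U', 'Climate': 'Marine', 'KWH': 7000},
    {'UR': 'R', 'Climate': 'Marine', 'KWH': 9000},
    {'UR': 'U', 'Climate': 'Hot-Humid', 'KWH': 14000},
    {'UR': 'R', 'Climate': 'Hot-Humid', 'KWH': 16000},
    {'UR': 'R', 'Climate': 'Hot-Humid', 'KWH': 18000},
]

panel_means = facet_apply(records, row='UR', col='Climate', func=mean, value='KWH')
print(panel_means)
```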
g = sns.FacetGrid(house, row='UR', col='Climate')
g.map(sns.barplot, 'House', 'KWH')
g.set_xticklabels(rotation=30)
g = sns.FacetGrid(house, row='UR', col='House')
g.map(sns.distplot, 'KWH')
g.set_xticklabels(rotation=35)
plt.tight_layout()
data = house.sample(1000).copy() # taking a subset of the whole dataframe
# Normalize and scale income to use as the marker size
data['markersize'] = data['Income']/data['Income'].mean() * 20
g = sns.FacetGrid(data, row='UR', col='STOVEN')
# Make the markers empty black circles with size scaled by Income.
# Passing 'markersize' as a column name lets FacetGrid hand each facet only its own sizes
# (the third positional argument of plt.scatter is s, the marker size)
g.map(plt.scatter, 'TOTSQFT', 'KWH', 'markersize', facecolors='none', edgecolors='black')
g.set_xticklabels(rotation=30)
plt.tight_layout()
There are three main categories of color palettes for data visualization: qualitative, sequential, and diverging.
Qualitative palettes are used to differentiate groups or classes within a dataset. Seaborn provides access to a wide range of prepared qualitative color palettes that can be selected from a default list (as shown below) or from the famous Color Brewer tool. See the seaborn docs for more information and tools.
By default, seaborn includes six qualitative palettes, called deep, muted, pastel, bright, dark, and colorblind. The default palette is deep, and all of the palettes have six colors (notice that the last two colors in the palplot objects below are repeats of the first two). Newer seaborn releases extend these palettes to ten colors, so the repeats only appear with older versions.
sns.palplot(sns.color_palette('deep', n_colors=8))
sns.palplot(sns.color_palette('colorblind', n_colors=8))
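The repeats come from cycling: when more colors are requested than a palette contains, color_palette() starts again from the beginning. The sketch below reproduces that behavior with itertools on six placeholder color names (not seaborn's actual hex values).

```python
from itertools import cycle, islice

# Six placeholder names standing in for a six-color qualitative palette
base = ['c0', 'c1', 'c2', 'c3', 'c4', 'c5']

# Take 8 colors from an endless cycle over the base palette
eight = list(islice(cycle(base), 8))
print(eight)  # ['c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c0', 'c1']
```

With more hue levels than palette colors, two groups end up sharing a color, so it is worth checking the palette length before plotting many categories.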
If you want to invest a bit more time in your qualitative color choices, seaborn allows you to pass through any colors (rgb, hex, etc) that you like. Even better, you can use named colors from the xkcd color survey - I suggest using this awesome interactive tool to pick out your favorite.
colors = ['cobalt', 'orangey yellow', 'steel grey', 'dark lilac', 'bluish green', 'tomato']
sns.palplot(sns.xkcd_palette(colors))
Sequential palettes encode some quantitative information from a dataset with upper and lower bounds. Importantly, the data should not have a natural mid-point (that's what you use a diverging palette for). Sequential palettes are ideal for kernel density and correlation plots. See the seaborn docs for more information and tools.
sns.palplot(sns.color_palette('Blues'))
# Cubehelix returns sequential palettes with linear change in color brightness
# and a range of hues
cmap = sns.cubehelix_palette(light=1, as_cmap=True)
sns.kdeplot(house['HDD65'], house['CDD65'], cmap=cmap, shade=True)
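A sequential palette is essentially a path from a light color to a dark one (or the reverse) with steadily changing lightness. The sketch below builds a naive one by linear interpolation in RGB and checks that a rough brightness score falls monotonically. Seaborn's own sequential palettes (ColorBrewer maps such as 'Blues', or cubehelix) are designed far more carefully, so linear_palette and approx_lightness are hypothetical helpers for illustration only.

```python
def linear_palette(start, end, n):
    """n RGB tuples (0-1 floats) evenly interpolated from start to end."""
    if n < 2:
        raise ValueError("need at least two colors")
    return [tuple(s + (e - s) * i / (n - 1) for s, e in zip(start, end))
            for i in range(n)]


def approx_lightness(rgb):
    """Rough brightness score: Rec. 709 luma weights applied to the RGB values."""
    r, g, b = rgb
    return 0.2126 * r + 0.7152 * g + 0.0722 * b


# From near-white to a dark blue
palette = linear_palette((0.97, 0.98, 1.0), (0.03, 0.19, 0.42), 6)
lightness = [approx_lightness(c) for c in palette]
print([round(v, 3) for v in lightness])
```

A steady change in lightness is what lets the eye read "more" versus "less" from a sequential palette without consulting the legend.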
Under no circumstance will seaborn let you use the Jet colormap. If you aren't familiar with the name, Jet is the rainbow colormap that you've seen in countless MATLAB figures and National Weather Service maps. As explained below, you should never use it.
# Showing a jet colorbar - notice how some colors are more intense than others?
import matplotlib as mpl
fig = plt.figure(figsize=(8, 1))
ax1 = fig.add_axes([0.05, 0.80, 0.9, 0.15])
mpl.colorbar.ColorbarBase(ax1, cmap='jet', orientation='horizontal')
# Trying to show a colorbar with jet in seaborn
sns.palplot(sns.color_palette('jet'))
ValueError                                Traceback (most recent call last)
<ipython-input-40-6fabc0328ae8> in <module>()
      1 # Trying to show a colorbar with jet in seaborn
----> 2 sns.palplot(sns.color_palette('jet'))
/Users/Home/anaconda/lib/python2.7/site-packages/seaborn/palettes.pyc in color_palette(palette, n_colors, desat)
    163         palette = husl_palette(n_colors)
    164     elif palette.lower() == "jet":
--> 165         raise ValueError("No.")
    166     elif palette in SEABORN_PALETTES:
    167         palette = SEABORN_PALETTES[palette]
ValueError: No.
Diverging palettes encode some quantitative information from a dataset with upper and lower bounds, where there is a natural mid-point between the two extremes. Diverging palettes are especially useful when showing the change with respect to a baseline. Diverging palettes usually show one of two different colors on each side of a white mid-point, so don't use two colors (e.g. red/green) that might look the same to some viewers. See the seaborn docs for more information and tools.
sns.palplot(sns.color_palette('RdBu', 9))
# Plot electricity use against size of house with color determined by income
# I find the color range easier to control with matplotlib than with seaborn
data = house.sample(500).copy()
# Subtract the midrange (average of the maximum and minimum income) so that zero
# falls at the white center of the color scale
data['income diff from midrange'] = data['Income'] - (data['Income'].max() + data['Income'].min())/2
plt.scatter('TOTSQFT', 'KWH', data=data, c=data['income diff from midrange'], cmap='RdBu')
plt.colorbar()
plt.xlabel('TOTSQFT')
plt.ylabel('KWH')
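Subtracting the midrange centers the color scale on the middle of the data range. When the meaningful midpoint is a baseline instead, the scale can use a different slope on each side so that the baseline always lands on the white middle of the palette. matplotlib offers this idea as matplotlib.colors.TwoSlopeNorm (matplotlib 3.2 and later); the pure-Python diverging_position below is a hypothetical helper that shows the arithmetic, and the $50,000 baseline is made up.

```python
def diverging_position(value, vmin, vcenter, vmax):
    """Map value to [0, 1] so that vmin -> 0, vcenter -> 0.5, and vmax -> 1."""
    if not vmin < vcenter < vmax:
        raise ValueError("need vmin < vcenter < vmax")
    value = min(max(value, vmin), vmax)  # clip to the color range
    if value <= vcenter:
        return 0.5 * (value - vmin) / (vcenter - vmin)
    return 0.5 + 0.5 * (value - vcenter) / (vmax - vcenter)


# Hypothetical incomes with a $50,000 baseline that should sit on the white midpoint
baseline = 50000
for income in (10000, 30000, 50000, 80000, 200000):
    print(income, round(diverging_position(income, 10000, baseline, 200000), 3))
```

Because the two halves are stretched independently, equal color distances from white no longer mean equal dollar distances from the baseline, which is worth stating in the figure caption.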
Check out the video below from SciPy2015. It's an amazing tour through color theory that I don't have space to reproduce here. Plus it gives a thorough explanation of why you should never use Jet. If you don't have 20 minutes, at least check out this blog post.
from IPython.display import HTML
HTML('<iframe width="640" height="360" src="https://www.youtube.com/embed/xAoljeRJ3lU" frameborder="0" allowfullscreen></iframe>')
# Read in a portion of the raw data
input_columns = ['TYPEHUQ', 'EDUCATION', 'MONEYPY', 'Climate_Region_Pub', 'UR',
'TOTSQFT', 'HDD65', 'CDD65', 'KWH', 'TOTALBTUSPH', 'STOVEN', 'TELLWORK']
fn = 'https://raw.githubusercontent.com/gschivley/Visualization-tutorial/master/recs2009_public.csv'
df = pd.read_csv(fn, usecols=input_columns, na_values=[-2])
# Several of the columns have integer values that need to be translated
df.head()
|   | TYPEHUQ | HDD65 | CDD65 | Climate_Region_Pub | UR | STOVEN | EDUCATION | TELLWORK | MONEYPY | TOTSQFT | KWH | TOTALBTUSPH |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 4742 | 1080 | 4 | U | 1 | 5 | 0 | 23 | 5075 | 18466 | 10873 |
| 1 | 2 | 2662 | 199 | 5 | U | 1 | 2 | 0 | 9 | 3136 | 5148 | 38606 |
| 2 | 5 | 6233 | 505 | 1 | U | 1 | 6 | 0 | 18 | 528 | 2218 | 40248 |
| 3 | 2 | 6034 | 672 | 1 | U | 0 | 2 | 0 | 10 | 2023 | 10015 | 40196 |
| 4 | 3 | 5388 | 702 | 1 | U | 1 | 5 | 0 | 20 | 1912 | 2869 | 36136 |
# Read in pre-defined translations of integer to category values
fn = 'https://raw.githubusercontent.com/gschivley/Visualization-tutorial/master/Translations.xlsx'
income_df = pd.read_excel(fn, 'Income', index_col=0)
education_df = pd.read_excel(fn, 'Education', index_col=0)
house_df = pd.read_excel(fn, 'House_stock', index_col=0)
climate_df = pd.read_excel(fn, 'Climate_region', index_col=0)
climate_df
|   | Climate |
|---|---|
| 1 | Very Cold/\nCold |
| 2 | Hot-Dry/\nMixed-Dry |
| 3 | Hot-Humid |
| 4 | Mixed-Humid |
| 5 | Marine |
def translate(codebook, code):
    """Takes a dataframe of translations and item code, returns the translated value
    Inputs:
        codebook (dataframe): Dataframe that matches RECS (integer) codes with category values
        code (integer): Integer code from the RECS database
    Outputs:
        (string or integer): Translated value returned from the codebook. In the case of income,
        the returned value is a random integer over a range.
    """
    # Compare the (single) column name rather than the whole columns Index
    if codebook.columns[0] == 'income':
        middle = codebook.loc[code].values[0]
        # Range is from halfway between middle and low to halfway between middle and high
        low = (middle + codebook.loc[code-1].values[0])/2
        high = (middle + codebook.loc[code+1].values[0])/2
        return np.random.randint(int(low), int(high))
    else:
        return codebook.loc[code].values[0]
# Translate columns
df['Income'] = df['MONEYPY'].apply(lambda x: translate(income_df, x))
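# astype('category', categories=...) is not accepted by current pandas; CategoricalDtype sets the category order
df['School'] = df['EDUCATION'].apply(lambda x: translate(education_df, x)).astype(
    pd.CategoricalDtype(categories=education_df['Education'].tolist()))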
df['House'] = df['TYPEHUQ'].apply(lambda x: translate(house_df, x)).astype('category')
df['Climate'] = df['Climate_Region_Pub'].apply(lambda x: translate(climate_df, x)).astype('category')
# Remove apartments (TYPEHUQ codes 4 and 5) from the dataset
# New DataFrame is named 'house'; .copy() avoids a SettingWithCopyWarning
house = df[(df['TYPEHUQ'] != 4) & (df['TYPEHUQ'] != 5)].copy()
# Remove the original integer-coded columns that have been translated
house.drop(['MONEYPY', 'EDUCATION', 'TYPEHUQ', 'Climate_Region_Pub'], axis=1, inplace=True)
# Drop the now-empty apartment categories
house['House'] = house['House'].cat.remove_categories(['Apartment in Building with 2-4 units',
    'Apartment in Building with 5+ units'])
house.head()
|   | HDD65 | CDD65 | UR | STOVEN | TELLWORK | TOTSQFT | KWH | TOTALBTUSPH | Income | School | House | Climate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 4742 | 1080 | U | 1 | 0 | 5075 | 18466 | 10873 | 111890 | Bachelor's | Single-Family Detached | Mixed-Humid |
| 1 | 2662 | 199 | U | 1 | 0 | 3136 | 5148 | 38606 | 32504 | High School | Single-Family Detached | Marine |
| 3 | 6034 | 672 | U | 0 | 0 | 2023 | 10015 | 40196 | 35187 | High School | Single-Family Detached | Very Cold/\nCold |
| 4 | 5388 | 702 | U | 1 | 0 | 1912 | 2869 | 36136 | 85197 | Bachelor's | Single-Family Attached | Very Cold/\nCold |
| 5 | 8866 | 270 | U | 1 | 0 | 3485 | 6387 | 74100 | 18411 | Associate's | Single-Family Detached | Very Cold/\nCold |
# These four plots are used to detect extreme outliers.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2,2, figsize=(10,8))
sns.distplot(house['TOTALBTUSPH'], ax=ax1)
sns.distplot(house['KWH'], fit=sp.stats.beta, ax=ax2)
sns.distplot(house['HDD65'], ax=ax3)
sns.distplot(house['CDD65'], ax=ax4)
plt.tight_layout()
In both space heating and electricity use there are a few very large values that I'll filter out to make the plots easier to read.
house = house[(house['TOTALBTUSPH']<200000) & (house['KWH']<60000)]
house.head()
|   | HDD65 | CDD65 | UR | STOVEN | TELLWORK | TOTSQFT | KWH | TOTALBTUSPH | Income | School | House | Climate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 4742 | 1080 | U | 1 | 0 | 5075 | 18466 | 10873 | 111890 | Bachelor's | Single-Family Detached | Mixed-Humid |
| 1 | 2662 | 199 | U | 1 | 0 | 3136 | 5148 | 38606 | 32504 | High School | Single-Family Detached | Marine |
| 3 | 6034 | 672 | U | 0 | 0 | 2023 | 10015 | 40196 | 35187 | High School | Single-Family Detached | Very Cold/\nCold |
| 4 | 5388 | 702 | U | 1 | 0 | 1912 | 2869 | 36136 | 85197 | Bachelor's | Single-Family Attached | Very Cold/\nCold |
| 5 | 8866 | 270 | U | 1 | 0 | 3485 | 6387 | 74100 | 18411 | Associate's | Single-Family Detached | Very Cold/\nCold |