Code and text from *Python for Data Analysis* by Wes McKinney, Chapter 9: Plotting and Visualization (GitHub: pydata-book)
# %matplotlib notebook
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
PREVIOUS_MAX_ROWS = pd.options.display.max_rows
pd.options.display.max_rows = 20
np.random.seed(12345)
plt.rc('figure', figsize=(10, 6))
np.set_printoptions(precision=4, suppress=True)
data = np.arange(10)
data
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
plt.plot(data)
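Plots in matplotlib reside within a Figure object. You can create a new figure with plt.figure.
You can’t make a plot with a blank figure; you have to create one or more subplots using add_subplot. For example, fig.add_subplot(2, 2, 1) means the figure should be 2 × 2 and selects the first of the four subplots (numbered from 1).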
fig = plt.figure()
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)
plt.plot(np.random.randn(50).cumsum(), 'k--')
# 'k--' is a style option instructing matplotlib to plot a black dashed line;
# plt.plot draws on the most recently created subplot (here ax3)
The objects returned by fig.add_subplot here are AxesSubplot objects. You can plot directly on the other, still empty subplots by calling each one’s instance methods:
_ = ax1.hist(np.random.randn(100), bins=20, color='k', alpha=0.3)
ax2.scatter(np.arange(30), np.arange(30) + 3 * np.random.randn(30))
fig
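pyplot functions like plt.plot act on the *current* subplot, which is why the dashed line ends up in the last subplot created. Instance methods, by contrast, target one specific subplot. The sketch below checks both behaviours. It assumes the non-interactive Agg backend so that it runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax1 = fig.add_subplot(2, 2, 1)
ax2 = fig.add_subplot(2, 2, 2)
ax3 = fig.add_subplot(2, 2, 3)

# pyplot functions act on the current axes, i.e. the subplot created last
current = plt.gca()
plt.plot(np.arange(5), 'k--')

# instance methods draw on exactly the subplot they are called on
ax1.plot(np.arange(5) ** 2)

print(current is ax3)                  # True
print(len(ax3.lines), len(ax2.lines))  # 1 0
```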
Creating a figure with a grid of subplots is a very common task, so matplotlib includes a convenience method, plt.subplots, that creates a new figure and returns a NumPy array containing the created subplot objects
fig, axes = plt.subplots(2, 3)
axes
# axes array can be easily indexed like a two-dimensional
# array; for example, axes[0, 1].
array([[<matplotlib.axes._subplots.AxesSubplot object at 0x000001F96DA02A48>, <matplotlib.axes._subplots.AxesSubplot object at 0x000001F96DC6D748>, <matplotlib.axes._subplots.AxesSubplot object at 0x000001F96DB280C8>], [<matplotlib.axes._subplots.AxesSubplot object at 0x000001F96DAB5188>, <matplotlib.axes._subplots.AxesSubplot object at 0x000001F96DB89608>, <matplotlib.axes._subplots.AxesSubplot object at 0x000001F96DBC1808>]], dtype=object)
# You can also indicate that subplots should have the
# same x- or y-axis using sharex and sharey, respectively
fig, axes = plt.subplots(2, 3, sharey=True)
pyplot.subplots options
Argument | Description |
---|---|
nrows | Number of rows of subplots |
ncols | Number of columns of subplots |
sharex | All subplots should use the same x-axis ticks (adjusting the xlim will affect all subplots) |
sharey | All subplots should use the same y-axis ticks (adjusting the ylim will affect all subplots) |
subplot_kw | Dict of keywords passed to add_subplot call used to create each subplot |
**fig_kw | Additional keyword arguments to subplots are passed to plt.figure when creating the figure, such as plt.subplots(2, 2, figsize=(8, 6)) |
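The sketch below combines several options from the table. subplot_kw is forwarded to every add_subplot call; the light grey facecolor is an arbitrary illustrative choice. figsize is forwarded to the figure. With sharex=True, setting the x-limits on one subplot should change them on all of them. The Agg backend is an assumption so that the code runs headless.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# subplot_kw goes to each add_subplot call; extra keywords such as figsize go to the figure
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True,
                         subplot_kw={'facecolor': '0.9'},
                         figsize=(8, 6))

print(axes.shape)                     # (2, 2)
print(fig.get_size_inches().tolist())  # [8.0, 6.0]

# with sharex=True, changing one subplot's x-limits changes them all
axes[0, 0].set_xlim(0, 5)
print(axes[1, 1].get_xlim())
```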
By default matplotlib leaves a certain amount of padding around the outside of the subplots and spacing between subplots. This spacing is all specified relative to the height and width of the plot. If you resize the plot, either programmatically or manually using the GUI window, the plot will dynamically adjust itself. You can change the spacing using the subplots_adjust method on Figure objects, also available as the top-level function plt.subplots_adjust:
*subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=None)*
left, bottom, right, and top are fractions of the figure width or height. wspace and hspace control the spacing between subplots as a fraction of the average subplot width and height, respectively.
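As a minimal sketch, the call below sets all six parameters explicitly and reads them back from fig.subplotpars. The specific values are arbitrary, and the Agg backend is assumed for headless use.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2)
# left/right/bottom/top: fractions of the figure; wspace/hspace: fractions of
# the average subplot width/height
fig.subplots_adjust(left=0.1, right=0.95, bottom=0.1, top=0.9,
                    wspace=0.3, hspace=0.5)

params = fig.subplotpars
print(params.wspace, params.hspace)  # 0.3 0.5
```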
# shrinking spaces to zero
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
for i in range(2):
    for j in range(2):
        axes[i, j].hist(np.random.randn(500), bins=50, color='k', alpha=0.5)
plt.subplots_adjust(wspace=0, hspace=0)
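# plot accepts arrays of x and y coordinates and, optionally, a string abbreviation
# indicating color and line style; for example, a green dashed line:
# ax.plot(x, y, 'g--')
# The same plot expressed more explicitly with keyword arguments:
# ax.plot(x, y, linestyle='--', color='g')
# Run ?matplotlib.pyplot.plot for the full list of format abbreviations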
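To check that the shorthand and the keyword form are equivalent, the sketch below draws the same random walk both ways and compares the resulting line properties. It uses 'go--', which also adds circle markers, and assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import numpy as np

x = np.arange(10)
y = np.random.randn(10).cumsum()

fig, ax = plt.subplots()
# format-string shorthand: 'g' = green, 'o' = circle markers, '--' = dashed
short, = ax.plot(x, y, 'go--')
# the same style spelled out with keyword arguments
explicit, = ax.plot(x, y, color='g', linestyle='dashed', marker='o')

for line in (short, explicit):
    print(to_rgba(line.get_color()), line.get_linestyle(), line.get_marker())
```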
# using MARKERS to highlight the actual data points
from numpy.random import randn
# plt.plot(randn(30).cumsum(), 'ko--')
plt.plot(randn(30).cumsum(), color='k', linestyle='dashed', marker='o')
For line plots, you will notice that subsequent points are linearly interpolated by default. This can be altered with the drawstyle option
data = np.random.randn(30).cumsum()
plt.plot(data, 'k--', label='Default')
plt.plot(data, 'k-', drawstyle='steps-post', label='steps-post')
plt.legend(loc='best')
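The pyplot interface, designed for interactive use, consists of functions like xlim and xticks, which control the plot range and the tick locations (xticks also accepts tick labels). They can be used in two ways: called with no arguments, they return the current value (plt.xlim() returns the current x-axis range); called with arguments, they set the value (plt.xlim([0, 10]) sets the x-axis range to 0 through 10). All such functions act on the active, most recently created subplot. Each corresponds to two methods on the subplot object itself; for xlim these are ax.get_xlim and ax.set_xlim.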
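The get/set pattern of pyplot functions such as plt.xlim can be seen in a short sketch. Calling the function with no arguments reads the current value, calling it with an argument sets it, and the matching Axes method reports the same range. Random data and the Agg backend are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(np.random.randn(100).cumsum())

# no arguments: return the current x-axis range
print(plt.xlim())

# with arguments: set the x-axis range on the current axes
plt.xlim([0, 50])

# the equivalent instance method reads the same value back
print(ax.get_xlim())
```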
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(np.random.randn(1000).cumsum())
To change the x-axis ticks, it’s easiest to use set_xticks and set_xticklabels. The former instructs matplotlib where to place the ticks along the data range; by default these locations will also be the labels. But we can set any other values as the labels using set_xticklabels
ticks = ax.set_xticks([0, 250, 500, 750, 1000])
labels = ax.set_xticklabels(['one', 'two', 'three', 'four', 'five'],
rotation=30, fontsize='small')
ax.set_title('My first matplotlib plot')
ax.set_xlabel('Stages')
# similar steps for y-axis
fig
The axes class has a set method that allows batch setting of plot properties. From the prior example, we could also have written
props = {
'title': 'My first matplotlib plot',
'xlabel': 'Stages'
}
ax.set(**props)
fig
Legends are another critical element for identifying plot elements
fig = plt.figure(); ax=fig.add_subplot(1, 1, 1)
ax.plot(randn(1000).cumsum(), 'k', label='one')
ax.plot(randn(1000).cumsum(), 'k--', label='two')
ax.plot(randn(1000).cumsum(), 'k.', label='three')
ax.legend(loc='best')
fig
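To leave an element out of the legend, give it no label or a label starting with an underscore, such as '_nolegend_'. The sketch below repeats the three-line plot with the dotted line excluded. It uses random data and assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(np.random.randn(100).cumsum(), 'k', label='one')
ax.plot(np.random.randn(100).cumsum(), 'k--', label='two')
# labels starting with an underscore are skipped by legend()
ax.plot(np.random.randn(100).cumsum(), 'k.', label='_nolegend_')

leg = ax.legend(loc='best')
print([t.get_text() for t in leg.get_texts()])  # ['one', 'two']
```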
In addition to the standard plot types, you may wish to draw your own plot annotations, which could consist of text, arrows, or other shapes. You can add annotations and text using the text, arrow, and annotate functions.
# text draws text at given coordinates (x, y) on the plot
# with optional custom styling
# ax.text(x, y, 'hello world',
# family='monospace', fontsize=10)
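A self-contained sketch of text and annotate on synthetic data follows. It uses a sine curve in place of the S&P 500 file. xy is the point being annotated and xytext is where the label is placed. The arrowprops keys mirror the ones used in the crisis example, and the Agg backend is assumed.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y, 'k-')

# plain text at data coordinates (x, y) with custom styling
ax.text(1, -0.5, 'hello world', family='monospace', fontsize=10)

# text plus an arrow pointing from xytext to xy
peak_x = x[np.argmax(y)]
ax.annotate('first peak', xy=(peak_x, y.max()), xytext=(peak_x + 1, 0.5),
            arrowprops=dict(facecolor='black', headwidth=4, width=2, headlength=4),
            horizontalalignment='left', verticalalignment='top')

print(len(ax.texts))  # 2 (an annotation is a Text subclass)
```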
# Annotations can draw both text and arrows arranged appropriately
from datetime import datetime
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
data = pd.read_csv(r'examples/spx.csv', index_col=0, parse_dates=True)
spx = data['SPX']
spx.plot(ax=ax, style='k-')
crisis_data = [
(datetime(2007, 10, 11), 'Peak of bull market'),
(datetime(2008, 3, 12), 'Bear Stearns Fails'),
(datetime(2008, 9, 15), 'Lehman Bankruptcy')
]
for date, label in crisis_data:
    ax.annotate(label, xy=(date, spx.asof(date) + 75),
                xytext=(date, spx.asof(date) + 255),
                arrowprops=dict(facecolor='black', headwidth=4, width=2,
                                headlength=4),
                horizontalalignment='left', verticalalignment='top')
# Zoom in on 2007-2010
ax.set_xlim(['1/1/2007', '1/1/2011'])
ax.set_ylim([600, 1800])
ax.set_title('Important dates in the 2008-2009 financial crisis')
# drawing shapes
# matplotlib has objects that represent many common shapes, referred to as patches
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
rect = plt.Rectangle((0.2, 0.75), 0.4, 0.15, color='k', alpha=0.3)
circ = plt.Circle((0.7, 0.2), 0.15, color='b', alpha=0.3)
pgon = plt.Polygon([[0.15, 0.15], [0.35, 0.4], [0.2, 0.6]],
color='g', alpha=0.5)
ax.add_patch(rect)
ax.add_patch(circ)
ax.add_patch(pgon)
In pandas we may have multiple columns of data, along with row and column labels. Pandas itself has built-in methods that simplify creating visualizations from DataFrame and Series objects. Another library is seaborn, a statistical graphics library created by Michael Waskom. Seaborn simplifies creating many common visualization types.
Older seaborn releases modified the default matplotlib color schemes and plot styles as soon as seaborn was imported, to improve readability and aesthetics. Since seaborn 0.8 you opt in explicitly, with sns.set() or, from seaborn 0.11 onward, sns.set_theme(). Even if you do not use the seaborn API, applying its theme is a simple way to improve the visual aesthetics of general matplotlib plots.
Series and DataFrame each have a plot attribute for making some basic plot types. By default, *plot()* makes line plots
s = pd.Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))
s.plot()
The Series object’s *index* is passed to matplotlib for plotting on the x-axis, though you can disable this by passing use_index=False. The x-axis ticks and limits can be adjusted with the xticks and xlim options, and the y-axis ticks and limits with yticks and ylim.
Most of pandas’s plotting methods accept an optional *ax* parameter, which can be a matplotlib subplot object.
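Several of these options appear together in the sketch below: ax places each plot on a chosen subplot, ylim fixes the y-range, and use_index=False replaces the index with positions 0 through n - 1. The random data, titles, and the Agg backend are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

s = pd.Series(np.random.randn(10).cumsum(), index=np.arange(0, 100, 10))

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# default: the index (0, 10, ..., 90) supplies the x values
s.plot(ax=axes[0], ylim=(-5, 5), title='index on x-axis')
# use_index=False: x values become 0, 1, ..., 9 instead
s.plot(ax=axes[1], use_index=False, style='ko--', title='use_index=False')
```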
df = pd.DataFrame(np.random.randn(10, 4).cumsum(0), columns=['A', 'B', 'C', 'D'],
index=np.arange(0, 100, 10))
df.plot()
Series.plot method arguments
Argument | Description |
---|---|
label | Label for plot legend |
ax | matplotlib subplot object to plot on; if nothing passed, uses active matplotlib subplot |
style | Style string, like 'ko--', to be passed to matplotlib |
alpha | The plot fill opacity (from 0 to 1) |
kind | Can be 'area', 'bar', 'barh', 'density', 'hist', 'kde', 'line', 'pie' |
logy | Use logarithmic scaling on the y-axis |
use_index | Use the object index for tick labels |
rot | Rotation of tick labels (0 through 360) |
xticks | Values to use for x-axis ticks |
yticks | Values to use for y-axis ticks |
xlim | x-axis limits (e.g., [0, 10]) |
ylim | y-axis limits |
grid | Display axis grid (default follows matplotlib's rcParams setting) |
*DataFrame-specific plot arguments*
Argument | Description |
---|---|
subplots | Plot each DataFrame column in a separate subplot |
sharex | If subplots=True, share the same x-axis, linking ticks and limits |
sharey | If subplots=True, share the same y-axis, linking ticks and limits |
figsize | Size of figure to create as tuple |
title | Plot title as string |
legend | Add a subplot legend (True by default) |
sort_columns | Plot columns in alphabetical order; by default uses existing column order (deprecated in pandas 1.5 and removed in 2.0) |
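The DataFrame-specific options can be combined in one call. With subplots=True, pandas returns an array with one Axes per column. The sketch below uses random data and a made-up title, and assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.randn(10, 4).cumsum(0), columns=['A', 'B', 'C', 'D'],
                  index=np.arange(0, 100, 10))

# one subplot per column, sharing the x-axis, inside an 8 x 8 inch figure
axes = df.plot(subplots=True, sharex=True, figsize=(8, 8), title='Cumulative sums')
print(len(axes))  # 4
```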
# plot.bar() and plot.barh() make vertical and horizontal bar plots
fig, axes = plt.subplots(2, 1)
data = pd.Series(np.random.rand(16), index=list('abcdefghijklmnop'))
data.plot.bar(ax=axes[0], color='k', alpha=0.7)
data.plot.barh(ax=axes[1], color='k', alpha=0.7)
With a DataFrame, bar plots group the values in each row together, drawing one bar per column side by side within each group.
df = pd.DataFrame(np.random.rand(6, 4), index=[
'one', 'two', 'three', 'four', 'five', 'six'],
columns=pd.Index(['A', 'B', 'C', 'D'], name='Genus'))
df
Genus | A | B | C | D |
---|---|---|---|---|
one | 0.301686 | 0.156333 | 0.371943 | 0.270731 |
two | 0.750589 | 0.525587 | 0.689429 | 0.358974 |
three | 0.381504 | 0.667707 | 0.473772 | 0.632528 |
four | 0.942408 | 0.180186 | 0.708284 | 0.641783 |
five | 0.840278 | 0.909589 | 0.010041 | 0.653207 |
six | 0.062854 | 0.589813 | 0.811318 | 0.060217 |
df.plot.bar()
# Note that the name “Genus” on the DataFrame’s columns is used to title the legend.
We create stacked bar plots from a DataFrame by passing stacked=True, resulting in the values in each row being stacked together
df.plot.barh(stacked=True, alpha=0.5)
tips = pd.read_csv(r'examples/tips.csv')
tips.head()
| | total_bill | tip | smoker | day | time | size |
---|---|---|---|---|---|---|
0 | 16.99 | 1.01 | No | Sun | Dinner | 2 |
1 | 10.34 | 1.66 | No | Sun | Dinner | 3 |
2 | 21.01 | 3.50 | No | Sun | Dinner | 3 |
3 | 23.68 | 3.31 | No | Sun | Dinner | 2 |
4 | 24.59 | 3.61 | No | Sun | Dinner | 4 |
# https://chrisalbon.com/python/data_wrangling/pandas_crosstabs/
party_counts = pd.crosstab(tips['day'], tips['size'])
party_counts
size | 1 | 2 | 3 | 4 | 5 | 6 |
---|---|---|---|---|---|---|
day | ||||||
Fri | 1 | 16 | 1 | 1 | 0 | 0 |
Sat | 2 | 53 | 18 | 13 | 1 | 0 |
Sun | 0 | 39 | 15 | 18 | 3 | 1 |
Thur | 1 | 48 | 4 | 5 | 1 | 3 |
party_counts = party_counts.loc[:, 2:5]
party_counts
size | 2 | 3 | 4 | 5 |
---|---|---|---|---|
day | ||||
Fri | 16 | 1 | 1 | 0 |
Sat | 53 | 18 | 13 | 1 |
Sun | 39 | 15 | 18 | 3 |
Thur | 48 | 4 | 5 | 1 |
print(party_counts.sum(axis=1))
# normalize so that each row sums to 1
party_pcts = party_counts.div(party_counts.sum(axis=1), axis=0)
day
Fri     18
Sat     85
Sun     75
Thur    58
dtype: int64
party_pcts
size | 2 | 3 | 4 | 5 |
---|---|---|---|---|
day | ||||
Fri | 0.888889 | 0.055556 | 0.055556 | 0.000000 |
Sat | 0.623529 | 0.211765 | 0.152941 | 0.011765 |
Sun | 0.520000 | 0.200000 | 0.240000 | 0.040000 |
Thur | 0.827586 | 0.068966 | 0.086207 | 0.017241 |
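Row normalization can also be done by pd.crosstab itself through its normalize='index' argument. The sketch below compares that with the divide-by-row-sum approach on a small hypothetical stand-in for the tips data; tips_demo is made up for illustration.

```python
import pandas as pd

# small hypothetical stand-in for the tips data
tips_demo = pd.DataFrame({'day': ['Fri', 'Fri', 'Sat', 'Sat', 'Sat', 'Sun'],
                          'size': [2, 2, 2, 3, 3, 4]})
counts = pd.crosstab(tips_demo['day'], tips_demo['size'])

# divide each row by its total so every row sums to 1
by_hand = counts.div(counts.sum(axis=1), axis=0)

# crosstab can apply the same row normalization directly
built_in = pd.crosstab(tips_demo['day'], tips_demo['size'], normalize='index')

print(by_hand)
```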
party_pcts.plot.bar()
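# With data that needs aggregating before plotting, seaborn can make things simpler;
# here we look at the tipping percentage by day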
import seaborn as sns
tips['tip_pct'] = tips['tip'] / (tips['total_bill'] - tips['tip'])
tips.head()
| | total_bill | tip | smoker | day | time | size | tip_pct |
---|---|---|---|---|---|---|---|
0 | 16.99 | 1.01 | No | Sun | Dinner | 2 | 0.063204 |
1 | 10.34 | 1.66 | No | Sun | Dinner | 3 | 0.191244 |
2 | 21.01 | 3.50 | No | Sun | Dinner | 3 | 0.199886 |
3 | 23.68 | 3.31 | No | Sun | Dinner | 2 | 0.162494 |
4 | 24.59 | 3.61 | No | Sun | Dinner | 4 | 0.172069 |
sns.barplot(x='tip_pct', y='day', data=tips, orient='h')
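sns.barplot aggregates the several observations per day, using the mean by default, and draws an error bar (a confidence interval by default) on top. Each bar length should therefore equal a groupby mean, as the sketch below computes on hypothetical toy data shaped like tips.

```python
import pandas as pd

# hypothetical toy data in the same shape as tips
toy = pd.DataFrame({'day': ['Thur', 'Thur', 'Fri', 'Fri', 'Fri'],
                    'tip_pct': [0.10, 0.20, 0.15, 0.15, 0.30]})

# seaborn's default estimator is the mean, so bar lengths should match this
bar_lengths = toy.groupby('day')['tip_pct'].mean()
print(bar_lengths)
```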
seaborn.barplot has a *hue* option that enables us to split by an additional categorical value
sns.barplot(x='tip_pct', y='day', hue='time', data=tips, orient='h')
A histogram is a kind of bar plot that gives a discretized display of value frequency. The data points are split into discrete, evenly spaced bins, and the number of data points in each bin is plotted.
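The binning step can be seen directly with np.histogram, which splits the data range into evenly spaced bins and counts the points in each one. The small array below is made up so the counts are easy to check by eye.

```python
import numpy as np

values = np.array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4])
# 4 evenly spaced bins spanning min..max; the last bin includes its right edge
counts, edges = np.histogram(values, bins=4)
print(counts)          # [1 2 3 4]
print(edges.tolist())  # [1.0, 1.75, 2.5, 3.25, 4.0]
```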
tips['tip_pct'].plot.hist(bins=50)
A related plot type is a density plot, which is formed by computing an estimate of a continuous probability distribution that might have generated the observed data. The usual procedure is to approximate this distribution as a mixture of “kernels”—that is, simpler distributions like the normal distribution. Thus, density plots are also known as kernel density estimate (KDE) plots. Using plot.kde makes a density plot using the conventional mixture-of-normals estimate
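The "mixture of normal kernels" idea can be written out by hand as an average of normal densities, one centred on each observation. This is a sketch with a hand-picked, fixed bandwidth of 0.4, which is an assumption. pandas' plot.kde instead uses scipy.stats.gaussian_kde, which chooses the bandwidth from the data (Scott's rule by default). The function name gaussian_kde_by_hand is hypothetical.

```python
import numpy as np

def gaussian_kde_by_hand(data, grid, bandwidth):
    """Average of normal densities centred on each observation (hypothetical helper)."""
    data = np.asarray(data, dtype=float)
    # z has shape (len(grid), len(data))
    z = (grid[:, None] - data[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z ** 2) / (bandwidth * np.sqrt(2 * np.pi))
    return kernels.mean(axis=1)

rng = np.random.default_rng(0)
sample = rng.normal(0, 1, size=200)
grid = np.linspace(-6, 6, 1201)
density = gaussian_kde_by_hand(sample, grid, bandwidth=0.4)

# a density estimate should integrate to (about) 1
dx = grid[1] - grid[0]
print(round(density.sum() * dx, 3))
```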
tips['tip_pct'].plot.density()
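Seaborn makes histograms and density plots even easier through its distplot function, which can plot both a histogram and a continuous density estimate simultaneously. distplot has been deprecated since seaborn 0.11; histplot (or displot) with kde=True is its replacement.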
# example - bimodal distribution consisting of draws from
# two different standard normal distributions
comp1 = np.random.normal(0, 1, size=200)
comp2 = np.random.normal(10, 2, size=200)
values = pd.Series(np.concatenate([comp1, comp2]))
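# distplot is deprecated since seaborn 0.11; histplot with kde=True draws the same combination
sns.histplot(values, bins=100, color='k', kde=True, stat='density')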
Point plots or scatter plots can be a useful way of examining the relationship between two one-dimensional data series.
macro = pd.read_csv(r'examples/macrodata.csv')
data = macro[['cpi', 'm1', 'tbilrate', 'unemp']]
trans_data = np.log(data).diff().dropna()
trans_data[-5:]
| | cpi | m1 | tbilrate | unemp |
---|---|---|---|---|
198 | -0.007904 | 0.045361 | -0.396881 | 0.105361 |
199 | -0.021979 | 0.066753 | -2.277267 | 0.139762 |
200 | 0.002340 | 0.010286 | 0.606136 | 0.160343 |
201 | 0.008419 | 0.037461 | -0.200671 | 0.127339 |
202 | 0.008894 | 0.012202 | -0.405465 | 0.042560 |
*We can then use seaborn’s regplot method, which makes a scatter plot and fits a linear regression line*
sns.regplot(x='m1', y='unemp', data=trans_data)
plt.title('Changes in log m1 versus log unemp')
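By default, regplot's line is an ordinary least-squares fit of degree 1. The sketch below reproduces that kind of fit with np.polyfit on synthetic data generated with a known slope (-0.5) and intercept (0.1), both chosen arbitrarily for illustration. It omits the confidence band that regplot also draws.

```python
import numpy as np

rng = np.random.default_rng(12345)
x = rng.normal(size=100)
# hypothetical data with a known slope of -0.5 and intercept of 0.1
y = 0.1 - 0.5 * x + rng.normal(scale=0.1, size=100)

# degree-1 least-squares fit; polyfit returns the highest-degree coefficient first
slope, intercept = np.polyfit(x, y, deg=1)
print(round(slope, 2), round(intercept, 2))
```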
In exploratory data analysis it’s helpful to be able to look at all the scatter plots among a group of variables; this is known as a *pairs plot* or *scatter plot matrix*. Making such a plot from scratch is a bit of work, so seaborn has a convenient *pairplot* function, which supports placing histograms or density estimates of each variable along the diagonal
sns.pairplot(trans_data, diag_kind='kde', plot_kws={'alpha': 0.7})
plot_kws passes configuration options down to the individual plotting calls on the off-diagonal elements.
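pandas ships its own scatter plot matrix, pd.plotting.scatter_matrix, which can serve as an alternative when seaborn is not installed. The sketch below runs it on three columns of random data (an assumption) with histograms on the diagonal; diagonal='kde' would require scipy. The Agg backend is also assumed.

```python
import matplotlib
matplotlib.use("Agg")
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
frame = pd.DataFrame(rng.normal(size=(100, 3)), columns=['a', 'b', 'c'])

# histograms on the diagonal, pairwise scatter plots elsewhere
axes = pd.plotting.scatter_matrix(frame, diagonal='hist', alpha=0.7, figsize=(6, 6))
print(axes.shape)  # (3, 3)
```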
What about datasets where we have additional grouping dimensions? One way to visualize data with many categorical variables is to use a facet grid. Seaborn has a useful built-in function, catplot (named factorplot before seaborn 0.9), that simplifies making many kinds of faceted plots
sns.catplot(x='day', y='tip_pct', hue='time', col='smoker',
kind='bar', data=tips[tips.tip_pct < 1])
Instead of distinguishing 'time' values by bar color within a facet, we can also expand the facet grid by adding one row per time value
sns.catplot(x='day', y='tip_pct', row='time', col='smoker', kind='bar',
data=tips[tips.tip_pct < 1])
*catplot* supports other plot types that may be useful depending on what you are trying to display. For example, box plots (which show the median, quartiles, and outliers) can be an effective visualization type
sns.catplot(x='tip_pct', y='day', kind='box', data=tips[tips.tip_pct < 0.5])
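The quantities a box plot shows can be computed with pandas. The median and quartiles define the box. Matplotlib and seaborn by default extend the whiskers to the most extreme points within 1.5 × IQR of the box and mark anything beyond as an outlier. The series below is made up so the numbers are easy to verify.

```python
import pandas as pd

values = pd.Series([1, 2, 3, 4, 5, 6, 7, 8, 9, 100])

q1, median, q3 = values.quantile([0.25, 0.5, 0.75])
iqr = q3 - q1
# points more than 1.5 * IQR beyond the box are drawn as outliers
lower_fence, upper_fence = q1 - 1.5 * iqr, q3 + 1.5 * iqr
outliers = values[(values < lower_fence) | (values > upper_fence)]

print(q1, median, q3)     # 3.25 5.5 7.75
print(outliers.tolist())  # [100]
```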