Start with the basic pyplot 2D histogram. First, import the libraries and generate some sample data: two clusters of random points.
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

# Use the ggplot look, but with a plain white plotting area instead of the
# style's default panel colour.
plt.style.use('ggplot')
plt.rcParams['axes.facecolor'] = 'w'
# Sample data: two clusters of 500 points each.
#   - a broad one (standard deviation 1.5) centred on the origin
#   - a tighter one (standard deviation 1) centred on (-5, -5)
# A fixed seed makes every figure below reproducible from run to run.
rng = np.random.default_rng(42)
x = np.concatenate([1.5 * rng.standard_normal(500), -5 + rng.standard_normal(500)])
y = np.concatenate([1.5 * rng.standard_normal(500), -5 + rng.standard_normal(500)])
Here's what we get by default in pyplot:
# 2D histogram of the point cloud with 20 bins along each axis. Unpacking the
# return value keeps the notebook from echoing it.
counts, xedges, yedges, image = plt.hist2d(x, y, bins=20)
or, as a plain scatter plot:
# Plain scatter plot of the same points using pyplot's defaults.
plt.scatter(x=x, y=y)
<matplotlib.collections.PathCollection at 0x7ff121a42050>
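To see the distribution of each variable on its own, we add 1D (marginal) histograms along the x and y axes. In the first examples, the histograms sit above and to the right of the scatter plot, and the scatter plot's ticks are moved to its top and right edges.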
def scatter_hist(x, y, ax, ax_histx, ax_histy, binwidth=0.2):
    """Scatter-plot (x, y) on ``ax`` and draw step-style marginal histograms.

    ``ax_histx`` receives a vertical step histogram of ``x``, and ``ax_histy``
    a horizontal step histogram of ``y``. The main axes' ticks are moved to
    the top and right edges.

    Both histograms share one set of bin edges. They are spaced ``binwidth``
    apart and run from ``-lim`` up to (about, given floating-point
    ``np.arange``) ``+lim``. Here ``lim`` is the smallest multiple of
    ``binwidth`` strictly larger than the largest absolute value in ``x``
    or ``y``.

    Parameters
    ----------
    x, y : array-like
        Point coordinates (assumed to be 1D arrays of equal length).
    ax : matplotlib.axes.Axes
        Axes that receives the scatter plot.
    ax_histx, ax_histy : matplotlib.axes.Axes
        Axes that receive the marginal histograms of ``x`` and ``y``.
    binwidth : float, optional
        Width of each histogram bin. Must be positive. Defaults to 0.2.

    Raises
    ------
    ValueError
        If ``binwidth`` is not positive.
    """
    if binwidth <= 0:
        raise ValueError(f"binwidth must be positive, got {binwidth!r}")

    # Show tick labels along the bottom of the x-histogram and along the left
    # of the y-histogram.
    ax_histx.tick_params(axis="x", labelbottom=True)
    ax_histy.tick_params(axis="y", labelleft=True)

    # Move the main plot's ticks to its top and right edges.
    ax.yaxis.tick_right()
    ax.xaxis.tick_top()

    ax.scatter(x, y, color='black', alpha=0.4)

    # Symmetric bins shared by both histograms, wide enough for every point.
    xymax = max(np.max(np.abs(x)), np.max(np.abs(y)))
    lim = (int(xymax / binwidth) + 1) * binwidth
    bins = np.arange(-lim, lim + binwidth, binwidth)

    ax_histx.hist(x, bins=bins, histtype='step')
    ax_histy.hist(y, bins=bins, histtype='step', orientation='horizontal')
# --- Layout (all values are fractions of the figure size) -----------------
# Main scatter panel: 0.65 x 0.65 with its lower-left corner at (0.1, 0.1).
left, width = 0.1, 0.65
bottom, height = 0.1, 0.65
spacing = 0.01  # gap between the main panel and each marginal panel

rect_scatter = [left, bottom, width, height]
# x-histogram: as wide as the main panel, 0.2 tall, directly above it.
rect_histx = [left, bottom + height + spacing, width, 0.2]
# y-histogram: as tall as the main panel, 0.2 wide, directly to its right.
rect_histy = [left + width + spacing, bottom, 0.2, height]

# A square figure keeps the square main panel square on screen.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_axes(rect_scatter)
ax_histx = fig.add_axes(rect_histx, sharex=ax)  # x-range tied to the scatter
ax_histy = fig.add_axes(rect_histy, sharey=ax)  # y-range tied to the scatter

# Remove the ticks on each histogram's count axis.
ax_histy.set_xticks([])
ax_histx.set_yticks([])

scatter_hist(x, y, ax, ax_histx, ax_histy)

# Axis labels are positioned by hand in figure coordinates.
fig.text(0.2, 0.71, 'x-label', va='center', rotation='horizontal', fontsize=14)
fig.text(0.7, 0.2, 'y-label', ha='center', rotation='vertical', fontsize=14)

plt.show()
We can also use more traditional filled bars instead of step outlines:
def scatter_hist(x, y, ax, ax_histx, ax_histy, binwidth=0.2, barwidth=0.07):
    """Scatter-plot (x, y) on ``ax`` and draw bar-style marginal histograms.

    ``ax_histx`` receives a vertical bar histogram of ``x``, and ``ax_histy``
    a horizontal bar histogram of ``y``. The main axes' ticks are moved to
    the top and right edges.

    Both histograms share one set of bin edges. They are spaced ``binwidth``
    apart and run from ``-lim`` up to (about) ``+lim``. Here ``lim`` is the
    smallest multiple of ``binwidth`` strictly larger than the largest
    absolute value in ``x`` or ``y``.

    Parameters
    ----------
    x, y : array-like
        Point coordinates (assumed to be 1D arrays of equal length).
    ax : matplotlib.axes.Axes
        Axes that receives the scatter plot.
    ax_histx, ax_histy : matplotlib.axes.Axes
        Axes that receive the marginal histograms of ``x`` and ``y``.
    binwidth : float, optional
        Width of each histogram bin. Must be positive. Defaults to 0.2.
    barwidth : float, optional
        Thickness of each drawn bar in data units. It is forwarded to
        matplotlib as the bars' ``width`` (x-histogram) or ``height``
        (y-histogram). Values smaller than ``binwidth`` leave visible gaps
        between bars. Defaults to 0.07.

    Raises
    ------
    ValueError
        If ``binwidth`` is not positive.
    """
    if binwidth <= 0:
        raise ValueError(f"binwidth must be positive, got {binwidth!r}")

    # Show tick labels along the bottom of the x-histogram and along the left
    # of the y-histogram.
    ax_histx.tick_params(axis="x", labelbottom=True)
    ax_histy.tick_params(axis="y", labelleft=True)

    ax.scatter(x, y, color='black', alpha=0.4)

    # Move the main plot's ticks to its top and right edges.
    ax.yaxis.tick_right()
    ax.xaxis.tick_top()

    # Symmetric bins shared by both histograms, wide enough for every point.
    xymax = max(np.max(np.abs(x)), np.max(np.abs(y)))
    lim = (int(xymax / binwidth) + 1) * binwidth
    bins = np.arange(-lim, lim + binwidth, binwidth)

    # Histogram counts do not depend on sample order, so the data is passed
    # in as-is.
    ax_histx.hist(x, bins=bins, width=barwidth)
    ax_histy.hist(y, bins=bins, height=barwidth, orientation='horizontal')
# --- Layout (all values are fractions of the figure size) -----------------
# Every axes rectangle has to lie inside the 0-1 figure canvas. Anything
# outside it is cropped by interactive windows, and by plt.savefig() unless
# bbox_inches='tight' is passed. Notebook inline backends typically pass it,
# which can hide the problem there.
# The main panel is therefore sized so that it, the thin gap and a marginal
# panel 1/5 of its size all fit on the canvas.
main_size = 0.7
left, width = 0.1, main_size
bottom, height = 0.1, main_size
spacing = 0.005 * main_size  # very thin gap, proportional to the main panel
hist_size = 0.2 * main_size  # marginal panels are 1/5 of the main panel
rect_scatter = [left, bottom, width, height]
rect_histx = [left, bottom + height + spacing, width, hist_size]
rect_histy = [left + width + spacing, bottom, hist_size, height]

# Start with a square figure so the square main panel stays square.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_axes(rect_scatter)

ax_histx = fig.add_axes(rect_histx, sharex=ax)
ax_histx.tick_params(axis="x", labelbottom=False)
ax_histx.tick_params(axis="y", labelleft=False)

ax_histy = fig.add_axes(rect_histy, sharey=ax)
ax_histy.tick_params(axis="x", labelbottom=False)
ax_histy.tick_params(axis="y", labelleft=False)

# Remove the ticks on each histogram's count axis.
ax_histy.set_xticks([])
ax_histx.set_yticks([])

# Note: scatter_hist re-enables the bottom tick labels of ax_histx and the
# left tick labels of ax_histy via tick_params.
scatter_hist(x, y, ax, ax_histx, ax_histy)

# plt.savefig("test.pdf")  # optional: save before plt.show()
plt.show()
Finally, move the histograms so they sit to the left of and below the scatter plot:
def scatter_hist(x, y, ax, ax_histx, ax_histy, bins=50):
    """Scatter-plot (x, y) on ``ax`` with marginal step histograms at left and bottom.

    This variant is meant for the gridspec layout that calls it:
    - ``ax_histx`` is the tall panel to the LEFT of ``ax``. Its vertical axis
      runs alongside the scatter plot's y axis, so it receives a horizontal
      histogram of ``y``.
    - ``ax_histy`` is the wide panel BELOW ``ax``. Its horizontal axis runs
      alongside the x axis, so it receives a vertical histogram of ``x``.

    The parameter names match the other ``scatter_hist`` variants in this
    notebook; here they identify panels by position.

    Parameters
    ----------
    x, y : array-like
        Point coordinates (assumed to be 1D arrays of equal length).
    ax : matplotlib.axes.Axes
        Axes that receives the scatter plot.
    ax_histx : matplotlib.axes.Axes
        Left-hand marginal panel (histogram of ``y``).
    ax_histy : matplotlib.axes.Axes
        Bottom marginal panel (histogram of ``x``).
    bins : int or sequence, optional
        Passed to ``Axes.hist`` for both histograms. Each histogram picks its
        own range when an int is given. Defaults to 50.
    """
    ax_histx.tick_params(axis="x", labelbottom=True)
    ax_histy.tick_params(axis="y", labelleft=True)

    # The main plot keeps matplotlib's default bottom/left tick positions.
    ax.scatter(x, y, color='black', alpha=0.4)

    # Left panel, parallel to the y axis: horizontal histogram of y.
    ax_histx.hist(y, bins=bins, histtype='step', orientation='horizontal')
    # Bottom panel, parallel to the x axis: vertical histogram of x.
    ax_histy.hist(x, bins=bins, histtype='step')
tick_spacing = 3  # major-tick interval on both axes of the main plot

# Start with a square figure.
fig = plt.figure(figsize=(8, 8))

# Axis labels are positioned by hand in figure coordinates (a manual hack
# instead of ax.set_xlabel / ax.set_ylabel).
fig.text(0.5, 0.15, 'x-label', va='center', rotation='horizontal')
fig.text(0.15, 0.5, 'y-label', ha='center', rotation='vertical')

# 2 x 2 grid. Along each direction the marginal panel and the main panel
# split the space 2 : 8. The margins are equal on every side, so the grid
# stays square inside the square figure.
gs = fig.add_gridspec(
    2, 2,
    width_ratios=(2, 8), height_ratios=(8, 2),
    left=0.2, right=0.8, bottom=0.2, top=0.8,
    wspace=0.2, hspace=0.2,
)
ax = fig.add_subplot(gs[0, 1])        # main scatter panel (top right)
ax_histx = fig.add_subplot(gs[0, 0])  # tall marginal panel to the left
ax_histy = fig.add_subplot(gs[1, 1])  # wide marginal panel below

ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
ax.yaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))

scatter_hist(x, y, ax, ax_histx, ax_histy)

# Remove every tick from both marginal panels. Commenting out different
# combinations of these lines (or enabling the ax lines below) controls which
# tick labels appear on the plot and where.
ax_histx.set_xticks([])
ax_histx.set_yticks([])
ax_histy.set_xticks([])
ax_histy.set_yticks([])
# ax.set_xticks([])
# ax.set_yticks([])

# To save the figure, call this BEFORE plt.show(); after show() the figure
# may already be closed and the saved image blank.
# plt.savefig("marginal_hist.png", dpi=200, bbox_inches='tight')
plt.show()