17th June 2020
See accompanying blog post on Engineering for Data Science
# install requirements if necessary
# ! pip install matplotlib pandas ffn nb_black
%load_ext nb_black
You can't directly loop through the subplots if both `nrows` and `ncols` are greater than 1. This is because `plt.subplots` returns a 2-D array of axes (effectively a list of lists, one inner list per grid row) rather than a flat list of subplot objects.
import matplotlib.pyplot as plt
%matplotlib inline
# build a 3-row by 2-column grid of subplots; axs comes back as a 2-D array of axes
fig, axs = plt.subplots(3, 2)
# confirm the grid dimensions, then display the array of axes itself
print(axs.shape)
axs
(3, 2)
array([[<AxesSubplot:>, <AxesSubplot:>], [<AxesSubplot:>, <AxesSubplot:>], [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
# library to get stock data
import ffn
# stocks to download daily prices for
tickers = ["aapl", "msft", "tsla", "nvda", "intc"]
# fetch the price history from the start of 2017 via ffn
prices = ffn.get(tickers, start="2017-01-01")

# reshape into a 'long' table: the Date index is kept, and each row holds
# one ticker name alongside its closing price for that date
df = prices.melt(
    value_name="closing_price",
    var_name="ticker",
    ignore_index=False,
)
df.head()
ticker | closing_price | |
---|---|---|
Date | ||
2017-01-03 | aapl | 27.413372 |
2017-01-04 | aapl | 27.382690 |
2017-01-05 | aapl | 27.521944 |
2017-01-06 | aapl | 27.828764 |
2017-01-09 | aapl | 28.083660 |
axs.ravel()
# Method 1: create the whole subplot grid up front, then flatten it for looping
# define subplot grid
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(15, 12))
plt.subplots_adjust(hspace=0.5)  # vertical padding between subplot rows
fig.suptitle("Daily closing prices", fontsize=18, y=0.95)

# loop through tickers and axes
# axs is a 3x2 array, so ravel() flattens it into a 1-D sequence of axes.
# zip stops at the shorter input: with 5 tickers the 6th axes is left empty.
for ticker, ax in zip(tickers, axs.ravel()):
    # filter df for ticker and plot on specified axes
    df[df["ticker"] == ticker].plot(ax=ax)

    # chart formatting
    ax.set_title(ticker.upper())
    ax.get_legend().remove()  # the title already names the ticker
    ax.set_xlabel("")

plt.show()
plt.subplot
# Method 2: start from an empty figure and add one subplot per ticker
plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.5)  # vertical padding between subplot rows
plt.suptitle("Daily closing prices", fontsize=18, y=0.95)

# loop through the length of tickers and keep track of index
for n, ticker in enumerate(tickers):
    # add a new subplot iteratively
    # (plt.subplot positions are 1-based, hence n + 1)
    ax = plt.subplot(3, 2, n + 1)

    # filter df and plot ticker on the new subplot axis
    df[df["ticker"] == ticker].plot(ax=ax)

    # chart formatting
    ax.set_title(ticker.upper())
    ax.get_legend().remove()  # the title already names the ticker
    ax.set_xlabel("")
Pros of Method 2
Cons of Method 2
The snippet below can be used to dynamically calculate the number of rows needed in a grid for a given number of columns.
# smallest number of rows that fits every ticker into a 2-column grid
ncols = 2
# ceiling division: negate, floor-divide, negate again
nrows = -(-len(tickers) // ncols)
nrows
3
# Method 2 with a dynamic grid: only the column count is fixed by hand,
# the row count is derived from the number of tickers
plt.figure(figsize=(15, 12))
plt.subplots_adjust(hspace=0.2)  # vertical padding between subplot rows
plt.suptitle("Daily closing prices", fontsize=18, y=0.95)

# set number of columns (use 3 to demonstrate the change)
ncols = 3
# calculate number of rows: full rows plus one more if any tickers are left over
nrows = len(tickers) // ncols + (len(tickers) % ncols > 0)

# loop through the length of tickers and keep track of index
for n, ticker in enumerate(tickers):
    # add a new subplot iteratively using nrows and cols
    # (plt.subplot positions are 1-based, hence n + 1)
    ax = plt.subplot(nrows, ncols, n + 1)

    # filter df and plot ticker on the new subplot axis
    df[df["ticker"] == ticker].plot(ax=ax)

    # chart formatting
    ax.set_title(ticker.upper())
    ax.get_legend().remove()  # the title already names the ticker
    ax.set_xlabel("")
This example is slightly contrived because there are inbuilt methods in Pandas which will do this for you, e.g. `df.groupby('ticker').plot()`. However, you may not have as much control over chart formatting. You could also use Seaborn, although its API for subplots (facet grids) can be just as cumbersome.