Disclaimer: Este notebook contiene mis notas sobre Matplotlib, resumiendo básicamente el capítulo 4 de Python Data Science Handbook escrito por Jake VanderPlas. Recomiendo leer la fuente original e ir ejecutando todos los ejemplos (___learn by doing!___).
Matplotlib (Math Plot Lib for Python) es una librería para la creación de visualizaciones en Python, basada en los arrays de NumPy, y diseñada para trabajar sobre todo el stack de SciPy.
Una de las características más importantes de Matplotlib es que funciona en múltiples sistemas operativos y frameworks. Su desventaja reside en su interfaz gráfica, algo anticuada; pero se han desarrollado por encima nuevos paquetes como Seaborn que potencian ese aspecto.
Para poder usar Matplotlib, simplemente tendremos que importar la librería. Usualmente también trabajaremos con el módulo pyplot
(proporciona una interfaz con sintaxis parecido al de Matlab), que también importamos por defecto.
import matplotlib as mpl # mpl es el alias usado habitualmente para matplotlib
import matplotlib.pyplot as plt # plt es el alias para pyplot
Podemos especificar fácilmente el estilo de las visualizaciones:
plt.style.use('seaborn-whitegrid')
Existen 3 formas de pintar gráficas con Matplotlib:
plt.show()
al final, de forma que se muestren todas las gráficas creadas.%matplotlib
en el terminal, todas las gráficas que se creen después se visualizarán automáticamente. Podemos usar plt.draw()
para refrescarlas.%matplotlib notebook
activaremos la inserción de gráficas interactivas; mientras que con %matplotlib inline
se insertarán imágenes estáticas.Estas imágenes se pueden guardar usando el método savefig('name')
.
También las podemos cargar usando la clase Image
del módulo IPython.display:
from IPython.display import Image
#Image('file')
Podemos trabajar con la interfaz de pyplot, al estilo de Matlab:
import numpy as np
%matplotlib inline
x = np.linspace(0, 10, 100)
# Creamos una figura, que es el canvas para todo lo demás
plt.figure()
# Creamos el subpanel 1 (de 2) y pintamos la gráfica
plt.subplot(2, 1, 1) # (filas, columnas, nº de subpanel)
plt.plot(x, np.sin(x))
# Creamos el subpanel 2 y pintamos la gráfica
plt.subplot(2, 1, 2) # (filas, columnas, nº de subpanel)
plt.plot(x, np.cos(x)); # el ';' evita que se imprima la salida, que es una línea de texto.
O bien podemos trabajar con la interfaz orientada a objetos:
# Creación de la figura y el panel, todo en uno
fig, ax = plt.subplots(2) # ax sería un array de 2 objetos Axes
# Pintamos las gráficas en los subpaneles (objetos Axes)
ax[0].plot(x, np.sin(x))
ax[1].plot(x, np.cos(x));
# Lo mismo sin usar subpaneles...
fig = plt.figure()
ax1 = fig.add_axes([0, 0.5, 1, 0.4], xticklabels=[], ylim=(-1.1, 1.1))
ax2 = fig.add_axes([0, 0.1, 1, 0.4], ylim=(-1.1, 1.1))
x = np.linspace(0, 10)
ax1.plot(x, np.sin(x))
ax2.plot(x, np.cos(x));
Matplotlib es bastante flexible con las gráficas que podemos dibujar dentro de una figura. A continuación 2 ejemplos sencillos:
# Con subplots, ejemplo de 2 filas y 3 columnas
fig, ax = plt.subplots(2, 3, sharex='col', sharey='row')
ax[1, 2].plot(x, np.sin(x));
# Con GridSpec podemos hacer cosas más complicadas...
grid = plt.GridSpec(2, 3, wspace=0.4, hspace=0.3)
plt.subplot(grid[0, 0])
plt.subplot(grid[0, 1:])
plt.subplot(grid[1, :2])
plt.subplot(grid[1, 2]);
Las visualizaciones más sencillas quizá sean las de una función lineal y = f(x)
fig = plt.figure() # creación de una figura (esto no pinta nada, sólo crea un contenedor para todo lo demás)
ax = plt.axes() # creación de los ejes (esto pinta el grid)
# Ejemplo usando figure y axes
fig = plt.figure()
ax = plt.axes()
x = np.linspace(0, 10, 1000)
ax.plot(x, np.sin(x))
ax.plot(x, np.cos(x));
Alternativamente podemos usar la interfaz pylab, y que tanto la figura como los ejes se creen de forma transparente:
plt.plot(x, np.sin(x))
plt.plot(x, np.cos(x));
Con plot()
podemos usar varios parámetros:
color
=...linestyle
=label
='str' : etiqueta para una función que se mostrará en la leyenda de la gráfica, si se muestraAmbos se pueden combinar; escribiendo simplemente '--g' tendremos dashed green.
Tenemos más métodos disponibles en pyplot
:
xlim(x1, x2)
: ajusta eje x, pudiendo ser creciente (lim2 > lim1) o decrecienteylim(y1, y2)
: ajusta eje y, creciente o decreciente.axis()
: ajusta ambos ejesaxis([x1, x2, y1, y2])
: ajusta ambos ejes a lo indicadoaxis('tight')
: ajusta los ejes al contenidoaxis('equal')
: ajusta los ejes para que las unidades en x e y tengan el mismo tamañotitle(str)
: establece el título de la gráficaxlabel(str)
: establece la etiqueta del eje xylabel(str)
: establece la etiqueta del eje ylegend()
: muestra la leyenda de la gráficaCuando usamos la interfaz orientada a objetos, algunos métodos de Axes se llaman igual que los de pyplot, pero otros no. En todo caso se suelen aplicar todos como parámetros del método set()
:
# Ejemplo
ax = plt.axes()
ax.plot(x, np.sin(x))
ax.set(xlim=(0, 10), ylim=(-2, 2),
xlabel='x', ylabel='sin(x)',
title='A Simple Plot');
Otro tipo de gráfica muy usada son los diagramas de dispersión, que viene a ser una versión discreta de la gráfica lineal.
Podemos crearlos usando plot()
:
# Ejemplo
x = np.linspace(0, 10, 30)
y = np.sin(x)
plt.plot(x, y, 'o'); # el tercer argumento es el símbolo usado para representar los puntos
# Ejemplo 2
rng = np.random.RandomState(0)
for marker in ['o', '.', ',', 'x', '+', 'v', '^', '<', '>', 's', 'd']:
plt.plot(rng.rand(5), rng.rand(5), marker,
label="marker='{0}'".format(marker))
plt.legend(numpoints=1, loc='upper right', frameon=True, fancybox=True, shadow=True)
plt.xlim(0, 1.8);
Existen más parámetros para personalizar las gráficas. Lo mejor es echar un vistazo a la documentación de pyplot.plot.
Una segunda forma de crear este tipo de gráficas es usando la función scatter()
de pyplot, que funciona de forma parecida a plot():
# Ejemplo
plt.scatter(x, y, marker='o');
La ventaja con respecto a plot, es que con scatter podemos tratar las propiedades de cada punto de forma individual. Esto nos permite en realidad ver representadas 4 dimensiones de un punto.
# Ejemplo con puntos aleatorios de distintos colores y tamaños
rng = np.random.RandomState(0)
x = rng.randn(100)
y = rng.randn(100)
colors = rng.rand(100)
sizes = 1000 * rng.rand(100)
plt.scatter(x, y, c=colors, s=sizes, alpha=0.3, cmap='viridis') # cmap especifica el mala de color; alpha la transparencia
plt.colorbar(); # muestra la escala de colores
La desventaja de usar scatter, es que para datasets grandes es sensiblemente menos eficiente que plot.
Para más información sobre mapas de color:
Podemos crear una gráfica de barras de error con una simple llamada a la función errorbar()
:
# Ejemplo
x = np.linspace(0, 10, 50)
dy = 0.8
y = np.sin(x) + dy * np.random.randn(50)
plt.errorbar(x, y, yerr=dy, fmt='.k',
ecolor='lightgray', elinewidth=3, capsize=0); # fmt indica el formato (mismo sinaxis que plot)
Para más opciones, lo mejor es consultar la documentación de pyplot.errorbar.
Aunque Matplotlib no tiene nada específico para mostrar el error de forma continua, se puede conseguir su visualización combinando plot() y fill()
:
# Ejemplo (https://scikit-learn.org/stable/auto_examples/gaussian_process/plot_gpr_noisy_targets.html)
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C
np.random.seed(1)
def f(x):
"""The function to predict."""
return x * np.sin(x)
# First the noiseless case
X = np.atleast_2d([1., 3., 5., 6., 7., 8.]).T
# Observations
y = f(X).ravel()
# Mesh the input space for evaluations of the real function, the prediction and its MSE
x = np.atleast_2d(np.linspace(0, 10, 1000)).T
# Instantiate a Gaussian Process model
kernel = C(1.0, (1e-3, 1e3)) * RBF(10, (1e-2, 1e2))
gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=9)
# Fit to data using Maximum Likelihood Estimation of the parameters
gp.fit(X, y)
# Make the prediction on the meshed x-axis (ask for MSE as well)
y_pred, sigma = gp.predict(x, return_std=True)
# Plot the function, the prediction and the 95% confidence interval based on the MSE
plt.figure()
plt.plot(x, f(x), 'r:', label=u'$f(x) = x\,\sin(x)$')
plt.plot(X, y, 'r.', markersize=10, label=u'Observations')
plt.plot(x, y_pred, 'b-', label=u'Prediction')
plt.fill(np.concatenate([x, x[::-1]]),
np.concatenate([y_pred - 1.9600 * sigma,
(y_pred + 1.9600 * sigma)[::-1]]),
alpha=.2, fc='b', ec='None', label='95% confidence interval')
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.ylim(-10, 20)
plt.legend(loc='upper left');
A veces queremos representar datos con 3 dimensiones en 2, usando curvas de nivel o regiones de color. Matplotlib tiene 3 funciones para realizar este tipo de gráficas:
plt.contour()
: para crear curvas de densidadplt.contourf()
: para crear curvas de densidad rellenasplt.imshow()
: para mostrar imágenes# Ejemplo
def f(x, y):
return np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x)
x = np.linspace(0, 5, 50)
y = np.linspace(0, 5, 40)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
plt.contour(X, Y, Z, 20, cmap='RdGy');
# Versión en b/n
plt.contour(X, Y, Z, colors='black'); #
# Mejora rellenando el espacio entre líneas
plt.contourf(X, Y, Z, 20, cmap='RdGy')
plt.colorbar();
# Mejora suavizando saltos
plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower', cmap='RdGy')
plt.colorbar()
plt.axis(aspect='image');
¡Lo mejor es combinar ambas funciones y usar etiquetas!
contours = plt.contour(X, Y, Z, 3, colors='black')
plt.clabel(contours, inline=True, fontsize=8)
plt.imshow(Z, extent=[0, 5, 0, 5], origin='lower', cmap='RdGy', alpha=0.4)
plt.colorbar();
Si tenemos outliers es posible que lleguemos a ver mal una gráfica de estas características. Pero tiene solución...
# Ejemplo usando ruido
x = np.linspace(0, 10, 1000)
I = np.sin(x) * np.cos(x[:, np.newaxis])
speckles = (np.random.random(I.shape) < 0.01)
I[speckles] = np.random.normal(0, 3, np.count_nonzero(speckles))
plt.figure(figsize=(10, 3.5))
# En la gráfica de la izquierda vemos que la magnitud del ruido ha arruinado el rango de la gráfica
plt.subplot(1, 2, 1)
plt.imshow(I, cmap='RdBu')
plt.colorbar()
# En la gráfica de la derecha conseguimos arreglarlo
plt.subplot(1, 2, 2)
plt.imshow(I, cmap='RdBu')
# esto es lo que usuamos para arreglar la visualización...
plt.colorbar(extend='both')
plt.clim(-1, 1);
Para pintar histogramas tenemos la función hist()
de pyplot. Estas gráficas son fruto de repartir los datos disponibles en diferentes cajones o intervalos.
# Ejemplo
data = np.random.randn(1000)
plt.hist(data, bins=30, density=True, alpha=0.5,
histtype='stepfilled', color='steelblue',
edgecolor='none');
# Ejemplo múltiple
x1 = np.random.normal(0, 0.8, 1000)
x2 = np.random.normal(-2, 1, 1000)
x3 = np.random.normal(3, 2, 1000)
kwargs = dict(histtype='stepfilled', alpha=0.3, density=True, bins=40)
plt.hist(x1, **kwargs)
plt.hist(x2, **kwargs)
plt.hist(x3, **kwargs);
Para hacer el cómputo del histograma sin dibujarlo (contar el nº de puntos por cajón), usamos la función histogram()
de NumPy:
# Ejemplo
counts, bin_edges = np.histogram(data, bins=5)
print(counts)
[ 6 141 453 349 51]
Igual que creamos los histogramas en 1D dividiendo valores en distintos cajones, podemos crear histogramas en 2D dividiendo los puntos en cajones de 2 dimensiones. En pyplot contamos con la función hist2d()
para representar estos histogramas, y con hexbin()
para usar hexágonos en vez de cuadros, así como otras opciones adicionales.
# Ejemplo:
media = [0, 0]
covarianza = [[1, 1], [1, 2]]
x, y = np.random.multivariate_normal(media, covarianza, 10000).T
plt.hist2d(x, y, bins=30, cmap='Blues')
cb = plt.colorbar()
cb.set_label('puntos por cajón')
# Mismo ejemplo con hexbin()
plt.hexbin(x, y, gridsize=30, cmap='Blues')
cb = plt.colorbar(label='puntos por cajón')
También podemos hacer el cómputo de estos histogramas usando la función histogram2d()
de NumPy. Si queremos pasar a más dimensiones usaremos histogramdd()
Podemos cambiar la configuración por defecto de Matplotlib usando la función rc()
de pyplot, que modifica la configuración en tiempo de ejecución. Lo vemos con un ejemplo:
# Histograma con la configuración por defecto
x = np.random.randn(1000)
plt.hist(x);
# Pasos previos
IPython_default = plt.rcParams.copy() # copia del diccionario rcParams, para poder resetear la configuración
from matplotlib import cycler
colors = cycler('color',
['#EE6666', '#3388BB', '#9988DD',
'#EECC55', '#88BB44', '#FFBBBB'])
# Config
plt.rc('axes', facecolor='#E6E6E6', edgecolor='none',
axisbelow=True, grid=True, prop_cycle=colors)
plt.rc('grid', color='w', linestyle='solid')
plt.rc('xtick', direction='out', color='gray')
plt.rc('ytick', direction='out', color='gray')
plt.rc('patch', edgecolor='#E6E6E6')
plt.rc('lines', linewidth=2)
# Histograma con la nueva configuración
plt.hist(x);
Para cambiar el aspecto de las gráficas es preferible usar las hojas de estilo. Por defecto tenemos varias incluidas dentro del módulo style
de pyplot, pero también podemos crear las nuestras propias. Podemos ver todas las disponibles ejecutando el atributo available
:
plt.style.available[:5]
['bmh', 'classic', 'dark_background', 'fast', 'fivethirtyeight']
Existen varias formas de cambiar la hoja de estilo:
# Indefinidamente
plt.style.use('seaborn-whitegrid')
# Para una porción del código
with plt.style.context('ggplot'):
plt.hist(x);
Podemos escribir anotaciones dentro de las gráficas para añadir información o marcar puntos interesantes, y ayudar así a construir una historia.
# Ejemplo
fig, ax = plt.subplots(facecolor='lightgray')
ax.axis([0, 10, 0, 10])
ax.text(1, 5, ". Data: (1, 5)", transform=ax.transData) # por defecto: posición absoluta según la escala
ax.text(0.5, 0.1, ". Axes: (0.5, 0.1)", transform=ax.transAxes) # posición relativa a los ejes
ax.text(0.2, 0.2, ". Figure: (0.2, 0.2)", transform=fig.transFigure); # posición relativa al tamaño de la figura
Podemos cambiar las escalas de los ejes, las unidades, y las marcas que se muestran (podemos incluso eliminarlas).
# Ejemplo
with plt.style.context('classic'):
ax = plt.axes((0, 0, 0.48, 0.35), xscale='log', yscale='log')
ax.grid();
Matplotlib incorpora el submódulo mplot3d
que nos permite crear representaciones en 3D:
from mpl_toolkits import mplot3d
fig = plt.figure()
ax = plt.axes(projection='3d')
# Ejemplo: puntos y líneas
ax = plt.axes(projection='3d')
# Data for a three-dimensional line
zline = np.linspace(0, 15, 1000)
xline = np.sin(zline)
yline = np.cos(zline)
ax.plot3D(xline, yline, zline, 'gray')
# Data for three-dimensional scattered points
zdata = 15 * np.random.random(100)
xdata = np.sin(zdata) + 0.1 * np.random.randn(100)
ydata = np.cos(zdata) + 0.1 * np.random.randn(100)
ax.scatter3D(xdata, ydata, zdata, c=zdata, cmap='Greens');
# Ejemplo: curvas de nivel
def f(x, y):
return np.sin(np.sqrt(x ** 2 + y ** 2))
x = np.linspace(-6, 6, 30)
y = np.linspace(-6, 6, 30)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.view_init(60, 35)
ax.contour3D(X, Y, Z, 25, cmap='binary')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z');
# Ejemplo: superficie
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1,
cmap='viridis', edgecolor='none')
ax.set_title('surface');
Proporciona una API para Matplotlib que ofrece mejores estilos por defecto, define funciones de alto nivel para gráficas estadísticas comunes, y se integra con la funcionalidad de los DataFrames de Pandas.
Podemos aplicar el estilo de Seaborn simplemente llamando a su método set()
para obtener visualizaciones más agradables por defecto:
import seaborn as sns
sns.set()
Algunas de las funciones de alto nivel incluidas en Seaborn:
kdeplot()
: dibuja la curva KDE (kernel density estimation) en lugar de un histograma (1D)distplot()
: dibuja la curva KDE y el histograma superpuestos (1 o 2D)jointplot()
: dibuja la distribución conjunta junto a las marginales (2D)pairplot()
: dibuja el equivalente a joinplot para más dimensiones o atributos ("pairs plot"), comparando 1 a 1 (+D)FacetGrid()
: dibuja histogramas por facetas (subconjuntos dentro de una dimensión)catplot()
: pinta la distribución de un parámetro dentro de los cajones definidos por otro parámetro, o pinta diagramas de barras con series temporales, etc.violinplot()
: compara distribuciones.regplot()
: ajusta una regresión lineal a los datos# Ejemplo de visualización de histogramas y curvas KDE
import pandas as pd
data = np.random.multivariate_normal([0, 0], [[5, 2], [2, 2]], size=2000)
data = pd.DataFrame(data, columns=['x', 'y'])
for col in 'xy':
sns.distplot(data[col]);
# Ejemplo de distribuciones conjuntas
with sns.axes_style('white'):
sns.jointplot("x", "y", data, kind='kde'); # kind=hex para cambiar a estilo hexagonal, =reg para regresión
# Ejemplo de "pairs plot"
iris = sns.load_dataset("iris")
sns.pairplot(iris, hue='species', height=2.5);
# Ejemplo de histogramas facetados
tips = sns.load_dataset('tips')
tips['tip_pct'] = 100 * tips['tip'] / tips['total_bill']
grid = sns.FacetGrid(tips, row="sex", col="time", margin_titles=True)
grid.map(plt.hist, "tip_pct", bins=np.linspace(0, 40, 15));
# Ejemplo de factor plots
with sns.axes_style(style='ticks'):
g = sns.catplot("day", "total_bill", "sex", data=tips, kind="box")
g.set_axis_labels("Day", "Total Bill");
# Ejemplo de diagramas de barras para series temporales
planets = sns.load_dataset('planets')
with sns.axes_style('white'):
g = sns.catplot("year", data=planets, aspect=2,
kind="count", color='steelblue')
g.set_xticklabels(step=5)
with sns.axes_style('white'):
g = sns.catplot("year", data=planets, aspect=4.0, kind='count',
hue='method', order=range(2001, 2015))
g.set_ylabels('Number of Planets Discovered')