Vamos a crear una red neuronal para tratar ahora con un conjunto mucho más grande, el dataset MNIST. Verás que, a medida que tratamos con conjuntos mayores, el tiempo de procesamiento se incrementa y la necesidad de contar con una GPU crece al mismo ritmo. Échale un vistazo a este vídeo para que te vayas familiarizando con Google Colab, por si necesitas usarlo.
Aquí tienes un buen tutorial sobre Google Colab
El conjunto MNIST está formado por 70.000 imágenes de dígitos manuscritos del 0 al 9 con un tamaño de 28x28 en escala de grises. A su vez, el conjunto se divide en 60.000 imágenes para entrenamiento y 10.000 para test.
Vamos a utilizar este dataset para entrenar una red y ver qué precisión obtenemos al clasificar el conjunto de test. Piensa bien lo que pretendemos lograr: hacer una red que será capaz de “ver”, aunque, por ahora, solo sean imágenes de dos dimensiones.
El conjunto MNIST es muy popular y se utiliza mucho para aprender y probar redes, así que Keras ya lo incluye como parte de la librería.
import keras
from keras.datasets import mnist
from matplotlib import pyplot as plt
import numpy as np
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Veamos la forma tiene x_train
print("Shape:", x_train.shape) # 60.000 imágenes de 28x28
# Veamos una imagen cualquiera, por ejemplo, con el índice 125
image = np.array(x_train[125], dtype='float')
plt.imshow(image, cmap='gray')
plt.show()
print("Label:", y_train[125])
Shape: (60000, 28, 28)
Label: 8
También es necesario saber en qué rango de valores se mueven nuestras muestras.
print("Max value:", max(x_train[125].reshape(784)))
print("Min value:", min(x_train[125].reshape(784)))
Max value: 255 Min value: 0
Vemos que cada pixel es un byte con un rango de valores que va desde el 0 hasta el 255 en formato entero. Esta escala no es muy adecuada para la red. Podemos facilitar mucho el trabajo de entrenamiento si transformamos esta escala en otra centrada en el 0 y con un rango de valores entre -0.5 y 0.5. Y, por supuesto, en formato real.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255 # Escalamos a un rango entre 0 y 1
x_test /= 255
x_train -= 0.5 # desplazamos el rango a -0.5 y 0.5
x_test -= 0.5
print("Max value:", max(x_train[125].reshape(784)))
print("Min value:", min(x_train[125].reshape(784)))
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
Max value: 0.5 Min value: -0.5
Ahora preparamos las etiquetas transformándolas a formato one_hot. Keras tiene funciones para ello.
y_train = keras.utils.to_categorical(y_train, 10) # 10 clases
y_test = keras.utils.to_categorical(y_test, 10)
print("Label:", y_train[125]) # Recordemos que esta muestra tenía valor 8
Label: [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
Tenemos que clasificar imágenes en diez categorías distintas, luego la capa final tendrá diez salidas. En cuanto a la capa de entrada, tenemos una matriz de 28×28. Lo que vamos a hacer es transformarla en un vector de 784 componentes ( 28×28=784). Simplemente tomaremos cada fila de la matriz y las iremos colocando secuencialmente.
Para la capa o capas ocultas vamos a probar primero con una capa oculta con 20 neuronas. Recuerda que estos son los hiperparámetros de los que ya hemos hablado. No hay forma de saber a priori cuál es el número óptimo de capas ni de neuronas por capa oculta.
Como funciones de activación utilizaremos sigmoides y, en la capa final, softmax.
En lugar de definir la red de nuevo con el modelo secuencial de Keras, vamos a hacerlo con la API funcional. Esta API nos sirve para poder definir modelos más complejos que los simplemente creados a partir de acumular capas apiladas. Hay redes que tienen arquitecturas donde hay capas compartidas, las salidas de una capa pueden ir a capas separadas varias capas, etc. Hay una gran variedad.
Por supuesto, la red que vamos a hacer ahora la podríamos hacer perfectamente con el modelo secuencial (y hasta nos sería más fácil), pero vamos a aprender a utilizar esta API.
from keras.layers import Input, Dense
from keras.models import Model
inputs = Input(shape=(784,)) # Capa de entrada
output_h = Dense(units=20, activation='sigmoid')(inputs) # Capa oculta
predictions = Dense(10, activation='softmax')(output_h) # Capa de salida
model = Model(inputs=inputs, outputs=predictions)
Si te has fijado, la API funcional usa la función Dense(units=20, activation='sigmoid')(inputs)
para crear una capa en lugar del método model.add(Dense(units=20, activation='sigmoid', input_dim=784))
del modelo. Esto nos da libertad para asignar la salida de una capa a la entrada de la capa que queramos, no obligatoriamente a la siguiente.
model.compile(loss='mse',
optimizer=keras.optimizers.SGD(learning_rate=1),
metrics=['accuracy'])
Durante el entrenamiento vamos a ir viendo también cómo va evolucionando el accuracy del conjunto de test. Cuando usamos el conjunto de test de esta manera se suele llamar conjunto de validación.
history = model.fit(x_train, y_train, epochs=50, batch_size=30, validation_data=(x_test, y_test))
Train on 60000 samples, validate on 10000 samples Epoch 1/50 60000/60000 [==============================] - 4s 69us/step - loss: 0.0461 - accuracy: 0.6982 - val_loss: 0.0213 - val_accuracy: 0.8863 Epoch 2/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0182 - accuracy: 0.8930 - val_loss: 0.0150 - val_accuracy: 0.9086 Epoch 3/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0147 - accuracy: 0.9084 - val_loss: 0.0132 - val_accuracy: 0.9168 Epoch 4/50 60000/60000 [==============================] - 4s 59us/step - loss: 0.0132 - accuracy: 0.9170 - val_loss: 0.0122 - val_accuracy: 0.9247 Epoch 5/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0122 - accuracy: 0.9226 - val_loss: 0.0115 - val_accuracy: 0.9278 Epoch 6/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0115 - accuracy: 0.9272 - val_loss: 0.0111 - val_accuracy: 0.9301 Epoch 7/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0109 - accuracy: 0.9306 - val_loss: 0.0110 - val_accuracy: 0.9289 Epoch 8/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0105 - accuracy: 0.9338 - val_loss: 0.0102 - val_accuracy: 0.9341 Epoch 9/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0101 - accuracy: 0.9372 - val_loss: 0.0102 - val_accuracy: 0.9344 Epoch 10/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0097 - accuracy: 0.9395 - val_loss: 0.0096 - val_accuracy: 0.9381 Epoch 11/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0094 - accuracy: 0.9411 - val_loss: 0.0095 - val_accuracy: 0.9390 Epoch 12/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0091 - accuracy: 0.9427 - val_loss: 0.0093 - val_accuracy: 0.9399 Epoch 13/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0089 - accuracy: 0.9446 - val_loss: 0.0092 - val_accuracy: 0.9422 Epoch 14/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0087 - accuracy: 0.9461 - val_loss: 0.0090 - val_accuracy: 0.9400 Epoch 15/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0085 - accuracy: 0.9471 - val_loss: 0.0089 - val_accuracy: 0.9425 Epoch 16/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0083 - accuracy: 0.9490 - val_loss: 0.0087 - val_accuracy: 0.9437 Epoch 17/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0081 - accuracy: 0.9499 - val_loss: 0.0085 - val_accuracy: 0.9451 Epoch 18/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0079 - accuracy: 0.9513 - val_loss: 0.0083 - val_accuracy: 0.9461 Epoch 19/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0078 - accuracy: 0.9516 - val_loss: 0.0084 - val_accuracy: 0.9463 Epoch 20/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0077 - accuracy: 0.9528 - val_loss: 0.0083 - val_accuracy: 0.9461 Epoch 21/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0075 - accuracy: 0.9535 - val_loss: 0.0081 - val_accuracy: 0.9471 Epoch 22/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0074 - accuracy: 0.9540 - val_loss: 0.0081 - val_accuracy: 0.9471 Epoch 23/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0073 - accuracy: 0.9549 - val_loss: 0.0079 - val_accuracy: 0.9485 Epoch 24/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0072 - accuracy: 0.9557 - val_loss: 0.0080 - val_accuracy: 0.9482 Epoch 25/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0071 - accuracy: 0.9561 - val_loss: 0.0079 - val_accuracy: 0.9483 Epoch 26/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0070 - accuracy: 0.9570 - val_loss: 0.0078 - val_accuracy: 0.9493 Epoch 27/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0069 - accuracy: 0.9578 - val_loss: 0.0077 - val_accuracy: 0.9492 Epoch 28/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0068 - accuracy: 0.9583 - val_loss: 0.0078 - val_accuracy: 0.9492 Epoch 29/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0067 - accuracy: 0.9585 - val_loss: 0.0077 - val_accuracy: 0.9508 Epoch 30/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0066 - accuracy: 0.9593 - val_loss: 0.0078 - val_accuracy: 0.9499 Epoch 31/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0065 - accuracy: 0.9602 - val_loss: 0.0075 - val_accuracy: 0.9500 Epoch 32/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0065 - accuracy: 0.9607 - val_loss: 0.0075 - val_accuracy: 0.9505 Epoch 33/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0064 - accuracy: 0.9610 - val_loss: 0.0074 - val_accuracy: 0.9517 Epoch 34/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0063 - accuracy: 0.9617 - val_loss: 0.0074 - val_accuracy: 0.9511 Epoch 35/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0063 - accuracy: 0.9618 - val_loss: 0.0074 - val_accuracy: 0.9520 Epoch 36/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0062 - accuracy: 0.9629 - val_loss: 0.0073 - val_accuracy: 0.9527 Epoch 37/50 60000/60000 [==============================] - 4s 65us/step - loss: 0.0061 - accuracy: 0.9633 - val_loss: 0.0074 - val_accuracy: 0.9533 Epoch 38/50 60000/60000 [==============================] - 4s 62us/step - loss: 0.0061 - accuracy: 0.9633 - val_loss: 0.0072 - val_accuracy: 0.9532 Epoch 39/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0060 - accuracy: 0.9640 - val_loss: 0.0072 - val_accuracy: 0.9536 Epoch 40/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0059 - accuracy: 0.9642 - val_loss: 0.0073 - val_accuracy: 0.9519 Epoch 41/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0059 - accuracy: 0.9645 - val_loss: 0.0072 - val_accuracy: 0.9540 Epoch 42/50 60000/60000 [==============================] - 4s 62us/step - loss: 0.0058 - accuracy: 0.9652 - val_loss: 0.0072 - val_accuracy: 0.9526 Epoch 43/50 60000/60000 [==============================] - 4s 62us/step - loss: 0.0058 - accuracy: 0.9656 - val_loss: 0.0073 - val_accuracy: 0.9527 Epoch 44/50 60000/60000 [==============================] - 4s 62us/step - loss: 0.0057 - accuracy: 0.9657 - val_loss: 0.0072 - val_accuracy: 0.9534 Epoch 45/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0057 - accuracy: 0.9657 - val_loss: 0.0072 - val_accuracy: 0.9534 Epoch 46/50 60000/60000 [==============================] - 4s 65us/step - loss: 0.0056 - accuracy: 0.9665 - val_loss: 0.0072 - val_accuracy: 0.9534 Epoch 47/50 60000/60000 [==============================] - 4s 63us/step - loss: 0.0056 - accuracy: 0.9668 - val_loss: 0.0071 - val_accuracy: 0.9539 Epoch 48/50 60000/60000 [==============================] - 4s 64us/step - loss: 0.0055 - accuracy: 0.9672 - val_loss: 0.0071 - val_accuracy: 0.9546 Epoch 49/50 60000/60000 [==============================] - 4s 61us/step - loss: 0.0055 - accuracy: 0.9672 - val_loss: 0.0072 - val_accuracy: 0.9530 Epoch 50/50 60000/60000 [==============================] - 4s 60us/step - loss: 0.0054 - accuracy: 0.9678 - val_loss: 0.0070 - val_accuracy: 0.9550
from matplotlib import pyplot as plt
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='validation accuracy')
plt.title('Entrenamiento MNIST')
plt.xlabel('Épocas')
plt.legend(loc="lower right")
plt.show()
Si nos enfrentáramos a un problema de clasificación con responsabilildad deberíamos ser capaces de asegurar que el rendimiento que decimos que tiene nuestra red es el que realmente tiene (y no necesariamente una red, sino a cualquier modelo de clasificación que utilicemos. No solo de redes vive el experto en machine learning).
Para ello, hasta ahora hemos hecho uso del conjunto de test. Pero, cuando entrenamos una red hacemos muchas pruebas, muchos cambios en su configuración (los hiperparámetros) buscando una de ellas que nos dé los mejores resultados. Llegará un momento en el que hemos hecho tantas modificaciones en la red que nuestro conjunto de test logrará un buen accuracy. Sin embargo, ¿cómo podemos estar seguros de que la red funcionaría bien para un nuevo conjunto de test? Es decir, quizá hayamos involuntariamente optimizado la red para que funcione bien sobre el conjunto de test.
La forma de asegurar que hemos entrenado una red que generaliza correctamente es disponer de tres conjuntos: entrenamiento, validación y test. Con el de entrenamiento, entrenamos, y utilizaremos el conjunto de validación para comprobar el nivel de accuracy logrado en ese modelo. Al final de todas las pruebas que hayamos hecho, dispondremos de nuestro modelo final. En ese momento tomaremos nuestro conjunto de test (que previamente habíamos guardado bajo llave para evitar la tentación de utilizarlo antes) y lo pasaremos por la red. El accuracy que nos devuelva este conjunto de test será nuestro resultado final.
Nosotros en las clases no nos vamos a preocupar mucho de esto, y utilizaremos el conjunto de test también como conjunto de validación.