#!/usr/bin/env python # coding: utf-8 # In[1]: import warnings warnings.simplefilter(action='ignore') # In[2]: import tensorflow as tf import numpy as np import matplotlib.pyplot as plt # ### 1. 准备 mnist 数据 # In[3]: (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() # ### 2. 查看 mnist 数据 # In[4]: print('x_train data:', len(x_train)) print('x_test data:', len(x_test)) # In[5]: print('x_train shape:', x_train.shape) print('y_train shape:', y_train.shape) print('x_test shape:', x_test.shape) print('y_test shape:', y_test.shape) # ### 3. 定义 plot_image 函数显示数字图像 # In[6]: def plot_image(image): fig = plt.gcf() fig.set_size_inches(2, 2) plt.imshow(image, cmap='binary') plt.show() plot_image(x_train[0]) print('y_train[0]:', y_train[0]) # ### 4. 定义 plot_images_labels_prediction 函数以查看数字图形、真实的数字与预测结果 # In[7]: def plot_images_labels_prediction(images, labels, predictions, idx, num=10): """ images: 数字图像数组 labels: 真实值数组 predictions: 预测结果数据 idx: 开始显示的数据index num: 要显示的数据项数, 默认为10, 不超过25 """ fig = plt.gcf() fig.set_size_inches(12, 14) if num > 25: num = 25 for i in range(0, num): ax = plt.subplot(5, 5, i+1) ax.imshow(images[idx], cmap='binary') title = 'lable=' + str(labels[idx]) if len(predictions) > 0: title += ',predict=' + str(predictions[idx]) ax.set_title(title, fontsize=10) ax.set_xticks([]) ax.set_yticks([]) idx += 1 plt.show() # In[8]: plot_images_labels_prediction(x_train, y_train, [], 0, 10) # In[9]: plot_images_labels_prediction(x_test, y_test, [], 0, 10) # ### 5. features 数据预处理 # #### 5.1 查看处理前的 shape # In[10]: print('x_train shape:', x_train.shape) print('y_train shape:', y_train.shape) # #### 5.2 将 image 以 reshape 转换为一维的向量 # In[11]: x_train = x_train.reshape(x_train.shape[0], 784).astype('float32') x_test = x_test.reshape(x_test.shape[0], 784).astype('float32') # #### 5.3 查看转换为一维向量的 shape # In[12]: print('x_train shape:', x_train.shape) print('x_test shape:', x_test.shape) # #### 5.4 查看 images 第0项的内容 # In[13]: x_train[0] # #### 5.5 将 images 中的数字标准化 # In[14]: x_train_normalize = x_train / 255.0 x_test_normalize = x_test / 255.0 # In[15]: x_train_normalize[0] # ### 6. label 数据预处理 # #### 6.1 查看原来的 label 标签字段 # In[16]: y_train[:5] # #### 6.2 将 label 标签字段进行 one-hot 编码转换 # In[17]: y_train_one_hot = tf.keras.utils.to_categorical(y_train) y_test_one_hot = tf.keras.utils.to_categorical(y_test) # #### 6.3 查看进行 one-hot 编码转换之后的 label 标签字段 # In[18]: y_train_one_hot[:5]