图片分类数据 (Fashion-MNIST)

In [1]:
%matplotlib inline
import d2l
from mxnet import gluon 
d2l.use_svg_display()

使用提供的 FashionMNIST 类来下载和读取数据。

In [2]:
mnist_train = gluon.data.vision.FashionMNIST(train=True)
mnist_test = gluon.data.vision.FashionMNIST(train=False)

len(mnist_train), len(mnist_test)
Out[2]:
(60000, 10000)

可视化图片。

In [3]:
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
X, y = mnist_train[:18]
d2l.show_images(X.squeeze(axis=-1), 2, 9, titles=get_fashion_mnist_labels(y));

读取一个小批量。

In [4]:
batch_size = 256
transformer = gluon.data.vision.transforms.ToTensor()
train_iter = gluon.data.DataLoader(mnist_train.transform_first(transformer),
                                   batch_size, shuffle=True, num_workers=4)

for X, y in train_iter:
    print(X.shape, y.shape)
    break
(256, 1, 28, 28) (256,)