PyTorch 実装の基本フロー

MNIST手書き数字の分類

  • 畳み込み層+全結合層で構築したネットワークモデルの実装を行い、MNIST分類を学習し推論します。

事前準備

モジュールインポート

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

Datasetの準備

MNISTデータセット

In [2]:
#データ前処理 transform を設定
transform = transforms.Compose(
    [transforms.ToTensor(),                      # Tensor変換とshape変換 [H, W, C] -> [C, H, W]
     transforms.Normalize((0.5, ), (0.5, ))])    # 標準化 平均:0.5  標準偏差:0.5

#訓練用Datasetを作成
train_dataset = datasets.MNIST(root='./data', 
                                        train=True,
                                        download=True,
                                        transform=transform)

#検証用Datasetを作成
val_dataset = datasets.MNIST(root='./data', 
                                        train=False, 
                                        download=True, 
                                        transform=transform)

Detaloaderの作成

In [3]:
#訓練用 Dataloder
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=64,
                                            shuffle=True)

#検証用 Dataloder
val_dataloader = torch.utils.data.DataLoader(val_dataset, 
                                            batch_size=64,
                                            shuffle=False)

# 辞書型変数にまとめる
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

動作の確認

In [4]:
batch_iterator = iter(dataloaders_dict["train"])  # イテレータに変換
imges, labels = next(batch_iterator)  # 1番目の要素を取り出す
print("imges size = ", imges.size())
print("labels size = ", labels.size())

#試しに1枚 plot してみる
plt.imshow(imges[0].numpy().reshape(28,28), cmap='gray')
plt.title("label = {}".format(labels[0].numpy()))
plt.show()
imges size =  torch.Size([64, 1, 28, 28])
labels size =  torch.Size([64])

ネットワークモデルの作成

  • 畳み込み層CNNのモデル
In [5]:
# 畳み込み層+全結合層のネットワークモデル
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)    #畳み込み層
        self.conv2 = nn.Conv2d(32, 64, 3, 1)   #畳み込み層
        self.fc1 = nn.Linear(9216, 128)        #全結合層
        self.fc2 = nn.Linear(128, 10)          #全結合層

    def forward(self, x):
        x = self.conv1(x)              # (Batch,  1, 28, 28) -> (Batch, 32, 26, 26)
        x = F.relu(x)
        x = self.conv2(x)              # (Batch, 32, 26, 26) -> (Batch, 64, 24, 24)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)         # (Batch, 64, 24, 24) -> (Batch, 64, 12, 12)
        x = torch.flatten(x, 1)        # (Batch, 64, 12, 12) -> (Batch, 9216)
        x = self.fc1(x)                # (Batch, 9216) -> (Batch, 128)
        x = self.fc2(x)                # (Batch, 128) -> (Batch, 10)

        return x
In [6]:
#モデル作成
net = Net()
In [7]:
#ネットワークのレイヤー確認
print(net)
Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

損失関数の定義

In [8]:
# nn.CrossEntropyLoss() はソフトマックス関数+クロスエントロピー誤差
criterion = nn.CrossEntropyLoss()

最適化手法の設定

In [9]:
optimizer = optim.Adam(net.parameters(), lr=0.001)

学習・検証の実施

In [10]:
# モデルを学習させる関数を作成
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):
    
    # epochのループ
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        # epochごとの学習と検証のループ
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()  # モデルを訓練モードに
            else:
                net.eval()   # モデルを検証モードに

            epoch_loss = 0.0  # epochの損失和
            epoch_corrects = 0  # epochの正解数

            # 未学習時の検証性能を確かめるため、epoch=0の訓練は省略
            if (epoch == 0) and (phase == 'train'):
                continue

            # データローダーからミニバッチを取り出すループ
            for i , (inputs, labels) in tqdm(enumerate(dataloaders_dict[phase])):

                # optimizerを初期化
                optimizer.zero_grad()

                # 順伝搬(forward)計算
                with torch.set_grad_enabled(phase == 'train'):  # 訓練モードのみ勾配を算出
                    outputs = net(inputs)              # 順伝播
                    loss = criterion(outputs, labels)  # 損失を計算
                    _, preds = torch.max(outputs, 1)   # ラベルを予測
                    
  
                    # 訓練時はバックプロパゲーション
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                    # イタレーション結果の計算
                    # lossの合計を更新
                    epoch_loss += loss.item() * inputs.size(0)  
                    # 正解数の合計を更新
                    epoch_corrects += torch.sum(preds == labels.data)

            # epochごとのlossと正解率を表示
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
            
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
In [11]:
# 学習・検証を実行する
num_epochs = 3
train_model(net, dataloaders_dict, criterion, optimizer, num_epochs=num_epochs)
3it [00:00, 23.63it/s]
Epoch 1/3
-------------
157it [00:05, 30.87it/s]
1it [00:00,  9.87it/s]
val Loss: 2.3047 Acc: 0.1128
Epoch 2/3
-------------
938it [01:11, 13.09it/s]
3it [00:00, 28.23it/s]
train Loss: 0.1247 Acc: 0.9617
157it [00:04, 33.20it/s]
2it [00:00, 13.95it/s]
val Loss: 0.0486 Acc: 0.9844
Epoch 3/3
-------------
938it [01:26, 10.81it/s]
3it [00:00, 26.54it/s]
train Loss: 0.0447 Acc: 0.9862
157it [00:05, 26.97it/s]
val Loss: 0.0338 Acc: 0.9894

テストデータに対する予測

In [12]:
batch_iterator = iter(dataloaders_dict["val"])  # イテレータに変換
imges, labels = next(batch_iterator)  # 1番目の要素を取り出す

net.eval() #推論モード
with torch.set_grad_enabled(False):   # 推論モードでは勾配を算出しない
    outputs = net(imges)               # 順伝播
    _, preds = torch.max(outputs, 1)  # ラベルを予測
    
#テストデータの予測結果を描画
plt.imshow(imges[0].numpy().reshape(28,28), cmap='gray')
plt.title("Label: Target={}, Predict={}".format(labels[0], preds[0].numpy()))
plt.show()

以上