Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Author: Sebastian Raschka

Python implementation: CPython
Python version       : 3.8.8
IPython version      : 7.21.0

torch: 1.8.0
This notebook demonstrates the pipeline parallelism functionality added in PyTorch 1.8, using VGG-16 as an example. For more details, see https://pytorch.org/docs/1.8.0/pipeline.html?highlight=pipeline#.
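As a quick orientation, the core pattern from the PyTorch 1.8 pipeline documentation looks roughly like the following (a minimal sketch, assuming at least two GPUs and an already-initialized RPC framework; the VGG-16 sections below spell out the full workflow):

import torch
from torch.distributed.pipeline.sync import Pipe

# Each stage lives on its own GPU; Pipe splits every mini-batch into
# `chunks` micro-batches and pipelines them through the stages.
stage_1 = torch.nn.Linear(16, 8).cuda(0)
stage_2 = torch.nn.Linear(8, 4).cuda(1)
toy_model = Pipe(torch.nn.Sequential(stage_1, stage_2), chunks=8)

x = torch.rand(32, 16).cuda(0)     # input goes to the first stage's device
out = toy_model(x).local_value()   # forward returns an RRef to the output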
import sys
import torch
sys.path.insert(0, "..") # to include ../helper_evaluate.py etc.
from helper_utils import set_all_seeds, set_deterministic
from helper_evaluate import compute_accuracy
from helper_data import get_dataloaders_cifar10
from helper_train import train_classifier_simple_v1
##########################
### SETTINGS
##########################
# Data settings
num_classes = 10
# Hyperparameters
random_seed = 1
learning_rate = 0.0001
batch_size = 128
num_epochs = 50
set_all_seeds(random_seed)
#set_deterministic()
##########################
### Dataset
##########################
train_loader, valid_loader, test_loader = get_dataloaders_cifar10(
batch_size,
num_workers=2,
validation_fraction=0.1)
Files already downloaded and verified
This section implements the VGG-16 network in the conventional manner as a reference. The next section replicates this setup using pipeline parallelism.
##########################
### Model
##########################
class VGG16(torch.nn.Module):
def __init__(self, num_classes):
super().__init__()
# calculate same padding:
# (w - k + 2*p)/s + 1 = o
# => p = (s(o-1) - w + k)/2
self.block_1 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=3,
out_channels=64,
kernel_size=(3, 3),
stride=(1, 1),
# (1(32-1)- 32 + 3)/2 = 1
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=64,
out_channels=64,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
self.block_2 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=64,
out_channels=128,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=128,
out_channels=128,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
self.block_3 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=128,
out_channels=256,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
self.block_4 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=256,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
self.block_5 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
self.classifier = torch.nn.Sequential(
torch.nn.Flatten(),
torch.nn.Linear(512, 4096),
torch.nn.ReLU(True),
#torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(True),
#torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.block_1(x)
x = self.block_2(x)
x = self.block_3(x)
x = self.block_4(x)
x = self.block_5(x)
x = self.classifier(x) # logits
return x
model = VGG16(num_classes=num_classes)
device = torch.device('cuda:0')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
_ = train_classifier_simple_v1(num_epochs=num_epochs, model=model,
optimizer=optimizer, device=device,
train_loader=train_loader, valid_loader=valid_loader,
logging_interval=200)
Epoch: 001/050 | Batch 0000/0352 | Loss: 2.3034 Epoch: 001/050 | Batch 0200/0352 | Loss: 1.9757 ***Epoch: 001/050 | Train. Acc.: 24.324% | Loss: 1.883 ***Epoch: 001/050 | Valid. Acc.: 24.480% | Loss: 1.873 Time elapsed: 0.54 min Epoch: 002/050 | Batch 0000/0352 | Loss: 1.8859 Epoch: 002/050 | Batch 0200/0352 | Loss: 1.8733 ***Epoch: 002/050 | Train. Acc.: 33.400% | Loss: 1.765 ***Epoch: 002/050 | Valid. Acc.: 33.280% | Loss: 1.781 Time elapsed: 1.09 min Epoch: 003/050 | Batch 0000/0352 | Loss: 1.8217 Epoch: 003/050 | Batch 0200/0352 | Loss: 1.5797 ***Epoch: 003/050 | Train. Acc.: 41.444% | Loss: 1.540 ***Epoch: 003/050 | Valid. Acc.: 40.460% | Loss: 1.549 Time elapsed: 1.63 min Epoch: 004/050 | Batch 0000/0352 | Loss: 1.6600 Epoch: 004/050 | Batch 0200/0352 | Loss: 1.4421 ***Epoch: 004/050 | Train. Acc.: 48.211% | Loss: 1.400 ***Epoch: 004/050 | Valid. Acc.: 47.960% | Loss: 1.418 Time elapsed: 2.17 min Epoch: 005/050 | Batch 0000/0352 | Loss: 1.4752 Epoch: 005/050 | Batch 0200/0352 | Loss: 1.3762 ***Epoch: 005/050 | Train. Acc.: 55.909% | Loss: 1.202 ***Epoch: 005/050 | Valid. Acc.: 54.940% | Loss: 1.231 Time elapsed: 2.71 min Epoch: 006/050 | Batch 0000/0352 | Loss: 1.2081 Epoch: 006/050 | Batch 0200/0352 | Loss: 1.1761 ***Epoch: 006/050 | Train. Acc.: 62.300% | Loss: 1.034 ***Epoch: 006/050 | Valid. Acc.: 60.620% | Loss: 1.069 Time elapsed: 3.25 min Epoch: 007/050 | Batch 0000/0352 | Loss: 1.0605 Epoch: 007/050 | Batch 0200/0352 | Loss: 0.9945 ***Epoch: 007/050 | Train. Acc.: 67.273% | Loss: 0.909 ***Epoch: 007/050 | Valid. Acc.: 65.260% | Loss: 0.970 Time elapsed: 3.80 min Epoch: 008/050 | Batch 0000/0352 | Loss: 0.9049 Epoch: 008/050 | Batch 0200/0352 | Loss: 0.8547 ***Epoch: 008/050 | Train. Acc.: 69.647% | Loss: 0.844 ***Epoch: 008/050 | Valid. Acc.: 66.620% | Loss: 0.939 Time elapsed: 4.34 min Epoch: 009/050 | Batch 0000/0352 | Loss: 0.8435 Epoch: 009/050 | Batch 0200/0352 | Loss: 0.8415 ***Epoch: 009/050 | Train. Acc.: 71.738% | Loss: 0.787 ***Epoch: 009/050 | Valid. Acc.: 68.700% | Loss: 0.912 Time elapsed: 4.88 min Epoch: 010/050 | Batch 0000/0352 | Loss: 0.7540 Epoch: 010/050 | Batch 0200/0352 | Loss: 0.7508 ***Epoch: 010/050 | Train. Acc.: 73.676% | Loss: 0.734 ***Epoch: 010/050 | Valid. Acc.: 69.380% | Loss: 0.896 Time elapsed: 5.42 min Epoch: 011/050 | Batch 0000/0352 | Loss: 0.6758 Epoch: 011/050 | Batch 0200/0352 | Loss: 0.7798 ***Epoch: 011/050 | Train. Acc.: 74.769% | Loss: 0.708 ***Epoch: 011/050 | Valid. Acc.: 69.540% | Loss: 0.912 Time elapsed: 5.97 min Epoch: 012/050 | Batch 0000/0352 | Loss: 0.5798 Epoch: 012/050 | Batch 0200/0352 | Loss: 0.6676 ***Epoch: 012/050 | Train. Acc.: 77.618% | Loss: 0.649 ***Epoch: 012/050 | Valid. Acc.: 71.120% | Loss: 0.864 Time elapsed: 6.51 min Epoch: 013/050 | Batch 0000/0352 | Loss: 0.5435 Epoch: 013/050 | Batch 0200/0352 | Loss: 0.5304 ***Epoch: 013/050 | Train. Acc.: 72.060% | Loss: 0.811 ***Epoch: 013/050 | Valid. Acc.: 66.560% | Loss: 1.064 Time elapsed: 7.05 min Epoch: 014/050 | Batch 0000/0352 | Loss: 0.8212 Epoch: 014/050 | Batch 0200/0352 | Loss: 0.5295 ***Epoch: 014/050 | Train. Acc.: 80.340% | Loss: 0.565 ***Epoch: 014/050 | Valid. Acc.: 72.520% | Loss: 0.868 Time elapsed: 7.60 min Epoch: 015/050 | Batch 0000/0352 | Loss: 0.4467 Epoch: 015/050 | Batch 0200/0352 | Loss: 0.5036 ***Epoch: 015/050 | Train. Acc.: 80.913% | Loss: 0.556 ***Epoch: 015/050 | Valid. 
Acc.: 72.140% | Loss: 0.930 Time elapsed: 8.14 min Epoch: 016/050 | Batch 0000/0352 | Loss: 0.4458 Epoch: 016/050 | Batch 0200/0352 | Loss: 0.5054 ***Epoch: 016/050 | Train. Acc.: 83.393% | Loss: 0.497 ***Epoch: 016/050 | Valid. Acc.: 73.360% | Loss: 0.933 Time elapsed: 8.69 min Epoch: 017/050 | Batch 0000/0352 | Loss: 0.3353 Epoch: 017/050 | Batch 0200/0352 | Loss: 0.5200 ***Epoch: 017/050 | Train. Acc.: 83.424% | Loss: 0.505 ***Epoch: 017/050 | Valid. Acc.: 72.440% | Loss: 0.975 Time elapsed: 9.23 min Epoch: 018/050 | Batch 0000/0352 | Loss: 0.4049 Epoch: 018/050 | Batch 0200/0352 | Loss: 0.4376 ***Epoch: 018/050 | Train. Acc.: 85.198% | Loss: 0.447 ***Epoch: 018/050 | Valid. Acc.: 73.160% | Loss: 0.965 Time elapsed: 9.77 min Epoch: 019/050 | Batch 0000/0352 | Loss: 0.3085 Epoch: 019/050 | Batch 0200/0352 | Loss: 0.3960 ***Epoch: 019/050 | Train. Acc.: 86.209% | Loss: 0.419 ***Epoch: 019/050 | Valid. Acc.: 73.140% | Loss: 1.017 Time elapsed: 10.32 min Epoch: 020/050 | Batch 0000/0352 | Loss: 0.2945 Epoch: 020/050 | Batch 0200/0352 | Loss: 0.3203 ***Epoch: 020/050 | Train. Acc.: 86.078% | Loss: 0.411 ***Epoch: 020/050 | Valid. Acc.: 72.480% | Loss: 1.011 Time elapsed: 10.86 min Epoch: 021/050 | Batch 0000/0352 | Loss: 0.2950 Epoch: 021/050 | Batch 0200/0352 | Loss: 0.3598 ***Epoch: 021/050 | Train. Acc.: 88.662% | Loss: 0.340 ***Epoch: 021/050 | Valid. Acc.: 73.460% | Loss: 1.048 Time elapsed: 11.41 min Epoch: 022/050 | Batch 0000/0352 | Loss: 0.2688 Epoch: 022/050 | Batch 0200/0352 | Loss: 0.3375 ***Epoch: 022/050 | Train. Acc.: 90.007% | Loss: 0.310 ***Epoch: 022/050 | Valid. Acc.: 74.300% | Loss: 1.081 Time elapsed: 11.95 min Epoch: 023/050 | Batch 0000/0352 | Loss: 0.1888 Epoch: 023/050 | Batch 0200/0352 | Loss: 0.2363 ***Epoch: 023/050 | Train. Acc.: 91.322% | Loss: 0.261 ***Epoch: 023/050 | Valid. Acc.: 75.540% | Loss: 1.055 Time elapsed: 12.49 min Epoch: 024/050 | Batch 0000/0352 | Loss: 0.1915 Epoch: 024/050 | Batch 0200/0352 | Loss: 0.1584 ***Epoch: 024/050 | Train. Acc.: 91.958% | Loss: 0.241 ***Epoch: 024/050 | Valid. Acc.: 76.000% | Loss: 1.057 Time elapsed: 13.04 min Epoch: 025/050 | Batch 0000/0352 | Loss: 0.1950 Epoch: 025/050 | Batch 0200/0352 | Loss: 0.1886 ***Epoch: 025/050 | Train. Acc.: 89.602% | Loss: 0.304 ***Epoch: 025/050 | Valid. Acc.: 74.600% | Loss: 1.166 Time elapsed: 13.58 min Epoch: 026/050 | Batch 0000/0352 | Loss: 0.1521 Epoch: 026/050 | Batch 0200/0352 | Loss: 0.3536 ***Epoch: 026/050 | Train. Acc.: 92.018% | Loss: 0.241 ***Epoch: 026/050 | Valid. Acc.: 74.700% | Loss: 1.186 Time elapsed: 14.12 min Epoch: 027/050 | Batch 0000/0352 | Loss: 0.0479 Epoch: 027/050 | Batch 0200/0352 | Loss: 0.1856 ***Epoch: 027/050 | Train. Acc.: 93.753% | Loss: 0.196 ***Epoch: 027/050 | Valid. Acc.: 75.320% | Loss: 1.220 Time elapsed: 14.66 min Epoch: 028/050 | Batch 0000/0352 | Loss: 0.0642 Epoch: 028/050 | Batch 0200/0352 | Loss: 0.0706 ***Epoch: 028/050 | Train. Acc.: 94.836% | Loss: 0.152 ***Epoch: 028/050 | Valid. Acc.: 75.320% | Loss: 1.163 Time elapsed: 15.21 min Epoch: 029/050 | Batch 0000/0352 | Loss: 0.0688 Epoch: 029/050 | Batch 0200/0352 | Loss: 0.2222 ***Epoch: 029/050 | Train. Acc.: 94.131% | Loss: 0.181 ***Epoch: 029/050 | Valid. Acc.: 75.580% | Loss: 1.224 Time elapsed: 15.75 min Epoch: 030/050 | Batch 0000/0352 | Loss: 0.1569 Epoch: 030/050 | Batch 0200/0352 | Loss: 0.0786 ***Epoch: 030/050 | Train. Acc.: 96.044% | Loss: 0.126 ***Epoch: 030/050 | Valid. 
Acc.: 76.240% | Loss: 1.272 Time elapsed: 16.30 min Epoch: 031/050 | Batch 0000/0352 | Loss: 0.0563 Epoch: 031/050 | Batch 0200/0352 | Loss: 0.0863 ***Epoch: 031/050 | Train. Acc.: 96.498% | Loss: 0.112 ***Epoch: 031/050 | Valid. Acc.: 76.520% | Loss: 1.239 Time elapsed: 16.84 min Epoch: 032/050 | Batch 0000/0352 | Loss: 0.0244 Epoch: 032/050 | Batch 0200/0352 | Loss: 0.1219 ***Epoch: 032/050 | Train. Acc.: 96.169% | Loss: 0.121 ***Epoch: 032/050 | Valid. Acc.: 76.720% | Loss: 1.216 Time elapsed: 17.38 min Epoch: 033/050 | Batch 0000/0352 | Loss: 0.1291 Epoch: 033/050 | Batch 0200/0352 | Loss: 0.0828 ***Epoch: 033/050 | Train. Acc.: 95.909% | Loss: 0.136 ***Epoch: 033/050 | Valid. Acc.: 76.000% | Loss: 1.242 Time elapsed: 17.93 min Epoch: 034/050 | Batch 0000/0352 | Loss: 0.0858 Epoch: 034/050 | Batch 0200/0352 | Loss: 0.1213 ***Epoch: 034/050 | Train. Acc.: 96.098% | Loss: 0.124 ***Epoch: 034/050 | Valid. Acc.: 75.720% | Loss: 1.318 Time elapsed: 18.47 min Epoch: 035/050 | Batch 0000/0352 | Loss: 0.0288 Epoch: 035/050 | Batch 0200/0352 | Loss: 0.0558 ***Epoch: 035/050 | Train. Acc.: 95.222% | Loss: 0.162 ***Epoch: 035/050 | Valid. Acc.: 75.160% | Loss: 1.388 Time elapsed: 19.01 min Epoch: 036/050 | Batch 0000/0352 | Loss: 0.0136 Epoch: 036/050 | Batch 0200/0352 | Loss: 0.0345 ***Epoch: 036/050 | Train. Acc.: 95.422% | Loss: 0.151 ***Epoch: 036/050 | Valid. Acc.: 74.820% | Loss: 1.354 Time elapsed: 19.55 min Epoch: 037/050 | Batch 0000/0352 | Loss: 0.0644 Epoch: 037/050 | Batch 0200/0352 | Loss: 0.0265 ***Epoch: 037/050 | Train. Acc.: 94.347% | Loss: 0.187 ***Epoch: 037/050 | Valid. Acc.: 74.560% | Loss: 1.423 Time elapsed: 20.10 min Epoch: 038/050 | Batch 0000/0352 | Loss: 0.0368 Epoch: 038/050 | Batch 0200/0352 | Loss: 0.0925 ***Epoch: 038/050 | Train. Acc.: 95.807% | Loss: 0.135 ***Epoch: 038/050 | Valid. Acc.: 75.380% | Loss: 1.365 Time elapsed: 20.64 min Epoch: 039/050 | Batch 0000/0352 | Loss: 0.0510 Epoch: 039/050 | Batch 0200/0352 | Loss: 0.2865 ***Epoch: 039/050 | Train. Acc.: 96.160% | Loss: 0.125 ***Epoch: 039/050 | Valid. Acc.: 75.940% | Loss: 1.406 Time elapsed: 21.18 min Epoch: 040/050 | Batch 0000/0352 | Loss: 0.1398 Epoch: 040/050 | Batch 0200/0352 | Loss: 0.0894 ***Epoch: 040/050 | Train. Acc.: 98.256% | Loss: 0.056 ***Epoch: 040/050 | Valid. Acc.: 77.300% | Loss: 1.251 Time elapsed: 21.73 min Epoch: 041/050 | Batch 0000/0352 | Loss: 0.0201 Epoch: 041/050 | Batch 0200/0352 | Loss: 0.0795 ***Epoch: 041/050 | Train. Acc.: 97.827% | Loss: 0.071 ***Epoch: 041/050 | Valid. Acc.: 77.000% | Loss: 1.379 Time elapsed: 22.27 min Epoch: 042/050 | Batch 0000/0352 | Loss: 0.0309 Epoch: 042/050 | Batch 0200/0352 | Loss: 0.0177 ***Epoch: 042/050 | Train. Acc.: 98.500% | Loss: 0.049 ***Epoch: 042/050 | Valid. Acc.: 77.280% | Loss: 1.315 Time elapsed: 22.81 min Epoch: 043/050 | Batch 0000/0352 | Loss: 0.0442 Epoch: 043/050 | Batch 0200/0352 | Loss: 0.0547 ***Epoch: 043/050 | Train. Acc.: 98.369% | Loss: 0.051 ***Epoch: 043/050 | Valid. Acc.: 77.680% | Loss: 1.303 Time elapsed: 23.36 min Epoch: 044/050 | Batch 0000/0352 | Loss: 0.0294 Epoch: 044/050 | Batch 0200/0352 | Loss: 0.0188 ***Epoch: 044/050 | Train. Acc.: 97.716% | Loss: 0.073 ***Epoch: 044/050 | Valid. Acc.: 77.100% | Loss: 1.381 Time elapsed: 23.90 min Epoch: 045/050 | Batch 0000/0352 | Loss: 0.0226 Epoch: 045/050 | Batch 0200/0352 | Loss: 0.0128 ***Epoch: 045/050 | Train. Acc.: 98.140% | Loss: 0.059 ***Epoch: 045/050 | Valid. 
Acc.: 76.420% | Loss: 1.395 Time elapsed: 24.45 min Epoch: 046/050 | Batch 0000/0352 | Loss: 0.0170 Epoch: 046/050 | Batch 0200/0352 | Loss: 0.0198 ***Epoch: 046/050 | Train. Acc.: 98.527% | Loss: 0.047 ***Epoch: 046/050 | Valid. Acc.: 77.000% | Loss: 1.548 Time elapsed: 24.99 min Epoch: 047/050 | Batch 0000/0352 | Loss: 0.0011 Epoch: 047/050 | Batch 0200/0352 | Loss: 0.0362 ***Epoch: 047/050 | Train. Acc.: 98.682% | Loss: 0.041 ***Epoch: 047/050 | Valid. Acc.: 77.260% | Loss: 1.345 Time elapsed: 25.54 min Epoch: 048/050 | Batch 0000/0352 | Loss: 0.0180 Epoch: 048/050 | Batch 0200/0352 | Loss: 0.0191 ***Epoch: 048/050 | Train. Acc.: 98.838% | Loss: 0.036 ***Epoch: 048/050 | Valid. Acc.: 77.300% | Loss: 1.412 Time elapsed: 26.08 min Epoch: 049/050 | Batch 0000/0352 | Loss: 0.0114 Epoch: 049/050 | Batch 0200/0352 | Loss: 0.0217 ***Epoch: 049/050 | Train. Acc.: 98.378% | Loss: 0.054 ***Epoch: 049/050 | Valid. Acc.: 77.460% | Loss: 1.459 Time elapsed: 26.63 min Epoch: 050/050 | Batch 0000/0352 | Loss: 0.0044 Epoch: 050/050 | Batch 0200/0352 | Loss: 0.0069 ***Epoch: 050/050 | Train. Acc.: 98.804% | Loss: 0.036 ***Epoch: 050/050 | Valid. Acc.: 77.700% | Loss: 1.365 Time elapsed: 27.17 min Total Training Time: 27.17 min
Below, we first define the blocks that we are going to assemble into the pipeline-parallel model:
block_1 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=3,
out_channels=64,
kernel_size=(3, 3),
stride=(1, 1),
# (1(32-1)- 32 + 3)/2 = 1
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=64,
out_channels=64,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
block_2 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=64,
out_channels=128,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=128,
out_channels=128,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
block_3 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=128,
out_channels=256,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=256,
out_channels=256,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
block_4 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=256,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
block_5 = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(in_channels=512,
out_channels=512,
kernel_size=(3, 3),
stride=(1, 1),
padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(kernel_size=(2, 2),
stride=(2, 2))
)
classifier = torch.nn.Sequential(
torch.nn.Flatten(),
torch.nn.Linear(512, 4096),
torch.nn.ReLU(True),
#torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(True),
#torch.nn.Dropout(p=0.5),
torch.nn.Linear(4096, num_classes),
)
Before setting up the environment for the distributed run, we check whether the distributed package is available on our machine. The following should return True:
torch.distributed.is_available()
True
Next, we set the required environment variables for the machine. For MASTER_ADDR, just use the IP address of your machine, e.g., 123.45.67.89:
%env MASTER_ADDR=xxx.xx.xx.xx
Choose a free port:
%env MASTER_PORT=8891
env: MASTER_PORT=8891
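Alternatively, the same variables can be set from Python via os.environ (a minimal sketch; the address below is only a placeholder for your machine's IP):

import os

# Placeholder values -- replace with your machine's IP address and a free port.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8891"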
Set up the RPC framework if it is not already running (more details at https://pytorch.org/docs/stable/rpc.html):
try:
    torch.distributed.rpc.init_rpc(name='node1', rank=0, world_size=1)
except RuntimeError as e:
    if str(e) == 'Address already in use':
        pass
    else:
        raise  # re-raise any other error with its original traceback
This is the main part for running the model on multiple GPUs. We chain the blocks into a Sequential model and wrap it with Pipe, which splits each mini-batch into microbatches and pipelines them through the GPU stages; for more details, see https://pytorch.org/docs/1.8.0/pipeline.html?highlight=pipeline#

from torch.distributed.pipeline.sync import Pipe
# Distribute the blocks across the GPUs (.cuda() moves each module in place):
block_1 = block_1.cuda(0)
block_2 = block_2.cuda(0)
block_3 = block_3.cuda(2)
block_4 = block_4.cuda(2)
block_5 = block_5.cuda(3)
classifier = classifier.cuda(0)
model_parallel = torch.nn.Sequential(
block_1, block_2, block_3, block_4, block_5, classifier)
model_parallel = Pipe(model_parallel, chunks=8)
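Note that, according to the 1.8 pipeline documentation, the forward pass of a Pipe-wrapped model returns an RRef rather than a plain tensor. A rough sketch of a single forward pass (the dummy input below is only for illustration):

# Dummy CIFAR-10-sized batch; the input must live on the first stage's device.
features = torch.rand(batch_size, 3, 32, 32).cuda(0)

# Pipe splits this mini-batch into 8 micro-batches and pipelines them
# through the GPU stages; .local_value() retrieves the resulting tensor.
logits = model_parallel(features).local_value()  # shape: (batch_size, num_classes)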
optimizer = torch.optim.Adam(model_parallel.parameters(), lr=learning_rate)
_ = train_classifier_simple_v1(num_epochs=num_epochs, model=model_parallel,
optimizer=optimizer, device=torch.device('cuda:0'),
train_loader=train_loader, valid_loader=valid_loader,
logging_interval=200)
Epoch: 001/050 | Batch 0000/0352 | Loss: 2.3043 Epoch: 001/050 | Batch 0200/0352 | Loss: 2.0182 ***Epoch: 001/050 | Train. Acc.: 21.711% | Loss: 1.914 ***Epoch: 001/050 | Valid. Acc.: 21.420% | Loss: 1.906 Time elapsed: 1.35 min Epoch: 002/050 | Batch 0000/0352 | Loss: 1.9691 Epoch: 002/050 | Batch 0200/0352 | Loss: 1.7432 ***Epoch: 002/050 | Train. Acc.: 28.180% | Loss: 1.881 ***Epoch: 002/050 | Valid. Acc.: 28.580% | Loss: 1.895 Time elapsed: 2.68 min Epoch: 003/050 | Batch 0000/0352 | Loss: 1.8837 Epoch: 003/050 | Batch 0200/0352 | Loss: 1.5357 ***Epoch: 003/050 | Train. Acc.: 38.922% | Loss: 1.584 ***Epoch: 003/050 | Valid. Acc.: 38.960% | Loss: 1.587 Time elapsed: 4.02 min Epoch: 004/050 | Batch 0000/0352 | Loss: 1.6131 Epoch: 004/050 | Batch 0200/0352 | Loss: 1.4481 ***Epoch: 004/050 | Train. Acc.: 48.900% | Loss: 1.360 ***Epoch: 004/050 | Valid. Acc.: 49.240% | Loss: 1.367 Time elapsed: 5.36 min Epoch: 005/050 | Batch 0000/0352 | Loss: 1.3710 Epoch: 005/050 | Batch 0200/0352 | Loss: 1.2681 ***Epoch: 005/050 | Train. Acc.: 55.542% | Loss: 1.225 ***Epoch: 005/050 | Valid. Acc.: 55.100% | Loss: 1.246 Time elapsed: 6.70 min Epoch: 006/050 | Batch 0000/0352 | Loss: 1.1851 Epoch: 006/050 | Batch 0200/0352 | Loss: 1.2468 ***Epoch: 006/050 | Train. Acc.: 60.782% | Loss: 1.082 ***Epoch: 006/050 | Valid. Acc.: 58.540% | Loss: 1.123 Time elapsed: 8.04 min Epoch: 007/050 | Batch 0000/0352 | Loss: 1.0370 Epoch: 007/050 | Batch 0200/0352 | Loss: 1.1411 ***Epoch: 007/050 | Train. Acc.: 65.060% | Loss: 0.969 ***Epoch: 007/050 | Valid. Acc.: 63.240% | Loss: 1.023 Time elapsed: 9.38 min Epoch: 008/050 | Batch 0000/0352 | Loss: 0.9135 Epoch: 008/050 | Batch 0200/0352 | Loss: 1.0559 ***Epoch: 008/050 | Train. Acc.: 68.411% | Loss: 0.876 ***Epoch: 008/050 | Valid. Acc.: 66.260% | Loss: 0.946 Time elapsed: 10.73 min Epoch: 009/050 | Batch 0000/0352 | Loss: 0.8361 Epoch: 009/050 | Batch 0200/0352 | Loss: 1.0417 ***Epoch: 009/050 | Train. Acc.: 71.067% | Loss: 0.815 ***Epoch: 009/050 | Valid. Acc.: 67.320% | Loss: 0.908 Time elapsed: 12.06 min Epoch: 010/050 | Batch 0000/0352 | Loss: 0.7411 Epoch: 010/050 | Batch 0200/0352 | Loss: 0.9193 ***Epoch: 010/050 | Train. Acc.: 74.264% | Loss: 0.729 ***Epoch: 010/050 | Valid. Acc.: 70.560% | Loss: 0.847 Time elapsed: 13.39 min Epoch: 011/050 | Batch 0000/0352 | Loss: 0.6631 Epoch: 011/050 | Batch 0200/0352 | Loss: 0.7807 ***Epoch: 011/050 | Train. Acc.: 75.558% | Loss: 0.685 ***Epoch: 011/050 | Valid. Acc.: 71.380% | Loss: 0.840 Time elapsed: 14.74 min Epoch: 012/050 | Batch 0000/0352 | Loss: 0.6035 Epoch: 012/050 | Batch 0200/0352 | Loss: 0.7203 ***Epoch: 012/050 | Train. Acc.: 76.527% | Loss: 0.658 ***Epoch: 012/050 | Valid. Acc.: 71.760% | Loss: 0.856 Time elapsed: 16.08 min Epoch: 013/050 | Batch 0000/0352 | Loss: 0.5782 Epoch: 013/050 | Batch 0200/0352 | Loss: 0.6597 ***Epoch: 013/050 | Train. Acc.: 77.604% | Loss: 0.629 ***Epoch: 013/050 | Valid. Acc.: 71.920% | Loss: 0.859 Time elapsed: 17.42 min Epoch: 014/050 | Batch 0000/0352 | Loss: 0.5054 Epoch: 014/050 | Batch 0200/0352 | Loss: 0.6732 ***Epoch: 014/050 | Train. Acc.: 77.451% | Loss: 0.652 ***Epoch: 014/050 | Valid. Acc.: 71.200% | Loss: 0.883 Time elapsed: 18.78 min Epoch: 015/050 | Batch 0000/0352 | Loss: 0.4921 Epoch: 015/050 | Batch 0200/0352 | Loss: 0.5878 ***Epoch: 015/050 | Train. Acc.: 78.100% | Loss: 0.630 ***Epoch: 015/050 | Valid. 
Acc.: 71.240% | Loss: 0.907 Time elapsed: 20.12 min Epoch: 016/050 | Batch 0000/0352 | Loss: 0.5568 Epoch: 016/050 | Batch 0200/0352 | Loss: 0.5079 ***Epoch: 016/050 | Train. Acc.: 80.038% | Loss: 0.579 ***Epoch: 016/050 | Valid. Acc.: 71.440% | Loss: 0.927 Time elapsed: 21.48 min Epoch: 017/050 | Batch 0000/0352 | Loss: 0.4743 Epoch: 017/050 | Batch 0200/0352 | Loss: 0.5260 ***Epoch: 017/050 | Train. Acc.: 80.993% | Loss: 0.587 ***Epoch: 017/050 | Valid. Acc.: 71.620% | Loss: 0.972 Time elapsed: 22.84 min Epoch: 018/050 | Batch 0000/0352 | Loss: 0.4304 Epoch: 018/050 | Batch 0200/0352 | Loss: 0.4014 ***Epoch: 018/050 | Train. Acc.: 79.338% | Loss: 0.653 ***Epoch: 018/050 | Valid. Acc.: 69.800% | Loss: 1.060 Time elapsed: 24.18 min Epoch: 019/050 | Batch 0000/0352 | Loss: 0.3735 Epoch: 019/050 | Batch 0200/0352 | Loss: 0.3300 ***Epoch: 019/050 | Train. Acc.: 83.069% | Loss: 0.518 ***Epoch: 019/050 | Valid. Acc.: 73.120% | Loss: 0.991 Time elapsed: 25.54 min Epoch: 020/050 | Batch 0000/0352 | Loss: 0.3183 Epoch: 020/050 | Batch 0200/0352 | Loss: 0.4193 ***Epoch: 020/050 | Train. Acc.: 86.129% | Loss: 0.417 ***Epoch: 020/050 | Valid. Acc.: 74.960% | Loss: 0.961 Time elapsed: 26.90 min Epoch: 021/050 | Batch 0000/0352 | Loss: 0.1804 Epoch: 021/050 | Batch 0200/0352 | Loss: 0.4772 ***Epoch: 021/050 | Train. Acc.: 84.840% | Loss: 0.465 ***Epoch: 021/050 | Valid. Acc.: 73.140% | Loss: 1.030 Time elapsed: 28.23 min Epoch: 022/050 | Batch 0000/0352 | Loss: 0.1948 Epoch: 022/050 | Batch 0200/0352 | Loss: 0.4274 ***Epoch: 022/050 | Train. Acc.: 84.476% | Loss: 0.483 ***Epoch: 022/050 | Valid. Acc.: 73.260% | Loss: 1.068 Time elapsed: 29.55 min Epoch: 023/050 | Batch 0000/0352 | Loss: 0.2390 Epoch: 023/050 | Batch 0200/0352 | Loss: 0.2604 ***Epoch: 023/050 | Train. Acc.: 87.033% | Loss: 0.407 ***Epoch: 023/050 | Valid. Acc.: 74.260% | Loss: 1.054 Time elapsed: 30.90 min Epoch: 024/050 | Batch 0000/0352 | Loss: 0.2224 Epoch: 024/050 | Batch 0200/0352 | Loss: 0.3437 ***Epoch: 024/050 | Train. Acc.: 88.540% | Loss: 0.345 ***Epoch: 024/050 | Valid. Acc.: 74.440% | Loss: 0.996 Time elapsed: 32.25 min Epoch: 025/050 | Batch 0000/0352 | Loss: 0.1355 Epoch: 025/050 | Batch 0200/0352 | Loss: 0.2179 ***Epoch: 025/050 | Train. Acc.: 90.258% | Loss: 0.292 ***Epoch: 025/050 | Valid. Acc.: 75.200% | Loss: 1.034 Time elapsed: 33.59 min Epoch: 026/050 | Batch 0000/0352 | Loss: 0.1521 Epoch: 026/050 | Batch 0200/0352 | Loss: 0.1664 ***Epoch: 026/050 | Train. Acc.: 89.593% | Loss: 0.324 ***Epoch: 026/050 | Valid. Acc.: 74.920% | Loss: 1.139 Time elapsed: 34.94 min Epoch: 027/050 | Batch 0000/0352 | Loss: 0.1650 Epoch: 027/050 | Batch 0200/0352 | Loss: 0.2342 ***Epoch: 027/050 | Train. Acc.: 90.831% | Loss: 0.287 ***Epoch: 027/050 | Valid. Acc.: 75.440% | Loss: 1.242 Time elapsed: 36.29 min Epoch: 028/050 | Batch 0000/0352 | Loss: 0.1207 Epoch: 028/050 | Batch 0200/0352 | Loss: 0.1094 ***Epoch: 028/050 | Train. Acc.: 90.716% | Loss: 0.304 ***Epoch: 028/050 | Valid. Acc.: 74.680% | Loss: 1.334 Time elapsed: 37.64 min Epoch: 029/050 | Batch 0000/0352 | Loss: 0.0708 Epoch: 029/050 | Batch 0200/0352 | Loss: 0.1305 ***Epoch: 029/050 | Train. Acc.: 90.467% | Loss: 0.326 ***Epoch: 029/050 | Valid. Acc.: 73.800% | Loss: 1.357 Time elapsed: 38.97 min Epoch: 030/050 | Batch 0000/0352 | Loss: 0.0699 Epoch: 030/050 | Batch 0200/0352 | Loss: 0.0874 ***Epoch: 030/050 | Train. Acc.: 94.009% | Loss: 0.188 ***Epoch: 030/050 | Valid. 
Acc.: 75.640% | Loss: 1.134 Time elapsed: 40.31 min Epoch: 031/050 | Batch 0000/0352 | Loss: 0.0434 Epoch: 031/050 | Batch 0200/0352 | Loss: 0.0705 ***Epoch: 031/050 | Train. Acc.: 94.442% | Loss: 0.172 ***Epoch: 031/050 | Valid. Acc.: 76.360% | Loss: 1.172 Time elapsed: 41.63 min Epoch: 032/050 | Batch 0000/0352 | Loss: 0.1119 Epoch: 032/050 | Batch 0200/0352 | Loss: 0.0542 ***Epoch: 032/050 | Train. Acc.: 94.224% | Loss: 0.182 ***Epoch: 032/050 | Valid. Acc.: 76.940% | Loss: 1.169 Time elapsed: 42.96 min Epoch: 033/050 | Batch 0000/0352 | Loss: 0.0771 Epoch: 033/050 | Batch 0200/0352 | Loss: 0.0898 ***Epoch: 033/050 | Train. Acc.: 94.258% | Loss: 0.177 ***Epoch: 033/050 | Valid. Acc.: 76.440% | Loss: 1.189 Time elapsed: 44.31 min Epoch: 034/050 | Batch 0000/0352 | Loss: 0.0264 Epoch: 034/050 | Batch 0200/0352 | Loss: 0.1268 ***Epoch: 034/050 | Train. Acc.: 94.207% | Loss: 0.186 ***Epoch: 034/050 | Valid. Acc.: 76.520% | Loss: 1.160 Time elapsed: 45.66 min Epoch: 035/050 | Batch 0000/0352 | Loss: 0.0771 Epoch: 035/050 | Batch 0200/0352 | Loss: 0.0835 ***Epoch: 035/050 | Train. Acc.: 94.469% | Loss: 0.187 ***Epoch: 035/050 | Valid. Acc.: 76.440% | Loss: 1.249 Time elapsed: 47.01 min Epoch: 036/050 | Batch 0000/0352 | Loss: 0.0535 Epoch: 036/050 | Batch 0200/0352 | Loss: 0.0453 ***Epoch: 036/050 | Train. Acc.: 96.549% | Loss: 0.105 ***Epoch: 036/050 | Valid. Acc.: 77.480% | Loss: 1.154 Time elapsed: 48.35 min Epoch: 037/050 | Batch 0000/0352 | Loss: 0.0167 Epoch: 037/050 | Batch 0200/0352 | Loss: 0.0867
As expected, training is slower than before. This is not surprising: the main selling point of pipeline parallelism is not to speed up training but to fit a model across multiple GPUs when memory is the constraint.
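If we wanted to check how the memory load is spread across the devices, one option is to inspect the per-GPU allocation after building the pipeline (a minimal sketch; the exact numbers depend on the setup):

# Report how much memory each visible GPU currently has allocated (in MiB).
for i in range(torch.cuda.device_count()):
    allocated_mib = torch.cuda.memory_allocated(i) / 1024**2
    print(f"cuda:{i}: {allocated_mib:.1f} MiB allocated")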