Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch
Sebastian Raschka CPython 3.7.3 IPython 7.9.0 torch 1.7.0
Why do we care about gradient checkpointing? It can lower the memory requirements of deep neural networks quite substantially, allowing us to work with larger architectures within the memory limitations of conventional GPUs. There is no free lunch here, though: in exchange for the lower memory requirements, additional computations are carried out, which can prolong training. However, when GPU memory is a limiting factor that we cannot circumvent even by lowering the batch size, gradient checkpointing is a great and easy option for making things work!
Below is a brief summary of how gradient checkpointing works. For more details, please see the excellent explanations in [1] and [2].
In vanilla backpropagation (the standard version of backpropagation), the required memory grows linearly with the number of layers n in the neural network. This is because all nodes from the forward pass are being kept in memory (until all their dependent child nodes are processed).
In the low-memory version of backpropagation, the forward pass is recomputed at each step, which makes it more memory-efficient than vanilla backpropagation at the cost of additional computation: where vanilla backpropagation processes $n$ nodes, the low-memory version processes on the order of $n^2$ nodes.
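To make the quadratic cost concrete, here is a back-of-the-envelope count (a pure-Python sketch, not part of the notebook's training code): if no activations are stored, backpropagating through layer $i$ requires replaying the forward pass from the input up to layer $i$, so the total number of forward computations grows quadratically with the depth $n$:

```python
def forward_evals_low_memory(n):
    """Count forward computations when every activation is recomputed
    from the input for each of the n layers during the backward pass."""
    # layer i (counting from the input) costs i forward steps to reach
    return sum(i for i in range(1, n + 1))  # n*(n+1)/2, i.e., O(n^2)

for n in (4, 8, 16):
    print(n, forward_evals_low_memory(n))
```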
The gradient checkpointing method is a compromise between vanilla backpropagation and low-memory backpropagation, where nodes are recomputed more often than in vanilla backpropagation but not as often as in the low-memory version. In gradient checkpointing, we designate certain nodes as checkpoints; these are kept in memory rather than recomputed and serve as a basis for recomputing the other nodes. The optimal choice is to designate every $\sqrt{n}$-th node as a checkpoint node. Consequently, the memory requirement increases by a factor of $\sqrt{n}$ compared to the low-memory version of backpropagation.
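The $\sqrt{n}$ spacing can be seen with a small cost model (a sketch; real memory use also depends on the sizes of the individual layers): storing $k$ checkpoints plus the activations of one segment of length $n/k$ during its recomputation costs roughly $k + n/k$, which is minimized at $k \approx \sqrt{n}$:

```python
import math

def peak_memory(n, k):
    # k checkpoint activations kept, plus at most ceil(n/k) activations
    # while one segment is recomputed during the backward pass
    return k + math.ceil(n / k)

n = 100
best_k = min(range(1, n + 1), key=lambda k: peak_memory(n, k))
print(best_k, int(math.sqrt(n)))  # → 10 10
```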
As stated in [3], with gradient checkpointing we can implement models that are 4x to 10x larger than architectures that would usually fit into GPU memory.
PyTorch allows us to use gradient checkpointing very conveniently. In this notebook, we are only using the checkpointing for sequential models. However, it is also possible (and not much more complicated) to use checkpointing for non-sequential models. I recommend checking out the tutorial in [3] for more details.
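For reference, the non-sequential case uses torch.utils.checkpoint.checkpoint directly, wrapping an arbitrary function whose intermediate activations should be recomputed during the backward pass. Below is a minimal toy sketch (unrelated to the NiN model used later in this notebook):

```python
import torch
from torch.utils.checkpoint import checkpoint

def residual_block(x, w):
    # an arbitrary non-sequential computation; the intermediates inside
    # this function are not stored but recomputed during backward
    return torch.tanh(x @ w) + x

w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(2, 4, requires_grad=True)

out = checkpoint(residual_block, x, w)
out.sum().backward()
print(w.grad.shape)  # torch.Size([4, 4])
```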
A great performance benchmark and write-up is available at [4], showing the difference in memory consumption between a baseline ResNet-18 and one enhanced with gradient checkpointing.
[1] Saving memory using gradient-checkpointing: https://github.com/cybertronai/gradient-checkpointing
[2] Fitting larger networks into memory: https://medium.com/tensorflow/fitting-larger-networks-into-memory-583e3c758ff9
[3] Trading compute for memory in PyTorch models using Checkpointing: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
[4] Deep Learning Memory Usage and Pytorch Optimization Tricks: https://www.sicara.ai/blog/2019-28-10-deep-learning-memory-usage-and-pytorch-optimization-tricks
For this demo, I am using a simple Network-in-Network (NiN) architecture for the sake of code readability. The deeper the architecture, the larger the gain from gradient checkpointing can be.
The CNN architecture is based on the Network in Network paper by Lin et al.: https://arxiv.org/abs/1312.4400
import os
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
If you want to ensure that the data is shuffled in the same manner when you rerun this notebook and that the model receives the same initial random weights, I recommend calling a function like the following one before initializing the data loaders and the model:
def set_all_seeds(seed):
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
Similar to the set_all_seeds function above, I recommend making the behavior of PyTorch and cuDNN deterministic (this is particularly relevant when using GPUs). We can also define a function for that:
def set_deterministic():
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    torch.set_deterministic(True)
##########################
### SETTINGS
##########################
# Device
CUDA_DEVICE_NUM = 2 # change as appropriate
DEVICE = torch.device('cuda:%d' % CUDA_DEVICE_NUM if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)
# Hyperparameters
RANDOM_SEED = 1
LEARNING_RATE = 0.0001
BATCH_SIZE = 256
NUM_EPOCHS = 40
# Architecture
NUM_CLASSES = 10
set_all_seeds(RANDOM_SEED)
# Deterministic behavior not yet supported by AdaptiveAvgPool2d
#set_deterministic()
Device: cuda:2
import sys
sys.path.insert(0, "..") # to include ../helper_evaluate.py etc.
from helper_evaluate import compute_accuracy
from helper_data import get_dataloaders_cifar10
from helper_train import train_classifier_simple_v1
### Set random seed ###
set_all_seeds(RANDOM_SEED)
##########################
### Dataset
##########################
train_loader, valid_loader, test_loader = get_dataloaders_cifar10(
    batch_size=BATCH_SIZE,
    num_workers=2,
    validation_fraction=0.1)
Files already downloaded and verified
# Checking the dataset
print('Training Set:\n')
for images, labels in train_loader:
    print('Image batch dimensions:', images.size())
    print('Image label dimensions:', labels.size())
    print(labels[:10])
    break

# Checking the dataset
print('\nValidation Set:')
for images, labels in valid_loader:
    print('Image batch dimensions:', images.size())
    print('Image label dimensions:', labels.size())
    print(labels[:10])
    break

# Checking the dataset
print('\nTesting Set:')
for images, labels in test_loader:
    print('Image batch dimensions:', images.size())
    print('Image label dimensions:', labels.size())
    print(labels[:10])
    break
Training Set:

Image batch dimensions: torch.Size([256, 3, 32, 32])
Image label dimensions: torch.Size([256])
tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3])

Validation Set:
Image batch dimensions: torch.Size([256, 3, 32, 32])
Image label dimensions: torch.Size([256])
tensor([7, 1, 4, 1, 0, 2, 2, 5, 9, 6])

Testing Set:
Image batch dimensions: torch.Size([256, 3, 32, 32])
Image label dimensions: torch.Size([256])
tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3])
This is the basic NiN model without gradient checkpointing for reference.
##########################
### MODEL
##########################
class NiN(nn.Module):
    def __init__(self, num_classes):
        super(NiN, self).__init__()
        self.num_classes = num_classes
        self.classifier = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Dropout(0.5),
            nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            nn.Dropout(0.5),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.classifier(x)
        logits = x.view(x.size(0), self.num_classes)
        #probas = torch.softmax(logits, dim=1)
        return logits
set_all_seeds(RANDOM_SEED)
model = NiN(NUM_CLASSES)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
import tracemalloc
tracemalloc.start()
log_dict = train_classifier_simple_v1(num_epochs=2, model=model,
                                      optimizer=optimizer, device=DEVICE,
                                      train_loader=train_loader, valid_loader=valid_loader,
                                      logging_interval=50)
current, peak = tracemalloc.get_traced_memory()
print(f"{current}, {peak}")
tracemalloc.stop()
Epoch: 001/002 | Batch 0000/0176 | Loss: 2.3045
Epoch: 001/002 | Batch 0050/0176 | Loss: 2.2849
Epoch: 001/002 | Batch 0100/0176 | Loss: 2.1435
Epoch: 001/002 | Batch 0150/0176 | Loss: 2.0891
***Epoch: 001/002 | Train. Acc.: 20.751% | Loss: 2.119
***Epoch: 001/002 | Valid. Acc.: 20.600% | Loss: 2.121
Time elapsed: 0.40 min
Epoch: 002/002 | Batch 0000/0176 | Loss: 2.1154
Epoch: 002/002 | Batch 0050/0176 | Loss: 2.0218
Epoch: 002/002 | Batch 0100/0176 | Loss: 2.0404
Epoch: 002/002 | Batch 0150/0176 | Loss: 1.9474
***Epoch: 002/002 | Train. Acc.: 26.649% | Loss: 1.978
***Epoch: 002/002 | Valid. Acc.: 25.720% | Loss: 1.989
Time elapsed: 0.80 min
Total Training Time: 0.80 min
87518, 143683
### Delete model and free memory
model.cpu()
del model
The changes we have to make to the NiN code are highlighted below. Note that this example uses only 1 segment in checkpoint_sequential. Generally, a lower number of segments improves memory efficiency but worsens computational performance, since more values need to be recomputed. For this architecture, though, I found that segments=1 represents a good trade-off.
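As a quick sanity check (on a toy multilayer perceptron, not the NiN model), checkpoint_sequential yields the same gradients as an ordinary forward pass, regardless of the number of segments:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(123)
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(),
                      nn.Linear(16, 16), nn.ReLU(),
                      nn.Linear(16, 4))
x = torch.randn(2, 8)

# baseline: ordinary forward/backward
model(x).sum().backward()
grads_ref = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# checkpointed: the input must require grad so that the backward pass
# reaches the recomputed subgraph
x_ck = x.clone().requires_grad_(True)
out = checkpoint_sequential(functions=list(model.children()),
                            segments=2, input=x_ck)
out.sum().backward()
grads_ck = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_ref, grads_ck)))  # True
```

Because the recomputed forward pass is deterministic here (no dropout), the two sets of gradients match exactly.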
##########################
### MODEL
##########################
###### NEW ####################################################
from torch.utils.checkpoint import checkpoint_sequential
###############################################################
class NiN(nn.Module):
    def __init__(self, num_classes):
        super(NiN, self).__init__()
        self.num_classes = num_classes
        self.classifier = nn.Sequential(
            nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 160, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(160, 96, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Dropout(0.5),
            nn.Conv2d(96, 192, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
            nn.Dropout(0.5),
            nn.Conv2d(192, 192, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 10, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
        )

        ###### NEW ####################################################
        self.classifier_modules = [module for k, module in self.classifier._modules.items()]
        ###############################################################

    def forward(self, x):
        ###### NEW ####################################################
        x.requires_grad = True
        x = checkpoint_sequential(functions=self.classifier_modules,
                                  segments=1,
                                  input=x)
        ###############################################################
        x = x.view(x.size(0), self.num_classes)
        #probas = torch.softmax(x, dim=1)
        return x
set_all_seeds(RANDOM_SEED)
model = NiN(NUM_CLASSES)
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
tracemalloc.start()
log_dict = train_classifier_simple_v1(num_epochs=2, model=model,
                                      optimizer=optimizer, device=DEVICE,
                                      train_loader=train_loader, valid_loader=valid_loader,
                                      logging_interval=50)
current, peak = tracemalloc.get_traced_memory()
print(f"{current}, {peak}")
tracemalloc.stop()
Epoch: 001/002 | Batch 0000/0176 | Loss: 2.3045
Epoch: 001/002 | Batch 0050/0176 | Loss: 2.2849
Epoch: 001/002 | Batch 0100/0176 | Loss: 2.1435
Epoch: 001/002 | Batch 0150/0176 | Loss: 2.0891
***Epoch: 001/002 | Train. Acc.: 20.751% | Loss: 2.119
***Epoch: 001/002 | Valid. Acc.: 20.600% | Loss: 2.121
Time elapsed: 0.47 min
Epoch: 002/002 | Batch 0000/0176 | Loss: 2.1154
Epoch: 002/002 | Batch 0050/0176 | Loss: 2.0218
Epoch: 002/002 | Batch 0100/0176 | Loss: 2.0404
Epoch: 002/002 | Batch 0150/0176 | Loss: 1.9474
***Epoch: 002/002 | Train. Acc.: 26.649% | Loss: 1.978
***Epoch: 002/002 | Valid. Acc.: 25.720% | Loss: 1.989
Time elapsed: 0.93 min
Total Training Time: 0.93 min
57806, 115055
We can see that gradient checkpointing reduces the traced peak memory by approximately 20% (143,683 vs. 115,055 bytes) while the runtime becomes only about 16% longer (0.93 vs. 0.80 min).