This is a little extension of the work done in Distilling Task-Specific Knowledge from BERT into Simple Neural Networks by Tang et al. 2019. Hopefully this notebook will serve as an easy-to-follow guide to distillation, which is actually really simple. This is based on work I did for Polecat.
Tang demonstrates that training a lower-complexity student model to predict a teacher model's output logits is more effective than training the student model directly on the dataset. This is a really neat way of improving the performance of smaller models (which are much easier to productionize).
In the paper Tang uses BERT to train a BiLSTM. One of the suggestions for future work is to explore to what extent even simpler models can benefit from the technique. This notebook does just that - we'll try to use BERT to train a CNN and a simple linear model, both implemented in PyTorch.
The linear model is the FastText model (Joulin et al. 2016), which is normally an excellent compromise between speed and accuracy. The task is document classification. We wouldn't expect to get near BERT-like accuracy because FastText is a bag-of-words model (it ignores word order, although you can give it n-grams), but it will be interesting to see if we can increase its accuracy at all.
The CNN is the basic model described by Kim in Convolutional Neural Networks for Sentence Classification (2014). For simplicity, pretrained word embeddings haven't been used, although they would certainly improve performance.
Let's begin with our dependencies: PyTorch, the great Huggingface transformers library (for a BERT implementation) and other usual suspects.
!pip install torch transformers pandas tqdm altair joblib sklearn
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import altair as alt
from pathlib import Path
from joblib import Memory
from sklearn.metrics import f1_score
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, DistilBertForSequenceClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
We'll use the Amazon review dataset. It is freely available and consists of product reviews, each with a star rating; the task is simply to predict the star rating from the review text. It's a challenging task.
First, some data wrangling. I'm afraid this notebook won't run out-of-the-box, because the data and teacher model are too large to distribute.
ROOT = Path("/mnt/gdrive/My Drive/")
if not ROOT.exists():
from google.colab import drive
drive.mount("/mnt/gdrive")
assert ROOT.exists()
DATA = ROOT / "data"
MODELS = ROOT / "models"
CACHE = ROOT / "cache/distillation"
if not CACHE.exists():
CACHE.mkdir(parents=True)
memory = Memory(CACHE, verbose=False)
market = "uk"
reviews = (pd.read_csv(DATA / "amazon" / f"amazon_reviews_multilingual_{market.upper()}_v1_00.tsv.gz",
sep="\t",
usecols=["review_id", "star_rating", "review_headline", "review_body"],
dtype={"review_id": "string",
"star_rating": "Int32",
"review_headline": "string",
"review_body": "string"})
.dropna())
We balance the classes and shuffle the dataset. Ideally we should also remove some low-value reviews, e.g. single-word reviews and reviews in other languages, but there are few enough of these that they won't make much difference for this exploration.
MAX_LEN = 50_000
classes = {1, 2, 3, 4, 5}
class_examples = [reviews[reviews.star_rating == rating] for rating in classes]
min_len = min(MAX_LEN // len(classes), *[len(c) for c in class_examples])
balanced_df = pd.concat([c.sample(min_len, random_state=42) for c in class_examples])
shuffled_df = balanced_df.sample(len(balanced_df))
shuffled_df["label"] = shuffled_df.star_rating.astype(int) - 1
len(shuffled_df)
50000
shuffled_df.head(2)
| | review_id | star_rating | review_headline | review_body | label |
|---|---|---|---|---|---|
| 626969 | R16AH12YPHGU7C | 1 | No instructions | No instructions, only pictures that you can fi... | 0 |
| 1101179 | R1W6Y6B361L24G | 3 | A Little Slight But Still Entertaining | Although the Bee Gees had included some R+B/so... | 2 |
Split the data into a training set and a test set.
train_frac = 0.8
split_idx = int(train_frac * len(shuffled_df))
train_df = shuffled_df[:split_idx]
test_df = shuffled_df[split_idx:]
len(train_df), len(test_df)
(40000, 10000)
Tokenize the text and convert it to PyTorch tensors. We also need the attention mask and token type IDs for each example as input to BERT.
try:
    tokenizer = tokenizer  # reuse the tokenizer if it's already loaded in this session
except NameError:
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
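As a quick illustration of what the tokenizer produces (a throwaway example, using the same arguments as the dataset-building function below):

```python
example = tokenizer.batch_encode_plus(["Great kettle, boils quickly"],
                                      max_length=12,
                                      pad_to_max_length=True,
                                      return_attention_masks=True,
                                      return_token_type_ids=True,
                                      return_tensors="pt")
# contains "input_ids", "attention_mask" and "token_type_ids", each of shape (1, 12)
example["input_ids"].shape, example["attention_mask"].shape, example["token_type_ids"].shape
```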
To reduce training time, we will precompute the teacher's predictions. Additionally, joblib is used to cache the results of this function so that less time is wasted during development. This is the reason for the Torch pickling warnings that appear later in this notebook.
@memory.cache(ignore=["teacher"]) # cache but don't bother serializing the teacher - doesn't change and is slow to pickle
def dataframe_to_dataset(df, teacher):
max_len = 128
features = tokenizer.batch_encode_plus(df.review_body,
max_length=max_len,
pad_to_max_length=True,
return_attention_masks=True,
return_token_type_ids=True,
return_tensors="pt")
pre_dataset = TensorDataset(features["input_ids"],
features["attention_mask"],
features["token_type_ids"])
teacher.to(device)
teacher.eval()
teacher_predictions = []
for batch in tqdm(DataLoader(pre_dataset, batch_size=32, shuffle=False)):
batch = tuple([b.to(device) for b in batch])
inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
if teacher.base_model_prefix == "bert":
inputs["token_type_ids"] = batch[2]
with torch.no_grad():
outputs = teacher(**inputs)
teacher_predictions.append(outputs[0].to(torch.device("cpu"))) # put back on CPU
dataset = TensorDataset(features["input_ids"],
features["attention_mask"],
features["token_type_ids"],
torch.tensor(df.label.astype("int").to_numpy(), dtype=torch.long),
torch.cat(teacher_predictions, axis=0))
return dataset
These are more-or-less the default hyperparameters for FastText. The embedding dimension is reduced to 50 to speed up processing slightly.
Beware the batch size - we're using a batch size of 1 for training the linear model. This has a significant impact on its accuracy, and it's lightweight enough that we can get away with it.
N_EPOCHS = 5
EMBEDDING_DIM = 50
LR = 0.5
BATCH_SIZE = 32
N_LABELS = 5 # num review ratings
padding_idx = tokenizer.vocab["[PAD]"]
n_vocab = len(tokenizer.vocab)
The teacher is actually DistilBERT, rather than BERT. So we are distilling from a distilled model! Ideally the teacher should be BERT-proper so that results are more comparable. But this is running on Google Colab with limited GPU time, so a compromise is necessary.
I trained this DistilBERT model on the same dataset previously. Later on we'll check its accuracy.
try:
config = config
teacher = teacher
except NameError:
config = AutoConfig.from_pretrained("distilbert-base-multilingual-cased")
config.num_labels = N_LABELS
teacher = DistilBertForSequenceClassification(config)
teacher.load_state_dict(torch.load(MODELS / "distilbert_uk_50000.bin", map_location=device))
This is a simple convolutional neural network (CNN) with dropout, as per Kim (2014).
class CNN(nn.Module):
def __init__(self,
n_vocab,
n_labels,
embedding_dim=50,
n_filters=100,
filter_sizes=[3, 4, 5],
dropout=0.5,
special_chars=[],
pretrained_embeddings=None): # TODO make number of conv layers configurable
super(CNN, self).__init__()
self.n_vocab = n_vocab
self.n_labels = n_labels
self.embedding_dim = embedding_dim
self.n_filters = n_filters
self.filter_sizes = filter_sizes
self.dropout_p = dropout
self.width = len(filter_sizes) * n_filters
if pretrained_embeddings is not None:
assert n_vocab == pretrained_embeddings.shape[0]
assert embedding_dim == pretrained_embeddings.shape[1]
self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings)
else:
self.embedding = nn.Embedding(n_vocab, embedding_dim)
self.conv0 = nn.Conv2d(in_channels=1,
out_channels=n_filters,
kernel_size=(filter_sizes[0], embedding_dim))
self.conv1 = nn.Conv2d(in_channels=1,
out_channels=n_filters,
kernel_size=(filter_sizes[1], embedding_dim))
self.conv2 = nn.Conv2d(in_channels=1,
out_channels=n_filters,
kernel_size=(filter_sizes[2], embedding_dim))
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(in_features=self.width, out_features=n_labels)
for special in special_chars:
self.embedding.weight.data[special] = torch.zeros(embedding_dim)
def forward(self, input_ids, **kwargs):
"""Only input ids are required - kwargs are for API compat with BERT."""
X = self.embedding(input_ids)
X = X.unsqueeze(1) # add single channel as dim 1
X0 = F.relu(self.conv0(X).squeeze(3))
X1 = F.relu(self.conv1(X).squeeze(3))
X2 = F.relu(self.conv2(X).squeeze(3))
X0 = F.max_pool1d(X0, X0.shape[2]).squeeze(2)
X1 = F.max_pool1d(X1, X1.shape[2]).squeeze(2)
X2 = F.max_pool1d(X2, X2.shape[2]).squeeze(2)
X = torch.cat([X0, X1, X2], dim=1)
X = self.dropout(X)
X = self.fc(X)
return X
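A quick shape check of the CNN's forward pass (a throwaway model and random token IDs, just to confirm the output dimensions):

```python
_cnn = CNN(n_vocab=1000, n_labels=5, embedding_dim=50)
_dummy_ids = torch.randint(0, 1000, (2, 128))  # batch of 2 sequences of 128 token ids
_cnn(_dummy_ids).shape  # torch.Size([2, 5]) - one row per example, one logit per class
```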
This is a faithful implementation of the FastText linear model (Joulin et al. 2016).
class LinearModel(nn.Module):
def __init__(self, n_vocab, n_labels, embedding_dim, padding_idx):
super(LinearModel, self).__init__()
self.embeddings = nn.Embedding(n_vocab, embedding_dim, padding_idx=padding_idx)
self.output = nn.Linear(embedding_dim, n_labels)
with torch.no_grad():
# FastText initializes embeddings with uniform distribution vs normal in PyTorch
self.embeddings.weight.uniform_(to=1.0 / embedding_dim)
self.embeddings.weight[padding_idx] = 0 # but FT doesn't have a padding token
# FastText initializes output with zeros vs some random dist in PyTorch
self.output.weight.zero_()
def forward(self, input_ids, **kwargs):
"""Only input ids are required - kwargs are for API compat with BERT."""
X = self.embeddings(input_ids)
X = X.mean(dim=1)
X = self.output(X)
return X
This function trains the model for one epoch. If no teacher is provided it uses cross entropy loss (i.e. softmax then NLL) and compares the model predictions to the target label.
If a teacher is provided then model predictions are compared to the teacher's predictions and MSE loss is used.
In the paper Tang defines a cost function that is a weighted combination of the two, $L = \alpha L_{CE} + (1 - \alpha) L_{MSE}$, but observes that in practice the best value for α was zero, i.e. pure distillation loss.
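For reference, that combined objective would look something like this minimal sketch (a hypothetical helper, not used in the training loop below, since α = 0 reduces it to plain MSE on the logits):

```python
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.0):
    """Weighted combination of hard-label cross entropy and soft-logit MSE."""
    ce = F.cross_entropy(student_logits, labels)
    mse = F.mse_loss(student_logits, teacher_logits)
    return alpha * ce + (1.0 - alpha) * mse
```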
The accuracy on the training set is also output for visibility.
def train_epoch(train_iter, model, optim, epoch_num, distil=False):
train_loss = 0
train_acc = 0
y_true = []
y_pred = []
model.to(device)
model.train()
if distil:
cost = nn.MSELoss()
else:
cost = nn.CrossEntropyLoss()
for batch in tqdm(train_iter, total=len(train_iter), desc=f"Batch progress for epoch {epoch_num}"):
batch = tuple([t.to(device) for t in batch])
inputs = {"input_ids": batch[0],
"attention_mask": batch[1]}
labels = batch[3]
optim.zero_grad()
output = model(**inputs)
if distil:
target = batch[4]
else:
target = labels
batch_loss = cost(output, target)
# Had some trouble with linear distilled model dying in training
# but since starting to debug it the issue hasn't reoccurred.
# Gradient clipping might help.
if torch.isnan(batch_loss):
print("NAN batch loss!", epoch_num, batch_loss, output, target)
train_loss += batch_loss.item()
batch_acc = (output.argmax(1) == labels).sum().item()
train_acc += batch_acc
y_true.extend(labels.tolist())
y_pred.extend(output.argmax(1).tolist())
batch_loss.backward()
optim.step()
return train_loss / len(train_iter), train_acc / len(train_iter.dataset), f1_score(y_true, y_pred, average="macro") # classes are already balanced
def train_loop(model, optim, train_loader, test_loader, n_epochs=5, sched=None, distil=False):
training_results = {"epoch": list(range(n_epochs)),
"train_loss": [],
"train_acc": [],
"train_f1_macro": [],
"test_loss": [],
"test_acc": [],
"test_f1_macro": []}
model.to(device)
try:
for i in range(n_epochs):
train_loss, train_acc, train_f1 = train_epoch(train_loader, model, optim, epoch_num=i, distil=distil)
if sched is not None:
sched.step()
test_loss, test_acc, test_f1 = validate(test_loader, model)
training_results["train_loss"].append(train_loss)
training_results["train_acc"].append(train_acc)
training_results["train_f1_macro"].append(train_f1)
training_results["test_loss"].append(test_loss)
training_results["test_acc"].append(test_acc)
training_results["test_f1_macro"].append(test_f1)
except KeyboardInterrupt:
pass
return pd.DataFrame(training_results)
The validation function is similar, but there is no option to compare against the teacher's predictions: the ultimate point of the exercise is a better small model, so we always validate against the true labels.
The metrics are accuracy and macro F1. Macro F1 will help us understand whether the model is performing similarly across all classes (e.g. a model that always predicts class 2 will have accuracy of 20% but a terrible F1 score).
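To make that concrete, here's a toy illustration (hypothetical predictions) of why macro F1 exposes a degenerate classifier that plain accuracy doesn't:

```python
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

y_true = np.repeat([0, 1, 2, 3, 4], 100)   # balanced 5-class problem
y_pred = np.full_like(y_true, 2)           # a model that always predicts class 2

accuracy_score(y_true, y_pred)             # 0.2
f1_score(y_true, y_pred, average="macro")  # ~0.067 - only one class has non-zero F1
```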
def validate(test_iter, model):
test_acc = 0
test_loss = 0
y_true = []
y_pred = []
cost = nn.CrossEntropyLoss()
model.to(device)
model.eval()
for batch in tqdm(test_iter, desc="Validating"):
batch = tuple([t.to(device) for t in batch])
inputs = {"input_ids": batch[0],
"attention_mask": batch[1],
"token_type_ids": batch[2]}
labels = batch[3]
with torch.no_grad():
output = model(**inputs)
batch_loss = cost(output, labels)
test_loss += batch_loss.item()
batch_acc = (output.argmax(1) == labels).sum().item()
test_acc += batch_acc
y_true.extend(labels.tolist())
y_pred.extend(output.argmax(1).tolist())
return test_loss / len(test_iter), test_acc / len(test_iter.dataset), f1_score(y_true, y_pred, average="macro") # classes are balanced
We will construct two models of each architecture: one to train directly and one to train via distillation.
linear_model = LinearModel(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, padding_idx=padding_idx)
linear_model_dist = LinearModel(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, padding_idx=padding_idx)
cnn = CNN(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, special_chars=[padding_idx])
cnn_dist = CNN(n_vocab, N_LABELS, embedding_dim=EMBEDDING_DIM, special_chars=[padding_idx])
test_loader = DataLoader(dataframe_to_dataset(test_df, teacher), batch_size=BATCH_SIZE, shuffle=False)
100%|██████████| 313/313 [00:37<00:00, 8.25it/s] /usr/local/lib/python3.6/dist-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead warnings.warn("pickle support for Storage will be removed in 1.5. Use `torch.save` instead", FutureWarning)
Sanity check - we expect about 20% accuracy in each case. The loss (which is cross-entropy for validation, regardless of the training method) should be about 1.61, i.e. the loss we'd expect from near-uniform random predictions. The CNNs start with a somewhat higher loss, most likely because their randomly initialised convolutional and fully-connected layers produce less uniform logits than the zero-initialised linear model (dropout isn't the cause, as it's disabled in eval mode).
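As a quick check of that baseline loss figure:

```python
import math

# cross-entropy of a uniform prediction over 5 balanced classes
-math.log(1 / 5)  # ≈ 1.609
```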
validate(test_loader, linear_model)
Validating: 100%|██████████| 313/313 [00:00<00:00, 1010.79it/s]
(1.6121794857537022, 0.1963, 0.06563571010616068)
validate(test_loader, linear_model_dist)
Validating: 100%|██████████| 313/313 [00:00<00:00, 1017.86it/s]
(1.6159831746317708, 0.2011, 0.06697194238614604)
validate(test_loader, cnn)
Validating: 100%|██████████| 313/313 [00:05<00:00, 54.14it/s]
(1.9059412772663105, 0.2051, 0.06808298755186722)
validate(test_loader, cnn_dist)
Validating: 100%|██████████| 313/313 [00:05<00:00, 54.44it/s]
(1.8158414051555598, 0.2162, 0.1188892656541893)
train_loader = DataLoader(dataframe_to_dataset(train_df, teacher), batch_size=1, shuffle=False) # optimal training for the linear model
100%|██████████| 1250/1250 [02:30<00:00, 8.29it/s] /usr/local/lib/python3.6/dist-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead warnings.warn("pickle support for Storage will be removed in 1.5. Use `torch.save` instead", FutureWarning) /usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:1: UserWarning: Persisting input arguments took 0.91s to run. If this happens often in your code, it can cause performance problems (results will be correct in all cases). The reason for this is probably some large input arguments for a wrapped function (e.g. large strings). THIS IS A JOBLIB ISSUE. If you can, kindly provide the joblib's team with an example so that they can fix the problem. """Entry point for launching an IPython kernel.
optim = torch.optim.SGD(linear_model.parameters(), lr=LR)
sched = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=0.5)
linear_model_train_results = train_loop(linear_model, optim, train_loader, test_loader, N_EPOCHS, sched, distil=False)
linear_model_train_results["model"] = "Linear"
linear_model_train_results
Batch progress for epoch 0: 100%|██████████| 40000/40000 [01:08<00:00, 585.80it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 995.07it/s] Batch progress for epoch 1: 100%|██████████| 40000/40000 [01:07<00:00, 591.82it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 1012.24it/s] Batch progress for epoch 2: 100%|██████████| 40000/40000 [01:07<00:00, 594.67it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 972.25it/s] Batch progress for epoch 3: 100%|██████████| 40000/40000 [01:07<00:00, 594.98it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 972.59it/s] Batch progress for epoch 4: 100%|██████████| 40000/40000 [01:07<00:00, 591.64it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 978.57it/s]
| | epoch | train_loss | train_acc | train_f1_macro | test_loss | test_acc | test_f1_macro | model |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1.517863 | 0.327375 | 0.323505 | 1.503577 | 0.3584 | 0.280098 | Linear |
| 1 | 1 | 1.286499 | 0.443450 | 0.436256 | 1.385932 | 0.4166 | 0.350168 | Linear |
| 2 | 2 | 1.202731 | 0.488375 | 0.480658 | 1.331111 | 0.4408 | 0.391589 | Linear |
| 3 | 3 | 1.157016 | 0.516250 | 0.508555 | 1.294666 | 0.4603 | 0.427882 | Linear |
| 4 | 4 | 1.131102 | 0.529975 | 0.522605 | 1.267536 | 0.4727 | 0.451547 | Linear |
A second training loop for the linear model that is trained via distillation.
optim = torch.optim.SGD(linear_model_dist.parameters(), lr=LR)
sched = torch.optim.lr_scheduler.StepLR(optim, step_size=1, gamma=0.5)
linear_model_dist_train_results = train_loop(linear_model_dist, optim, train_loader, test_loader, N_EPOCHS, sched, distil=True)
linear_model_dist_train_results["model"] = "Linear (distilled)"
linear_model_dist_train_results
Batch progress for epoch 0: 100%|██████████| 40000/40000 [01:06<00:00, 601.33it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 999.62it/s] Batch progress for epoch 1: 100%|██████████| 40000/40000 [01:06<00:00, 600.73it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 998.13it/s] Batch progress for epoch 2: 100%|██████████| 40000/40000 [01:06<00:00, 600.61it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 977.32it/s] Batch progress for epoch 3: 100%|██████████| 40000/40000 [01:06<00:00, 599.38it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 1006.41it/s] Batch progress for epoch 4: 100%|██████████| 40000/40000 [01:06<00:00, 598.48it/s] Validating: 100%|██████████| 313/313 [00:00<00:00, 981.18it/s]
| | epoch | train_loss | train_acc | train_f1_macro | test_loss | test_acc | test_f1_macro | model |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 2.156607 | 0.364000 | 0.367437 | 1.315366 | 0.4376 | 0.431926 | Linear (distilled) |
| 1 | 1 | 1.915474 | 0.439525 | 0.443687 | 1.281477 | 0.4598 | 0.464885 | Linear (distilled) |
| 2 | 2 | 3.414282 | 0.442000 | 0.446172 | 1.294624 | 0.4524 | 0.456491 | Linear (distilled) |
| 3 | 3 | 1.905386 | 0.453650 | 0.458064 | 1.275777 | 0.4533 | 0.457361 | Linear (distilled) |
| 4 | 4 | 1.464136 | 0.478500 | 0.482614 | 1.259447 | 0.4655 | 0.470224 | Linear (distilled) |
It's interesting to see that the distilled student learned noticeably faster than the directly-trained model (look at the `train_acc` and `test_acc` columns). It also achieves a better macro F1 on the test set (although its final accuracy is a touch lower), but this may not be a significant result.
Note that you cannot directly compare the training losses - remember, these come from different loss functions.
For the CNN we use Adam and a more conventional batch size to speed up training - it's a more complicated architecture.
train_loader = DataLoader(dataframe_to_dataset(train_df, teacher), batch_size=BATCH_SIZE, shuffle=False) # optimal training for the cnn
optim = torch.optim.Adam(cnn.parameters())
cnn_train_results = train_loop(cnn, optim, train_loader, test_loader, n_epochs=N_EPOCHS, sched=None, distil=False)
cnn_train_results["model"] = "CNN"
cnn_train_results
Batch progress for epoch 0: 100%|██████████| 1250/1250 [00:27<00:00, 45.88it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.66it/s] Batch progress for epoch 1: 100%|██████████| 1250/1250 [00:27<00:00, 45.87it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.49it/s] Batch progress for epoch 2: 100%|██████████| 1250/1250 [00:27<00:00, 45.86it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.56it/s] Batch progress for epoch 3: 100%|██████████| 1250/1250 [00:27<00:00, 45.92it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.50it/s] Batch progress for epoch 4: 100%|██████████| 1250/1250 [00:27<00:00, 45.83it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.60it/s]
| | epoch | train_loss | train_acc | train_f1_macro | test_loss | test_acc | test_f1_macro | model |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 1.556353 | 0.308800 | 0.305734 | 1.364729 | 0.4036 | 0.385092 | CNN |
| 1 | 1 | 1.374294 | 0.396175 | 0.390798 | 1.276479 | 0.4454 | 0.444202 | CNN |
| 2 | 2 | 1.286066 | 0.438800 | 0.433470 | 1.237772 | 0.4617 | 0.454582 | CNN |
| 3 | 3 | 1.213118 | 0.471425 | 0.466652 | 1.226823 | 0.4673 | 0.454492 | CNN |
| 4 | 4 | 1.154105 | 0.501875 | 0.497450 | 1.224071 | 0.4739 | 0.462970 | CNN |
optim = torch.optim.Adam(cnn_dist.parameters())
cnn_dist_train_results = train_loop(cnn_dist, optim, train_loader, test_loader, n_epochs=N_EPOCHS, sched=None, distil=True)
cnn_dist_train_results["model"] = "CNN (distilled)"
cnn_dist_train_results
Batch progress for epoch 0: 100%|██████████| 1250/1250 [00:27<00:00, 46.02it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.67it/s] Batch progress for epoch 1: 100%|██████████| 1250/1250 [00:27<00:00, 46.02it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.51it/s] Batch progress for epoch 2: 100%|██████████| 1250/1250 [00:27<00:00, 45.76it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.38it/s] Batch progress for epoch 3: 100%|██████████| 1250/1250 [00:27<00:00, 45.89it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.50it/s] Batch progress for epoch 4: 100%|██████████| 1250/1250 [00:27<00:00, 45.98it/s] Validating: 100%|██████████| 313/313 [00:05<00:00, 54.49it/s]
| | epoch | train_loss | train_acc | train_f1_macro | test_loss | test_acc | test_f1_macro | model |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 2.292530 | 0.331050 | 0.334280 | 1.298172 | 0.4220 | 0.427988 | CNN (distilled) |
| 1 | 1 | 1.490734 | 0.416750 | 0.421440 | 1.236959 | 0.4490 | 0.455647 | CNN (distilled) |
| 2 | 2 | 1.209579 | 0.451575 | 0.455322 | 1.210864 | 0.4641 | 0.470251 | CNN (distilled) |
| 3 | 3 | 1.033341 | 0.472825 | 0.475761 | 1.181134 | 0.4880 | 0.491456 | CNN (distilled) |
| 4 | 4 | 0.919759 | 0.484475 | 0.486859 | 1.170595 | 0.4903 | 0.493316 | CNN (distilled) |
Again the distilled CNN made greater progress in the early epochs, and again the test accuracy is higher for the distilled model.
training_results = pd.concat([linear_model_train_results,
linear_model_dist_train_results,
cnn_train_results,
cnn_dist_train_results])
training_results.query("epoch == 4").set_index("model")[["test_acc", "test_f1_macro"]]
| model | test_acc | test_f1_macro |
|---|---|---|
| Linear | 0.4727 | 0.451547 |
| Linear (distilled) | 0.4655 | 0.470224 |
| CNN | 0.4739 | 0.462970 |
| CNN (distilled) | 0.4903 | 0.493316 |
alt.Chart(training_results).mark_line().encode(
x="epoch:Q",
y="test_f1_macro:Q",
color="model"
)
The distilled models learned faster (reaching higher accuracy in earlier epochs) and achieved slightly better test F1 scores.
We would need to check whether that result is significant, however. Because the data is multinomial and the metric is F1, we could either use the bootstrap method or follow the approach in A Bayesian Interpretation of the Confusion Matrix (Caelen 2017).
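As a rough sketch of the bootstrap approach (the prediction arrays y_pred_a and y_pred_b are hypothetical; in practice they would be collected from the two trained models on the test set):

```python
import numpy as np
from sklearn.metrics import f1_score

def bootstrap_f1_diff(y_true, y_pred_a, y_pred_b, n_rounds=1000, seed=42):
    """95% bootstrap confidence interval for the difference in macro F1."""
    rng = np.random.default_rng(seed)
    y_true, y_pred_a, y_pred_b = map(np.asarray, (y_true, y_pred_a, y_pred_b))
    diffs = []
    for _ in range(n_rounds):
        idx = rng.integers(0, len(y_true), size=len(y_true))  # resample the test set with replacement
        diffs.append(f1_score(y_true[idx], y_pred_a[idx], average="macro")
                     - f1_score(y_true[idx], y_pred_b[idx], average="macro"))
    return np.percentile(diffs, [2.5, 97.5])
```

If the resulting interval excludes zero, the difference between the two models is unlikely to be noise.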
For another interesting comparison, what can the teacher achieve?
This isn't quite fair because the teacher was not trained on the same splits of this dataset. But it's a large dataset and the likely proportion of the teacher's training data in this test set is low.
teacher.to(device)
teacher.eval()
teacher_test_acc5 = []
for batch_num, batch in enumerate(tqdm(test_loader)):
batch = tuple([t.to(device) for t in batch])
inputs = {"input_ids": batch[0],
"attention_mask": batch[1]}
if teacher.base_model_prefix == "bert":
inputs["token_type_ids"]: batch[2]
labels = batch[3]
with torch.no_grad():
logits = teacher(**inputs)[0]
probs = torch.softmax(logits, dim=1)
preds_5class = probs.argmax(dim=1)
acc_5class = (preds_5class == labels).sum().item() / len(batch[0])
teacher_test_acc5.append(acc_5class)
np.mean(teacher_test_acc5)
100%|██████████| 313/313 [00:38<00:00, 8.11it/s]
0.6168130990415336
So neither student got close to the teacher's accuracy, but we do see a (potentially significant) improvement from distillation, particularly for the CNN.
It's also very interesting to see that the students converged faster. This supports Tang's suggestion that the information about prediction uncertainty is valuable, and that this even outweighs the error from the teacher's inaccurate predictions.
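To see why the soft targets carry extra information, consider some made-up teacher logits for an ambiguous review (illustrative numbers only):

```python
import torch

teacher_logits = torch.tensor([[-1.2, 0.3, 2.1, 1.4, -0.8]])
torch.softmax(teacher_logits, dim=1)
# ≈ [[0.02, 0.09, 0.57, 0.28, 0.03]]
# The hard label would just say "3 stars"; the logits also say "4 stars is plausible,
# 1 star is not", and the student is trained on exactly this richer signal.
```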
We might get better results if we implement the data augmentation that Tang suggests. We could also probably do better with a more complex student - you can see that in this NLP Town blog post, which inspired me to try this. NLP Town trained spaCy's "ensemble" classifier, which is a more sophisticated CNN than this and would be expected to perform better.
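For reference, a minimal sketch of the kind of augmentation Tang describes - random word masking and n-gram sampling to build a larger transfer set for the teacher to label (this helper is hypothetical and not part of the notebook above):

```python
import random

def augment(tokens, p_mask=0.1, p_ngram=0.25, mask_token="[MASK]"):
    """Randomly mask words and/or keep only a sampled n-gram from the example."""
    if random.random() < p_ngram and len(tokens) > 5:
        n = random.randint(3, min(10, len(tokens)))   # n-gram sampling
        start = random.randint(0, len(tokens) - n)
        tokens = tokens[start:start + n]
    return [mask_token if random.random() < p_mask else t for t in tokens]

augment("the battery died after only two days of use".split())
```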