#!/usr/bin/env python
# coding: utf-8

# # Using BERT embeddings for text classification of movie reviews
#
# ![](support/single-classification-bert.png)
#
# ![](support/godfather.png)

# In[1]:


import gluonnlp as nlp
import mxnet as mx
from mxnet import gluon, nd
import numpy as np


# ## Data

# We are going to use the [imdb dataset](https://ai.stanford.edu/~amaas/data/sentiment/), trying to predict whether a review is positive or negative

# In[6]:


def transform_label(data):
    """
    Transform the 1-10 rating into a positive / negative label
    """
    text, label = data
    return text, 1 if label >= 5 else 0


# In[7]:


train_dataset = nlp.data.IMDB('train')
test_dataset = nlp.data.IMDB('test')


# In[8]:


k = {i+1: 0 for i in range(10)}
for elem in train_dataset:
    k[elem[1]] += 1


# In[9]:


print("Distribution of the ratings")
k


# In[10]:


print("Positive Review:\n{}".format(test_dataset[0][0]))
print()
print("Negative Review:\n{}".format(test_dataset[12501][0]))


# In[11]:


train_dataset = train_dataset.transform(transform_label)
test_dataset = test_dataset.transform(transform_label)


# In[12]:


print("There are {} training examples and {} test examples".format(len(train_dataset), len(test_dataset)))


# # Sklearn TFIDF baseline
#
# Let's use sklearn to build TF-IDF + logistic regression pipelines with word n-grams (up to tri-grams) as a baseline

# In[9]:


from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_selection import chi2
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

for n_gram in [1, 2, 3]:
    text_clf = Pipeline([
        ('tfidf', TfidfVectorizer(sublinear_tf=True, min_df=2+n_gram, norm='l2',
                                  encoding='latin-1', ngram_range=(1, n_gram),
                                  stop_words='english')),
        ('clf', LogisticRegression()),
    ])

    train_x = [elem[0] for elem in train_dataset]
    test_x = [elem[0] for elem in test_dataset]
    train_y = np.array([elem[1] for elem in train_dataset])
    test_y = np.array([elem[1] for elem in test_dataset])

    text_clf.fit(train_x, train_y)
    test_y_hat = text_clf.predict(test_x)
    train_y_hat = text_clf.predict(train_x)
    print("{}-gram accuracy -- train: {:.1%}, test: {:.1%}".format(
        n_gram, (train_y_hat == train_y).mean(), (test_y_hat == test_y).mean()))


# # Fine-tuning BERT
#
# We download a pre-trained BERT model and fine-tune it on the same dataset

# In[13]:


ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()


# In[14]:


# 'bert_24_1024_16' is the BERT-large architecture: 24 layers, 1024 hidden units, 16 attention heads
bert_base, vocabulary = nlp.model.get_model('bert_24_1024_16',
                                            dataset_name='book_corpus_wiki_en_uncased',
                                            pretrained=True,
                                            use_pooler=True,
                                            use_decoder=False,
                                            use_classifier=False,
                                            ctx=ctx)


# In[15]:


batch_size = 8


# We need to process the text the same way it was processed during pre-training. For that we use the `BERTTokenizer` and the `BERTSentenceTransform`

# In[16]:


# use the vocabulary from the pre-trained model for tokenization
bert_tokenizer = nlp.data.BERTTokenizer(vocabulary)
max_len = 128
transform = nlp.data.BERTSentenceTransform(bert_tokenizer, max_len, pad=False, pair=False)


# We create a custom network for BERT classification. We take advantage of the pooler output, which is the `[CLS]` token representation passed through a dense layer and a non-linearity

# In[17]:


class BERTTextClassifier(gluon.nn.Block):
    def __init__(self, bert, num_classes):
        super(BERTTextClassifier, self).__init__()
        self.bert = bert
        with self.name_scope():
            self.classifier = gluon.nn.Dense(num_classes)

    def forward(self, inputs, token_types, valid_length):
        # the pooler output summarizes the whole sequence through the [CLS] token
        out, pooler = self.bert(inputs, token_types, valid_length)
        return self.classifier(pooler)


# In[18]:


net = BERTTextClassifier(bert_base, 2)
net.classifier.initialize(ctx=ctx)
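# Before wiring `transform` into the data loading pipeline, it can help to inspect what it
# produces for a single piece of text. This is a minimal sketch using the `bert_tokenizer`
# and `transform` objects defined above; the sample sentence is made up for illustration.

# In[ ]:


sample_text = "This movie was surprisingly good."
token_ids, valid_length, segment_ids = transform([sample_text])
print(bert_tokenizer(sample_text))  # WordPiece tokens
print(token_ids)                    # token ids: [CLS], the WordPiece ids, then [SEP]
print(valid_length)                 # number of tokens used (no padding, since pad=False)
print(segment_ids)                  # all zeros for single-sentence input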
# **Data Loading**:

# In[19]:


def transform_fn(text, label):
    data, length, token_type = transform([text])
    return data.astype('float32'), length.astype('float32'), token_type.astype('float32'), label


# In[20]:


batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(axis=0),
    nlp.data.batchify.Stack(),
    nlp.data.batchify.Pad(axis=0),
    nlp.data.batchify.Stack(np.float32))


# In[21]:


train_data = gluon.data.DataLoader(train_dataset.transform(transform_fn),
                                   batchify_fn=batchify_fn,
                                   shuffle=True,
                                   batch_size=batch_size,
                                   num_workers=8,
                                   thread_pool=True)
test_data = gluon.data.DataLoader(test_dataset.transform(transform_fn),
                                  batchify_fn=batchify_fn,
                                  shuffle=True,
                                  batch_size=batch_size*4,
                                  num_workers=8,
                                  thread_pool=True)


# **Training**

# In[22]:


trainer = gluon.Trainer(net.collect_params(), 'bertadam',
                        {'learning_rate': 0.000005, 'wd': 0.001, 'epsilon': 1e-6})
loss_fn = gluon.loss.SoftmaxCELoss()
net.hybridize(static_alloc=True, static_shape=True)
num_epoch = 3


# Training loop

# In[23]:


for epoch in range(num_epoch):
    accuracy = mx.metric.Accuracy()
    running_loss = 0
    for i, (inputs, seq_len, token_types, label) in enumerate(train_data):
        inputs = inputs.as_in_context(ctx)
        seq_len = seq_len.as_in_context(ctx)
        token_types = token_types.as_in_context(ctx)
        label = label.as_in_context(ctx)

        with mx.autograd.record():
            out = net(inputs, token_types, seq_len)
            loss = loss_fn(out, label.astype('float32'))

        loss.backward()
        running_loss += loss.mean()
        trainer.step(batch_size)

        accuracy.update(label, out.softmax())
        if i % 50 == 0:
            print("Batch", i, "Accuracy", accuracy.get()[1], "Loss", running_loss.asscalar()/(i+1))
    print("Epoch {}, Accuracy {}, Loss {}".format(epoch, accuracy.get(), running_loss.asscalar()/(i+1)))


# **Evaluation**

# In[45]:


accuracy = 0
for i, (inputs, seq_len, token_types, label) in enumerate(test_data):
    inputs = inputs.as_in_context(ctx)
    seq_len = seq_len.as_in_context(ctx)
    token_types = token_types.as_in_context(ctx)
    label = label.as_in_context(ctx)

    out = net(inputs, token_types, seq_len)
    accuracy += (out.argmax(axis=1).squeeze() == label).mean()
    if i % 50 == 0 and i > 0:
        print(accuracy.asscalar()/(i+1))
print("Test Accuracy {}".format(accuracy.asscalar()/(i+1)))


# Final accuracies:
#
# | Model          | Training Accuracy | Testing Accuracy |
# |----------------|-------------------|------------------|
# | TF-IDF 1-gram  | 93.6%             | 88.4%            |
# | TF-IDF 2-gram  | 95.0%             | 88.6%            |
# | TF-IDF 3-gram  | 94.9%             | 88.7%            |
# | BERT 1024      | **97.0%**         | **90.3%**        |
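# As a usage sketch, the fine-tuned model can score a previously unseen review with the same
# `transform`. The `predict_sentiment` helper below is ours for illustration and assumes the
# `transform`, `net` and `ctx` objects defined above.

# In[ ]:


def predict_sentiment(text):
    token_ids, valid_length, segment_ids = transform([text])
    # add a batch dimension of 1 and move the arrays to the compute context
    inputs = nd.array([token_ids], ctx=ctx)
    token_types = nd.array([segment_ids], ctx=ctx)
    valid_len = nd.array([valid_length], ctx=ctx)
    out = net(inputs, token_types, valid_len)
    return "positive" if out.argmax(axis=1).asscalar() == 1 else "negative"

print(predict_sentiment("One of the best films I have seen in years."))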