Licensed under the MIT License.

Text Classification of SST-2 Sentences using a 3-Player Introspective Model

In [ ]:
import sys
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import scrapbook as sb

from interpret_text.experimental.introspective_rationale import IntrospectiveRationaleExplainer
from interpret_text.experimental.common.preprocessor.glove_preprocessor import GlovePreprocessor
from interpret_text.experimental.common.preprocessor.bert_preprocessor import BertPreprocessor
from interpret_text.experimental.common.model_config.introspective_rationale_model_config import IntrospectiveRationaleModelConfig
from interpret_text.experimental.widget import ExplanationDashboard

from notebooks.test_utils.utils_sst2 import load_sst2_pandas_df
from notebooks.test_utils.utils_data_shared import load_glove_embeddings


In this notebook, we train and evaluate a three-player explainer model on a subset of the SST-2 dataset. To run this notebook, we used the SST-2 data files provided here.

Set parameters

Here we set some parameters that we use for our modeling task.

In [ ]:
# if QUICK_RUN is True, skip over embedding loading, most of model training, and model evaluation; used to quickly test the pipeline
QUICK_RUN = True
MODEL_TYPE = "RNN"  # currently supports RNN, BERT, or a combination of RNN and BERT
CUDA = False

# data processing parameters
DATA_FOLDER = "../../../data/sst2"
LABEL_COL = "labels" 
TEXT_COL = "sentences"
token_count_thresh = 1
max_sentence_token_count = 70

# training procedure parameters
load_pretrained_model = False
pretrained_model_path = "../models/rnn.pth"
MODEL_SAVE_DIR = os.path.join("..", "models")
model_prefix = "sst2rnpmodel"
In [ ]:
model_config = {
    "cuda": CUDA,
    "model_save_dir": MODEL_SAVE_DIR,
    "model_prefix": model_prefix,
    "lr": 2e-4
}

if QUICK_RUN:
    model_config["save_best_model"] = False
    model_config["pretrain_cls"] = True
    model_config["num_epochs"] = 1

if MODEL_TYPE == "RNN":
    # if using the RNN model (i.e. not using BERT), load pretrained glove embeddings
    if not QUICK_RUN:
        model_config["embedding_path"] = load_glove_embeddings(DATA_FOLDER)
    else:
        model_config["embedding_path"] = os.path.join(DATA_FOLDER, "")

Read Dataset

We start by loading a subset of the data for training and testing.

In [ ]:
train_data = load_sst2_pandas_df('train')
test_data = load_sst2_pandas_df('test')
all_data = pd.concat([train_data, test_data])
if QUICK_RUN:
    batch_size = 50
    train_data = train_data.head(batch_size)
    test_data = test_data.head(batch_size)
X_train = train_data[TEXT_COL]
X_test = test_data[TEXT_COL]
In [ ]:
# get all unique labels
y_labels = all_data[LABEL_COL].unique()
In [ ]:
model_config["labels"] = np.array(sorted(y_labels))
model_config["num_labels"] = len(y_labels)

Tokenization and embedding

The data is then tokenized and embedded: with GloVe embeddings for the RNN model, or with BERT's own tokenizer for the BERT model.

In [ ]:
if MODEL_TYPE == "RNN":
    preprocessor = GlovePreprocessor(token_count_thresh, max_sentence_token_count)
elif MODEL_TYPE == "BERT":
    preprocessor = BertPreprocessor()
In [ ]:
# append labels to tokenizer output
df_train = pd.concat([train_data[LABEL_COL], preprocessor.preprocess(X_train)], axis=1)
df_test = pd.concat([test_data[LABEL_COL], preprocessor.preprocess(X_test)], axis=1)
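
The idea behind a GloVe-style preprocessor can be sketched as follows: build a vocabulary from tokens that occur at least `token_count_thresh` times, then convert each sentence to a fixed-length sequence of vocabulary indices, padded or truncated to `max_sentence_token_count`. The helper names below (`build_vocab`, `index_and_pad`) are illustrative only, not the real `GlovePreprocessor` API:

```python
from collections import Counter

# Hypothetical sketch of GloVe-style preprocessing; the real
# GlovePreprocessor's internals differ.
def build_vocab(sentences, token_count_thresh):
    """Map each token seen at least token_count_thresh times to an index.
    Index 0 is reserved for padding/unknown tokens."""
    counts = Counter(tok for s in sentences for tok in s.lower().split())
    vocab = {"<pad>": 0}
    for tok, n in counts.items():
        if n >= token_count_thresh:
            vocab[tok] = len(vocab)
    return vocab

def index_and_pad(sentence, vocab, max_len):
    """Convert a sentence to a fixed-length list of vocabulary indices."""
    ids = [vocab.get(tok, 0) for tok in sentence.lower().split()][:max_len]
    return ids + [0] * (max_len - len(ids))

sentences = ["a good movie", "a bad movie"]
vocab = build_vocab(sentences, token_count_thresh=1)
print(index_and_pad("a good movie", vocab, max_len=5))  # → [1, 2, 3, 0, 0]
```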


Then, we create the explainer and train it (or load a pretrained model). The steps involved to set up the explainer:

  • Initialize the explainer
  • Set up the preprocessor for the explainer
  • Supply the necessary model configuration to the explainer
  • Load the explainer once all necessary modules are set up
  • Fit/train the explainer
In [ ]:
explainer = IntrospectiveRationaleExplainer(classifier_type=MODEL_TYPE, cuda=CUDA)
In [ ]:
explainer.set_preprocessor(preprocessor)
In [ ]:
explainer.build_model_config(model_config)
In [ ]:
explainer.load()
In [ ]:
explainer.fit(df_train, df_test)

We can test the explainer and measure its performance:

In [ ]:
if not QUICK_RUN:
    # score the explainer on the test set to populate the metrics below
    explainer.score(df_test)
    sparsity = explainer.model.avg_sparsity
    accuracy = explainer.model.avg_accuracy
    anti_accuracy = explainer.model.avg_anti_accuracy
    print("Test sparsity: ", sparsity)
    print("Test accuracy: ", accuracy, "% Anti-accuracy: ", anti_accuracy)
    # for testing
    sb.glue("sparsity", sparsity)
    sb.glue("accuracy", accuracy)
    sb.glue("anti_accuracy", anti_accuracy)
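
For intuition, the three metrics follow directly from the three-player setup: sparsity is the average fraction of tokens the generator selects as a rationale, accuracy is the classifier's accuracy on the selected words, and anti-accuracy is the anti-predictor's accuracy on the unselected words (low anti-accuracy means the rationale captured the label-relevant words). A toy illustration with made-up masks and predictions, not the library's internal computation:

```python
import numpy as np

# Illustrative computation of the three metrics, assuming a binary
# rationale mask per token (1 = token kept in the rationale).
rationale_masks = np.array([
    [1, 0, 0, 1, 0],   # 2 of 5 tokens selected
    [0, 1, 0, 0, 0],   # 1 of 5 tokens selected
])
# sparsity: average fraction of tokens selected as the rationale
sparsity = rationale_masks.mean()

labels = np.array([1, 0])
preds_on_rationale = np.array([1, 0])    # predictions from selected tokens
preds_on_complement = np.array([0, 0])   # predictions from unselected tokens
accuracy = (preds_on_rationale == labels).mean()
# anti-accuracy should stay low: the unselected words alone should not
# be enough to recover the label
anti_accuracy = (preds_on_complement == labels).mean()
print(sparsity, accuracy, anti_accuracy)  # → 0.3 1.0 0.5
```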

Local importances

We can display the local importances found by the model (the most and least important words for a given sentence):

In [ ]:
# Enter a sentence that needs to be interpreted
sentence = "Beautiful movie ; really good , the popcorn was bad"
s2 = "a beautiful and haunting examination of the stories we tell ourselves to make sense of the mundane horrors of the world."
s3 = "the premise is in extremely bad taste , and the film's supposed insights are so poorly executed and done that even a high school dropout taking his or her first psychology class could dismiss them ."
s4 = "This is a super amazing movie with bad acting"

local_explanation = explainer.explain_local(s4)

Visualize explanations

We can visualize local feature importances as a heatmap over words in the document and view importance values of individual words.

In [ ]:
explainer.visualize(local_explanation)
In [ ]:
ExplanationDashboard(local_explanation)
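
A word-importance heatmap of this kind boils down to coloring each word by its normalized importance. A minimal sketch (the function name and color scheme are hypothetical, not the dashboard's actual rendering):

```python
# Hypothetical heatmap rendering: each word gets a background whose
# opacity scales with |importance| / max |importance|;
# green for positive importance, red for negative.
def importance_heatmap_html(words, importances):
    peak = max(abs(v) for v in importances) or 1.0  # avoid divide-by-zero
    spans = []
    for word, value in zip(words, importances):
        alpha = abs(value) / peak
        rgb = (0, 128, 0) if value >= 0 else (200, 0, 0)
        spans.append(
            '<span style="background: rgba(%d, %d, %d, %.2f)">%s</span>'
            % (*rgb, alpha, word)
        )
    return " ".join(spans)

html = importance_heatmap_html(
    ["super", "amazing", "movie", "bad"], [0.9, 0.8, 0.1, -0.7]
)
```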