Licensed under the MIT License.
import sys
sys.path.append("../..")
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import scrapbook as sb
from interpret_text.experimental.introspective_rationale import IntrospectiveRationaleExplainer
from interpret_text.experimental.common.preprocessor.glove_preprocessor import GlovePreprocessor
from interpret_text.experimental.common.preprocessor.bert_preprocessor import BertPreprocessor
from interpret_text.experimental.common.model_config.introspective_rationale_model_config import IntrospectiveRationaleModelConfig
from interpret_text.experimental.widget import ExplanationDashboard
from notebooks.test_utils.utils_sst2 import load_sst2_pandas_df
from notebooks.test_utils.utils_data_shared import load_glove_embeddings
In this notebook, we train and evaluate a three-player introspective rationale explainer model on a subset of the SST-2 (Stanford Sentiment Treebank) dataset. To run this notebook, download the SST-2 data files and place them in the data folder configured below.
Here we set some parameters that we use for our modeling task.
# QUICK_RUN skips the GloVe embedding download, most of model training, and
# model evaluation; use it to quickly smoke-test the whole pipeline.
QUICK_RUN = True
MODEL_TYPE = "RNN"  # supported: RNN, BERT, or a combination of RNN and BERT
CUDA = False

# --- data processing parameters ---
DATA_FOLDER = "../../../data/sst2"
LABEL_COL = "labels"
TEXT_COL = "sentences"
token_count_thresh = 1
max_sentence_token_count = 70

# --- training procedure parameters ---
load_pretrained_model = False
pretrained_model_path = "../models/rnn.pth"
MODEL_SAVE_DIR = os.path.join("..", "models")
model_prefix = "sst2rnpmodel"

model_config = {
    "cuda": CUDA,
    "model_save_dir": MODEL_SAVE_DIR,
    "model_prefix": model_prefix,
    "lr": 2e-4,
}

if QUICK_RUN:
    # shortest possible training run that still exercises every stage
    model_config.update(save_best_model=False, pretrain_cls=True, num_epochs=1)

if MODEL_TYPE == "RNN":
    # not using BERT, so the classifier needs pretrained GloVe embeddings
    if QUICK_RUN:
        # NOTE(review): the empty filename makes this just the data folder
        # path — presumably a placeholder since the quick run skips the
        # embedding download; confirm the intended file name.
        model_config["embedding_path"] = os.path.join(DATA_FOLDER, "")
    else:
        model_config["embedding_path"] = load_glove_embeddings(DATA_FOLDER)
We start by loading a subset of the data for training and testing.
# Load the train/test splits; the concatenation is kept for vocabulary
# building and label discovery over the full dataset.
train_data = load_sst2_pandas_df('train')
test_data = load_sst2_pandas_df('test')
all_data = pd.concat([train_data, test_data])

if QUICK_RUN:
    # shrink both splits to one small batch for the smoke test
    batch_size = 50
    train_data, test_data = train_data.head(batch_size), test_data.head(batch_size)

# raw text columns that feed the preprocessor
X_train, X_test = train_data[TEXT_COL], test_data[TEXT_COL]

# record every distinct label (sorted) in the model config
y_labels = all_data[LABEL_COL].unique()
model_config["labels"] = np.array(sorted(y_labels))
model_config["num_labels"] = len(y_labels)
When using the RNN model, the data is then tokenized and embedded using GloVe embeddings; the BERT preprocessor uses its own tokenizer instead.
# Build the preprocessor matching the classifier type: the GloVe-based one
# requires a vocabulary built from the full corpus, while the BERT one is
# constructed without a corpus-built vocabulary.
if MODEL_TYPE == "RNN":
    preprocessor = GlovePreprocessor(token_count_thresh, max_sentence_token_count)
    preprocessor.build_vocab(all_data[TEXT_COL])
elif MODEL_TYPE == "BERT":
    preprocessor = BertPreprocessor()
# NOTE(review): any other MODEL_TYPE leaves `preprocessor` undefined and the
# lines below raise NameError — confirm whether more types need handling.

# prepend the label column to the tokenizer output for train and test
df_train = pd.concat([train_data[LABEL_COL], preprocessor.preprocess(X_train)], axis=1)
df_test = pd.concat([test_data[LABEL_COL], preprocessor.preprocess(X_test)], axis=1)
Then, we create the explainer and train it (or load a pretrained model). Setting up the explainer involves attaching the preprocessor, building the model configuration, and loading the model before fitting.
# Assemble the explainer for the chosen classifier type, attach the matching
# preprocessor, and apply the model configuration built above.
explainer = IntrospectiveRationaleExplainer(classifier_type=MODEL_TYPE, cuda=CUDA)
explainer.set_preprocessor(preprocessor)
explainer.build_model_config(model_config)
# NOTE(review): presumably load() initializes the underlying model from the
# config; `load_pretrained_model` / `pretrained_model_path` are defined above
# but never passed here — confirm against the explainer's API.
explainer.load()
# train on the preprocessed train split, evaluating against the test split
explainer.fit(df_train, df_test)
We can test the explainer and measure its performance:
# Evaluate the trained explainer on the test split (skipped on quick runs)
# and record the headline metrics.
if not QUICK_RUN:
    explainer.score(df_test)
    sparsity = explainer.model.avg_sparsity
    accuracy = explainer.model.avg_accuracy
    anti_accuracy = explainer.model.avg_anti_accuracy
    print("Test sparsity: ", sparsity)
    print("Test accuracy: ", accuracy, "% Anti-accuracy: ", anti_accuracy)
    # glue the metrics for notebook regression testing
    for metric_name, metric_value in (
        ("sparsity", sparsity),
        ("accuracy", accuracy),
        ("anti_accuracy", anti_accuracy),
    ):
        sb.glue(metric_name, metric_value)
We can display the found local importances (the most and least important words for a given sentence):
# Example sentences to interpret; only `s4` is explained below — the others
# are alternative inputs to try.
sentence = "Beautiful movie ; really good , the popcorn was bad"
s2 = "a beautiful and haunting examination of the stories we tell ourselves to make sense of the mundane horrors of the world."
s3 = "the premise is in extremely bad taste , and the film's supposed insights are so poorly executed and done that even a high school dropout taking his or her first psychology class could dismiss them ."
s4 = "This is a super amazing movie with bad acting"

# compute word-level importances for the chosen sentence
local_explanation = explainer.explain_local(s4)
We can visualize local feature importances as a heatmap over words in the document and view importance values of individual words.
# Render the heatmap of per-word importances for the local explanation.
explainer.visualize(local_explanation)
# Launch the interactive dashboard widget for the same explanation.
ExplanationDashboard(local_explanation)