Question Answering with DeepMatcher

Note: you can run this notebook live in Google Colab.

DeepMatcher can easily be used for text matching tasks such as Question Answering and Text Entailment. In this tutorial we will see how to use DeepMatcher for Answer Selection, a major sub-task of Question Answering. Specifically, we will look at WikiQA, a benchmark dataset for Answer Selection. There are three main steps in this tutorial:

  1. Get data and transform it into DeepMatcher input format
  2. Set up and train DeepMatcher model
  3. Evaluate model using QA eval metrics

Before we begin, if you are running this notebook in Colab, you will first need to install the necessary packages by running the code below:

In [ ]:
try:
    import deepmatcher
except ImportError:
    !pip install -qqq deepmatcher

Step 1: Get data and transform it into DeepMatcher input format

First let's import relevant packages and download the dataset:

In [1]:
import deepmatcher as dm
import pandas as pd
import os

!wget -qnc https://download.microsoft.com/download/E/5/F/E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip
!unzip -qn WikiQACorpus.zip

Let's see what this dataset looks like:

In [2]:
raw_train = pd.read_csv(os.path.join('WikiQACorpus', 'WikiQA-train.txt'), sep='\t', header=None)
raw_train.head()
Out[2]:
   0                               1                                                  2
0  how are glacier caves formed ?  A partly submerged glacier cave on Perito More...  0
1  how are glacier caves formed ?  The ice facade is approximately 60 m high          0
2  how are glacier caves formed ?  Ice formations in the Titlis glacier cave          0
3  how are glacier caves formed ?  A glacier cave is a cave formed within the ice...  1
4  how are glacier caves formed ?  Glacier caves are often called ice caves , but...  0

Clearly, this is not the input format deepmatcher expects: the file has no column names, no ID column, and it's not a CSV file. Let's fix that:

In [3]:
raw_train.columns = ['left_value', 'right_value', 'label']
raw_train.index.name = 'id'
raw_train.head()
Out[3]:
    left_value                      right_value                                        label
id
0   how are glacier caves formed ?  A partly submerged glacier cave on Perito More...  0
1   how are glacier caves formed ?  The ice facade is approximately 60 m high          0
2   how are glacier caves formed ?  Ice formations in the Titlis glacier cave          0
3   how are glacier caves formed ?  A glacier cave is a cave formed within the ice...  1
4   how are glacier caves formed ?  Glacier caves are often called ice caves , but...  0

Looks good. Now let's save this to disk and transform the validation and test data in the same way:

In [4]:
raw_train.to_csv(os.path.join('WikiQACorpus', 'dm_train.csv'))

raw_files = ['WikiQA-dev.txt', 'WikiQA-test.txt']
csv_files = ['dm_valid.csv', 'dm_test.csv']
# Apply the same column names and ID index to the validation and test splits
for raw_file, csv_file in zip(raw_files, csv_files):
    raw_data = pd.read_csv(os.path.join('WikiQACorpus', raw_file), sep='\t', header=None)
    raw_data.columns = ['left_value', 'right_value', 'label']
    raw_data.index.name = 'id'
    raw_data.to_csv(os.path.join('WikiQACorpus', csv_file))
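
For readers who want to see the target layout without pandas, the conversion can also be sketched with the standard library alone. The helper name tsv_to_deepmatcher_csv and the two sample rows are made up for illustration. Reading with csv.QUOTE_NONE is an assumption that quote characters inside the sentences should be treated as ordinary text; it is not something the cells above do.

```python
import csv
import io


def tsv_to_deepmatcher_csv(tsv_file, csv_file):
    """Copy (question, answer, label) rows into an id/left_/right_/label CSV.

    Hypothetical helper for illustration; returns the number of rows written.
    """
    # QUOTE_NONE: quote characters in the text are kept as plain characters.
    reader = csv.reader(tsv_file, delimiter='\t', quoting=csv.QUOTE_NONE)
    writer = csv.writer(csv_file, lineterminator='\n')
    writer.writerow(['id', 'left_value', 'right_value', 'label'])
    num_rows = 0
    for row_id, (question, answer, label) in enumerate(reader):
        writer.writerow([row_id, question, answer, int(label)])
        num_rows += 1
    return num_rows


# Made-up rows for illustration only (not taken from WikiQA).
sample = io.StringIO(
    'what is a toy question ?\tThis sentence is not the answer\t0\n'
    'what is a toy question ?\tA "toy question" is a made-up example\t1\n'
)
converted = io.StringIO()
num_rows = tsv_to_deepmatcher_csv(sample, converted)
print(converted.getvalue())
```

The writer quotes the second answer because it contains quote characters, so the resulting CSV parses back to the original text.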

Step 2: Set up and train DeepMatcher model

Now we are ready to load and process the data for deepmatcher:

In [5]:
train, validation, test = dm.data.process(
    path='WikiQACorpus',
    train='dm_train.csv',
    validation='dm_valid.csv',
    test='dm_test.csv')
Rebuilding data cache because: {'One or more data files have been modified.'}
Load time: 6.962715303525329
Vocab time: 14.411666898056865
Metadata time: 4.01532360445708
Cache time: 7.646213295869529

Next, we create a deepmatcher model and train it. Since this is a demo, we do not perform hyperparameter tuning: we simply use the default settings for everything except the pos_neg_ratio parameter. This must be set because there are very few "positive matches" (candidates that correctly answer the question) in this dataset. In a real application you should also tune the other model hyperparameters to get the best performance.
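
To see how skewed the labels are before choosing pos_neg_ratio, one option is to count the labels in the training CSV. The sketch below does this with the standard library on a few made-up rows. The helper name label_balance is hypothetical, and treating the negative-to-positive ratio as a starting value is an assumption to be checked by tuning, not a rule taken from deepmatcher.

```python
import csv
import io
from collections import Counter


def label_balance(csv_file):
    """Return (num_negatives, num_positives) from a CSV with a 'label' column."""
    counts = Counter(int(row['label']) for row in csv.DictReader(csv_file))
    return counts[0], counts[1]


# Made-up rows for illustration only.
toy_csv = io.StringIO(
    'id,left_value,right_value,label\n'
    '0,q1,a1,0\n'
    '1,q1,a2,0\n'
    '2,q1,a3,1\n'
    '3,q2,a4,0\n'
)
negatives, positives = label_balance(toy_csv)
print('negatives per positive:', negatives / positives)
```

On the real data, the same function could be called on a file opened with open(os.path.join('WikiQACorpus', 'dm_train.csv'), newline='').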

In [6]:
model = dm.MatchingModel()
model.run_train(
    train,
    validation,
    epochs=10,
    best_save_path='hybrid_model.pth',
    pos_neg_ratio=7)
* Number of trainable parameters: 2798703
===>  TRAIN Epoch 1 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 1 || Run Time:   18.7 | Load Time:    6.6 || F1:  12.55 | Prec:  18.40 | Rec:   9.52 || Ex/s: 803.13

===>  EVAL Epoch 1 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 1 || Run Time:    1.1 | Load Time:    0.9 || F1:  29.49 | Prec:  21.77 | Rec:  45.71 || Ex/s: 1378.35

* Best F1: 29.493087557603683
Saving best model...
===>  TRAIN Epoch 2 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 2 || Run Time:   18.5 | Load Time:    6.7 || F1:  30.13 | Prec:  24.97 | Rec:  37.98 || Ex/s: 808.23

===>  EVAL Epoch 2 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 2 || Run Time:    1.1 | Load Time:    0.9 || F1:  34.00 | Prec:  24.60 | Rec:  55.00 || Ex/s: 1389.86

* Best F1: 33.99558498896247
Saving best model...
===>  TRAIN Epoch 3 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 3 || Run Time:   18.6 | Load Time:    6.7 || F1:  40.77 | Prec:  31.73 | Rec:  57.02 || Ex/s: 804.25

===>  EVAL Epoch 3 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 3 || Run Time:    1.0 | Load Time:    0.9 || F1:  31.07 | Prec:  22.40 | Rec:  50.71 || Ex/s: 1429.85

===>  TRAIN Epoch 4 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 4 || Run Time:   18.6 | Load Time:    6.7 || F1:  50.94 | Prec:  39.58 | Rec:  71.44 || Ex/s: 803.71

===>  EVAL Epoch 4 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 4 || Run Time:    1.0 | Load Time:    0.9 || F1:  34.23 | Prec:  26.02 | Rec:  50.00 || Ex/s: 1427.62

* Best F1: 34.22982885085574
Saving best model...
===>  TRAIN Epoch 5 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 5 || Run Time:   18.6 | Load Time:    6.7 || F1:  63.19 | Prec:  50.67 | Rec:  83.94 || Ex/s: 805.32

===>  EVAL Epoch 5 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 5 || Run Time:    1.0 | Load Time:    0.9 || F1:  35.14 | Prec:  33.33 | Rec:  37.14 || Ex/s: 1430.46

* Best F1: 35.13513513513514
Saving best model...
===>  TRAIN Epoch 6 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 6 || Run Time:   18.6 | Load Time:    6.7 || F1:  74.33 | Prec:  62.93 | Rec:  90.77 || Ex/s: 804.23

===>  EVAL Epoch 6 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 6 || Run Time:    1.1 | Load Time:    0.9 || F1:  34.93 | Prec:  33.55 | Rec:  36.43 || Ex/s: 1388.46

===>  TRAIN Epoch 7 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 7 || Run Time:   18.6 | Load Time:    6.7 || F1:  82.96 | Prec:  73.62 | Rec:  95.00 || Ex/s: 805.43

===>  EVAL Epoch 7 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 7 || Run Time:    1.1 | Load Time:    0.9 || F1:  32.07 | Prec:  27.09 | Rec:  39.29 || Ex/s: 1386.02

===>  TRAIN Epoch 8 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 8 || Run Time:   18.6 | Load Time:    6.7 || F1:  86.39 | Prec:  79.07 | Rec:  95.19 || Ex/s: 804.31

===>  EVAL Epoch 8 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 8 || Run Time:    1.0 | Load Time:    0.9 || F1:  33.33 | Prec:  30.72 | Rec:  36.43 || Ex/s: 1426.20

===>  TRAIN Epoch 9 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 9 || Run Time:   18.7 | Load Time:    6.7 || F1:  91.81 | Prec:  86.68 | Rec:  97.60 || Ex/s: 802.35

===>  EVAL Epoch 9 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 9 || Run Time:    1.1 | Load Time:    0.9 || F1:  34.88 | Prec:  38.14 | Rec:  32.14 || Ex/s: 1385.16

===>  TRAIN Epoch 10 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 10 || Run Time:   18.7 | Load Time:    6.7 || F1:  94.18 | Prec:  90.59 | Rec:  98.08 || Ex/s: 803.29

===>  EVAL Epoch 10 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 10 || Run Time:    1.0 | Load Time:    0.9 || F1:  32.86 | Prec:  32.86 | Rec:  32.86 || Ex/s: 1426.35

Loading best model...

Now that we have a trained model, we obtain the predictions for the test data. Note that deepmatcher computes F1, precision and recall by default, but these may not be the best evaluation metrics for your end task. For instance, in Question Answering the more relevant metrics are MAP and MRR, which we will compute in the next step.

In [7]:
predictions = model.run_prediction(test, output_attributes=True)
===>  PREDICT Epoch 5 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:04
Finished Epoch 5 || Run Time:    2.5 | Load Time:    1.9 || F1:  28.88 | Prec:  26.50 | Rec:  31.74 || Ex/s: 1381.48
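
The F1, Prec and Rec columns in the logs above are the usual binary classification metrics, shown there as percentages. As a rough illustration of how they relate, the sketch below computes them in plain Python from made-up labels and scores. The 0.5 decision threshold and the function name are assumptions for illustration and do not describe deepmatcher's internals.

```python
def precision_recall_f1(labels, scores, threshold=0.5):
    """Compute precision, recall and F1 for the positive class.

    A score above `threshold` counts as a predicted match (an assumption).
    """
    predicted = [score > threshold for score in scores]
    tp = sum(1 for y, p in zip(labels, predicted) if y == 1 and p)
    fp = sum(1 for y, p in zip(labels, predicted) if y == 0 and p)
    fn = sum(1 for y, p in zip(labels, predicted) if y == 1 and not p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


# Made-up labels and match scores for five candidate answers.
toy_labels = [1, 0, 0, 1, 0]
toy_scores = [0.9, 0.6, 0.2, 0.4, 0.1]
precision, recall, f1 = precision_recall_f1(toy_labels, toy_scores)
print('Prec:', precision, 'Rec:', recall, 'F1:', f1)
```

These metrics depend only on which side of the threshold each score falls, not on how the candidates for one question are ordered relative to each other. That is why ranking metrics can tell a different story.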

Step 3: Evaluate model using QA eval metrics

Finally, we compute the Mean Average Precision (MAP) and Mean Reciprocal Rank (MRR) using the model's predictions on the test set. Following the approach of the paper that introduced this dataset, questions in the test set without answers are ignored when computing these metrics.

In [8]:
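MAP, MRR = 0, 0

# Rank the candidate answers of each question by the model's match score
grouped = predictions.groupby('left_value')
num_questions = 0
for question, answers in grouped:
    sorted_answers = answers.sort_values('match_score', ascending=False)

    # p: correct answers seen so far, ap: running sum of precision at each hit
    p, ap = 0, 0
    top_answer_found = False
    for idx, answer in enumerate(sorted_answers.itertuples()):
        if answer.label == 1:
            if not top_answer_found:
                # Reciprocal rank of the highest-ranked correct answer
                MRR += 1 / (idx + 1)
                top_answer_found = True
            p += 1
            ap += p / (idx + 1)

    # Questions without any correct answer are ignored
    if p > 0:
        MAP += ap / p
        num_questions += 1

MAP /= num_questions
MRR /= num_questions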

print('MAP:', MAP)
print('MRR:', MRR)
MAP: 0.6951284386872146
MRR: 0.7099084137865672
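
To make the two metrics concrete, here is a small worked example in plain Python, independent of deepmatcher and pandas. The rankings are made up, and the function names are chosen for illustration. Each list holds the labels of one question's candidates, already sorted by descending match score.

```python
def average_precision(ranked_labels):
    """Mean of precision@rank over the ranks that hold a correct answer."""
    hits, total = 0, 0.0
    for rank, label in enumerate(ranked_labels, start=1):
        if label == 1:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0


def reciprocal_rank(ranked_labels):
    """1 / rank of the first correct answer, or 0.0 if there is none."""
    for rank, label in enumerate(ranked_labels, start=1):
        if label == 1:
            return 1 / rank
    return 0.0


# Made-up rankings for illustration only.
toy_rankings = {
    'q1': [0, 1, 0, 1],  # correct answers at ranks 2 and 4
    'q2': [1, 0, 0],     # correct answer at rank 1
    'q3': [0, 0],        # no correct answer, so the question is skipped
}
answered = [labels for labels in toy_rankings.values() if 1 in labels]
toy_map = sum(average_precision(labels) for labels in answered) / len(answered)
toy_mrr = sum(reciprocal_rank(labels) for labels in answered) / len(answered)
print('MAP:', toy_map)
print('MRR:', toy_mrr)
```

For q1 the average precision is (1/2 + 2/4) / 2 = 0.5 and the reciprocal rank is 1/2. For q2 both are 1. Averaging over the two answered questions gives 0.75 for both MAP and MRR.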