Note: you can run this notebook live in Google Colab.
DeepMatcher can easily be used for text matching tasks such as Question Answering and Text Entailment. In this tutorial we will see how to use DeepMatcher for Answer Selection, a major sub-task of Question Answering. Specifically, we will look at WikiQA, a benchmark dataset for Answer Selection. There are three main steps in this tutorial: transforming the data into the format deepmatcher expects, training a model, and evaluating it with Question Answering metrics.
Before we begin, if you are running this notebook in Colab, you will first need to install necessary packages by running the code below:
try:
    import deepmatcher
except ImportError:
    # Install only when the package is not already available.
    !pip install -qqq deepmatcher
First let's import relevant packages and download the dataset:
import deepmatcher as dm
import pandas as pd
import os
!wget -qnc https://download.microsoft.com/download/E/5/F/E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip
!unzip -qn WikiQACorpus.zip
Let's see what this dataset looks like:
raw_train = pd.read_csv(os.path.join('WikiQACorpus', 'WikiQA-train.txt'), sep='\t', header=None)
raw_train.head()
| | 0 | 1 | 2 |
|---|---|---|---|
| 0 | how are glacier caves formed ? | A partly submerged glacier cave on Perito More... | 0 |
| 1 | how are glacier caves formed ? | The ice facade is approximately 60 m high | 0 |
| 2 | how are glacier caves formed ? | Ice formations in the Titlis glacier cave | 0 |
| 3 | how are glacier caves formed ? | A glacier cave is a cave formed within the ice... | 1 |
| 4 | how are glacier caves formed ? | Glacier caves are often called ice caves , but... | 0 |
Clearly, this is not the format deepmatcher expects for its input data: the file has no column names, no ID column, and it is tab-separated rather than a CSV file. Let's fix that:
raw_train.columns = ['left_value', 'right_value', 'label']
raw_train.index.name = 'id'
raw_train.head()
| id | left_value | right_value | label |
|---|---|---|---|
| 0 | how are glacier caves formed ? | A partly submerged glacier cave on Perito More... | 0 |
| 1 | how are glacier caves formed ? | The ice facade is approximately 60 m high | 0 |
| 2 | how are glacier caves formed ? | Ice formations in the Titlis glacier cave | 0 |
| 3 | how are glacier caves formed ? | A glacier cave is a cave formed within the ice... | 1 |
| 4 | how are glacier caves formed ? | Glacier caves are often called ice caves , but... | 0 |
This looks good. Now let's save it to disk and transform the validation and test data in the same way:
raw_train.to_csv(os.path.join('WikiQACorpus', 'dm_train.csv'))
raw_files = ['WikiQA-dev.txt', 'WikiQA-test.txt']
csv_files = ['dm_valid.csv', 'dm_test.csv']
for raw_file, csv_file in zip(raw_files, csv_files):
    raw_data = pd.read_csv(os.path.join('WikiQACorpus', raw_file), sep='\t', header=None)
    raw_data.columns = ['left_value', 'right_value', 'label']
    raw_data.index.name = 'id'
    raw_data.to_csv(os.path.join('WikiQACorpus', csv_file))
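As a sanity check on the converted data, the layout described above (an `id` index or column, a `label` column, and matching `left_`/`right_` attribute columns) can be verified before handing the files to deepmatcher. The sketch below is only illustrative: `check_dm_format` is a hypothetical helper, not part of deepmatcher, and the rules it checks are inferred from the column names used in this tutorial, so the library's own validation may differ.

```python
import pandas as pd

def check_dm_format(df):
    """Hypothetical helper: return a list of problems (empty if none found)."""
    problems = []
    cols = [str(c) for c in df.columns]
    # An ID may live either in the index (as in this tutorial) or in a column.
    if df.index.name != 'id' and 'id' not in cols:
        problems.append("missing 'id' column or index")
    if 'label' not in cols:
        problems.append("missing 'label' column")
    elif not df['label'].isin([0, 1]).all():
        problems.append("labels are not all 0 or 1")
    # Every left_ attribute should have a right_ counterpart and vice versa.
    left = {c[len('left_'):] for c in cols if c.startswith('left_')}
    right = {c[len('right_'):] for c in cols if c.startswith('right_')}
    if not left or left != right:
        problems.append("left_/right_ attribute columns do not pair up")
    return problems

# Toy frame shaped like the converted WikiQA data (made-up rows).
toy = pd.DataFrame({
    'left_value': ['what is a cave ?', 'what is a cave ?'],
    'right_value': ['A cave is a natural void in the ground', 'Ice is frozen water'],
    'label': [1, 0],
})
toy.index.name = 'id'

# Toy frame shaped like the raw file: integer column names, no id, no label.
raw_like = pd.DataFrame([['q', 'a', 0]])

print(check_dm_format(toy))
print(check_dm_format(raw_like))
```

On the real data, the same function could be applied to the frame read back from `dm_train.csv` with `index_col='id'`.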
Now we are ready to load and process the data for deepmatcher:
train, validation, test = dm.data.process(
    path='WikiQACorpus',
    train='dm_train.csv',
    validation='dm_valid.csv',
    test='dm_test.csv')
Rebuilding data cache because: {'One or more data files have been modified.'}
Load time: 6.962715303525329
Vocab time: 14.411666898056865
Metadata time: 4.01532360445708
Cache time: 7.646213295869529
Next, we create a deepmatcher model and train it. Note that since this is a demo, we do not perform hyperparameter tuning; we simply use the default settings for everything except the pos_neg_ratio parameter. This must be set because there are very few "positive matches" (candidates that correctly answer the question) in this dataset. In a real application you should tune the other model hyperparameters as well to get the best performance.
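To see how scarce positive matches are, one can count the labels before training. The sketch below uses a made-up toy frame so that it runs on its own; `label_balance` is a hypothetical helper, and treating the resulting ratio as a rough starting point for pos_neg_ratio is an assumption rather than something deepmatcher prescribes.

```python
import pandas as pd

def label_balance(df, label_col='label'):
    """Hypothetical helper: count positives, negatives, and negatives per positive."""
    pos = int((df[label_col] == 1).sum())
    neg = int((df[label_col] == 0).sum())
    ratio = neg / pos if pos else float('inf')
    return pos, neg, ratio

# Toy labels (made up); on the real data, pass the frame read from dm_train.csv.
toy = pd.DataFrame({'label': [0, 0, 1, 0, 0, 0]})
pos, neg, ratio = label_balance(toy)
print(f'positives={pos} negatives={neg} negatives per positive={ratio:.1f}')
```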
model = dm.MatchingModel()
model.run_train(
    train,
    validation,
    epochs=10,
    best_save_path='hybrid_model.pth',
    pos_neg_ratio=7)
* Number of trainable parameters: 2798703

===> TRAIN Epoch 1 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 1 || Run Time: 18.7 | Load Time: 6.6 || F1: 12.55 | Prec: 18.40 | Rec: 9.52 || Ex/s: 803.13

===> EVAL Epoch 1 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 1 || Run Time: 1.1 | Load Time: 0.9 || F1: 29.49 | Prec: 21.77 | Rec: 45.71 || Ex/s: 1378.35
* Best F1: 29.493087557603683
Saving best model...

===> TRAIN Epoch 2 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 2 || Run Time: 18.5 | Load Time: 6.7 || F1: 30.13 | Prec: 24.97 | Rec: 37.98 || Ex/s: 808.23

===> EVAL Epoch 2 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 2 || Run Time: 1.1 | Load Time: 0.9 || F1: 34.00 | Prec: 24.60 | Rec: 55.00 || Ex/s: 1389.86
* Best F1: 33.99558498896247
Saving best model...

===> TRAIN Epoch 3 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 3 || Run Time: 18.6 | Load Time: 6.7 || F1: 40.77 | Prec: 31.73 | Rec: 57.02 || Ex/s: 804.25

===> EVAL Epoch 3 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 3 || Run Time: 1.0 | Load Time: 0.9 || F1: 31.07 | Prec: 22.40 | Rec: 50.71 || Ex/s: 1429.85

===> TRAIN Epoch 4 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 4 || Run Time: 18.6 | Load Time: 6.7 || F1: 50.94 | Prec: 39.58 | Rec: 71.44 || Ex/s: 803.71

===> EVAL Epoch 4 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 4 || Run Time: 1.0 | Load Time: 0.9 || F1: 34.23 | Prec: 26.02 | Rec: 50.00 || Ex/s: 1427.62
* Best F1: 34.22982885085574
Saving best model...

===> TRAIN Epoch 5 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 5 || Run Time: 18.6 | Load Time: 6.7 || F1: 63.19 | Prec: 50.67 | Rec: 83.94 || Ex/s: 805.32

===> EVAL Epoch 5 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 5 || Run Time: 1.0 | Load Time: 0.9 || F1: 35.14 | Prec: 33.33 | Rec: 37.14 || Ex/s: 1430.46
* Best F1: 35.13513513513514
Saving best model...

===> TRAIN Epoch 6 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 6 || Run Time: 18.6 | Load Time: 6.7 || F1: 74.33 | Prec: 62.93 | Rec: 90.77 || Ex/s: 804.23

===> EVAL Epoch 6 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 6 || Run Time: 1.1 | Load Time: 0.9 || F1: 34.93 | Prec: 33.55 | Rec: 36.43 || Ex/s: 1388.46

===> TRAIN Epoch 7 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 7 || Run Time: 18.6 | Load Time: 6.7 || F1: 82.96 | Prec: 73.62 | Rec: 95.00 || Ex/s: 805.43

===> EVAL Epoch 7 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 7 || Run Time: 1.1 | Load Time: 0.9 || F1: 32.07 | Prec: 27.09 | Rec: 39.29 || Ex/s: 1386.02

===> TRAIN Epoch 8 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 8 || Run Time: 18.6 | Load Time: 6.7 || F1: 86.39 | Prec: 79.07 | Rec: 95.19 || Ex/s: 804.31

===> EVAL Epoch 8 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 8 || Run Time: 1.0 | Load Time: 0.9 || F1: 33.33 | Prec: 30.72 | Rec: 36.43 || Ex/s: 1426.20

===> TRAIN Epoch 9 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 9 || Run Time: 18.7 | Load Time: 6.7 || F1: 91.81 | Prec: 86.68 | Rec: 97.60 || Ex/s: 802.35

===> EVAL Epoch 9 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 9 || Run Time: 1.1 | Load Time: 0.9 || F1: 34.88 | Prec: 38.14 | Rec: 32.14 || Ex/s: 1385.16

===> TRAIN Epoch 10 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:25
Finished Epoch 10 || Run Time: 18.7 | Load Time: 6.7 || F1: 94.18 | Prec: 90.59 | Rec: 98.08 || Ex/s: 803.29

===> EVAL Epoch 10 :
0% [█████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:01
Finished Epoch 10 || Run Time: 1.0 | Load Time: 0.9 || F1: 32.86 | Prec: 32.86 | Rec: 32.86 || Ex/s: 1426.35

Loading best model...
Now that we have a trained model, we obtain its predictions for the test data. Note that deepmatcher computes F1, precision, and recall by default, but these may not be the best evaluation metrics for your end task. For instance, in Question Answering the more relevant metrics are MAP and MRR, which we will compute in the next step.
predictions = model.run_prediction(test, output_attributes=True)
===> PREDICT Epoch 5 :
0% [██████████████████████████████] 100% | ETA: 00:00:00
Total time elapsed: 00:00:04
Finished Epoch 5 || Run Time: 2.5 | Load Time: 1.9 || F1: 28.88 | Prec: 26.50 | Rec: 31.74 || Ex/s: 1381.48
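For answer selection, the predictions are usually consumed per question: the candidate with the highest match score is taken as the answer. A minimal sketch of that step follows, using a made-up toy frame in place of the real `predictions` object. It assumes the frame has `left_value`, `right_value`, and `match_score` columns, which is what the evaluation code in this tutorial relies on; `top_answers` is a hypothetical helper.

```python
import pandas as pd

# Toy stand-in for the predictions frame (made-up questions, answers, scores).
toy_predictions = pd.DataFrame({
    'left_value': ['q1', 'q1', 'q1', 'q2', 'q2'],
    'right_value': ['a', 'b', 'c', 'd', 'e'],
    'match_score': [0.10, 0.85, 0.40, 0.30, 0.20],
    'label': [0, 1, 0, 0, 1],
})

def top_answers(predictions):
    """Hypothetical helper: keep the highest-scoring candidate for each question."""
    # idxmax returns, per question, the index label of the row with the top score.
    best_rows = predictions.groupby('left_value')['match_score'].idxmax()
    return predictions.loc[best_rows, ['left_value', 'right_value', 'match_score']]

top = top_answers(toy_predictions)
print(top.to_string(index=False))
```

In this toy data the top candidate for `q2` is not the correct one, which is the kind of ranking error that MAP and MRR are designed to measure.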
Finally, we compute the Mean Average Precision (MAP) and Mean Reciprocal Rank (MRR) using the model's predictions on the test set. Following the approach of the paper that introduced this dataset, questions in the test set without answers are ignored when computing these metrics.
MAP, MRR = 0, 0
grouped = predictions.groupby('left_value')  # one group per question
num_questions = 0
for question, answers in grouped:
    # Rank the candidate answers by descending match score.
    sorted_answers = answers.sort_values('match_score', ascending=False)
    p, ap = 0, 0
    top_answer_found = False
    for idx, answer in enumerate(sorted_answers.itertuples()):
        if answer.label == 1:
            if not top_answer_found:
                # Reciprocal rank of the first correct answer.
                MRR += 1 / (idx + 1)
                top_answer_found = True
            p += 1
            ap += p / (idx + 1)  # precision at this rank
    # Questions without any correct answer are ignored.
    if p > 0:
        ap /= p
        num_questions += 1
    MAP += ap
MAP /= num_questions
MRR /= num_questions
print('MAP:', MAP)
print('MRR:', MRR)
MAP: 0.6951284386872146
MRR: 0.7099084137865672
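To make the metric computation easier to verify, the same logic can be wrapped in a function and run on a tiny hand-checkable example. This is a sketch under the same assumptions as the code above (columns `left_value`, `match_score`, and `label`); the function name `map_mrr` and the toy data are made up for illustration.

```python
import pandas as pd

def map_mrr(predictions, group_col='left_value',
            score_col='match_score', label_col='label'):
    """Return (MAP, MRR), skipping questions that have no correct answer."""
    ap_sum, rr_sum, num_questions = 0.0, 0.0, 0
    for _, answers in predictions.groupby(group_col):
        ranked = answers.sort_values(score_col, ascending=False)
        hits, precision_sum, first_hit_rank = 0, 0.0, None
        for rank, label in enumerate(ranked[label_col].tolist(), start=1):
            if label == 1:
                hits += 1
                precision_sum += hits / rank  # precision at this rank
                if first_hit_rank is None:
                    first_hit_rank = rank
        if hits == 0:
            continue  # no correct answer for this question: ignore it
        num_questions += 1
        ap_sum += precision_sum / hits
        rr_sum += 1 / first_hit_rank
    if num_questions == 0:
        raise ValueError('no question has a correct answer')
    return ap_sum / num_questions, rr_sum / num_questions

toy = pd.DataFrame({
    'left_value':  ['q1', 'q1', 'q1', 'q2', 'q2', 'q3', 'q3'],
    'match_score': [0.9,  0.8,  0.1,  0.7,  0.2,  0.6,  0.5],
    'label':       [0,    1,    1,    1,    0,    0,    0],
})
toy_map, toy_mrr = map_mrr(toy)
print(f'MAP: {toy_map:.4f}  MRR: {toy_mrr:.4f}')
```

Working through the toy data by hand: for `q1` the correct answers sit at ranks 2 and 3, so its average precision is (1/2 + 2/3) / 2 = 7/12 and its reciprocal rank is 1/2; for `q2` both values are 1; `q3` has no correct answer and is skipped. That gives MAP = 19/24 (about 0.7917) and MRR = 0.75.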