I recommend you take a look at these materials first.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
import nltk
import random
import numpy as np
from collections import Counter, OrderedDict
import nltk
from copy import deepcopy
import os
import re
import unicodedata
def flatten(l):
    """Return a single list holding every item of every sub-sequence in *l*, in order."""
    merged = []
    for sublist in l:
        for item in sublist:
            merged.append(item)
    return merged
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence
# Seed Python's RNG so that the shuffling and sampling done through `random`
# start from the same state on every run.
random.seed(1024)

USE_CUDA = torch.cuda.is_available()
gpus = [0]
# Select the GPU only when CUDA is actually available: calling
# torch.cuda.set_device on a machine without CUDA can raise at import time.
if USE_CUDA:
    torch.cuda.set_device(gpus[0])

# Tensor constructors that resolve to the CUDA variants when a GPU is in use
# and to the CPU variants otherwise.
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor
def getBatch(batch_size, train_data):
    """Shuffle *train_data* in place and yield it in consecutive mini-batches.

    Args:
        batch_size: Number of samples per batch; must be a positive integer.
        train_data: Mutable sequence of samples. It is shuffled in place
            with ``random.shuffle`` before batching.

    Yields:
        Slices of *train_data* of length *batch_size*; the last slice holds
        the remainder and may be shorter. Nothing is yielded when
        *train_data* is empty.

    Raises:
        ValueError: If *batch_size* is not positive (raised when iteration
            starts, because this is a generator).
    """
    if batch_size <= 0:
        # A non-positive step would never advance through the data.
        raise ValueError("batch_size must be a positive integer")

    random.shuffle(train_data)
    for start in range(0, len(train_data), batch_size):
        yield train_data[start:start + batch_size]
def pad_to_batch(batch, w_to_ix):  # for bAbI dataset
    """Pad a mini-batch of (facts, question, answer) samples to common sizes.

    Every fact row is right-padded with ``w_to_ix['<PAD>']`` to the longest
    fact in the batch, and every story is extended with all-padding rows up
    to the largest number of facts in the batch. Questions and answers are
    right-padded to their longest member and concatenated into one tensor each.

    Returns:
        ``(facts, fact_masks, questions, question_masks, answers)`` where
        ``facts`` and ``fact_masks`` are lists with one entry per sample.
        The masks are byte tensors marking the positions whose index is 0.
    """
    fact, q, a = list(zip(*batch))
    max_fact = max([len(story) for story in fact])
    max_len = max([sentence.size(1) for sentence in flatten(fact)])
    max_q = max([question.size(1) for question in q])
    max_a = max([answer.size(1) for answer in a])

    def pad_row(row, width):
        # Rows that already reach the target width are returned untouched.
        if not row.size(1) < width:
            return row
        filler = Variable(LongTensor([w_to_ix['<PAD>']] * (width - row.size(1)))).view(1, -1)
        return torch.cat([row, filler], 1)

    def zero_mask(matrix):
        # One mask row per matrix row; an entry is set where the index equals 0.
        mask_rows = [
            Variable(ByteTensor(tuple(map(lambda s: s == 0, row.data))), volatile=False)
            for row in matrix
        ]
        return torch.cat(mask_rows).view(matrix.size(0), -1)

    facts, fact_masks, q_p, a_p = [], [], [], []
    for story, question, answer in zip(fact, q, a):
        rows = [pad_row(sentence, max_len) for sentence in story]
        # Stories with fewer facts than the longest one get pure-padding rows.
        for _ in range(max_fact - len(rows)):
            rows.append(Variable(LongTensor([w_to_ix['<PAD>']] * max_len)).view(1, -1))

        story_matrix = torch.cat(rows)
        facts.append(story_matrix)
        fact_masks.append(zero_mask(story_matrix))

        q_p.append(pad_row(question, max_q))
        a_p.append(pad_row(answer, max_a))

    questions = torch.cat(q_p)
    answers = torch.cat(a_p)
    question_masks = zero_mask(questions)
    return facts, fact_masks, questions, question_masks, answers
def prepare_sequence(seq, to_index):
    """Map each token of *seq* to its index and wrap the result in a LongTensor Variable.

    Tokens with no entry in *to_index* are mapped to ``to_index["<UNK>"]``.
    """
    idxs = []
    for token in seq:
        found = to_index.get(token)
        if found is None:
            found = to_index["<UNK>"]
        idxs.append(found)
    return Variable(LongTensor(idxs))
def bAbI_data_load(path):
    """Parse a bAbI task file into a list of ``[facts, question, answer]`` samples.

    Each line starts with a numeric index; an index of ``1`` starts a new
    story. Lines containing ``?`` are questions (tab-separated from the
    answer); every other line is a fact. Each question produces one sample
    holding a copy of the facts seen so far in the current story.

    Args:
        path: Location of the bAbI text file.

    Returns:
        A list of ``[facts, question_tokens, answer_tokens]`` where ``facts``
        is a list of token lists ending in ``'</s>'``, the question ends in
        ``'?'`` and the answer ends in ``'</s>'``. Returns ``None`` (after
        printing a message) when the file cannot be opened or a question
        line has no tab-separated answer.
    """
    try:
        # The context manager closes the handle even if reading fails.
        with open(path, encoding="utf-8") as handle:
            data = handle.readlines()
    except OSError:
        print("Such a file does not exist at {}".format(path))
        return None

    data_p = []
    fact = []
    try:
        for raw in data:
            # Strip only the newline, so a final line without one keeps
            # its last character.
            d = raw.rstrip('\n')
            if not d.strip():
                continue  # blank lines carry no fact or question

            index = d.split(' ')[0]
            if index == '1':
                fact = []  # a new story begins

            if '?' in d:
                temp = d.split('\t')
                q = temp[0].strip().replace('?', '').split(' ')[1:] + ['?']
                a = temp[1].split() + ['</s>']
                # Copy so later facts of the same story do not leak into
                # samples that were already emitted.
                data_p.append([deepcopy(fact), q, a])
            else:
                tokens = d.replace('.', '').split(' ')[1:] + ['</s>']
                fact.append(tokens)
    except IndexError:
        # A question line without a tab-separated answer.
        print("Please check the data is right")
        return None
    return data_p
# Load the training split of bAbI task 5 (three-argument relations).
train_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_train.txt')
# The loader returns None on failure; stop here with a clear message instead
# of failing on the subscript below.
if train_data is None:
    raise RuntimeError("Could not load the bAbI training set; check the dataset path")
# Peek at the first [facts, question, answer] sample.
train_data[0]
[[['Bill', 'travelled', 'to', 'the', 'office', '</s>'], ['Bill', 'picked', 'up', 'the', 'football', 'there', '</s>'], ['Bill', 'went', 'to', 'the', 'bedroom', '</s>'], ['Bill', 'gave', 'the', 'football', 'to', 'Fred', '</s>']], ['What', 'did', 'Bill', 'give', 'to', 'Fred', '?'], ['football', '</s>']]
# Split the samples into parallel tuples of facts, questions and answers.
fact, q, a = list(zip(*train_data))
# Sort the unique tokens so each word gets the same index on every run;
# the iteration order of a set of strings can differ between processes.
vocab = sorted(set(flatten(flatten(fact)) + flatten(q) + flatten(a)))

# Special tokens take the first four indices; the data tokens follow.
word2index = {'<PAD>': 0, '<UNK>': 1, '<s>': 2, '</s>': 3}
for vo in vocab:
    if vo not in word2index:
        word2index[vo] = len(word2index)

# Reverse lookup, used to turn predicted indices back into words.
index2word = {v: k for k, v in word2index.items()}
len(word2index)
44
# Replace the token lists of every training sample, in place, with
# (1, length) index tensors built from word2index.
for sample in train_data:
    sample[0][:] = [
        prepare_sequence(sentence, word2index).view(1, -1)
        for sentence in sample[0]
    ]
    for slot in (1, 2):  # question, then answer
        sample[slot] = prepare_sequence(sample[slot], word2index).view(1, -1)
class DMN(nn.Module):
    """Dynamic Memory Network: input, question, episodic-memory and answer modules.

    Facts and questions are embedded with a shared embedding table and encoded
    by separate GRUs. The episodic memory makes several gated passes over the
    encoded facts, and a GRU cell decodes the answer from the final memory.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout_p=0.1):
        """Create the sub-modules.

        Args:
            input_size: Number of rows of the embedding table.
            hidden_size: Embedding width and hidden width of every recurrent unit.
            output_size: Number of classes produced by the answer projection.
            dropout_p: Dropout probability applied to embeddings during training.
        """
        super(DMN, self).__init__()
        self.hidden_size = hidden_size

        # Index 0 is registered as the padding index of the embedding table.
        self.embed = nn.Embedding(input_size, hidden_size, padding_idx=0)
        self.input_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.question_gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

        # Scores one fact against the question and the current memory; the
        # input is the concatenation of four hidden-sized feature vectors.
        gate_layers = [
            nn.Linear(hidden_size * 4, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        ]
        self.gate = nn.Sequential(*gate_layers)

        self.attention_grucell = nn.GRUCell(hidden_size, hidden_size)
        self.memory_grucell = nn.GRUCell(hidden_size, hidden_size)
        # The answer cell consumes [previous word embedding ; question encoding].
        self.answer_grucell = nn.GRUCell(hidden_size * 2, hidden_size)
        self.answer_fc = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def init_hidden(self, inputs):
        """Return a zero state of shape (1, inputs.size(0), hidden_size), on GPU if USE_CUDA."""
        state = Variable(torch.zeros(1, inputs.size(0), self.hidden_size))
        if USE_CUDA:
            return state.cuda()
        return state

    def init_weight(self):
        """Xavier-initialise the weight tensors and zero the answer projection bias."""
        nn.init.xavier_uniform(self.embed.state_dict()['weight'])

        xavier_normal_modules = (
            self.input_gru,
            self.question_gru,
            self.gate,
            self.attention_grucell,
            self.memory_grucell,
            self.answer_grucell,
        )
        for module in xavier_normal_modules:
            for name, param in module.state_dict().items():
                if 'weight' in name:
                    nn.init.xavier_normal(param)

        nn.init.xavier_normal(self.answer_fc.state_dict()['weight'])
        self.answer_fc.bias.data.fill_(0)

    def forward(self, facts, fact_masks, questions, question_masks, num_decode, episodes=3, is_training=False):
        """Run the network and return per-step log-probabilities.

        facts : (B,T_C,T_I) / LongTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)
        fact_masks : (B,T_C,T_I) / ByteTensor in List # batch_size, num_of_facts, length_of_each_fact(padded)
        questions : (B,T_Q) / LongTensor # batch_size, question_length
        question_masks : (B,T_Q) / ByteTensor # batch_size, question_length
        num_decode : number of answer tokens to decode
        episodes : number of passes over the facts in the episodic memory
        is_training : apply dropout to the embeddings when True

        Returns the log-softmax outputs of all decoding steps, viewed as
        (B * num_decode, -1).
        """
        # ---- Input module: encode every fact of every story ----
        story_encodings = []
        for fact, fact_mask in zip(facts, fact_masks):
            embedded = self.embed(fact)
            if is_training:
                embedded = self.dropout(embedded)
            outputs, _ = self.input_gru(embedded, self.init_hidden(fact))

            last_states = []
            for row, states in enumerate(outputs):  # B,T,D
                # Number of mask entries equal to 0, i.e. non-padding positions.
                n_real = fact_mask[row].data.tolist().count(0)
                last_states.append(states[n_real - 1])
            story_encodings.append(torch.cat(last_states).view(fact.size(0), -1).unsqueeze(0))
        encoded_facts = torch.cat(story_encodings)  # B,T_C,D

        # ---- Question module ----
        embedded = self.embed(questions)
        if is_training:
            embedded = self.dropout(embedded)
        outputs, hidden = self.question_gru(embedded, self.init_hidden(questions))

        if isinstance(question_masks, torch.autograd.variable.Variable):
            last_states = []
            for row, states in enumerate(outputs):  # B,T,D
                n_real = question_masks[row].data.tolist().count(0)
                last_states.append(states[n_real - 1])
            encoded_question = torch.cat(last_states).view(questions.size(0), -1)  # B,D
        else:  # for inference mode
            encoded_question = hidden.squeeze(0)  # B,D

        # ---- Episodic memory module ----
        def fact_at(step):
            # Encoding of the fact at position `step` for every story: B,D
            return encoded_facts.transpose(0, 1)[step]

        memory = encoded_question
        T_C = encoded_facts.size(1)
        B = encoded_facts.size(0)
        for _ in range(episodes):
            hidden = self.init_hidden(fact_at(0)).squeeze(0)  # B,D
            for t in range(T_C):
                # TODO: fact masking
                # TODO: gate function => softmax
                features = [
                    fact_at(t) * encoded_question,             # B,D , element-wise product
                    fact_at(t) * memory,                       # B,D , element-wise product
                    torch.abs(fact_at(t) - encoded_question),  # B,D
                    torch.abs(fact_at(t) - memory),            # B,D
                ]
                g_t = self.gate(torch.cat(features, 1))  # B,1 scalar
                # Gated update: keep the old state where the gate is closed.
                hidden = g_t * self.attention_grucell(fact_at(t), hidden) + (1 - g_t) * hidden
            memory = self.memory_grucell(hidden, memory)

        # ---- Answer module ----
        answer_hidden = memory
        start_decode = Variable(LongTensor([[word2index['<s>']] * memory.size(0)])).transpose(0, 1)
        y_t_1 = self.embed(start_decode).squeeze(1)  # B,D
        # NOTE(review): y_t_1 is never reassigned in the loop below, so every
        # decoding step is fed the '<s>' embedding - confirm this is intended.
        decodes = []
        for _ in range(num_decode):
            answer_hidden = self.answer_grucell(torch.cat([y_t_1, encoded_question], 1), answer_hidden)
            decodes.append(F.log_softmax(self.answer_fc(answer_hidden), 1))
        return torch.cat(decodes, 1).view(B * num_decode, -1)
Training takes a while if you use only the CPU.
# Hyper-parameters.
HIDDEN_SIZE = 80
BATCH_SIZE = 64
LR = 0.001
EPOCH = 50
NUM_EPISODE = 3
EARLY_STOPPING = False  # set once the running mean loss drops below 0.01

# The vocabulary size serves as both the embedding size and the output size.
model = DMN(len(word2index), HIDDEN_SIZE, len(word2index))
model.init_weight()
if USE_CUDA:
    model = model.cuda()

# Targets equal to 0 are ignored by the loss.
loss_function = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCH):
    losses = []
    if EARLY_STOPPING:
        break

    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):
        padded = pad_to_batch(batch, word2index)
        facts, fact_masks, questions, question_masks, answers = padded

        model.zero_grad()
        pred = model(
            facts, fact_masks, questions, question_masks,
            answers.size(1), episodes=NUM_EPISODE, is_training=True,
        )
        loss = loss_function(pred, answers.view(-1))
        losses.append(loss.data.tolist()[0])
        loss.backward()
        optimizer.step()

        # Report, and check the stopping criterion, every 100 batches.
        if i % 100 != 0:
            continue
        mean_loss = np.mean(losses)
        print("[%d/%d] mean_loss : %0.2f" % (epoch, EPOCH, mean_loss))
        if mean_loss < 0.01:
            EARLY_STOPPING = True
            print("Early Stopping!")
            break
        losses = []
[0/50] mean_loss : 3.86 [0/50] mean_loss : 1.32 [1/50] mean_loss : 0.68 [1/50] mean_loss : 0.65 [2/50] mean_loss : 0.62 [2/50] mean_loss : 0.65 [3/50] mean_loss : 0.65 [3/50] mean_loss : 0.64 [4/50] mean_loss : 0.60 [4/50] mean_loss : 0.62 [5/50] mean_loss : 0.63 [5/50] mean_loss : 0.61 [6/50] mean_loss : 0.60 [6/50] mean_loss : 0.61 [7/50] mean_loss : 0.63 [7/50] mean_loss : 0.60 [8/50] mean_loss : 0.62 [8/50] mean_loss : 0.60 [9/50] mean_loss : 0.58 [9/50] mean_loss : 0.60 [10/50] mean_loss : 0.60 [10/50] mean_loss : 0.60 [11/50] mean_loss : 0.62 [11/50] mean_loss : 0.60 [12/50] mean_loss : 0.61 [12/50] mean_loss : 0.60 [13/50] mean_loss : 0.57 [13/50] mean_loss : 0.60 [14/50] mean_loss : 0.59 [14/50] mean_loss : 0.60 [15/50] mean_loss : 0.61 [15/50] mean_loss : 0.60 [16/50] mean_loss : 0.59 [16/50] mean_loss : 0.60 [17/50] mean_loss : 0.59 [17/50] mean_loss : 0.60 [18/50] mean_loss : 0.51 [18/50] mean_loss : 0.50 [19/50] mean_loss : 0.44 [19/50] mean_loss : 0.37 [20/50] mean_loss : 0.30 [20/50] mean_loss : 0.33 [21/50] mean_loss : 0.31 [21/50] mean_loss : 0.31 [22/50] mean_loss : 0.29 [22/50] mean_loss : 0.31 [23/50] mean_loss : 0.29 [23/50] mean_loss : 0.31 [24/50] mean_loss : 0.24 [24/50] mean_loss : 0.31 [25/50] mean_loss : 0.30 [25/50] mean_loss : 0.30 [26/50] mean_loss : 0.14 [26/50] mean_loss : 0.16 [27/50] mean_loss : 0.12 [27/50] mean_loss : 0.15 [28/50] mean_loss : 0.18 [28/50] mean_loss : 0.14 [29/50] mean_loss : 0.12 [29/50] mean_loss : 0.14 [30/50] mean_loss : 0.14 [30/50] mean_loss : 0.14 [31/50] mean_loss : 0.13 [31/50] mean_loss : 0.14 [32/50] mean_loss : 0.11 [32/50] mean_loss : 0.13 [33/50] mean_loss : 0.08 [33/50] mean_loss : 0.06 [34/50] mean_loss : 0.01 [34/50] mean_loss : 0.03 [35/50] mean_loss : 0.01 Early Stopping!
def pad_to_fact(fact, x_to_ix):  # this is for inference
    """Pad the facts of a single story to equal length and build their mask.

    Each fact row is right-padded with ``x_to_ix['<PAD>']`` up to the longest
    fact, the rows are concatenated into one tensor, and a byte mask marking
    the positions whose index is 0 is returned alongside it.
    """
    max_x = max([sentence.size(1) for sentence in fact])

    x_p = []
    for sentence in fact:
        if not sentence.size(1) < max_x:
            x_p.append(sentence)
            continue
        filler = Variable(LongTensor([x_to_ix['<PAD>']] * (max_x - sentence.size(1)))).view(1, -1)
        x_p.append(torch.cat([sentence, filler], 1))

    fact = torch.cat(x_p)
    mask_rows = [
        Variable(ByteTensor(tuple(map(lambda s: s == 0, row.data))), volatile=False)
        for row in fact
    ]
    fact_mask = torch.cat(mask_rows).view(fact.size(0), -1)
    return fact, fact_mask
# Load the test split of the same bAbI task.
test_data = bAbI_data_load('../dataset/bAbI/en-10k/qa5_three-arg-relations_test.txt')
# The loader returns None on failure; stop with a clear message instead of
# failing when the loop below tries to iterate over None.
if test_data is None:
    raise RuntimeError("Could not load the bAbI test set; check the dataset path")

# Convert facts, question and answer of every sample to (1, length) index
# tensors, replacing the token lists in place.
for t in test_data:
    for i, fact in enumerate(t[0]):
        t[0][i] = prepare_sequence(fact, word2index).view(1, -1)
    t[1] = prepare_sequence(t[1], word2index).view(1, -1)
    t[2] = prepare_sequence(t[2], word2index).view(1, -1)
# Count the test samples whose decoded index sequence equals the gold answer.
accuracy = 0
for story, question, gold in test_data:
    fact, fact_mask = pad_to_fact(story, word2index)
    # A single question has no padding, so its mask is all zeros.
    question_mask = Variable(ByteTensor([0] * question.size(1)), volatile=False).unsqueeze(0)
    answer = gold.squeeze(0)

    model.zero_grad()
    pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)

    predicted_ids = pred.max(1)[1].data.tolist()
    accuracy += int(predicted_ids == answer.data.tolist())

print(accuracy/len(test_data) * 100)
97.39999999999999
# Run the model on one randomly chosen test sample and print it in words.
t = random.choice(test_data)
fact, fact_mask = pad_to_fact(t[0], word2index)
question = t[1]
question_mask = Variable(ByteTensor([0] * t[1].size(1)), volatile=False).unsqueeze(0)
answer = t[2].squeeze(0)

model.zero_grad()
pred = model([fact], [fact_mask], question, question_mask, answer.size(0), NUM_EPISODE)


def _to_words(ids):
    """Join the vocabulary words for a list of indices with single spaces."""
    return ' '.join(index2word[ix] for ix in ids)


print("Facts : ")
print('\n'.join(_to_words(row) for row in fact.data.tolist()))
print("")
print("Question : ", _to_words(question.data.tolist()[0]))
print("")
print("Answer : ", _to_words(answer.data.tolist()))
print("Prediction : ", _to_words(pred.max(1)[1].data.tolist()))
Facts : Bill went back to the bedroom </s> Mary went to the office </s> <PAD> Jeff journeyed to the kitchen </s> <PAD> Fred journeyed to the kitchen </s> <PAD> Fred got the milk there </s> <PAD> Fred handed the milk to Jeff </s> Jeff passed the milk to Fred </s> Fred gave the milk to Jeff </s> Question : Who received the milk ? Answer : Jeff </s> Prediction : Jeff </s>