import sys
sys.path.insert(0, '..')
from mxnet import nd
from mxnet.gluon import rnn, nn
import d2l
class Seq2SeqAttentionDecoder(d2l.Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention_cell = d2l.MLPAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.LSTM(num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Dense(vocab_size, flatten=False)

    def init_state(self, enc_outputs, enc_valid_len, *args):
        outputs, hidden_state = enc_outputs
        # Transpose outputs to (batch_size, seq_len, hidden_size)
        return (outputs.swapaxes(0, 1), hidden_state, enc_valid_len)

    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_len = state
        # X shape after embedding and swap: (seq_len, batch_size, embed_size)
        X = self.embedding(X).swapaxes(0, 1)
        outputs = []
        for x in X:
            # query: hidden state of the last RNN layer at the previous step,
            # shape (batch_size, 1, hidden_size)
            query = hidden_state[0][-1].expand_dims(axis=1)
            # context has the same shape as query
            context = self.attention_cell(
                query, enc_outputs, enc_outputs, enc_valid_len)
            # Concatenate on the feature dimension
            x = nd.concat(context, x.expand_dims(axis=1), dim=-1)
            # Reshape x to (1, batch_size, embed_size + hidden_size)
            out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)
            outputs.append(out)
        # Stack the per-step outputs; after dense the shape is
        # (seq_len, batch_size, vocab_size)
        outputs = self.dense(nd.concat(*outputs, dim=0))
        return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,
                                        enc_valid_len]
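At each decoding step, the attention cell takes the decoder's current hidden state as the query and attends over the encoder outputs, which serve as both keys and values; the context it returns has the same shape as the query. The snippet below is a quick shape check of that contract in isolation. It is not part of the original example and assumes the d2l.MLPAttention behavior used above:

attention = d2l.MLPAttention(16, dropout=0)
attention.initialize()
query = nd.ones((4, 1, 16))   # (batch_size, 1, num_hiddens)
keys = nd.ones((4, 7, 16))    # encoder outputs: (batch_size, seq_len, num_hiddens)
# valid_len=None means no encoder positions are masked out
context = attention(query, keys, keys, None)
context.shape                 # expected: (4, 1, 16), same as the query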
Example
encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8,
                             num_hiddens=16, num_layers=2)
encoder.initialize()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
                                  num_hiddens=16, num_layers=2)
decoder.initialize()
X = nd.zeros((4, 7))
state = decoder.init_state(encoder(X), None)
out, state = decoder(X, state)
out.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
((4, 7, 10), 3, (4, 7, 16), 2, (2, 4, 16))
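That is, out carries one vocabulary-sized score vector per target position, with shape (batch_size, seq_len, vocab_size) = (4, 7, 10). The returned state has three elements: the encoder outputs of shape (4, 7, 16) that serve as attention keys and values, the LSTM state (a list of two tensors, the hidden state and the memory cell, each of shape (num_layers, batch_size, num_hiddens) = (2, 4, 16)), and the encoder valid length (None in this example).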
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.0
batch_size, num_examples, max_len = 64, 1e3, 10
lr, num_epochs, ctx = 0.005, 200, d2l.try_gpu()
src_vocab, tgt_vocab, train_iter = d2l.load_data_nmt(
    batch_size, max_len, num_examples)
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = d2l.EncoderDecoder(encoder, decoder)
d2l.train_ch7(model, train_iter, lr, num_epochs, ctx)
epoch 50, loss 0.115, time 34.2 sec
epoch 100, loss 0.067, time 34.6 sec
epoch 150, loss 0.043, time 35.4 sec
epoch 200, loss 0.032, time 34.6 sec
Predict
for sentence in ['Go .', 'Wow !', "I'm OK .", 'I won !']:
    print(sentence + ' => ' + d2l.translate_ch7(
        model, sentence, src_vocab, tgt_vocab, max_len, ctx))
Go . => va !
Wow ! => <unk> !
I'm OK . => je vais bien .
I won ! => je l'ai emporté !