#!/usr/bin/env python
# coding: utf-8

# # Sequence to Sequence with Attention Mechanism

# In[1]:


import sys
sys.path.insert(0, '..')

from mxnet import nd
from mxnet.gluon import rnn, nn
import d2l


# In[2]:


class Seq2SeqAttentionDecoder(d2l.Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
        self.attention_cell = d2l.MLPAttention(num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.LSTM(num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Dense(vocab_size, flatten=False)

    def init_state(self, enc_outputs, enc_valid_len, *args):
        outputs, hidden_state = enc_outputs
        # Transpose encoder outputs to (batch_size, seq_len, num_hiddens)
        return (outputs.swapaxes(0, 1), hidden_state, enc_valid_len)

    def forward(self, X, state):
        enc_outputs, hidden_state, enc_valid_len = state
        X = self.embedding(X).swapaxes(0, 1)
        outputs = []
        for x in X:
            # query shape: (batch_size, 1, num_hiddens), taken from the last
            # LSTM layer's hidden state at the previous time step
            query = hidden_state[0][-1].expand_dims(axis=1)
            # context has the same shape as query
            context = self.attention_cell(
                query, enc_outputs, enc_outputs, enc_valid_len)
            # Concatenate on the feature dimension:
            # (batch_size, 1, embed_size + num_hiddens)
            x = nd.concat(context, x.expand_dims(axis=1), dim=-1)
            # Reshape x to (1, batch_size, embed_size + num_hiddens) for the RNN
            out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)
            outputs.append(out)
        outputs = self.dense(nd.concat(*outputs, dim=0))
        return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,
                                        enc_valid_len]


# ## Example

# In[3]:


encoder = d2l.Seq2SeqEncoder(vocab_size=10, embed_size=8,
                             num_hiddens=16, num_layers=2)
encoder.initialize()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8,
                                  num_hiddens=16, num_layers=2)
decoder.initialize()
X = nd.zeros((4, 7))  # (batch_size, seq_len)
state = decoder.init_state(encoder(X), None)
out, state = decoder(X, state)
out.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape


# ## Training

# In[4]:


embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.0
batch_size, num_examples, max_len = 64, 1e3, 10
lr, num_epochs, ctx = 0.005, 200, d2l.try_gpu()

src_vocab, tgt_vocab, train_iter = d2l.load_data_nmt(
    batch_size, max_len, num_examples)
encoder = d2l.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = d2l.EncoderDecoder(encoder, decoder)
d2l.train_ch7(model, train_iter, lr, num_epochs, ctx)


# ## Prediction

# In[5]:


for sentence in ['Go .', 'Wow !', "I'm OK .", 'I won !']:
    print(sentence + ' => ' + d2l.translate_ch7(
        model, sentence, src_vocab, tgt_vocab, max_len, ctx))
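

# The decoder above relies on `d2l.MLPAttention` without showing what it
# computes. The cell below is a minimal sketch, under the assumption that the
# attention cell implements additive (MLP) attention: scores
# v^T tanh(W_q q + W_k k) over the encoder outputs, softmax-normalized and
# used to take a weighted sum of the values. The helper
# `mlp_attention_sketch` and the weight names `W_q`, `W_k`, `v` are
# illustrative, not part of d2l's API, and the valid-length masking that the
# real cell applies to the softmax is omitted here.

# In[6]:


def mlp_attention_sketch(query, key, value, W_q, W_k, v):
    # query: (batch_size, 1, q_dim); key/value: (batch_size, seq_len, k_dim)
    # Project query and keys into a common `units`-dimensional space
    q_feat = nd.dot(query, W_q)  # (batch_size, 1, units)
    k_feat = nd.dot(key, W_k)    # (batch_size, seq_len, units)
    features = nd.tanh(nd.broadcast_add(q_feat, k_feat))
    # One scalar score per key position: (batch_size, 1, seq_len)
    scores = nd.dot(features, v).swapaxes(1, 2)
    weights = nd.softmax(scores, axis=-1)
    # Weighted sum of values: (batch_size, 1, value_dim), same shape as query
    return nd.batch_dot(weights, value)


# Shape check with hypothetical sizes: the context keeps the query's shape
q = nd.random.normal(shape=(4, 1, 16))
kv = nd.random.normal(shape=(4, 7, 16))
W_q = nd.random.normal(shape=(16, 16)) * 0.1
W_k = nd.random.normal(shape=(16, 16)) * 0.1
v = nd.random.normal(shape=(16, 1)) * 0.1
mlp_attention_sketch(q, kv, kv, W_q, W_k, v).shape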