import sys
sys.path.insert(0, '..')
import time
from mxnet import nd, init, gluon, autograd
from mxnet.gluon import nn, rnn, loss as gloss
import d2l
class Seq2SeqEncoder(d2l.Encoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.LSTM(num_hiddens, num_layers, dropout=dropout)

    def forward(self, X, *args):
        X = self.embedding(X)  # X shape: (batch_size, seq_len, embed_size)
        X = X.swapaxes(0, 1)   # RNN needs the first axis to be time
        state = self.rnn.begin_state(batch_size=X.shape[1], ctx=X.context)
        out, state = self.rnn(X, state)
        # The shape of out is (seq_len, batch_size, num_hiddens).
        # state contains the hidden state and the memory cell of the last
        # time step, each with shape (num_layers, batch_size, num_hiddens).
        return out, state
A quick sanity test of the encoder:
encoder = Seq2SeqEncoder(vocab_size=10, embed_size=8,
                         num_hiddens=16, num_layers=2)
encoder.initialize()
X = nd.zeros((4, 7))
output, state = encoder(X)
output.shape, len(state), state[0].shape, state[1].shape
((7, 4, 16), 2, (2, 4, 16), (2, 4, 16))
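Since the encoder is an LSTM, output holds the top layer's hidden states at every time step, while state[0] holds each layer's hidden state at the final step, so the last row of output should equal the top layer's entry in state[0]. A quick check, reusing the tensors above:

(output[-1] == state[0][-1]).prod().asscalar()  # 1.0 means every entry matches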
class Seq2SeqDecoder(d2l.Decoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = rnn.LSTM(num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Dense(vocab_size, flatten=False)

    def init_state(self, enc_outputs, *args):
        return enc_outputs[1]

    def forward(self, X, state):
        X = self.embedding(X).swapaxes(0, 1)
        out, state = self.rnn(X, state)
        # Make batch the first dimension to simplify loss computation.
        out = self.dense(out).swapaxes(0, 1)
        return out, state
A sanity test for the decoder, initializing its state from the encoder:
decoder = Seq2SeqDecoder(vocab_size=10, embed_size=8,
                         num_hiddens=16, num_layers=2)
decoder.initialize()
state = decoder.init_state(encoder(X))
out, state = decoder(X, state)
out.shape, len(state), state[0].shape, state[1].shape
((4, 7, 10), 2, (2, 4, 16), (2, 4, 16))
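Because the updated state is returned on every call, the decoder can also be driven one token at a time, which is what the prediction loop further below relies on. A minimal single-step sketch, keeping the batch size of 4 used above:

step_X = nd.zeros((4, 1))  # one new token per sequence in the batch
out_step, state = decoder(step_X, state)
out_step.shape  # (4, 1, 10)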
X = nd.array([[1,2,3], [4,5,6]])
nd.SequenceMask(X, nd.array([1,2]), True, axis=1)
[[1. 0. 0.]
 [4. 5. 0.]]
<NDArray 2x3 @cpu(0)>
Applied to an $n$-dimensional tensor $X$, it sets X[i, len[i]:, :, ..., :] = 0. We can also fill the masked entries with a value other than 0:
X = nd.ones((2, 3, 4))
nd.SequenceMask(X, nd.array([1,2]), True, value=-1, axis=1)
[[[ 1.  1.  1.  1.]
  [-1. -1. -1. -1.]
  [-1. -1. -1. -1.]]

 [[ 1.  1.  1.  1.]
  [ 1.  1.  1.  1.]
  [-1. -1. -1. -1.]]]
<NDArray 2x3x4 @cpu(0)>
Now we can implement a masked version of the softmax cross-entropy loss. Gluon's SoftmaxCELoss accepts per-example weights, so we use the mask above as sample weights and padded positions contribute zero loss.
class MaskedSoftmaxCELoss(gloss.SoftmaxCELoss):
    # pred shape: (batch_size, seq_len, vocab_size)
    # label shape: (batch_size, seq_len)
    # valid_length shape: (batch_size, )
    def forward(self, pred, label, valid_length):
        # The sample weights shape should be (batch_size, seq_len, 1)
        weights = nd.ones_like(label).expand_dims(axis=-1)
        weights = nd.SequenceMask(weights, valid_length, True, axis=1)
        return super(MaskedSoftmaxCELoss, self).forward(pred, label, weights)
A sanity check with three identical prediction sequences, masked to valid lengths of 4, 2 and 0:
loss = MaskedSoftmaxCELoss()
loss(nd.ones((3, 4, 10)), nd.ones((3, 4)), nd.array([4, 2, 0]))
[2.3025851 1.1512926 0. ] <NDArray 3 @cpu(0)>
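These values are easy to verify by hand: with all-ones predictions the softmax is uniform over the 10-token vocabulary, so every valid position costs ln 10 ≈ 2.3026, and the per-example loss averages that cost over all 4 positions, of which 4, 2 and 0 are valid:

import math
per_token = math.log(10)  # loss of one valid position under a uniform softmax
[per_token * valid / 4 for valid in (4, 2, 0)]  # ≈ [2.3026, 1.1513, 0.0]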
def train_ch7(model, data_iter, lr, num_epochs, ctx):  # Saved in d2l
    model.initialize(init.Xavier(), force_reinit=True, ctx=ctx)
    trainer = gluon.Trainer(model.collect_params(),
                            'adam', {'learning_rate': lr})
    loss = MaskedSoftmaxCELoss()
    tic = time.time()
    for epoch in range(1, num_epochs+1):
        l_sum, num_tokens_sum = 0.0, 0.0
        for batch in data_iter:
            X, X_vlen, Y, Y_vlen = [x.as_in_context(ctx) for x in batch]
            # Teacher forcing: feed the target shifted by one token and
            # predict the next token at every position.
            Y_input, Y_label, Y_vlen = Y[:, :-1], Y[:, 1:], Y_vlen - 1
            with autograd.record():
                Y_hat, _ = model(X, Y_input, X_vlen, Y_vlen)
                l = loss(Y_hat, Y_label, Y_vlen)
            l.backward()
            d2l.grad_clipping_gluon(model, 5, ctx)
            num_tokens = Y_vlen.sum().asscalar()
            # Normalize the gradient by the number of valid target tokens.
            trainer.step(num_tokens)
            l_sum += l.sum().asscalar()
            num_tokens_sum += num_tokens
        if epoch % 50 == 0:
            print("epoch %d, loss %.3f, time %.1f sec" % (
                epoch, l_sum/num_tokens_sum, time.time()-tic))
            tic = time.time()
Now we create a model instance, set the hyperparameters, and train on the machine translation dataset:
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.0
batch_size, num_examples, max_len = 64, 1e3, 10
lr, num_epochs, ctx = 0.005, 300, d2l.try_gpu()
src_vocab, tgt_vocab, train_iter = d2l.load_data_nmt(
    batch_size, max_len, num_examples)
encoder = Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
model = d2l.EncoderDecoder(encoder, decoder)
train_ch7(model, train_iter, lr, num_epochs, ctx)
epoch 50, loss 0.120, time 10.2 sec
epoch 100, loss 0.066, time 10.4 sec
epoch 150, loss 0.041, time 10.3 sec
epoch 200, loss 0.031, time 10.3 sec
epoch 250, loss 0.028, time 10.0 sec
epoch 300, loss 0.025, time 9.5 sec
def translate_ch7(model, src_sentence, src_vocab, tgt_vocab, max_len, ctx):
    src_tokens = src_vocab[src_sentence.lower().split(' ')]
    src_len = len(src_tokens)
    if src_len < max_len:
        src_tokens += [src_vocab.pad] * (max_len - src_len)
    enc_X = nd.array(src_tokens, ctx=ctx)
    enc_valid_length = nd.array([src_len], ctx=ctx)
    # Use expand_dims to add the batch_size dimension.
    enc_outputs = model.encoder(enc_X.expand_dims(axis=0), enc_valid_length)
    dec_state = model.decoder.init_state(enc_outputs, enc_valid_length)
    dec_X = nd.array([tgt_vocab.bos], ctx=ctx).expand_dims(axis=0)
    predict_tokens = []
    for _ in range(max_len):
        Y, dec_state = model.decoder(dec_X, dec_state)
        # The token with the highest score is used as the next time step input.
        dec_X = Y.argmax(axis=2)
        py = dec_X.squeeze(axis=0).astype('int32').asscalar()
        if py == tgt_vocab.eos:
            break
        predict_tokens.append(py)
    return ' '.join(tgt_vocab.to_tokens(predict_tokens))
Try several examples:
for sentence in ['Go .', 'Wow !', "I'm OK .", 'I won !']:
    print(sentence + ' => ' + translate_ch7(
        model, sentence, src_vocab, tgt_vocab, max_len, ctx))
Go . => va !
Wow ! => <unk> !
I'm OK . => je vais bien .
I won ! => je l'ai emporté !