# Notebook-style setup: pull in the helpers from the previous seq2seq lesson,
# load the prepared giga-fren DataBunch and the pre-trained embedding layers
# saved by the data-preparation notebook.
from seq2seq import *
path = Config().data_path()/'giga-fren'          # dataset root (fastai Config)
data = load_data(path)                           # DataBunch saved earlier
model_path = Config().model_path()
emb_enc = torch.load(model_path/'fr_emb.pth')    # encoder (French) embedding layer
emb_dec = torch.load(model_path/'en_emb.pth')    # decoder (English) embedding layer
class Seq2SeqRNN_attn(nn.Module):
    "Seq2seq GRU encoder/decoder with additive attention over the encoder outputs."
    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):
        # emb_enc/emb_dec: pre-trained nn.Embedding layers for source/target vocab.
        # nh: encoder hidden size per direction; out_sl: max decoded length.
        # nl: number of GRU layers; bos_idx/pad_idx: special-token ids.
        super().__init__()
        # pr_force=1 -> always teacher-force at first; a TeacherForcing callback anneals it.
        self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1
        self.bos_idx,self.pad_idx = bos_idx,pad_idx
        self.emb_enc,self.emb_dec = emb_enc,emb_dec
        self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim
        self.voc_sz_dec = emb_dec.num_embeddings
        self.emb_enc_drop = nn.Dropout(0.15)
        # Bidirectional encoder -> outputs and hidden states carry 2*nh features.
        self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25,
                              batch_first=True, bidirectional=True)
        # Project the concatenated encoder hidden state down to the decoder's size.
        self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        # Decoder input at each step = target embedding concatenated with the attention context.
        self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,
                              dropout=0.1, batch_first=True)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)
        # Weight tying: the output projection shares the decoder embedding matrix.
        self.out.weight.data = self.emb_dec.weight.data
        # Additive attention: score(t) = V . tanh(enc_att(enc_out_t) + hid_att(hid)).
        self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)
        self.V = self.init_param(self.emb_sz_dec)

    def encoder(self, bs, inp):
        "Encode `inp` (bs x sl token ids); return decoder-ready hidden state and all encoder outputs."
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        # NOTE(review): 2*h equals h here because h is all zeros — the factor looks
        # like a leftover; confirm before relying on it.
        enc_out, hid = self.gru_enc(emb, 2*h)
        # Regroup the (2*nl, bs, nh) bidirectional hidden state into (nl, bs, 2*nh)
        # by pairing the forward/backward states of each layer.
        pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous()
        pre_hid = pre_hid.view(self.nl, bs, 2*self.nh)
        hid = self.out_enc(pre_hid)
        return hid,enc_out

    def decoder(self, dec_inp, hid, enc_att, enc_out):
        "One decoding step: attend over `enc_out`, run one GRU step, score the target vocab."
        hid_att = self.hid_att(hid[-1])        # attend with the top decoder layer's state
        # Additive-attention energies over the source positions.
        u = torch.tanh(enc_att + hid_att[:,None])
        attn_wgts = F.softmax(u @ self.V, 1)   # normalise over the source length
        # Context vector = attention-weighted sum of the encoder outputs.
        ctx = (attn_wgts[...,None] * enc_out).sum(1)
        emb = self.emb_dec(dec_inp)
        # Feed [embedding; context] as a length-1 sequence to the decoder GRU.
        outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)
        outp = self.out(self.out_drop(outp[:,0]))
        return hid, outp

    def forward(self, inp, targ=None):
        "Decode up to `out_sl` tokens; teacher-force with probability `pr_force` when `targ` is given."
        bs, sl = inp.size()
        hid,enc_out = self.encoder(bs, inp)
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx   # every sequence starts at <bos>
        enc_att = self.enc_att(enc_out)   # encoder side of the attention, computed once
        res = []
        for i in range(self.out_sl):
            hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
            res.append(outp)
            dec_inp = outp.max(1)[1]      # greedy pick is the default next input
            if (dec_inp==self.pad_idx).all(): break   # whole batch finished
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: continue   # target exhausted: keep the greedy token
                dec_inp = targ[:,i]             # teacher forcing: feed the ground truth
        return torch.stack(res, dim=1)

    def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh)
    def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))
# Build the attention model (hidden size 256, max output length 30) and train it
# with one-cycle; the TeacherForcing callback anneals pr_force to 0 over 30 epochs.
model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30)
learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,
                callback_fns=partial(TeacherForcing, end_epoch=30))
learn.fit_one_cycle(5, 3e-3)
epoch | train_loss | valid_loss | seq2seq_acc | bleu | time |
---|---|---|---|---|---|
0 | 2.411852 | 3.832309 | 0.489762 | 0.266689 | 02:08 |
1 | 1.962503 | 4.193816 | 0.490516 | 0.398154 | 01:52 |
2 | 1.705578 | 4.441123 | 0.456639 | 0.390446 | 01:51 |
3 | 1.562641 | 4.048814 | 0.472333 | 0.403023 | 01:49 |
4 | 1.373088 | 4.048331 | 0.477947 | 0.413261 | 01:50 |
# Restore the checkpoint saved after the 5-epoch run above.
# learn.save('5')
learn.load('5');  # trailing ';' suppresses the notebook cell output
/home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Seq2SeqRNN_attn. It won't be checked for correctness upon loading. "type " + obj.__name__ + ". It won't be checked "
def preds_acts(learn, ds_type=DatasetType.Valid):
    "Same as `get_predictions` but also returns non-reconstructed activations"
    learn.model.eval()
    dataset = learn.data.train_ds
    rxs, rys, rzs = [], [], []   # reconstructed input / target / prediction items
    xs, ys, zs = [], [], []      # raw tensors and model activations
    with torch.no_grad():
        for inp_b, targ_b in progress_bar(learn.dl(ds_type)):
            act_b = learn.model(inp_b)
            for inp, targ, act in zip(inp_b, targ_b, act_b):
                rxs.append(dataset.x.reconstruct(inp))
                rys.append(dataset.y.reconstruct(targ))
                rzs.append(dataset.y.reconstruct(act.argmax(1)))
                xs.append(inp)
                ys.append(targ)
                zs.append(act)
    return rxs, rys, rzs, xs, ys, zs
# Collect predictions/activations on the validation set and inspect one example
# (its input, target and greedy prediction, both reconstructed and raw).
rxs,rys,rzs,xs,ys,zs = preds_acts(learn)
idx=701
rx,ry,rz = rxs[idx],rys[idx],rzs[idx]
x,y,z = xs[idx],ys[idx],zs[idx]
rx,ry,rz
(Text xxbos quelles sont les lacunes qui existent encore dans notre connaissance du travail autonome et sur lesquelles les recherches devraient se concentrer à l’avenir ?, Text xxbos what gaps remain in our knowledge of xxunk on which future research should focus ?, Text xxbos what gaps are needed in our work and what is the research of the work and what research will be in place to future ?)
def select_topk(outp, k=5):
    """Sample the next token from the `k` highest-probability entries of `outp`.

    `outp` is a 1-D tensor of unnormalised scores (logits) over the vocabulary.
    Returns a 1-D LongTensor holding the sampled vocabulary index.

    Fix: sample proportionally to the top-k probabilities (standard top-k
    sampling) via `torch.multinomial` — the previous `torch.randint` picked
    uniformly among the top k, ignoring the model's confidence.
    """
    probs = F.softmax(outp, dim=-1)
    vals, idxs = probs.topk(k, dim=-1)
    # multinomial renormalises `vals` internally, so no explicit division needed
    return idxs[torch.multinomial(vals, 1)]
def select_nucleus(outp, p=0.5):
    """Nucleus (top-p) sampling: sample from the smallest set of tokens whose
    cumulative probability exceeds `p`.

    `outp` is a 1-D tensor of unnormalised scores (logits) over the vocabulary.
    Returns a 1-D LongTensor holding the sampled vocabulary index.

    Fixes over the draft version:
    * samples proportionally to the probabilities inside the nucleus — the old
      `random.choice` picked uniformly over it;
    * always returns a tensor — previously `p >= 1` fell off the loop end and
      implicitly returned None.
    """
    probs = F.softmax(outp, dim=-1)
    sorted_probs, idxs = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=-1)
    # Size of the nucleus: prefix positions with cumulative mass <= p, plus the
    # token that crosses the threshold (always at least one token).
    n = min(int((cum <= p).sum().item()) + 1, len(idxs))
    # multinomial renormalises the truncated distribution internally.
    return idxs[torch.multinomial(sorted_probs[:n], 1)]
def decode(self, inp):
    "Generate a sequence for one un-batched input `inp`, sampling each step with nucleus sampling."
    batch = inp[None]                      # add a batch dimension of 1
    bs, sl = batch.size()
    hid, enc_out = self.encoder(bs, batch)
    enc_att = self.enc_att(enc_out)        # encoder side of the attention, computed once
    dec_inp = batch.new_zeros(bs).long() + self.bos_idx   # start from <bos>
    tokens = []
    for _ in range(self.out_sl):
        hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
        dec_inp = select_nucleus(outp[0], p=0.3)
        # dec_inp = select_topk(outp[0], k=2)
        tokens.append(dec_inp)
        if (dec_inp == self.pad_idx).all(): break
    return torch.cat(tokens)
def predict_with_decode(learn, x, y):
    "Decode `x` with sampling and return reconstructed (input, target, prediction)."
    learn.model.eval()
    dataset = learn.data.train_ds
    with torch.no_grad():
        generated = decode(learn.model, x)
        rx = dataset.x.reconstruct(x)
        ry = dataset.y.reconstruct(y)
        rz = dataset.y.reconstruct(generated)
    return rx, ry, rz
# Generate with nucleus sampling for the example picked above and show the result.
rx,ry,rz = predict_with_decode(learn, x, y)
rz
Text xxbos what gaps are needed in our understanding of work and security and how research will need to be put in place ?