import os import sys import torch from torch.nn import functional as F import numpy as np from torchtext import data from torchtext import datasets from torchtext.vocab import Vectors, GloVe tokenize = lambda x: x.split() TEXT = data.Field(sequential=True, tokenize=tokenize, lower=True, include_lengths=True, batch_first=True, fix_length=200) LABEL = data.LabelField() train_data, test_data = datasets.IMDB.splits(TEXT, LABEL) TEXT.build_vocab(train_data, vectors=GloVe(name='6B', dim=300)) LABEL.build_vocab(train_data) word_embeddings = TEXT.vocab.vectors print ("Length of Text Vocabulary: " + str(len(TEXT.vocab))) print ("Vector size of Text Vocabulary: ", TEXT.vocab.vectors.size()) print ("Label Length: " + str(len(LABEL.vocab))) train_data, valid_data = train_data.split() # Further splitting of training_data to create new training_data & validation_data train_iter, valid_iter, test_iter = data.BucketIterator.splits((train_data, valid_data, test_data), batch_size=32, sort_key=lambda x: len(x.text), repeat=False, shuffle=True) '''Alternatively we can also use the default configurations''' #train_iter_, test_iter_ = datasets.IMDB.iters(batch_size=32) vocab_size = len(TEXT.vocab) #return TEXT, vocab_size, word_embeddings, train_iter, valid_iter, test_iter from google.colab import drive drive.mount('/content/drive') import os os.chdir("/content/drive/My Drive/Colab Notebooks/Optimization project") os.getcwd() file_path = "/content/drive/My Drive/Colab Notebooks/Optimization project/IMDB" #directory = os.path.dirname(file_path) try: os.stat(file_path) except: os.mkdir(file_path) import sug from sug import SUG import torch from torch.optim import Optimizer import math import copy class SUG(Optimizer): def __init__(self, params, l_0, d_0=0, prob=1., eps=1e-4, momentum=0, dampening=0, weight_decay=0, nesterov=False): if l_0 < 0.0: raise ValueError("Invalid Lipsitz constant of gradient: {}".format(l_0)) if d_0 < 0.0: raise ValueError("Invalid disperion of gradient: {}".format(d_0)) if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) defaults = dict(L=l_0, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") self.Lips = l_0 self.prev_Lips = l_0 self.D_0 = d_0 self.eps = eps self.prob = prob self.start_param = params self.upd_sq_grad_norm = None self.sq_grad_norm = None self.loss = torch.tensor(0.) self.cur_loss = 0 self.closure = None super(SUG, self).__init__(params, defaults) def __setstate__(self, state): super(SUG, self).__setstate__(state) for group in self.param_groups: group.setdefault('nesterov', False) def comp_batch_size(self): """Returns optimal batch size for given d_0, eps and l_0; """ return math.ceil(2 * self.D_0 * self.eps / self.prev_Lips) def step(self, loss, closure): """Performs a single optimization step. Arguments: loss : current loss closure (callable, optional): A closure that reevaluates the model and returns the loss. """ self.start_params = [] self.loss = loss self.sq_grad_norm = 0 self.cur_loss = loss self.closure = closure for gr_idx, group in enumerate(self.param_groups): weight_decay = group['weight_decay'] momentum = group['momentum'] dampening = group['dampening'] nesterov = group['nesterov'] self.start_params.append([]) for p_idx, p in enumerate(group['params']): self.start_params[gr_idx].append([p.data.clone()]) if p.grad is None: continue self.start_params[gr_idx][p_idx].append(p.grad.data.clone()) d_p = self.start_params[gr_idx][p_idx][1] p_ = self.start_params[gr_idx][p_idx][0] if weight_decay != 0: d_p.add_(weight_decay, p.data) self.cur_loss += weight_decay * torch.sum(p * p).item() self.sq_grad_norm += torch.sum(d_p * d_p).item() if momentum != 0: param_state = self.state[p] if 'momentum_buffer' not in param_state: buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) buf.mul_(momentum).add_(d_p) else: buf = param_state['momentum_buffer'] buf.mul_(momentum).add_(1 - dampening, d_p) if nesterov: d_p = d_p.add(momentum, buf) else: d_p = buf self.start_params[gr_idx][p_idx][1] = d_p i = 0 self.Lips = max(self.prev_Lips / 2, 0.1) difference = -1 while difference < 0 or i == 0: if (i > 0): self.Lips = max(self.Lips * 2, 0.1) for gr_idx, group in enumerate(self.param_groups): for p_idx, p in enumerate(group['params']): if p.grad is None: continue start_param_val = self.start_params[gr_idx][p_idx][0] start_param_grad = self.start_params[gr_idx][p_idx][1] p.data = start_param_val - 1/(2*self.Lips) * start_param_grad difference, upd_loss = self.stop_criteria() i += 1 self.prev_Lips = self.Lips return self.Lips, i def stop_criteria(self): """Checks if the Lipsitz constant of gradient is appropriate + 2L_k / 2 ||x_k - w_k||^2 = - 1 / (2L_k)||g(x_k)||^2 + 1 / (4L_k)||g(x_k)||^2 = -1 / (4L_k)||g(x_k)||^2 """ upd_loss = self.closure() major = self.cur_loss - 1 / (4 * self.Lips) * self.sq_grad_norm return major - upd_loss - self.l2_reg() + self.eps / 10, upd_loss def get_lipsitz_const(self): """Returns current Lipsitz constant of the gradient of the loss function """ return self.Lips def get_sq_grad(self): """Returns the current second norm of the gradient of the loss function calculated by the formula ||f'(p_1,...,p_n)||_2^2 ~ \sum\limits_{i=1}^n ((df/dp_i) * (df/dp_i))(p1,...,p_n)) """ self.upd_sq_grad_norm = 0 for gr_idx, group in enumerate(self.param_groups): for p_idx, p in enumerate(group['params']): if p.grad is None: continue self.upd_sq_grad_norm += torch.sum(p.grad.data * p.grad.data).item() return self.upd_sq_grad_norm def l2_reg(self): """Returns the current l2 regularization addiction """ self.upd_l2_reg = 0 for gr_idx, group in enumerate(self.param_groups): weight_decay = group['weight_decay'] if weight_decay != 0: for p_idx, p in enumerate(group['params']): self.upd_l2_reg += weight_decay * torch.sum(p * p).item() return self.upd_l2_reg device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.autograd import Variable from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence class SimpleLSTMBaseline(nn.Module): def __init__(self, hidden_dim, emb_dim=300, num_linear=1): super().__init__() self.embedding = nn.Embedding(len(TEXT.vocab), emb_dim) self.encoder = nn.LSTM(emb_dim, hidden_dim, num_layers=1, bidirectional=True, batch_first=True) self.linear1 = nn.Linear(2 * hidden_dim, 32) self.linear1.weight.data.fill_(2) self.linear2 = nn.Linear(32, 2) self.linear2.weight.data.fill_(2) def forward(self, seq, lens): embeds = self.embedding(seq) packed = pack_padded_sequence(embeds, lens, batch_first=True) hdn, _ = self.encoder(packed) hdn, _ = pad_packed_sequence(hdn, batch_first=True) output = nn.functional.max_pool1d(hdn, kernel_size=10) output = nn.functional.relu(self.linear1(hdn[:,1,:])) prob = nn.functional.log_softmax(self.linear2(output), -1) return prob class LSTMClassifier(nn.Module): def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights): super(LSTMClassifier, self).__init__() """ Arguments --------- batch_size : Size of the batch which is same as the batch_size of the data returned by the TorchText BucketIterator output_size : 2 = (pos, neg) hidden_sie : Size of the hidden_state of the LSTM vocab_size : Size of the vocabulary containing unique words embedding_length : Embeddding dimension of GloVe word embeddings weights : Pre-trained GloVe word_embeddings which we will use to create our word_embedding look-up table """ self.batch_size = batch_size self.output_size = output_size self.hidden_size = hidden_size self.vocab_size = vocab_size self.embedding_length = embedding_length self.num_layers = 1 self.word_embeddings = nn.Embedding(vocab_size, embedding_length)# Initializing the look-up table. self.word_embeddings.weight = nn.Parameter(weights, requires_grad=False) # Assigning the look-up table to the pre-trained GloVe word embedding. self.lstm = nn.LSTM(embedding_length, hidden_size, batch_first=True, bidirectional=False, num_layers=self.num_layers) self.label = nn.Linear(1 * hidden_size * self.num_layers, output_size) def forward(self, input_sentence, batch_size=None): """ Parameters ---------- input_sentence: input_sentence of shape = (batch_size, num_sequences) batch_size : default = None. Used only for prediction on a single sentence after training (batch_size = 1) Returns ------- Output of the linear layer containing logits for positive & negative class which receives its input as the final_hidden_state of the LSTM final_output.shape = (batch_size, output_size) """ ''' Here we will map all the indexes present in the input sequence to the corresponding word vector using our pre-trained word_embedddins.''' input = self.word_embeddings(input_sentence) # embedded input of shape = (batch_size, num_sequences, embedding_length) #input = input.permute(1, 0, 2) # input.size() = (num_sequences, batch_size, embedding_length) batch_size = input_sentence.size(0) h_0 = Variable(torch.zeros(1 * self.num_layers, batch_size, self.hidden_size).cuda()) c_0 = Variable(torch.zeros(1 * self.num_layers, batch_size, self.hidden_size).cuda()) #packed = pack_padded_sequence(input, lens, batch_first=True) #output, (final_hidden_state, final_cell_state) = self.lstm(packed, (h_0, c_0)) #output, _ = pad_packed_sequence(output, batch_first=True) #print(input.size()) output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0)) final_output = self.label(final_hidden_state.view(batch_size, self.num_layers*1*self.hidden_size)) # final_hidden_state.size() = (2, batch_size, hidden_size) & final_output.size() = (batch_size, output_size) return final_output import torch import torch.nn as nn from torch.autograd import Variable from torch.nn import functional as F import numpy as np class AttentionModel(torch.nn.Module): def __init__(self, batch_size, output_size, hidden_size, vocab_size, embedding_length, weights): super(AttentionModel, self).__init__() """ Arguments --------- batch_size : Size of the batch which is same as the batch_size of the data returned by the TorchText BucketIterator output_size : 2 = (pos, neg) hidden_sie : Size of the hidden_state of the LSTM vocab_size : Size of the vocabulary containing unique words embedding_length : Embeddding dimension of GloVe word embeddings weights : Pre-trained GloVe word_embeddings which we will use to create our word_embedding look-up table -------- """ self.batch_size = batch_size self.output_size = output_size self.hidden_size = hidden_size self.vocab_size = vocab_size self.embedding_length = embedding_length self.word_embeddings = nn.Embedding(vocab_size, embedding_length) self.word_embeddings.weights = nn.Parameter(weights, requires_grad=False) self.lstm = nn.LSTM(embedding_length, hidden_size) self.label = nn.Linear(hidden_size, output_size) #self.attn_fc_layer = nn.Linear() def attention_net(self, lstm_output, final_state): """ Now we will incorporate Attention mechanism in our LSTM model. In this new model, we will use attention to compute soft alignment score corresponding between each of the hidden_state and the last hidden_state of the LSTM. We will be using torch.bmm for the batch matrix multiplication. Arguments --------- lstm_output : Final output of the LSTM which contains hidden layer outputs for each sequence. final_state : Final time-step hidden state (h_n) of the LSTM --------- Returns : It performs attention mechanism by first computing weights for each of the sequence present in lstm_output and and then finally computing the new hidden state. Tensor Size : hidden.size() = (batch_size, hidden_size) attn_weights.size() = (batch_size, num_seq) soft_attn_weights.size() = (batch_size, num_seq) new_hidden_state.size() = (batch_size, hidden_size) """ hidden = final_state.squeeze(0) attn_weights = torch.bmm(lstm_output, hidden.unsqueeze(2)).squeeze(2) soft_attn_weights = F.softmax(attn_weights, 1) new_hidden_state = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2) return new_hidden_state def forward(self, input_sentences, batch_size=None): """ Parameters ---------- input_sentence: input_sentence of shape = (batch_size, num_sequences) batch_size : default = None. Used only for prediction on a single sentence after training (batch_size = 1) Returns ------- Output of the linear layer containing logits for pos & neg class which receives its input as the new_hidden_state which is basically the output of the Attention network. final_output.shape = (batch_size, output_size) """ batch_size = input_sentences.size(0) input = self.word_embeddings(input_sentences) input = input.permute(1, 0, 2) h_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda()) c_0 = Variable(torch.zeros(1, batch_size, self.hidden_size).cuda()) output, (final_hidden_state, final_cell_state) = self.lstm(input, (h_0, c_0)) # final_hidden_state.size() = (1, batch_size, hidden_size) output = output.permute(1, 0, 2) # output.size() = (batch_size, num_seq, hidden_size) attn_output = self.attention_net(output, final_hidden_state) logits = self.label(attn_output) return nn.functional.log_softmax(logits, -1) import time import math def time_since(since): s = time.time() - since m = math.floor(s / 60) s -= m * 60 return '%dm %ds' % (m, s) def model_step(model, optimizer, criterion, inputs, labels): outputs = model(inputs) loss = criterion(outputs, labels) acc = (torch.argmax(outputs, 1) == labels).float().sum().item() if model.training: optimizer.zero_grad() loss.backward(retain_graph=True) if optimizer.__class__.__name__ != 'SUG': optimizer.step() else: def closure(): optimizer.zero_grad() upd_outputs = model(inputs) upd_loss = criterion(upd_outputs, labels).item() return upd_loss optimizer.step(loss.item(), closure) return loss.item(), acc def train(model, trainloader, criterion, optimizer, path=None, n_epochs=2, validloader=None, eps=1e-5, print_every=1): tr_loss, val_loss, lips, times, grad, tr_acc, val_acc = ([] for i in range(7)) start_time = time.time() model.to(device=device) print(len(list(trainloader))) for ep in range(n_epochs): model.train() i = 0 tot_acc = 0 n_ex = 0 for i, batch in enumerate(trainloader): #t, l = batch #(text, lens), target = t text = batch.text[0] lens = batch.text[1] target = batch.label target = torch.autograd.Variable(target).long() if torch.cuda.is_available(): text = text.cuda() target = target.cuda() loss, acc = model_step(model, optimizer, criterion, text, target) tot_acc += acc n_ex += text.size(0) tr_loss.append(loss) if optimizer.__class__.__name__ == 'SUG': lips.append(optimizer.get_lipsitz_const()) grad.append(optimizer.get_sq_grad) if i % 100 == 0: print(tr_loss[-1], i) times.append(time_since(start_time)) model.zero_grad() optimizer.zero_grad() states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss, 'val_loss' : val_loss, 'lips' : lips, 'grad' : grad, 'times' : times } if path is not None: torch.save(states, path) tr_acc.append(tot_acc / n_ex) times.append(time_since(start_time)) if ep % print_every == 0: print("Epoch {}, training loss {}, time passed {}, training accuracy {}".format(ep, sum(tr_loss[-i:]) / i, time_since(start_time), tr_acc[-1])) if validloader is None: continue model.zero_grad() model.eval() j = 0 count = 0 n_ex = 0 for j, batch in enumerate(validloader): text = batch.text[0] target = batch.label target = torch.autograd.Variable(target).long() if torch.cuda.is_available(): text = text.cuda() target = target.cuda() outputs = model(text) #outputs_lab = torch.argmax(outputs, 1) count += (torch.argmax(outputs, 1) == target).float().sum().item() n_ex += outputs.size(0) val_loss.append(criterion(outputs, target).item()) val_acc.append(count / n_ex) if ep % print_every == 0: print("Validation loss {}, validation accuracy {}".format(sum(val_loss[-j:]) / j, val_acc[-1])) return tr_loss, times, val_loss, lips, grad, tr_acc, val_acc def concat_states(state1, state2): states = { 'epoch': state1['epoch'] + state2['epoch'], 'state_dict': state2['state_dict'], 'optimizer': state2['optimizer'], 'tr_loss' : state1['tr_loss'] + state2['tr_loss'], 'val_loss' : state1['val_loss'] + state2['val_loss'], 'lips' : state1['lips'] + state2['lips'], 'grad' : state1['grad'] + state2['grad'], #'times' : state1['times'] + list(map(lambda x: x + state1['times'][-1],state2['times'])) 'times' : state1['times'] + state2['times'] } return states print_every = 1 n_epochs = 10 tr_loss = {} tr_loss['sgd'] = {} val_loss = {} val_loss['sgd'] = {} #lrs = [0.05, 0.01, 0.005] em_sz = 128 hidden_size = 256 embedding_length = 300 nl = 2 torch.manual_seed(999) batch_size = 32 criterion = nn.CrossEntropyLoss() n_epochs = 20 vocab_size = int(vocab_size) vocab_size lrs = [0.0001, 0.001] for lr in lrs: model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SGD lr={}, momentum=0. :".format(lr)) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.) tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sgd'][lr], 'val_loss' : val_loss['sgd'][lr], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/LSTM_' + str(lr)) l_0 = 20 model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SUG l_0={}, momentum=0. :".format(l_0)) optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.) tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sug'], 'val_loss' : val_loss['sug'], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/LSTM_sug') lrs = [0.01] for lr in lrs: model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SGD lr={}, momentum=0. :".format(lr)) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9) tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sgd'][lr], 'val_loss' : val_loss['sgd'][lr], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/LSTM_' + str(lr)) lrs = [0.001] for lr in lrs: model = LSTMClassifier(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SGD lr={}, momentum=0. :".format(lr)) optimizer = optim.Adam(model.parameters(), lr=lr) tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sgd'][lr], 'val_loss' : val_loss['sgd'][lr], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/LSTM_adam_' + str(lr)) lrs = [0.0001, 0.001] for lr in lrs: model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SGD lr={}, momentum=0. :".format(lr)) optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.) tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sgd'][lr], 'val_loss' : val_loss['sgd'][lr], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/attn_' + str(lr)) l_0 = 20 model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SUG l_0={}, momentum=0. :".format(l_0)) optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.) tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sug'], 'val_loss' : val_loss['sug'], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/attn_sug') lrs = [0.0001] for lr in lrs: model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SGD lr={}, momentum=0. :".format(lr)) optimizer = optim.Adam(model.parameters(), lr=lr) tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sgd'][lr], 'val_loss' : val_loss['sgd'][lr], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/attn_adam_' + str(lr)) l_0 = 20 model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SUG l_0={}, momentum=0. :".format(l_0)) optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.9, weight_decay=1e-3) tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sug'], 'val_loss' : val_loss['sug'], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/attn_sug_0.9') l_0 = 20 model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SUG l_0={}, momentum=0. :".format(l_0)) optimizer = SUG(model.parameters(), l_0=l_0, momentum=0.5, weight_decay=0.) tr_loss['sug'], times, val_loss['sug'], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sug'], 'val_loss' : val_loss['sug'], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/attn_sug_0.5_wd_1e-4') lrs = [0.0001] for lr in lrs: model = AttentionModel(batch_size, 2, hidden_size, vocab_size, embedding_length, word_embeddings) print("SGD lr={}, momentum=0. :".format(lr)) optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4) tr_loss['sgd'][lr], times, val_loss['sgd'][lr], lips, grad, tr_acc, val_acc = train(model, train_iter, criterion, optimizer, n_epochs=n_epochs, print_every=print_every, validloader=valid_iter) states = { 'epoch': n_epochs, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'tr_loss' : tr_loss['sgd'][lr], 'val_loss' : val_loss['sgd'][lr], 'lips' : lips, 'grad' : grad, 'times' : times, 'tr_acc' : tr_acc, 'val_acc' : val_acc } torch.save(states, './IMDB/atnn_adam_' + str(lr)+'_wd_1e-4')