# !pip install torch pytorch_lightning datasets wandb
import os
import torch
import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertConfig, BertTokenizer
from datasets import load_dataset
import pytorch_lightning as pl
import wandb
wandb: WARNING W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.
# custom dataset class
class SentimentDataset(Dataset):
def __init__(self, tokenizer, text, target, max_len=180):
self.tokenizer = tokenizer
self.text = text
self.target = target
self.max_len = max_len
def __len__(self):
return len(self.text)
def __getitem__(self, idx):
text = self.text[idx]
target = self.target[idx]
# encode the text and target into tensors return the attention masks as well
encoding = self.tokenizer.encode_plus(
text=text,
add_special_tokens=True,
max_length=self.max_len,
return_token_type_ids=False,
pad_to_max_length=True,
return_attention_mask=True,
return_tensors='pt',
truncation=True
)
return {
'text': text,
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'targets': torch.tensor(target, dtype=torch.long)
}
class BertClassifier(torch.nn.Module):
def __init__(self, config, model, dim=256, num_classes=2):
super(BertClassifier, self).__init__()
# create the model config and BERT initialize the pretrained BERT, also layers wise outputs
self.config = config
self.base = model
# classifier head [not useful]
self.head = torch.nn.Sequential(*[
torch.nn.Dropout(p=self.config.hidden_dropout_prob),
torch.nn.Linear(in_features=self.config.hidden_size, out_features=dim),
torch.nn.ReLU(),
torch.nn.Dropout(p=self.config.hidden_dropout_prob),
torch.nn.Linear(in_features=dim, out_features=num_classes)
])
def forward(self, input_ids, attention_mask=None):
# first output is top layer output, second output is context of input seq and third output will be layerwise token embeddings
top_layer, pooled, layers = self.base(input_ids, attention_mask)
outputs = self.head(pooled)
return top_layer, outputs, layers
class BertFinetuner(pl.LightningModule):
def __init__(self, model=None, tokenizer=None, data_file="./data/twitter/train.csv", use_cols=['review_text', 'sentiment'], batch_size=32):
super(BertFinetuner, self).__init__()
# initialize the BERT model c
self.model = model
self.data_file = data_file
self.use_cols = use_cols
self.batch_size = batch_size
self.tokenizer = tokenizer
self.f_score= Fbeta()
def accuracy(self, outputs, targets):
correct = 0
for i in range(outputs.shape[0]):
if outputs[i]==targets[i]:
correct+=1
return correct/outputs.shape[0]
def forward(self, input_ids, attention_mask=None):
top_layer, outputs, layers = self.model(input_ids, attention_mask)
return top_layer, outputs, layers
def configure_optimizers(self):
return torch.optim.Adam(params=self.parameters(), lr=1e-5)
def train_dataloader(self):
# first 30% data reserved for validation
train = load_dataset("csv", data_files=self.data_file, split='train[20%:]')
text, target = train['review_text'], train['sentiment']
dataset = SentimentDataset(tokenizer=self.tokenizer, text=text, target=target)
loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
return loader
def training_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']
_, logits, _ = self(input_ids, attention_mask)
loss = F.cross_entropy(logits, targets)
acc = self.accuracy(logits.argmax(dim=1), targets)
wandb.log({"Loss": loss, "Accuracy": torch.tensor(acc)})
return {"loss": loss, "accuracy": torch.tensor(acc)}
def val_dataloader(self):
# first 30% data reserved for validation
val = load_dataset("csv", data_files=self.data_file, split='train[:20%]')
text, target = val['review_text'], val['sentiment']
dataset = SentimentDataset(tokenizer=self.tokenizer, text=text, target=target)
loader = DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
return loader
def validation_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['targets']
_, logits, _ = self(input_ids, attention_mask)
loss = F.cross_entropy(logits, targets)
acc = self.accuracy(logits.argmax(dim=1), targets)
# wandb.log({"val_loss":loss, "val_accuracy":acc})
self.f_score(logits.argmax(dim=1), targets)
return {"val_loss": loss, "val_accuracy": torch.tensor(acc)}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
avg_f_score = self.f_score.compute()
wandb.log({"val_loss":avg_loss, "val_accuracy":avg_acc, "val_fb":avg_f_score})
return {'val_accuracy': avg_loss, 'val_accuracy': avg_acc, "val_fb":avg_f_score}
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.metrics import Fbeta
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar
ROOT_DIR = "../input/amazonproductsreview/amazon-review/"
DATASET = "dvd"
NUM_CLASSES = 2
BATCH_SIZE = 32
EPOCH = 20
# logger
logger = WandbLogger(
name=DATASET,
save_dir="../working/",
project="domain-adaptation",
log_model = True
)
# callbacks
early_stopping = EarlyStopping(
monitor="val_accuracy",
)
model_checkpoint = ModelCheckpoint(
filepath="{epoch}-{val_accuracy:.2f}-{val_loss:.2f}",
monitor="val_accuracy",
save_top_k=1,
)
progress_bar = ProgressBar()
# create the BERTConfig, BERTTokenizer, and BERTModel
model_name = "bert-base-uncased"
config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=True)
bert = BertModel.from_pretrained(model_name, config=config)
classifier = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)
model = BertFinetuner(
model=classifier,
data_file=os.path.join(ROOT_DIR, DATASET+".csv"),
tokenizer=tokenizer,
batch_size=BATCH_SIZE
)
tuner = pl.Trainer(
logger=logger,
gpus=[0],
checkpoint_callback=model_checkpoint,
max_epochs=EPOCH,
)
GPU available: True, used: True TPU available: False, using: 0 TPU cores LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
tuner.fit(model)
wandb: WARNING Calling wandb.login() after wandb.init() has no effect.
VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…
../working/wandb/run-20201026_093116-2jdo8tzd/logs/debug.log
../working/wandb/run-20201026_093116-2jdo8tzd/logs/debug-internal.log
val_loss | 0.65587 |
val_accuracy | 0.83062 |
val_fb | 0.83688 |
_step | 1020 |
_runtime | 746 |
_timestamp | 1603705422 |
Loss | 0.00673 |
Accuracy | 1.0 |
val_loss | █▄▁▁▂▁▃▃▃▄▄▅▆▄▆▅▅▆▅▆▇ |
val_accuracy | ▁▇█████████▇▇███████▇ |
val_fb | ▁▆██▇██▇███▇▇███████▇ |
_step | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
_runtime | ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
_timestamp | ▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
Loss | ██▅▅▄▃▄▄▂▅▁▄▂▂▁▁▁▁▁▂▁▄▁▁▁▁▃▁▁▃▁▁▁▁▁▁▁▃▃▁ |
Accuracy | ▁▃▇▆▇█▆▇█▆█▇█████████▇██████████████████ |
wandb: wandb version 0.10.8 is available! To upgrade, please run: wandb: $ pip install wandb --upgrade
../working/wandb/run-20201026_094759-1rj1ncqv
| Name | Type | Params ------------------------------------------- 0 | model | BertClassifier | 109 M 1 | f_score | Fbeta | 0 Using custom data configuration default
Downloading and preparing dataset csv/default-343faf1a87cc9b22 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4...
HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))
Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4. Subsequent calls will reuse this data.
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: Your val_dataloader has `shuffle=True`, it is best practice to turn this off for validation and test dataloaders. warnings.warn(*args, **kwargs)
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…
/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:45: UserWarning: The validation_epoch_end should not return anything as of 9.1.to log, use self.log(...) or self.write(...) directly in the LightningModule warnings.warn(*args, **kwargs) Using custom data configuration default Reusing dataset csv (/root/.cache/huggingface/datasets/csv/default-343faf1a87cc9b22/0.0.0/49187751790fa4d820300fd4d0707896e5b941f1a9c644652645b866716a4ac4)
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…
1
PATH = DATASET+".pt"
# save the model
torch.save(classifier.state_dict(), PATH)
### Load from state dictionary
classifier_trained = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)
classifier_trained.load_state_dict(torch.load(PATH))
# you can evaluate the model on top 20% data
<All keys matched successfully>
PATH = DATASET+".pt"
# save the model
torch.save(classifier.state_dict(), PATH)
### Load from state dictionary
classifier_dvd = BertClassifier(config=config, model=bert, num_classes=NUM_CLASSES)
classifier_dvd.load_state_dict(torch.load(PATH))
# you can evaluate the model on top 20% data
## There you go