import itertools
import pandas as pd
import torch
from mnist_ten.data import train_labeled_loader, train_rotoflip_loader
from mnist_ten.models import weights_path, classifier_head, classifier, rotoflip as rotoflip_classifier
from tqdm.auto import tqdm, trange
import plotly.graph_objects as go
# Loss shared by every training stage in this file; train_step reads it as a global.
criterion = torch.nn.CrossEntropyLoss()


def train_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    images: torch.Tensor,
    labels: torch.Tensor,
) -> float:
    """Run a single optimisation step on one batch and return its loss.

    Args:
        model: Callable applied to ``images`` to produce predictions.
        optimizer: Optimizer whose parameters are updated by this step.
        images: Input batch passed straight to ``model``.
        labels: Targets compared against the predictions by ``criterion``.

    Returns:
        The batch loss as a plain Python float (``loss.item()``).
    """
    predictions = model(images)
    loss = criterion(predictions, labels)
    # Clear gradients left from the previous step before computing new ones.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
def train(model, optimizer, data_loader, epochs):
    """Train ``model`` over ``data_loader`` once per item in ``epochs``.

    Args:
        model: Forwarded unchanged to ``train_step``.
        optimizer: Forwarded unchanged to ``train_step``.
        data_loader: Iterable yielding ``(images, labels)`` pairs. It is
            iterated from the start once for every epoch.
        epochs: Iterable whose length sets the number of passes (for example
            ``range(n)`` or tqdm's ``trange(n)``); the yielded values
            themselves are not used.

    Returns:
        A list with one loss value per processed batch, in processing order
        across all epochs.
    """
    losses = []
    # Only the number of items in ``epochs`` matters, hence the throwaway name.
    for _ in epochs:
        for images, labels in data_loader:
            losses.append(train_step(
                model=model,
                optimizer=optimizer,
                images=images,
                labels=labels
            ))
    return losses
def plot_losses(losses, title, window=16):
    """Build a Plotly figure of per-batch losses with a rolling-mean overlay.

    Args:
        losses: Sequence of loss values, one per batch.
        title: Figure title.
        window: Number of consecutive batches averaged for the smoothed
            trace. Defaults to 16, the value previously hard-coded here.

    Returns:
        A ``plotly.graph_objects.Figure`` holding two line traces: the raw
        losses (faint) and their rolling mean (opaque), in that order.
    """
    # With pandas' default min_periods the first ``window - 1`` smoothed
    # values are NaN, so the smoothed series only has data from that point on.
    losses_smooth = pd.Series(losses).rolling(window=window).mean()
    return go.Figure(
        layout=dict(
            title=title,
            xaxis_title='Batch number',
            yaxis_title='Loss',
            showlegend=False
        ),
        data=[
            # Raw per-batch loss, drawn mostly transparent.
            go.Scatter(
                y=losses,
                line=dict(color='rgba(128,206,255,0.25)')
            ),
            # Rolling mean in the same colour at full opacity.
            go.Scatter(
                y=losses_smooth,
                line=dict(color='rgba(128,206,255,1)')
            )
        ]
    )
# Stage 1: train the rotation/flip classifier with Adam (lr=2e-3) for a single
# pass over the rotoflip loader. tqdm wraps the loader to show batch progress.
rotoflip_losses = train(
    model=rotoflip_classifier,
    optimizer=torch.optim.Adam(rotoflip_classifier.parameters(), lr=2e-3),
    data_loader=tqdm(train_rotoflip_loader),
    epochs=range(1)
)
# NOTE(review): the returned figure is not stored or shown here; presumably it
# was displayed as the last expression of a notebook cell -- confirm.
plot_losses(rotoflip_losses, 'Rotation/flip classification losses')
# Captured notebook widget output, kept for reference only. It was previously a
# bare statement and raised NameError because HBox/FloatProgress/HTML are
# not defined in this file:
# HBox(children=(FloatProgress(value=0.0, max=3750.0), HTML(value='')))
# Stage 2: train the digit classifier for 4096 passes over the labeled loader.
# The forward pass uses ``classifier``, but Adam (lr=1e-4) is given only
# ``classifier_head.parameters()``, so this optimizer updates the head alone.
# NOTE(review): presumably classifier_head is a sub-module of classifier and
# the rest is meant to stay fixed -- confirm against mnist_ten.models.
classifier_losses = train(
    model=classifier,
    optimizer=torch.optim.Adam(classifier_head.parameters(), lr=1e-4),
    data_loader=train_labeled_loader,
    # trange provides the progress bar, one tick per epoch.
    epochs=trange(4096)
)
plot_losses(classifier_losses, 'Digit classification losses')
# Captured notebook widget output, kept for reference only. It was previously a
# bare statement and raised NameError before the weights could be saved:
# HBox(children=(FloatProgress(value=0.0, max=4096.0), HTML(value='')))
# Persist the full classifier state dict to the project's weights path.
torch.save(classifier.state_dict(), weights_path)