This notebook aims to give a minimal PyTorch example by training a very simple model to predict the square of its input.
We need to import torch (the package containing the whole PyTorch library) to create our model and tensors. Strictly speaking, importing torch alone would suffice, since the Dataset base class is reachable as torch.utils.data.Dataset, but for brevity we import it explicitly. The same applies to DataLoader. tqdm is imported to create a progress bar, and os to access the filesystem.
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook as tqdm
import os
To load your data and target labels it is recommended to use an instance of torch.utils.data.Dataset. Since this class is an abstract base class, you need to subclass it to suit your use case. In our case this looks like:
class CustomIntegerDataset(Dataset):
    def __init__(self, transforms=None, length=100000, max_int=100, min_int=0):
        assert length >= max_int - min_int, \
            "The dataset's length should be greater than or equal to the difference of max_int and min_int"
        self.transforms = transforms
        self.max_int = max_int
        self.min_int = min_int
        self.length = length
        self.normalization = max_int
        self.data = [torch.Tensor([i]) for i in range(min_int, max_int + 1)]

    def __getitem__(self, index):
        true_index = index % len(self.data)
        data = self.data[true_index]
        label = data ** 2
        if self.transforms:
            # apply the transforms to the data
            data = self.transforms(data)
        return data / self.normalization, label / self.normalization

    def __len__(self):
        return self.length
This class must define an __init__ in which the data is set up. The __getitem__ method is also required and should accept an integer index into the data. You can either load your data in __init__ and only index the cached data in __getitem__, or load it lazily in __getitem__.
The last method that must be defined in your custom dataset is __len__, which returns the length (i.e. the number of elements) of your dataset. In our implementation we can specify the dataset's length to adjust the number of iterations per epoch.
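With these three methods in place, the dataset behaves like an indexable sequence. The following self-contained sketch mirrors the class above (slightly condensed, without the length assertion) and inspects one sample:

```python
import torch
from torch.utils.data import Dataset

# Condensed mirror of the CustomIntegerDataset defined above
class CustomIntegerDataset(Dataset):
    def __init__(self, transforms=None, length=100000, max_int=100, min_int=0):
        self.transforms = transforms
        self.length = length
        self.normalization = max_int
        self.data = [torch.Tensor([i]) for i in range(min_int, max_int + 1)]

    def __getitem__(self, index):
        data = self.data[index % len(self.data)]
        label = data ** 2
        if self.transforms:
            data = self.transforms(data)
        return data / self.normalization, label / self.normalization

    def __len__(self):
        return self.length

dataset = CustomIntegerDataset()
data, label = dataset[5]            # invokes __getitem__
print(len(dataset))                 # invokes __len__ -> 100000
print(data.item(), label.item())    # 5/100 and 25/100, normalized by max_int
```

Because `len(dataset)` is larger than the number of distinct integers, indices wrap around via the modulo in `__getitem__`, so one "epoch" visits each integer many times.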
Each model in PyTorch should be defined as a subclass of torch.nn.Module, since this class handles things like parameter registration and correctly calling forward and backward hooks. A model also needs an __init__ that defines the actual layers (in our case only some linear/dense layers with ReLU activations) and calls the base class's __init__ via super().__init__(), as well as a forward method defining the actual model behaviour. The forward method is wrapped by the __call__ method, which handles the hooks and some other bookkeeping. You should therefore only call your model instance like your_model(input_tensor) and never call its forward directly. The forward method can be implemented to take an arbitrary number of (keyword) arguments, and all (keyword) arguments passed to the module's call are forwarded to it. Our model definition looks like this:
class SimpleCustomModel(torch.nn.Module):
    def __init__(self, num_layers=5, max_hidden_dim=64):
        super().__init__()
        layers = [torch.nn.Linear(1, 16),
                  torch.nn.ReLU()]
        curr_dim = 16
        for i in range(num_layers - 2):
            new_dim = min(curr_dim * 2, max_hidden_dim)
            layers += [torch.nn.Linear(curr_dim, new_dim),
                       torch.nn.ReLU()]
            curr_dim = new_dim
        layers.append(torch.nn.Linear(curr_dim, 1))
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
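As a quick sanity check, an instance can be called directly on a batch of shape (batch_size, 1); the following self-contained sketch mirrors the class above:

```python
import torch

# Mirror of the SimpleCustomModel defined above
class SimpleCustomModel(torch.nn.Module):
    def __init__(self, num_layers=5, max_hidden_dim=64):
        super().__init__()
        layers = [torch.nn.Linear(1, 16), torch.nn.ReLU()]
        curr_dim = 16
        for _ in range(num_layers - 2):
            new_dim = min(curr_dim * 2, max_hidden_dim)
            layers += [torch.nn.Linear(curr_dim, new_dim), torch.nn.ReLU()]
            curr_dim = new_dim
        layers.append(torch.nn.Linear(curr_dim, 1))
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

model = SimpleCustomModel()
batch = torch.rand(32, 1)   # 32 samples, 1 feature each
out = model(batch)          # __call__ dispatches to forward (plus hooks)
print(out.shape)            # torch.Size([32, 1])
```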
In order to train our model's parameters we first need to do a few things:
# Create Model instance
model = SimpleCustomModel()
# Create optimizer holding the model's parameters and a common learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# Create an instance of our custom dataset with default arguments
dataset = CustomIntegerDataset()
# Create a dataloader containing our dataset
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# specify a device for training (for device-agnostic code in train-loop)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# push model to device
model = model.to(device)
# create instance of loss function
loss_fn = torch.nn.L1Loss()
# specify and create the directory, to which the checkpoints will be saved
out_dir = "./checkpoints"
os.makedirs(out_dir, exist_ok=True)
# set the number of epochs to train
NUM_EPOCHS = 150
We create an instance of our model class and an optimizer holding this model's parameters (here we use plain SGD, but other algorithms such as Adam are also available). The DataLoader wraps our dataset instance and collates the tensors it returns into batches. Furthermore, we define our loss function (here the L1 loss), specify and create the directory in which the model's parameters will be saved, and set the number of epochs to train.
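As a side note on the optimizer choice: switching from SGD to another algorithm only requires changing the optimizer construction. A minimal self-contained sketch (using a hypothetical single-parameter objective, not the notebook's model) that minimizes (w - 3)^2 with Adam:

```python
import torch

# A single trainable parameter; we minimize (w - 3)^2
w = torch.nn.Parameter(torch.zeros(1))
# Drop-in replacement for torch.optim.SGD
optimizer = torch.optim.Adam([w], lr=0.1)

for _ in range(500):
    loss = ((w - 3.0) ** 2).sum()
    optimizer.zero_grad()   # clear old gradients
    loss.backward()         # compute new gradients
    optimizer.step()        # update the parameter

print(w.item())             # close to 3.0
```

All torch.optim algorithms share this zero_grad/backward/step interface, so the training loop below would stay unchanged.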
To train the network we iterate over the number of epochs. With tqdm we create a progress bar wrapping our dataloader, so we can simply iterate over the progress bar instead of the dataloader to obtain the batches. With data, label = data.to(torch.float).to(device), label.to(torch.float).to(device) we convert the data and label batches to float tensors and push them to the correct device (CPU or GPU), since tensor operations are usually only defined for tensors of the same dtype on the same device (this is why we defined the device beforehand; weight and bias parameters are float by default and have already been pushed to the right device).
pred = model(data) feeds the input (data) through the model and obtains its prediction (pred). With loss_val = loss_fn(pred, label) we calculate the loss between the predictions and the desired targets (label). Afterwards we set the parameters' gradients to zero (optimizer.zero_grad()) to prevent old gradients from influencing the parameter update, calculate the new gradients (loss_val.backward()) and adjust the parameters (optimizer.step()). Calling loss_val.item() yields the value of the scalar tensor (a tensor with a single element) as a plain Python float, which we append to the loss_vals list; appending the tensor itself would carry the overhead of tensor management and, unless detached, even the computation graph.
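The difference can be demonstrated in isolation: .item() returns a plain Python float, while the scalar tensor itself stays attached to the computation graph:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
loss = (x ** 2).sum()        # scalar tensor, attached to the graph
value = loss.item()          # plain Python float, graph-free
print(type(value).__name__)  # float
print(loss.requires_grad)    # True -> storing `loss` in a list would keep the graph alive
```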
With torch.save we can save the model's parameters to a file (the parameter values are obtained via model.state_dict()); it is recommended to save only the parameters instead of pickling the whole class.
In the following training loop you can see the mean loss decreasing:
# Iterate over epochs
for epoch in range(1, NUM_EPOCHS + 1):
    print("EPOCH: %03d of %03d" % (epoch, NUM_EPOCHS))
    # create progress bar
    p_bar = tqdm(dataloader)
    p_bar.desc = "Loss: %d" % 0
    loss_vals = []
    for data, label in p_bar:
        # convert to suitable dtype and push to correct device
        data, label = data.to(torch.float).to(device), label.to(torch.float).to(device)
        # feed data through network to obtain prediction
        pred = model(data)
        # calculate loss value
        loss_val = loss_fn(pred, label)
        # zero out the gradients
        optimizer.zero_grad()
        # calculate gradients (perform backprop)
        loss_val.backward()
        # adjust parameters
        optimizer.step()
        p_bar.desc = "Loss: %.6f" % loss_val.item()
        loss_vals.append(loss_val.item())
    print("\tMean Loss Value: %.6f" % (sum(loss_vals) / len(loss_vals)))
    # save model's parameters to file
    torch.save(model.state_dict(), os.path.join(out_dir, "model_epoch_%03d.pth" % epoch))
EPOCH: 001 of 150
	Mean Loss Value: 15.485562
EPOCH: 002 of 150
	Mean Loss Value: 1.154639
EPOCH: 003 of 150
	Mean Loss Value: 1.027465
EPOCH: 004 of 150
	Mean Loss Value: 0.959208
EPOCH: 005 of 150
	Mean Loss Value: 0.908520
...
EPOCH: 149 of 150
	Mean Loss Value: 0.431074
EPOCH: 150 of 150
	Mean Loss Value: 0.429768
# load weights
model.load_state_dict(torch.load(os.path.join(out_dir, "model_epoch_%03d.pth" % epoch)))
The loop below predicts the values from 0-99 (values the model has seen during training).
Since large numbers are hard for networks to learn, the input and output data has been normalized inside the dataset by dividing by max_int. The same normalization must be applied at inference time, and must be reverted to obtain the real predictions:
for i in range(0, 100):
    # create tensor from integer, convert it to float and push it to correct device
    tensor = torch.Tensor([i]).to(torch.float).to(device)
    # add batch dimension since we only want to forward a single element
    tensor = tensor.unsqueeze(0) / dataset.normalization
    # feed tensor through model
    pred = model(tensor)
    # print input and prediction
    print("Input: %05d \t Prediction: %05d \t Target: %05d" % (i, int(round(pred.item() * dataset.normalization)), i**2))
Input: 00000 	 Prediction: -0001 	 Target: 00000
Input: 00001 	 Prediction: 00002 	 Target: 00001
Input: 00002 	 Prediction: 00005 	 Target: 00004
Input: 00003 	 Prediction: 00009 	 Target: 00009
Input: 00004 	 Prediction: 00016 	 Target: 00016
...
Input: 00098 	 Prediction: 09416 	 Target: 09604
Input: 00099 	 Prediction: 09603 	 Target: 09801
The loop below predicts the outputs for inputs above 100. Since the network has never seen such inputs during training, its predictions may be far off: it may not have generalized the underlying function but merely learned the outputs belonging to the specific inputs it was trained on.
for i in range(101, 200):
    # create tensor from integer, convert it to float and push it to correct device
    tensor = torch.Tensor([i]).to(torch.float).to(device)
    # add batch dimension since we only want to forward a single element
    tensor = tensor.unsqueeze(0) / dataset.normalization
    # feed tensor through model
    pred = model(tensor)
    # print input and prediction
    print("Input: %05d \t Prediction: %05d \t Target: %05d" % (i, int(round(pred.item() * dataset.normalization)), i**2))
Input: 00101 	 Prediction: 09919 	 Target: 10201
Input: 00102 	 Prediction: 10077 	 Target: 10404
Input: 00103 	 Prediction: 10235 	 Target: 10609
Input: 00104 	 Prediction: 10392 	 Target: 10816
Input: 00105 	 Prediction: 10550 	 Target: 11025
...
Input: 00198 	 Prediction: 25171 	 Target: 39204
Input: 00199 	 Prediction: 25328 	 Target: 39601
If this is the case, you should consider increasing the diversity of your dataset (with samples representing every type of data you expect).
If the training results are not sufficient, you can try different hyperparameters (the length argument of the dataset, the number of epochs, the learning rate) or another model architecture.
This model uses linear layers only. When dealing with other kinds of data (e.g. images), you would probably want to switch to convolutional architectures.
This example is minimalistic and omits techniques commonly used for successful training (e.g. dropout, intermediate normalization, learning-rate scheduling), since they would make the example unnecessarily complicated.
A network's performance should virtually always be tested on a dataset consisting of samples that are not in your training set.
A third dataset may be used for hyperparameter optimization.
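Such a test might look like the following self-contained sketch. It uses an untrained stand-in model just to show the mechanics (in the notebook you would reuse the trained model, device and dataset.normalization), computing the mean absolute error on inputs the model has never seen:

```python
import torch

# Untrained stand-in for the trained model; same input/output shape (1 -> 1)
model = torch.nn.Linear(1, 1)
normalization = 100              # stands in for dataset.normalization

model.eval()                     # switch off train-only behaviour (Dropout etc.)
errors = []
with torch.no_grad():            # no gradients needed for evaluation
    for i in range(101, 200):    # inputs outside the training range
        x = torch.Tensor([[i]]) / normalization
        pred = model(x).item() * normalization   # revert the normalization
        errors.append(abs(pred - i ** 2))

mean_abs_error = sum(errors) / len(errors)
print("mean absolute error on unseen inputs: %.1f" % mean_abs_error)
```

model.eval() and torch.no_grad() are good habits for any evaluation loop: the former fixes the behaviour of train-only layers, the latter saves the memory and time otherwise spent building the computation graph.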