%matplotlib inline
%reload_ext autoreload
%autoreload 2
from fastai.conv_learner import *
# Root folder for the CIFAR-10 dataset (created if missing).
PATH = "data/cifar10/"
os.makedirs(PATH,exist_ok=True)
# The ten CIFAR-10 class labels.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Per-channel (RGB) mean and std used to normalize images.
stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))
def get_data(sz,bs):
    """Build a CIFAR-10 data object for image size `sz` and batch size `bs`.

    Normalizes with the dataset `stats`, augments with random horizontal
    flips, and pads by sz//8 pixels. Validation comes from the 'test' folder.
    """
    transforms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)
    return ImageClassifierData.from_paths(PATH, tfms=transforms, bs=bs, val_name='test')
bs=128  # base batch size; multiplied below for the small-image warm-up runs
import torch
import torch.nn as nn
class tofp16(nn.Module):
    """Module that casts its input tensor to half precision (fp16).

    Placed at the front of a network so incoming fp32 batches are
    converted before reaching the fp16 model.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # Cast the incoming batch to float16.
        return input.half()
def copy_in_params(net, params):
    """Copy the data of `params` into the corresponding parameters of `net`.

    In fp16 training this writes the fp32 master weights back into the
    model after an optimizer step. `params` must align one-to-one, in
    order, with `net.parameters()`.
    """
    # zip pairs each master param with its model counterpart directly,
    # replacing the manual range(len(...)) index bookkeeping.
    for net_param, master_param in zip(net.parameters(), params):
        net_param.data.copy_(master_param.data)
def set_grad(params, params_with_grad):
    """Copy gradients from `params_with_grad` into `params`.

    In fp16 training this transfers the model's gradients onto the fp32
    master copy of the weights before the optimizer step. A gradient
    buffer is allocated on demand for any param that has no .grad yet.
    """
    for param, param_w_grad in zip(params, params_with_grad):
        if param.grad is None:
            # empty_like replaces the deprecated new().resize_() pattern,
            # and avoids wrapping the grad buffer in nn.Parameter — a
            # gradient should be a plain tensor, not a Parameter.
            param.grad = torch.empty_like(param.data)
        param.grad.data.copy_(param_w_grad.grad.data)
# BatchNorm layers need to keep their parameters in single precision.
# Find all BatchNorm layers and convert them back to float. This can't
# be done with the built-in .apply, as that function would apply the
# fn to all modules, parameters, and buffers, so we wouldn't be able
# to guard the float conversion based on the module type.
def BN_convert_float(module):
    """Recursively cast every BatchNorm layer inside `module` back to fp32.

    After a blanket .half() conversion, BatchNorm parameters and buffers
    are restored to single precision; everything else stays in half
    precision. Returns `module` (converted in place).
    """
    # Guard on the BatchNorm base class so every BN variant is caught.
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.float()
    # Explicit recursion (rather than .apply) lets us convert based on
    # the module type only.
    for submodule in module.children():
        BN_convert_float(submodule)
    return module
def network_to_half(network):
    """Move `network` to the GPU in half precision, keeping BatchNorm in fp32.

    Returns a Sequential that first casts incoming batches to fp16
    (via tofp16) and then runs the converted network.
    """
    half_model = BN_convert_float(network.cuda().half())
    return nn.Sequential(tofp16(), half_model)
from fastai.models.cifar10.resnext import resnext29_8_64
# Build the ResNeXt-29 (8x64d) model for CIFAR-10.
m = resnext29_8_64()
# m = resnet50(False)
# Wrap the model for fp16: inputs cast to half, BatchNorm kept in fp32.
# NOTE(review): the name 'cifar10_resnet50' looks stale — the model built
# above is a ResNeXt-29, not a ResNet-50.
bm = BasicModel(network_to_half(m).cuda(), name='cifar10_resnet50')
# Warm-up phase: 8x8 images with a 16x base batch size.
data = get_data(8,bs*4*4)
learn = ConvLearner(data, bm)
learn.unfreeze()
lr=4e-2; wd=5e-4
# NOTE(review): wd is set but never passed to fit — presumably intended as
# learn.fit(..., wds=wd); confirm.
%time learn.fit(lr, 1, cycle_len=3, use_clr=(20,8))
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean that the widgets JavaScript is still loading. If this message persists, it likely means that the widgets JavaScript library is either not installed or not enabled. See the Jupyter Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static rendering on GitHub or NBViewer), it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy 0 3.748584 4.073438 0.199611 1 2.67269 1.886133 0.30006 2 2.250531 1.782031 0.36761 CPU times: user 1min 40s, sys: 47.9 s, total: 2min 28s Wall time: 1min 39s
[1.7820313, 0.367610102891922]
# NOTE(review): `data` is rebound here but never handed to `learn`
# (no learn.set_data(data)), so this fit may still run on the 8px data —
# confirm intent. Also verify that the object returned by get_data
# actually has a .half() method.
data = get_data(32,bs*4).half()
%time learn.fit(lr, 1, cycle_len=3, use_clr=(20,8))
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean that the widgets JavaScript is still loading. If this message persists, it likely means that the widgets JavaScript library is either not installed or not enabled. See the Jupyter Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static rendering on GitHub or NBViewer), it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy 0 1.583998 1.559375 0.462343 1 1.508549 1.411523 0.492938 2 1.443428 1.383203 0.507284 CPU times: user 1min 48s, sys: 45.1 s, total: 2min 33s Wall time: 1min 46s
[1.3832031, 0.50728360414505]
from fastai.models.cifar10.resnext import resnext29_8_64
# Baseline: the same ResNeXt-29 (8x64d) model in full fp32 precision.
mf = resnext29_8_64()
# m = resnet50(False)
# NOTE(review): name 'cifar10_resnet50' is stale here too — model is ResNeXt-29.
bmf = BasicModel(mf.cuda(), name='cifar10_resnet50')
dataf = get_data(8,bs*4*4)
learnf = ConvLearner(dataf, bmf)
learnf.unfreeze()
lr=4e-2; wd=5e-4
# NOTE(review): as in the fp16 run, wd is defined but not passed to fit.
%time learnf.fit(lr, 1, cycle_len=3, use_clr=(20,8))
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean that the widgets JavaScript is still loading. If this message persists, it likely means that the widgets JavaScript library is either not installed or not enabled. See the Jupyter Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static rendering on GitHub or NBViewer), it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy 0 3.712339 2.205959 0.234404 1 2.59699 1.735602 0.357074 2 2.128743 1.61046 0.409041 CPU times: user 1min 22s, sys: 46.6 s, total: 2min 8s Wall time: 1min 33s
[1.6104597, 0.4090405464172363]
# NOTE(review): `data` is rebound but learnf was built on `dataf` and never
# receives the new data (no learnf.set_data(data)) — confirm this fit really
# trains on 32px images.
data = get_data(32,bs*4)
%time learnf.fit(lr, 1, cycle_len=3, use_clr=(20,8))
Failed to display Jupyter Widget of type HBox
.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean that the widgets JavaScript is still loading. If this message persists, it likely means that the widgets JavaScript library is either not installed or not enabled. See the Jupyter Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static rendering on GitHub or NBViewer), it may mean that your frontend doesn't currently support widgets.
epoch trn_loss val_loss accuracy 0 1.606625 1.529435 0.445549 1 1.535117 1.434556 0.485208 2 1.466418 1.392065 0.501437 CPU times: user 1min 20s, sys: 47.4 s, total: 2min 8s Wall time: 1min 33s
[1.3920648, 0.5014371871948242]
FP16 is actually slower in these tests; I will have to look into why this is. Possible reasons: