GroupNorm vs BatchNorm on the Pets Dataset
In this notebook, we implement GroupNorm with Weight Standardization and compare the results with BatchNorm. Simply replacing BN with GN leads to sub-optimal results.
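The interesting difference is which axes the normalization statistics are computed over: BatchNorm normalizes each channel across the whole batch, so its estimates degrade as the batch shrinks, while GroupNorm normalizes each sample independently over groups of channels and is therefore batch-size independent. A minimal sketch of the difference (the shapes here are only illustrative):
import torch
import torch.nn as nn

x = torch.randn(4, 64, 56, 56)   # (batch, channels, height, width)
bn = nn.BatchNorm2d(64)          # per-channel stats over (batch, H, W)
gn = nn.GroupNorm(32, 64)        # per-sample stats over each group of channels

# GroupNorm gives the same answer for a sample regardless of what else is in the batch,
# which is why it behaves well when only tiny batches fit in memory.
assert torch.allclose(gn(x[:1]), gn(x)[:1], atol=1e-6)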
from fastai2.vision.all import *
from nbdev.showdoc import *
import glob
import albumentations
from torchvision import models
from albumentations.pytorch.transforms import ToTensorV2
set_seed(2)
ResNet Implementation
We copy the implementation of Weight Standardization from the official repository here, and also copy the implementation of ResNet from TorchVision.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
We replace the convolution layers inside ResNet with the standardized version, as in the Weight Standardization research paper. Everything else remains the same.
class Conv2d_WS(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2d_WS, self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias)
def forward(self, x):
weight = self.weight
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
keepdim=True).mean(dim=3, keepdim=True)
weight = weight - weight_mean
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
weight = weight / std.expand_as(weight)
return F.conv2d(x, weight, self.bias, self.stride,
self.padding, self.dilation, self.groups)
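As a quick sanity check of Conv2d_WS (a small sketch, assuming the class above): the stored self.weight is left untouched and the standardization happens on the fly inside forward(), so the weights actually used should be roughly zero-mean with unit standard deviation per output filter.
conv = Conv2d_WS(3, 8, kernel_size=3, padding=1, bias=False)
out = conv(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 8, 32, 32])

# reproduce the standardization from forward() to inspect the effective weights
w = conv.weight - conv.weight.mean(dim=(1, 2, 3), keepdim=True)
w = w / (w.view(w.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5)
print(w.view(8, -1).mean(dim=1).abs().max())  # ~0
print(w.view(8, -1).std(dim=1))               # ~1 for each of the 8 filters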
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return Conv2d_WS(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return Conv2d_WS(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = Conv2d_WS(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
Pets Dataset
Now we use the wonderful fastai library to get the Pets dataset.
bs = 4
path = untar_data(URLs.PETS); path
Path('/home/ubuntu/.fastai/data/oxford-iiit-pet')
(path/'images').ls()
(#7381) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]
Dataset
The implementation of the PetsDataset has been heavily inspired and partially copied (the regex part) from the fastai2 repo here.
class PetsDataset:
def __init__(self, paths, transforms=None):
self.image_paths = paths
self.transforms = transforms
def __len__(self):
return len(self.image_paths)
def setup(self, pat=r'(.+)_\d+.jpg$', label2int=None):
"adds a label dictionary to `self`"
self.pat = re.compile(pat)
if label2int is not None:
self.label2int = label2int
self.int2label = {v:i for i,v in self.label2int.items()}
else:
labels = [os.path.basename(self.pat.search(str(p)).group(1))
for p in self.image_paths]
self.labels = set(labels)
self.label2int = {label:i for i,label in enumerate(self.labels)}
self.int2label = {v:i for i,v in self.label2int.items()}
def __getitem__(self, idx):
img_path = self.image_paths[idx]
img = Image.open(img_path)
img = np.array(img)
target = os.path.basename(self.pat.search(str(img_path)).group(1))
target = self.label2int[target]
if self.transforms:
img = self.transforms(image=img)['image']
return img, torch.tensor(target, dtype=torch.long)
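The regex used in setup() and __getitem__ simply strips the trailing index and extension from the filename to recover the breed name; a quick illustration on one of the paths listed above:
import os, re
pat = re.compile(r'(.+)_\d+.jpg$')
p = '/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'
os.path.basename(pat.search(p).group(1))  # 'Abyssinian'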
image_paths = get_image_files(path/'images')
image_paths
(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]
# remove those images that are not 3 channel
from tqdm.notebook import tqdm
run_remove = False
def remove(o):
img = Image.open(o)
img = np.array(img)
if img.shape[2] != 3:
os.remove(o)
if run_remove:
for o in tqdm(image_paths): remove(o)
image_paths = get_image_files(path/'images')
image_paths
(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]
# augmentations using `albumentations` library
sz = 224
tfms = albumentations.Compose([
albumentations.Resize(sz, sz) if sz else albumentations.NoOp(),
albumentations.OneOf(
[albumentations.Cutout(random.randint(1,8), 16, 16),
albumentations.CoarseDropout(random.randint(1,8), 16, 16)]
),
albumentations.Normalize(always_apply=True),
ToTensorV2()
])
dataset = PetsDataset(image_paths, tfms)
# to setup the `label2int` dictionary
dataset.setup()
dataset[0]
(tensor([[[ 0.8618, 0.1597, 0.4166, ..., -0.6452, -0.3198, -0.2171], [ 1.1872, 0.3481, 0.4166, ..., -0.3027, 0.0912, 0.3138], [ 0.8104, 0.6049, 0.0227, ..., -0.3712, -0.1657, -0.1828], ..., [ 1.2385, 0.4851, 0.0227, ..., 0.8789, 1.2214, 0.8961], [ 0.7077, 0.9474, -0.6965, ..., 0.1254, 1.5297, 1.6667], [ 0.1083, -0.0801, 0.3652, ..., 0.2111, 0.5193, 0.6734]], [[ 0.9230, 0.4328, 0.4503, ..., -0.2850, -0.0224, -0.0399], [ 1.3256, 0.7304, 0.4678, ..., -0.0399, 0.1527, 0.3277], [ 0.8354, 0.8354, 0.3102, ..., -0.2500, -0.1975, -0.3200], ..., [ 1.3606, 1.3431, 0.6078, ..., 0.9755, 1.3957, 1.1331], [ 0.7654, 1.0455, -0.0574, ..., 0.7654, 1.6232, 1.7458], [ 0.4153, 0.5903, 0.9230, ..., 0.7654, 0.8529, 1.0980]], [[ 0.3393, -0.3578, -0.4275, ..., -0.7936, -0.4624, -0.3578], [ 0.6531, -0.2358, -0.4973, ..., -0.3753, -0.0615, 0.1128], [ 0.0431, 0.1128, -1.0201, ..., -0.4101, -0.2707, -0.3578], ..., [ 0.7228, 0.3219, -0.5321, ..., 0.4439, 1.0017, 0.7576], [ 0.2173, 0.4265, -1.1247, ..., -0.0790, 1.1411, 1.2457], [-0.4450, -0.2881, 0.1302, ..., 0.0082, 0.2696, 0.4439]]]), tensor(24))
dataset[0][0].shape
torch.Size([3, 224, 224])
DataLoaders
We divide the image_paths into train and validation with a 20% split.
nval = int(len(image_paths)*0.2)
nval
1475
trn_img_paths = image_paths[:-nval]
val_img_paths = image_paths[-nval:]
assert len(trn_img_paths) + len(val_img_paths) == len(image_paths)
len(trn_img_paths), len(val_img_paths)
(5903, 1475)
trn_dataset = PetsDataset(trn_img_paths, transforms=tfms)
val_dataset = PetsDataset(val_img_paths, transforms=tfms)
# use same `label2int` dictionary as in `dataset` for consistency across train and val
trn_dataset.setup(label2int=dataset.label2int)
val_dataset.setup(label2int=dataset.label2int)
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=bs, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, num_workers=4, shuffle=False)
# make sure everything works so far
next(iter(trn_loader))[0].shape, next(iter(val_loader))[0].shape
(torch.Size([4, 3, 224, 224]), torch.Size([4, 3, 224, 224]))
Model
Now we define the resnet34 from the torchvision repo with pretrained=False, as we do not have pretrained weights for the GroupNorm layers.
# Vanilla resnet with `BatchNorm`
resnet34_bn = models.resnet34(num_classes=len(trn_dataset.label2int), pretrained=False)
resnet34_bn
ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) 
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (5): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (relu): ReLU(inplace=True) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=512, out_features=37, bias=True) )
Next, we define the GroupNorm_32 class with a default of 32 groups, as in the Group Normalization research paper here.
class GroupNorm_32(torch.nn.GroupNorm):
def __init__(self, num_channels, num_groups=32, **kwargs):
super().__init__(num_groups, num_channels, **kwargs)
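The only reason for this tiny wrapper is the argument order: the ResNet code above calls norm_layer(planes) with a single channel count, whereas nn.GroupNorm expects (num_groups, num_channels). A one-line check using the class just defined:
gn = GroupNorm_32(64)   # 64 channels split into the default 32 groups
gn                      # GroupNorm_32(32, 64, eps=1e-05, affine=True)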
# resnet34 with `GroupNorm` and `Standardized Weights`
# `conv2d` replaced with `Conv2d_WS` and `BatchNorm` replaced with `GroupNorm`
resnet34_gn = resnet34(norm_layer=GroupNorm_32, num_classes=len(trn_dataset.label2int))
resnet34_gn
ResNet( (conv1): Conv2d_WS(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True) ) (1): BasicBlock( (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True) ) (2): BasicBlock( (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d_WS(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True) (downsample): Sequential( (0): Conv2d_WS(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): GroupNorm_32(32, 128, eps=1e-05, affine=True) ) ) (1): BasicBlock( (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True) ) (2): BasicBlock( (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True) ) (3): BasicBlock( (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d_WS(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True) (downsample): Sequential( (0): Conv2d_WS(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): GroupNorm_32(32, 256, eps=1e-05, affine=True) ) ) (1): BasicBlock( (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), 
padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True) ) (2): BasicBlock( (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True) ) (3): BasicBlock( (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True) ) (4): BasicBlock( (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True) ) (5): BasicBlock( (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d_WS(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True) (downsample): Sequential( (0): Conv2d_WS(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False) (1): GroupNorm_32(32, 512, eps=1e-05, affine=True) ) ) (1): BasicBlock( (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True) ) (2): BasicBlock( (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True) (relu): ReLU(inplace=True) (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=512, out_features=37, bias=True) )
# make sure we are able to make forward pass
resnet34_gn(next(iter(trn_loader))[0]).shape
torch.Size([4, 37])
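This also makes concrete why we set pretrained=False above: the BatchNorm model carries running_mean and running_var buffers that have no counterpart in the GroupNorm variant, so the ImageNet checkpoints cannot be loaded into it as-is. A small sketch comparing the state dicts of the two models we just built:
bn_only = set(resnet34_bn.state_dict()) - set(resnet34_gn.state_dict())
sorted(k for k in bn_only if 'running' in k)[:4]
# ['bn1.running_mean', 'bn1.running_var', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var']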
PyTorch Lightning
Finally, we use PyTorch Lightning to train the models.
from pytorch_lightning import LightningModule, Trainer
class Model(LightningModule):
def __init__(self, base):
super().__init__()
self.base = base
def forward(self, x):
return self.base(x)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
def step(self, batch):
x, y = batch
y_hat = self(x)
loss = nn.CrossEntropyLoss()(y_hat, y)
return loss, y, y_hat
def training_step(self, batch, batch_nb):
loss, _, _ = self.step(batch)
return {'loss': loss}
def validation_step(self, batch, batch_nb):
loss, y, y_hat = self.step(batch)
return {'loss': loss, 'y': y.detach(), 'y_hat': y_hat.detach()}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
acc = self.get_accuracy(outputs)
print(f"Epoch:{self.current_epoch} | Loss:{avg_loss} | Accuracy:{acc}")
return {'loss': avg_loss}
def get_accuracy(self, outputs):
from sklearn.metrics import accuracy_score
y = torch.cat([x['y'] for x in outputs])
y_hat = torch.cat([x['y_hat'] for x in outputs])
preds = y_hat.argmax(1)
return accuracy_score(y.cpu().numpy(), preds.cpu().numpy())
# define PL versions
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)
debug = False
gpus = torch.cuda.device_count()
trainer = Trainer(gpus=gpus, max_epochs=50,
num_sanity_val_steps=1 if debug else 0)
GPU available: True, used: True TPU available: False, using: 0 TPU cores CUDA_VISIBLE_DEVICES: [0]
batch_size=4
# train model with `GroupNorm` with `bs=4` on the `Pets` dataset
trainer = Trainer(gpus=gpus, max_epochs=25,
num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True TPU available: False, using: 0 TPU cores CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params -------------------------------- 0 | base | ResNet | 21 M
Epoch:0 | Loss:3.638690710067749 | Accuracy:0.022372881355932205
Epoch:1 | Loss:3.5767452716827393 | Accuracy:0.03728813559322034
Epoch:2 | Loss:3.532081365585327 | Accuracy:0.05152542372881356
Epoch:3 | Loss:3.497438907623291 | Accuracy:0.06033898305084746
Epoch:4 | Loss:3.437784194946289 | Accuracy:0.07457627118644068
Epoch:5 | Loss:3.3992772102355957 | Accuracy:0.07322033898305084
Epoch:6 | Loss:3.3322556018829346 | Accuracy:0.08203389830508474
Epoch:7 | Loss:3.278475761413574 | Accuracy:0.09220338983050848
Epoch:8 | Loss:3.2041637897491455 | Accuracy:0.12
Epoch:9 | Loss:3.1338086128234863 | Accuracy:0.13288135593220338
Epoch:10 | Loss:2.9662578105926514 | Accuracy:0.15593220338983052
Epoch:11 | Loss:2.9380886554718018 | Accuracy:0.16203389830508474
Epoch:12 | Loss:2.7531585693359375 | Accuracy:0.21627118644067797
Epoch:13 | Loss:2.7896103858947754 | Accuracy:0.2223728813559322
Epoch:14 | Loss:2.5649585723876953 | Accuracy:0.26372881355932204
Epoch:15 | Loss:2.5243453979492188 | Accuracy:0.3071186440677966
Epoch:16 | Loss:2.453778028488159 | Accuracy:0.3220338983050847
Epoch:17 | Loss:2.575655460357666 | Accuracy:0.33016949152542374
Epoch:18 | Loss:2.723491668701172 | Accuracy:0.3193220338983051
Epoch:19 | Loss:3.0088090896606445 | Accuracy:0.3369491525423729
Epoch:20 | Loss:3.221853494644165 | Accuracy:0.3213559322033898
Epoch:21 | Loss:3.3212766647338867 | Accuracy:0.34576271186440677
Epoch:22 | Loss:3.6144063472747803 | Accuracy:0.3247457627118644
Epoch:23 | Loss:3.542142868041992 | Accuracy:0.34440677966101696
Epoch:24 | Loss:3.8027701377868652 | Accuracy:0.32610169491525426
1
# train model with `BatchNorm` with `bs=4` on the `Pets` dataset
trainer = Trainer(gpus=gpus, max_epochs=25,
num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True TPU available: False, using: 0 TPU cores CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params -------------------------------- 0 | base | ResNet | 21 M
Epoch:0 | Loss:4.403476715087891 | Accuracy:0.01966101694915254
Epoch:1 | Loss:3.615051746368408 | Accuracy:0.03932203389830508
Epoch:2 | Loss:3.6922903060913086 | Accuracy:0.05084745762711865
Epoch:3 | Loss:3.4302172660827637 | Accuracy:0.062372881355932205
Epoch:4 | Loss:3.351684331893921 | Accuracy:0.08271186440677966
Epoch:5 | Loss:3.2836146354675293 | Accuracy:0.0935593220338983
Epoch:6 | Loss:3.2269628047943115 | Accuracy:0.10915254237288136
Epoch:7 | Loss:3.2704873085021973 | Accuracy:0.1023728813559322
Epoch:8 | Loss:3.071798801422119 | Accuracy:0.1423728813559322
Epoch:9 | Loss:3.0656063556671143 | Accuracy:0.15457627118644068
Epoch:10 | Loss:3.0375216007232666 | Accuracy:0.17288135593220338
Epoch:11 | Loss:2.8739380836486816 | Accuracy:0.2094915254237288
Epoch:12 | Loss:2.7329418659210205 | Accuracy:0.23186440677966103
Epoch:13 | Loss:2.737560510635376 | Accuracy:0.24813559322033898
Epoch:14 | Loss:2.541532516479492 | Accuracy:0.27728813559322035
Epoch:15 | Loss:2.540792226791382 | Accuracy:0.3064406779661017
Epoch:16 | Loss:2.485729217529297 | Accuracy:0.3328813559322034
Epoch:17 | Loss:2.7257814407348633 | Accuracy:0.31389830508474575
Epoch:18 | Loss:3.07981276512146 | Accuracy:0.3247457627118644
Epoch:19 | Loss:3.1801645755767822 | Accuracy:0.31661016949152543
Epoch:20 | Loss:3.270585298538208 | Accuracy:0.3328813559322034
Epoch:21 | Loss:3.355048656463623 | Accuracy:0.3376271186440678
Epoch:22 | Loss:3.362093687057495 | Accuracy:0.29898305084745763
Epoch:23 | Loss:3.470551013946533 | Accuracy:0.3389830508474576
Epoch:24 | Loss:3.5411648750305176 | Accuracy:0.31254237288135595
1
batch_size=64
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=64, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=4, shuffle=False)
# redefine PL versions to remove trained weights
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)
trainer = Trainer(gpus=gpus, max_epochs=25,
num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True TPU available: False, using: 0 TPU cores CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params -------------------------------- 0 | base | ResNet | 21 M
Epoch:0 | Loss:4.571255683898926 | Accuracy:0.33152542372881355
Epoch:1 | Loss:4.823599815368652 | Accuracy:0.33084745762711865
Epoch:2 | Loss:4.738388538360596 | Accuracy:0.33152542372881355
Epoch:3 | Loss:4.6921844482421875 | Accuracy:0.3383050847457627
Epoch:4 | Loss:5.571420669555664 | Accuracy:0.3227118644067797
Epoch:5 | Loss:4.973819255828857 | Accuracy:0.31864406779661014
Epoch:6 | Loss:4.960039138793945 | Accuracy:0.31186440677966104
Epoch:7 | Loss:4.72049617767334 | Accuracy:0.33152542372881355
Epoch:8 | Loss:4.7438859939575195 | Accuracy:0.3410169491525424
Epoch:9 | Loss:4.7650861740112305 | Accuracy:0.33220338983050846
Epoch:10 | Loss:4.842560768127441 | Accuracy:0.33491525423728813
Epoch:11 | Loss:5.002099514007568 | Accuracy:0.3410169491525424
Epoch:12 | Loss:4.969579696655273 | Accuracy:0.3328813559322034
Epoch:13 | Loss:4.797631740570068 | Accuracy:0.3328813559322034
Epoch:14 | Loss:4.790388107299805 | Accuracy:0.33220338983050846
Epoch:15 | Loss:4.84404993057251 | Accuracy:0.3464406779661017
Epoch:16 | Loss:4.882577896118164 | Accuracy:0.3416949152542373
Epoch:17 | Loss:4.831890106201172 | Accuracy:0.3403389830508475
Epoch:18 | Loss:4.815413475036621 | Accuracy:0.34576271186440677
Epoch:19 | Loss:4.880715370178223 | Accuracy:0.34779661016949154
Epoch:20 | Loss:4.870474815368652 | Accuracy:0.34508474576271186
Epoch:21 | Loss:4.8547258377075195 | Accuracy:0.3430508474576271
Epoch:22 | Loss:4.814042568206787 | Accuracy:0.3505084745762712
Epoch:23 | Loss:5.573678970336914 | Accuracy:0.29152542372881357
Epoch:24 | Loss:4.861083030700684 | Accuracy:0.33220338983050846
1
trainer = Trainer(gpus=gpus, max_epochs=25,
num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True TPU available: False, using: 0 TPU cores CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params -------------------------------- 0 | base | ResNet | 21 M
Epoch:0 | Loss:4.338170051574707 | Accuracy:0.36135593220338985
Epoch:1 | Loss:4.264873027801514 | Accuracy:0.3593220338983051
Epoch:2 | Loss:4.475521564483643 | Accuracy:0.368135593220339
Epoch:3 | Loss:4.5568928718566895 | Accuracy:0.37559322033898307
Epoch:4 | Loss:4.563418865203857 | Accuracy:0.36610169491525424
Epoch:5 | Loss:4.532094955444336 | Accuracy:0.36677966101694914
Epoch:6 | Loss:4.709390163421631 | Accuracy:0.36474576271186443
Epoch:7 | Loss:4.703502178192139 | Accuracy:0.34983050847457625
Epoch:8 | Loss:4.687512397766113 | Accuracy:0.36135593220338985
Epoch:9 | Loss:4.453052997589111 | Accuracy:0.37559322033898307
Epoch:10 | Loss:4.729727745056152 | Accuracy:0.3423728813559322
Epoch:11 | Loss:4.887462139129639 | Accuracy:0.34847457627118644
Epoch:12 | Loss:4.761058807373047 | Accuracy:0.36
Epoch:13 | Loss:4.628625869750977 | Accuracy:0.36610169491525424
Epoch:14 | Loss:4.939492225646973 | Accuracy:0.3735593220338983
Epoch:15 | Loss:4.9373321533203125 | Accuracy:0.36
Epoch:16 | Loss:4.884154796600342 | Accuracy:0.3701694915254237
Epoch:17 | Loss:5.015425682067871 | Accuracy:0.34576271186440677
Epoch:18 | Loss:5.0034356117248535 | Accuracy:0.34372881355932206
Epoch:19 | Loss:5.081662178039551 | Accuracy:0.34372881355932206
Epoch:20 | Loss:5.115207195281982 | Accuracy:0.3403389830508475
Epoch:21 | Loss:4.923257827758789 | Accuracy:0.368135593220339
Epoch:22 | Loss:5.064967632293701 | Accuracy:0.3701694915254237
Epoch:23 | Loss:4.966062545776367 | Accuracy:0.368135593220339
Epoch:24 | Loss:5.010922431945801 | Accuracy:0.376271186440678
1
batch_size=1
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=1, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=4, shuffle=False)
model_bn = Model(resnet34_bn)
model_gn = Model(resnet34_gn)
trainer = Trainer(gpus=gpus, max_epochs=25,
num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True TPU available: False, using: 0 TPU cores CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params -------------------------------- 0 | base | ResNet | 21 M
Epoch:0 | Loss:3.6087236404418945 | Accuracy:0.04067796610169491
Epoch:1 | Loss:3.8362090587615967 | Accuracy:0.025084745762711864
Epoch:2 | Loss:3.6673178672790527 | Accuracy:0.03593220338983051
Epoch:3 | Loss:3.7399044036865234 | Accuracy:0.03389830508474576
Epoch:4 | Loss:4.054337501525879 | Accuracy:0.0488135593220339
Epoch:5 | Loss:4.010653972625732 | Accuracy:0.04542372881355932
Epoch:6 | Loss:4.764206886291504 | Accuracy:0.05288135593220339
Epoch:7 | Loss:10.56059455871582 | Accuracy:0.04474576271186441
Epoch:8 | Loss:5.048521041870117 | Accuracy:0.05830508474576271
Epoch:9 | Loss:4.828557014465332 | Accuracy:0.06508474576271187
Epoch:10 | Loss:7.225879192352295 | Accuracy:0.05694915254237288
Epoch:11 | Loss:6.472527027130127 | Accuracy:0.06779661016949153
Epoch:12 | Loss:9.755941390991211 | Accuracy:0.07050847457627119
Epoch:13 | Loss:13.05939769744873 | Accuracy:0.059661016949152545
Epoch:14 | Loss:18.591503143310547 | Accuracy:0.06508474576271187
Epoch:15 | Loss:11.946345329284668 | Accuracy:0.06915254237288136
Epoch:16 | Loss:16.744611740112305 | Accuracy:0.06983050847457627
Epoch:17 | Loss:12.913531303405762 | Accuracy:0.07661016949152542
Epoch:18 | Loss:23.76015281677246 | Accuracy:0.06508474576271187
Epoch:19 | Loss:26.5297794342041 | Accuracy:0.06576271186440678
Epoch:20 | Loss:35.212242126464844 | Accuracy:0.05898305084745763
Epoch:21 | Loss:16.634546279907227 | Accuracy:0.06169491525423729
Epoch:22 | Loss:21.815725326538086 | Accuracy:0.062372881355932205
Epoch:23 | Loss:12.68907356262207 | Accuracy:0.0711864406779661
Epoch:24 | Loss:19.639753341674805 | Accuracy:0.06779661016949153
1
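The BatchNorm run degrades badly at batch_size=1: the validation loss climbs into the double digits while accuracy stays around 6-7%. With a single image per step, BatchNorm's batch statistics are just that one image's spatial mean and variance, so training is noisy and the running estimates used at validation time drift. Below is a standalone illustration of this (not part of the notebook's training code); all names in it are for demonstration only.

import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(8)                            # per-channel statistics computed over the batch
gn = nn.GroupNorm(num_groups=4, num_channels=8)   # per-sample, per-group statistics

x = torch.randn(1, 8, 32, 32)                     # a "batch" containing a single image

bn.train()
out_bn_train = bn(x)                              # normalised with this one sample's statistics
bn.eval()
out_bn_eval = bn(x)                               # uses running estimates built from single-sample batches

out_gn = gn(x)                                    # GroupNorm behaves the same in train and eval mode

# BatchNorm's train-mode and eval-mode outputs diverge because its running estimates
# come from single-sample batches; GroupNorm's output does not depend on the batch at all.
print((out_bn_train - out_bn_eval).abs().mean().item())
print(out_gn.mean().item(), out_gn.std().item())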
trainer = Trainer(gpus=gpus, max_epochs=25,
num_sanity_val_steps=1 if debug else 0)
trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M
Epoch:0 | Loss:3.6178038120269775 | Accuracy:0.03593220338983051
Epoch:1 | Loss:3.5887539386749268 | Accuracy:0.03864406779661017
Epoch:2 | Loss:3.4937922954559326 | Accuracy:0.0576271186440678
Epoch:3 | Loss:3.426539421081543 | Accuracy:0.06508474576271187
Epoch:4 | Loss:3.4010708332061768 | Accuracy:0.06915254237288136
Epoch:5 | Loss:3.352757453918457 | Accuracy:0.08949152542372882
Epoch:6 | Loss:3.3006396293640137 | Accuracy:0.10033898305084746
Epoch:7 | Loss:3.2513763904571533 | Accuracy:0.09966101694915254
Epoch:8 | Loss:3.2186825275421143 | Accuracy:0.11254237288135593
Epoch:9 | Loss:3.1824042797088623 | Accuracy:0.12067796610169491
Epoch:10 | Loss:3.1842432022094727 | Accuracy:0.1152542372881356
Epoch:11 | Loss:3.080850839614868 | Accuracy:0.1342372881355932
Epoch:12 | Loss:3.1100575923919678 | Accuracy:0.1430508474576271
Epoch:13 | Loss:3.085071563720703 | Accuracy:0.14508474576271185
Epoch:14 | Loss:3.007901906967163 | Accuracy:0.17559322033898306
Epoch:15 | Loss:3.1437573432922363 | Accuracy:0.1694915254237288
Epoch:16 | Loss:3.110459089279175 | Accuracy:0.1864406779661017
Epoch:17 | Loss:3.5012593269348145 | Accuracy:0.18847457627118644
Epoch:18 | Loss:3.4454123973846436 | Accuracy:0.21084745762711865
Epoch:19 | Loss:3.8177714347839355 | Accuracy:0.21152542372881356
Epoch:20 | Loss:4.031371116638184 | Accuracy:0.1952542372881356
Epoch:21 | Loss:4.404645919799805 | Accuracy:0.1769491525423729
Epoch:22 | Loss:4.856805324554443 | Accuracy:0.1959322033898305
Epoch:23 | Loss:4.558755874633789 | Accuracy:0.21152542372881356
(Output for the final epoch was dropped because the notebook server's IOPub message rate limit was exceeded.)
A model with GroupNorm + Standardized Weights was able to achieve performance similar to that of BatchNorm. Thus, GroupNorm can be considered an alternative to BatchNorm.
On the Pets dataset, GroupNorm does not necessarily achieve better performance than BatchNorm at lower batch sizes, as it does in the paper.
The research paper uses the ImageNet dataset, whereas this experiment was run on the Pets dataset due to the lack of compute required to train on ImageNet.
For more details, refer to my blogpost.
For bs=1, GroupNorm performs significantly better than BatchNorm.
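As a closing illustration of swapping BatchNorm for GroupNorm in a standard ResNet: torchvision's ResNet constructors accept a norm_layer callable, so a plain GroupNorm variant can be sketched as below. This is illustrative only, not the notebook's actual construction of resnet34_gn (which also replaces nn.Conv2d with the weight-standardized Conv2d_WS); the 37-way head assumes the Oxford-IIIT Pets label set.

import torch.nn as nn
from torchvision import models

def group_norm(num_channels):
    # 32 groups is the GroupNorm paper's default; all ResNet-34 widths (64/128/256/512) are divisible by 32
    return nn.GroupNorm(num_groups=32, num_channels=num_channels)

# Every place the stock ResNet would insert BatchNorm2d now gets GroupNorm instead.
resnet34_gn_plain = models.resnet34(norm_layer=group_norm)
resnet34_gn_plain.fc = nn.Linear(512, 37)  # assumption: 37-class head for the Pets dataset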