ImageNet Validation Techniques Compared - Square vs. Rectangle vs. TTA

Here we compare different validation techniques on a pretrained ResNet-50 model.

In [1]:
%reload_ext autoreload
%autoreload 2
In [2]:
from validation_utils import sort_ar, chunks, map_idx2ar, ValDataset, SequentialIndexSampler, RectangularCropTfm, validate

import sys, os, shutil, time, warnings
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import urllib.request
import pandas as pd

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import Sampler
In [5]:
cudnn.benchmark = True
data = Path.home()/'data/imagenet'
workers = 7
valdir = data/'validation'
batch_size = 64
fp16 = True

Step 1: Create image-to-aspect-ratio mapping

In [6]:
idx_ar_sorted = sort_ar(data, valdir)
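sort_ar comes from validation_utils, whose source is not shown here. As a rough sketch of the idea only, assuming aspect ratio means width divided by height and that the result is a list of (index, aspect ratio) pairs sorted by aspect ratio, it could look like this (the function name is hypothetical):

```python
def sort_by_aspect_ratio(sizes):
    """Return (index, aspect_ratio) pairs sorted by aspect ratio.

    `sizes` is a list of (width, height) tuples, for example gathered
    with PIL's `Image.open(path).size` for every validation image.
    Assumption: aspect ratio = width / height, rounded to 5 places.
    """
    idx_ar = [(i, round(w / h, 5)) for i, (w, h) in enumerate(sizes)]
    return sorted(idx_ar, key=lambda pair: pair[1])


# Three toy images: landscape, portrait, square
example = sort_by_aspect_ratio([(500, 375), (375, 500), (400, 400)])
print(example)
```

Sorting by aspect ratio means that images read in this order have similar shapes, which is what makes rectangular batches possible later on.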

Or just download the precomputed ImageNet sizes

In [8]:
idx2ar_path = data/'sorted_idxar.p'
url = 'https://s3-us-west-2.amazonaws.com/ashaw-fastai-imagenet/sorted_idxar.p'
if not idx2ar_path.exists(): urllib.request.urlretrieve(url, idx2ar_path)
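# Assumes sort_ar loads the cached sorted_idxar.p when it is present instead of rescanning the images
idx_ar_sorted = sort_ar(data, valdir)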

Step 2: Get pretrained ResNet model

In [9]:
import resnet
model = resnet.resnet50(pretrained=True)
model = model.cuda()
criterion = nn.CrossEntropyLoss().cuda()
if fp16: model = model.half()

Global dataset settings

In [11]:
val_bs = 64
target_size = 288

idx_sorted, _ = zip(*idx_ar_sorted)
idx2ar, ar_means = map_idx2ar(idx_ar_sorted, val_bs)
val_sampler_ar = SequentialIndexSampler(idx_sorted)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
tensor_tfm = [transforms.ToTensor(), normalize]
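Every image in a batch must end up with the same tensor shape, so each batch needs one shared aspect ratio. map_idx2ar is not shown in this notebook; the sketch below is one plausible way to build such a mapping, assuming each image is assigned the mean aspect ratio of the batch it falls into when the sorted list is read sequentially (the function name is hypothetical):

```python
def map_index_to_batch_ar(idx_ar_sorted, batch_size):
    """Map each image index to the mean aspect ratio of its batch.

    Assumption: batches are consecutive slices of the sorted list,
    which is what a sequential sampler over the sorted indices gives.
    """
    idx2ar, ar_means = {}, []
    for start in range(0, len(idx_ar_sorted), batch_size):
        batch = idx_ar_sorted[start:start + batch_size]
        mean_ar = round(sum(ar for _, ar in batch) / len(batch), 5)
        ar_means.append(mean_ar)
        for idx, _ in batch:
            idx2ar[idx] = mean_ar
    return idx2ar, ar_means


pairs = [(1, 0.75), (2, 1.0), (0, 1.5), (3, 2.0)]
idx2ar_demo, ar_means_demo = map_index_to_batch_ar(pairs, batch_size=2)
print(idx2ar_demo, ar_means_demo)
```

This only works if the loader visits images in the same sorted order, which is presumably why the loaders below use a sequential sampler over idx_sorted.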

Compare different validation techniques

Test Square Validation Technique

This was the validation technique used in fast.ai's original DAWNBench entry.
Resize the image so its shorter side is 1.14x the target size, then center-crop to the target size (288).
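As a quick worked example of the numbers involved, using the same target_size and 1.14 factor as the cell below (the variable names resize_to and border are only for illustration):

```python
target_size = 288

# transforms.Resize(int) scales the shorter side to this many pixels
resize_to = int(target_size * 1.14)

# CenterCrop(target_size) then trims this many pixels from each edge
# of the shorter axis; the longer axis loses more.
border = (resize_to - target_size) // 2

print(resize_to, border)  # 328 20
```

For a non-square image the longer axis is cut down to 288 as well, so everything outside the central square is discarded.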

In [13]:
val_tfms = [transforms.Resize(int(target_size*1.14)), transforms.CenterCrop(target_size)] + tensor_tfm
val_dataset = datasets.ImageFolder(valdir,  transforms.Compose(val_tfms))

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=val_bs, shuffle=False,
    num_workers=workers, pin_memory=True, sampler=val_sampler_ar)

orig_prec5 = validate(val_loader, model, criterion)
Test: [100/782]	Time 0.074 (0.130)	Loss 1.4658 (1.0240)	Prec@1 59.375 (74.469)	Prec@5 85.938 (92.047)
Test: [200/782]	Time 0.055 (0.114)	Loss 0.4976 (0.9866)	Prec@1 85.938 (75.430)	Prec@5 100.000 (92.594)
Test: [300/782]	Time 0.055 (0.110)	Loss 0.9302 (0.9541)	Prec@1 75.000 (76.198)	Prec@5 93.750 (93.010)
Test: [400/782]	Time 0.064 (0.107)	Loss 0.7817 (0.9018)	Prec@1 81.250 (77.461)	Prec@5 95.312 (93.516)
Test: [500/782]	Time 0.056 (0.108)	Loss 0.3689 (0.9054)	Prec@1 93.750 (77.328)	Prec@5 98.438 (93.653)
Test: [600/782]	Time 0.056 (0.108)	Loss 1.2910 (0.9347)	Prec@1 70.312 (76.797)	Prec@5 90.625 (93.206)
Test: [700/782]	Time 0.055 (0.107)	Loss 0.4583 (0.9248)	Prec@1 85.938 (76.944)	Prec@5 96.875 (93.366)
Test: [782/782]	Time 1.049 (0.108)	Loss 1.2539 (0.9217)	Prec@1 62.500 (76.914)	Prec@5 93.750 (93.430)
Total Time:0.02348072333333333	 Top 5 Accuracy: 93.430

 * Prec@1 76.914 Prec@5 93.430

Test fast.ai Rectangular Validation

Perform validation with rectangular images: each batch is cropped to its own aspect ratio instead of to a square.
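RectangularCropTfm is imported from validation_utils and its source is not shown. The sketch below only illustrates how a rectangular crop size might be derived from a batch aspect ratio. It assumes aspect ratio is width / height, that the shorter side is kept at target_size, and that the longer side is rounded down to a multiple of 8; all three are assumptions, and the function name is hypothetical.

```python
def rect_crop_size(aspect_ratio, target_size, multiple=8):
    """Return (height, width) of a center crop with roughly the given ratio.

    The shorter side equals target_size. The longer side is scaled by the
    aspect ratio and rounded down to a multiple of `multiple` (assumed).
    """
    if aspect_ratio >= 1:  # landscape or square
        height = target_size
        width = int(target_size * aspect_ratio) // multiple * multiple
    else:  # portrait
        width = target_size
        height = int(target_size / aspect_ratio) // multiple * multiple
    return height, width


print(rect_crop_size(1.0, 288))   # (288, 288)
print(rect_crop_size(1.5, 288))   # (288, 432)
print(rect_crop_size(0.75, 288))  # (384, 288)
```

A (height, width) tuple in this form could be passed to torchvision.transforms.functional.center_crop, so a wide image keeps more of its width than a 288x288 square crop would.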

In [14]:
val_ar_tfms = [transforms.Resize(int(target_size*1.14)), RectangularCropTfm(idx2ar, target_size)]
val_dataset_rect = ValDataset(valdir, val_ar_tfms+tensor_tfm)
val_loader = torch.utils.data.DataLoader(
    val_dataset_rect, batch_size=val_bs, shuffle=False,
    num_workers=workers, pin_memory=True, sampler=val_sampler_ar)

rect_prec5 = validate(val_loader, model, criterion)
Test: [100/782]	Time 0.100 (0.329)	Loss 1.4004 (1.0478)	Prec@1 59.375 (76.172)	Prec@5 87.500 (93.422)
Test: [200/782]	Time 0.058 (0.268)	Loss 0.5542 (1.0024)	Prec@1 85.938 (76.594)	Prec@5 100.000 (93.539)
Test: [300/782]	Time 0.075 (0.234)	Loss 0.9624 (0.9680)	Prec@1 76.562 (76.948)	Prec@5 92.188 (93.698)
Test: [400/782]	Time 0.071 (0.208)	Loss 0.9170 (0.9237)	Prec@1 79.688 (77.914)	Prec@5 93.750 (94.078)
Test: [500/782]	Time 0.071 (0.189)	Loss 0.3828 (0.9363)	Prec@1 90.625 (77.713)	Prec@5 98.438 (94.056)
Test: [600/782]	Time 0.071 (0.177)	Loss 1.2441 (0.9607)	Prec@1 70.312 (77.188)	Prec@5 93.750 (93.693)
Test: [700/782]	Time 0.080 (0.174)	Loss 0.5312 (0.9544)	Prec@1 85.938 (77.304)	Prec@5 95.312 (93.846)
Test: [782/782]	Time 1.264 (0.204)	Loss 0.8638 (0.9548)	Prec@1 75.000 (77.354)	Prec@5 93.750 (93.914)
Total Time:0.04431332388888889	 Top 5 Accuracy: 93.914

 * Prec@1 77.354 Prec@5 93.914

Comparison: Square vs. Rectangles

  • Square Validation
    • Top 5 - 93.430
    • Total Time (hours) - 0.0235
  • Rectangular Validation
    • Top 5 - 93.914
    • Total Time (hours) - 0.0443
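The logs do not state the unit of Total Time. A quick consistency check suggests hours: multiplying the 782 batches by the running average batch time from the last log line, and assuming that time is in seconds, reproduces the reported totals.

```python
batches = 782

# Running average batch time from the final "Test: [782/782]" log lines.
# Assumption: these times are in seconds.
avg_batch_seconds = {"square": 0.108, "rectangular": 0.204}

total_hours = {name: batches * sec / 3600 for name, sec in avg_batch_seconds.items()}
for name, hours in total_hours.items():
    print(f"{name}: {hours:.4f}")
```

The results round to 0.0235 and 0.0443, matching the totals above, so the square run took about 84 seconds and the rectangular run about 160 seconds.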
In [27]:
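# Average per-batch values in groups of `size` batches (the size argument was previously ignored)
def batch_mean(array, size=100): return [np.array(c).mean() for c in chunks(array, size)]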
batch_means = batch_mean(ar_means)
rect_prec5_mean = batch_mean(rect_prec5)
orig_prec5_mean = batch_mean(orig_prec5)
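The chunks helper is imported from validation_utils and is not shown. Assuming it yields consecutive slices of a fixed length, and that validate returns one top-5 value per batch, grouping 782 batches by 100 gives the 8 rows in the table that follows, with the last row covering only 82 batches. A stand-in version (hypothetical name):

```python
def chunks_sketch(seq, n):
    """Yield consecutive slices of `seq` with at most `n` items each."""
    for start in range(0, len(seq), n):
        yield seq[start:start + n]


per_batch = list(range(782))  # stand-in for 782 per-batch top-5 scores
group_sizes = [len(c) for c in chunks_sketch(per_batch, 100)]
print(group_sizes)
```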
In [30]:
d = {'OriginalValidation': orig_prec5_mean, 
     'RectangularValidation': rect_prec5_mean, 
     'AR Mean': batch_means,
     'Difference': np.array(rect_prec5_mean)-np.array(orig_prec5_mean)}
df = pd.DataFrame(data=d); df
Out[30]:
   OriginalValidation  RectangularValidation   AR Mean  Difference
0           92.046875              93.421875  0.704379    1.375000
1           93.140625              93.656250  0.806230    0.515625
2           93.843750              94.015625  1.072789    0.171875
3           95.031250              95.218750  1.301455    0.187500
4           94.203125              93.968750  1.333330   -0.234375
5           90.968750              91.875000  1.333330    0.906250
6           94.328125              94.765625  1.406869    0.437500
7           93.978659              94.493140  1.585774    0.514482

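Rectangular validation scores higher than the original square crop in 7 of the 8 groups. The largest gain (+1.38 top-5) is in the group farthest below an aspect ratio of 1 (mean 0.70), and the gain shrinks to +0.17 in the group closest to square (mean 1.07). The trend is not monotonic: one group at mean aspect ratio 1.33 loses 0.23 while the next gains 0.91.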
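To put a number on this, the table rows can be split into portrait groups (mean aspect ratio below 1) and landscape groups (1 or above), comparing the average top-5 difference in each. The values are copied from the table; the split at 1 is just one convenient way to slice them.

```python
ar_mean = [0.704379, 0.806230, 1.072789, 1.301455,
           1.333330, 1.333330, 1.406869, 1.585774]
difference = [1.375000, 0.515625, 0.171875, 0.187500,
              -0.234375, 0.906250, 0.437500, 0.514482]

# Split the per-group gains by whether the group is taller or wider than square
portrait = [d for ar, d in zip(ar_mean, difference) if ar < 1]
landscape = [d for ar, d in zip(ar_mean, difference) if ar >= 1]

portrait_gain = sum(portrait) / len(portrait)
landscape_gain = sum(landscape) / len(landscape)
print(f"portrait: {portrait_gain:.3f}, landscape: {landscape_gain:.3f}")
```

With only 8 groups this is a rough indication rather than a statistical result, but the two portrait groups gain noticeably more on average than the six landscape groups.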

Validate with TTA (Test Time Augmentation)

Takes 4 random crops plus the original validation image and averages the predictions together.
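The averaging happens inside validate, which is not shown here. The core idea can be sketched with plain Python lists standing in for the model's per-class scores; in the notebook these would be tensors, and the function name below is hypothetical.

```python
def average_predictions(original_scores, augmented_scores):
    """Average class scores over the original view and every augmented view."""
    views = [original_scores] + list(augmented_scores)
    n_classes = len(original_scores)
    return [sum(view[c] for view in views) / len(views) for c in range(n_classes)]


# Toy scores for 3 classes: one original view and 4 augmented views
original = [0.2, 0.5, 0.3]
augmented = [[0.1, 0.7, 0.2], [0.3, 0.4, 0.3], [0.2, 0.6, 0.2], [0.2, 0.3, 0.5]]

averaged = average_predictions(original, augmented)
predicted_class = max(range(len(averaged)), key=averaged.__getitem__)
print(averaged, predicted_class)
```

Each extra view costs another forward pass, which is why the TTA runs below take roughly four to five times as long per batch as the single-view runs.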

In [31]:
min_scale = 0.5
trn_tfms = [
        transforms.RandomResizedCrop(target_size, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ] + tensor_tfm
aug_dataset = datasets.ImageFolder(valdir, transforms.Compose(trn_tfms))

val_tfms = [transforms.Resize(int(target_size*1.14)), transforms.CenterCrop(target_size)] + tensor_tfm
val_dataset = datasets.ImageFolder(valdir,  transforms.Compose(val_tfms))

aug_loader = torch.utils.data.DataLoader(
    aug_dataset, batch_size=val_bs, shuffle=False,
    num_workers=workers, pin_memory=True, sampler=val_sampler_ar)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=val_bs, shuffle=False,
    num_workers=workers, pin_memory=True, sampler=val_sampler_ar)
In [32]:
tta_prec5 = validate(val_loader, model, criterion, aug_loader=aug_loader, num_augmentations=4)
Test: [100/782]	Time 0.393 (0.494)	Loss 1.4531 (1.0414)	Prec@1 59.375 (76.203)	Prec@5 87.500 (93.359)
Test: [200/782]	Time 0.368 (0.461)	Loss 0.5205 (1.0151)	Prec@1 85.938 (76.742)	Prec@5 100.000 (93.555)
Test: [300/782]	Time 0.348 (0.461)	Loss 0.9717 (0.9848)	Prec@1 76.562 (77.208)	Prec@5 92.188 (93.693)
Test: [400/782]	Time 0.349 (0.454)	Loss 0.8115 (0.9328)	Prec@1 82.812 (78.309)	Prec@5 96.875 (94.105)
Test: [500/782]	Time 0.463 (0.458)	Loss 0.4075 (0.9356)	Prec@1 95.312 (78.256)	Prec@5 98.438 (94.194)
Test: [600/782]	Time 0.464 (0.456)	Loss 1.3145 (0.9632)	Prec@1 73.438 (77.776)	Prec@5 93.750 (93.831)
Test: [700/782]	Time 0.368 (0.453)	Loss 0.4526 (0.9523)	Prec@1 87.500 (77.931)	Prec@5 98.438 (94.011)
Test: [782/782]	Time 0.078 (0.447)	Loss 1.1768 (0.9465)	Prec@1 68.750 (78.010)	Prec@5 87.500 (94.094)
Total Time:0.09710323	 Top 5 Accuracy: 94.094

 * Prec@1 78.010 Prec@5 94.094

Validate with TTA and Rectangles

Takes 4 random crops plus the rectangular validation image and averages the predictions together.

In [33]:
min_scale = 0.5
trn_tfms = [
        transforms.RandomResizedCrop(target_size, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ] + tensor_tfm
aug_dataset = datasets.ImageFolder(valdir, transforms.Compose(trn_tfms))

aug_loader = torch.utils.data.DataLoader(
    aug_dataset, batch_size=val_bs, shuffle=False,
    num_workers=workers, pin_memory=True, sampler=val_sampler_ar)

val_ar_tfms = [transforms.Resize(int(target_size*1.14)), RectangularCropTfm(idx2ar, target_size)]
val_dataset_rect = ValDataset(valdir, val_ar_tfms+tensor_tfm)
val_loader = torch.utils.data.DataLoader(
    val_dataset_rect, batch_size=val_bs, shuffle=False,
    num_workers=workers, pin_memory=True, sampler=val_sampler_ar)

tta_rect_prec5 = validate(val_loader, model, criterion, aug_loader=aug_loader, num_augmentations=4)
Test: [100/782]	Time 0.450 (0.533)	Loss 1.4590 (1.0473)	Prec@1 60.938 (76.484)	Prec@5 87.500 (93.344)
Test: [200/782]	Time 0.404 (0.511)	Loss 0.5361 (1.0162)	Prec@1 92.188 (77.000)	Prec@5 100.000 (93.594)
Test: [300/782]	Time 0.384 (0.507)	Loss 1.0049 (0.9867)	Prec@1 75.000 (77.469)	Prec@5 92.188 (93.771)
Test: [400/782]	Time 0.500 (0.504)	Loss 0.8462 (0.9368)	Prec@1 84.375 (78.496)	Prec@5 96.875 (94.160)
Test: [500/782]	Time 0.408 (0.505)	Loss 0.3413 (0.9410)	Prec@1 93.750 (78.428)	Prec@5 98.438 (94.219)
Test: [600/782]	Time 0.352 (0.501)	Loss 1.2451 (0.9683)	Prec@1 71.875 (77.893)	Prec@5 93.750 (93.872)
Test: [700/782]	Time 0.431 (0.498)	Loss 0.4773 (0.9577)	Prec@1 87.500 (78.013)	Prec@5 98.438 (94.069)
Test: [782/782]	Time 0.114 (0.494)	Loss 1.1299 (0.9529)	Prec@1 68.750 (78.090)	Prec@5 93.750 (94.168)
Total Time:0.10739349833333334	 Top 5 Accuracy: 94.168

 * Prec@1 78.090 Prec@5 94.168

Comparing all the Techniques

  • Square Validation
    • Top 5 - 93.430
    • Total Time (hours) - 0.0235
  • Rectangular Validation
    • Top 5 - 93.914
    • Total Time (hours) - 0.0443
  • TTA
    • Top 5 - 94.094
    • Total Time (hours) - 0.0971
  • TTA + Rectangles
    • Top 5 - 94.168
    • Total Time (hours) - 0.1074
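To make the trade-off easier to read, the summary numbers can be expressed relative to square validation: top-5 gain in percentage points and total time as a multiple of the baseline. The values are copied from the list above.

```python
# (top-5 accuracy, total time) per technique, from the summary list
results = {
    "Square": (93.430, 0.0235),
    "Rectangular": (93.914, 0.0443),
    "TTA": (94.094, 0.0971),
    "TTA + Rectangles": (94.168, 0.1074),
}

base_acc, base_time = results["Square"]
relative = {
    name: (acc - base_acc, time / base_time)
    for name, (acc, time) in results.items()
}
for name, (gain, slowdown) in relative.items():
    print(f"{name}: {gain:+.3f} top-5, {slowdown:.2f}x time")
```

On these runs rectangular validation recovers about two thirds of the TTA + Rectangles gain (0.484 of 0.738 points) at under 2x the baseline time, while both TTA variants cost more than 4x.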
In [34]:
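# Average per-batch values in groups of `size` batches (the size argument was previously ignored)
def batch_mean(array, size=100): return [np.array(c).mean() for c in chunks(array, size)]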
batch_means = batch_mean(ar_means)
rect_prec5_mean = batch_mean(rect_prec5)
orig_prec5_mean = batch_mean(orig_prec5)
tta_prec5_mean = batch_mean(tta_prec5)
tta_rect_prec5_mean = batch_mean(tta_rect_prec5)
In [36]:
d = {'Original Validation': orig_prec5_mean, 
     'Rectangular Validation': rect_prec5_mean, 
     'TTA Validation': tta_prec5_mean, 
     'TTA + Rectangular Validation': tta_rect_prec5_mean, 
     'AR Mean': batch_means}
df = pd.DataFrame(data=d); df
Out[36]:
   Original Validation  Rectangular Validation  TTA Validation  TTA + Rectangular Validation   AR Mean
0            92.046875               93.421875       93.359375                     93.343750  0.704379
1            93.140625               93.656250       93.750000                     93.843750  0.806230
2            93.843750               94.015625       93.968750                     94.125000  1.072789
3            95.031250               95.218750       95.343750                     95.328125  1.301455
4            94.203125               93.968750       94.546875                     94.453125  1.333330
5            90.968750               91.875000       92.015625                     92.140625  1.333330
6            94.328125               94.765625       95.093750                     95.250000  1.406869
7            93.978659               94.493140       94.740854                     95.007622  1.585774