Comparison of three validation preprocessing strategies: "Fast.AI's Rectangular Image Crop", "Imagenet Center Crop", and "Test Time Augmentation"
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import urllib.request
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from validation_utils import sort_ar, map_idx2ar, ValDataset, RectangularCropTfm
# Local ImageNet root (under the user's home directory) and its validation split.
data = Path.home()/'data/imagenet'
valdir = data/'validation'
batch_size = 128
# sort_ar comes from validation_utils; its result is later unpacked as
# (index, aspect_ratio) pairs.
# NOTE(review): presumably the list is ordered by aspect ratio - confirm in validation_utils.
idx_ar_sorted = sort_ar(data, valdir)
Alternatively, just download the precomputed indexes:
# Local path for the precomputed index file and the URL it is fetched from.
idx2ar_path = data/'sorted_idxar.p'
url = 'https://s3-us-west-2.amazonaws.com/ashaw-fastai-imagenet/sorted_idxar.p'
# Download only when the file is not already present on disk.
if not idx2ar_path.exists(): urllib.request.urlretrieve(url, idx2ar_path)
# NOTE(review): presumably sort_ar picks up the downloaded file rather than
# recomputing the list - confirm against validation_utils.sort_ar.
idx_ar_sorted = sort_ar(data, valdir)
# Peek at both ends of the list.
print('First five:', idx_ar_sorted[:5])
print('Last five:', idx_ar_sorted[-5:])
val_bs = 128
# map_idx2ar returns two values; only the first is kept here.
idx2ar, _ = map_idx2ar(idx_ar_sorted, val_bs)
# Size handed to the resize / crop transforms that build the datasets.
target_size = 128
# Four views of the same validation folder, one per preprocessing strategy.

# Original Images: resized only, no cropping.
orig_dataset = datasets.ImageFolder(valdir, transforms.Resize(target_size))
# Center Crop: resize to 1.14x the target size, then crop the center to target_size.
val_tfms = [transforms.Resize(int(target_size*1.14)), transforms.CenterCrop(target_size)]
val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
# Rect: same 1.14x resize, followed by RectangularCropTfm built from idx2ar.
# NOTE(review): the transform list is passed to ValDataset un-composed -
# presumably ValDataset applies the list itself; confirm in validation_utils.
val_rect_tfms = [transforms.Resize(int(target_size*1.14)), RectangularCropTfm(idx2ar, target_size)]
val_dataset_rect = ValDataset(valdir, val_rect_tfms)
# TTA Augmentations: random resized crop (scale between min_scale and 1.0)
# plus a random horizontal flip, so repeated reads of one item can differ.
min_scale = 0.5
trn_tfms = [transforms.RandomResizedCrop(target_size, scale=(min_scale, 1.0)),
transforms.RandomHorizontalFlip()]
tta_aug_dataset = datasets.ImageFolder(valdir, transforms.Compose(trn_tfms))
def show_validation_image(sorted_index, figsize=(10,10)):
    """Plot one validation image under each preprocessing pipeline.

    Looks up entry ``sorted_index`` of the module-level ``idx_ar_sorted``
    list, prints its aspect ratio, and draws a 3x2 grid showing the image
    taken from ``orig_dataset``, ``val_dataset_rect``, ``val_dataset`` and
    three separate reads of ``tta_aug_dataset``.

    Args:
        sorted_index: Position in ``idx_ar_sorted``; negative values index
            from the end, as with any Python sequence.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    index, aspect_ratio = idx_ar_sorted[sorted_index]
    print('Aspect ratio:', aspect_ratio)

    fig, axes = plt.subplots(3, 2, figsize=figsize)
    fig.tight_layout(h_pad=6.0)
    title_size = 15

    # One (dataset, title) pair per panel, listed in row-major grid order.
    # tta_aug_dataset appears three times on purpose: every read goes
    # through its random transforms again.
    panels = (
        (orig_dataset, 'Original Image'),
        (val_dataset_rect, 'FastAi Rectangular Crop'),
        (val_dataset, 'Imagenet Center Crop'),
        (tta_aug_dataset, 'Test Time Augmentation 1'),
        (tta_aug_dataset, 'Test Time Augmentation 2'),
        (tta_aug_dataset, 'Test Time Augmentation 3'),
    )
    # axes.flat walks the 3x2 grid row by row, matching the panel order.
    for ax, (dataset, title) in zip(axes.flat, panels):
        ax.imshow(dataset[index][0])
        ax.set_title(title, fontsize=title_size)
Lots of these images are objects or figures, so it helps to see the whole thing.
# Two entries near the start of idx_ar_sorted ...
show_validation_image(22)
show_validation_image(32)
# ... and two near the end (negative positions count from the end of the list).
show_validation_image(-5)
show_validation_image(-23)
Not much data loss here:
# Entry one third of the way through idx_ar_sorted.
show_validation_image(len(idx_ar_sorted)//3)