#!/usr/bin/env python
# coding: utf-8

# Deep Learning Models -- A collection of various deep learning architectures,
# models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.
# - Author: Sebastian Raschka
# - GitHub Repository: https://github.com/rasbt/deeplearning-models

# In[1]:

# Load the `watermark` IPython extension and print an environment watermark
# (author and package versions) so readers know what produced these results.
ipython = get_ipython()
ipython.run_line_magic('load_ext', 'watermark')
ipython.run_line_magic('watermark', "-a 'Sebastian Raschka' -v -p torch")


# # Using PyTorch Dataset Loading Utilities for a Custom Dataset -- Asian Face Dataset (AFAD)

# This notebook provides an example of how to prepare a custom dataset for
# PyTorch's data loading utilities. More in-depth information can be found in
# the official documentation:
#
# - [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)
# - [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation
#
# In this example, we use the Asian Face Dataset (AFAD), a face image dataset
# with age labels [1]. There are two versions of this dataset, a smaller Lite
# version and the full version, which are available at
#
# - https://github.com/afad-dataset/tarball-lite
# - https://github.com/afad-dataset/tarball
#
# Here, we work with the Lite dataset, but the same code can be used for the
# full dataset as well -- the Lite dataset is simply smaller than the full
# dataset and thus faster to process.
#
# [1] Niu, Z., Zhou, M., Wang, L., Gao, X., & Hua, G. (2016). Ordinal regression
# with multiple output CNN for age estimation. In Proceedings of the IEEE
# conference on computer vision and pattern recognition (pp. 4920-4928).
# ## Imports

# In[2]:

import time
import os
import pandas as pd
import numpy as np
from PIL import Image

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import Dataset

import torch.nn.functional as F
import torch


if torch.cuda.is_available():
    # Ask cuDNN for deterministic algorithms so runs are reproducible.
    torch.backends.cudnn.deterministic = True


# ## Downloading the Dataset

# The following lines of code (bash commands) will download, unzip, and untar
# the dataset from GitHub.

# In[3]:

# Download
get_ipython().system('git clone https://github.com/afad-dataset/tarball-lite.git')


# In[4]:

# Join individual tars
get_ipython().system('cat tarball-lite/AFAD-Lite.tar.xz* > tarball-lite/AFAD-Lite.tar.xz')


# In[5]:

# "Unzip"
get_ipython().system('tar xf tarball-lite/AFAD-Lite.tar.xz')


# In[6]:

# Get image paths (relative to rootDir) of every .jpg below rootDir.
rootDir = 'AFAD-Lite'

files = [os.path.relpath(os.path.join(dirpath, file), rootDir)
         for (dirpath, dirnames, filenames) in os.walk(rootDir)
         for file in filenames if file.endswith('.jpg')]


# In[7]:

print(f'Number of images in total: {len(files)}')


# ## Creating Label Files (CSVs)

# In[8]:

d = {'age': [], 'gender': [], 'file': [], 'path': []}

for f in files:
    # Relative paths have the layout <age>/<gender code>/<file name>. Split on
    # the OS separator (after normalizing) because os.path.relpath() returns
    # backslash-separated paths on Windows.
    age, gender, fname = os.path.normpath(f).split(os.sep)
    # Gender code '111' is treated as male; every other code as female.
    gender = 'male' if gender == '111' else 'female'
    # Parse the age now so that later comparisons are numeric, not
    # lexicographic string comparisons.
    d['age'].append(int(age))
    d['gender'].append(gender)
    d['file'].append(fname)
    d['path'].append(f)


# In[9]:

df = pd.DataFrame.from_dict(d)
df.head()


# Normalize labels such that they start with `0`:

# In[10]:

# Smallest raw age label (displayed in the notebook).
df['age'].min()


# In[11]:

# Shift the integer ages so that the smallest one becomes label 0.
df['age'] = df['age'] - df['age'].min()


# Separate the dataset into training and test subsets:

# In[12]:

np.random.seed(123)
msk = np.random.rand(len(df)) < 0.8
df_train = df[msk]
df_test = df[~msk]


# Save the data partitions as CSV files:

# In[13]:

df_train.to_csv('training_set_lite.csv', index=False)


# In[14]:

df_test.to_csv('test_set_lite.csv', index=False)


# In[15]:

num_ages = np.unique(df['age'].values).shape[0]
print(f'Number of age labels: {num_ages}')


# In[16]:

print(f'Number of training examples: {df_train.shape[0]}')
print(f'Number of test examples: {df_test.shape[0]}')


# ## Implementing a Custom Dataset Class

# In[17]:

class AFADDatasetAge(Dataset):
    """Custom Dataset for loading AFAD face images with their age labels.

    Args:
        csv_path: Path to a CSV file with (at least) an ``age`` column holding
            the integer class label and a ``path`` column holding image paths
            relative to ``img_dir``.
        img_dir: Directory that the ``path`` entries are relative to.
        transform: Optional callable applied to each PIL image before it is
            returned.
    """

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.csv_path = csv_path
        self.df = df
        # Cache the columns as plain arrays so __getitem__ does not build a
        # pandas Series (df.iloc[index]) for every single sample.
        self.img_paths = df['path'].values
        self.y = df['age'].values
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the example at position ``index``."""
        img_path = os.path.join(self.img_dir, self.img_paths[index])
        # convert() forces the pixel data to load, so the file can be closed
        # immediately; converting to RGB also guarantees that every sample has
        # the same number of channels, which the DataLoader needs in order to
        # stack samples into a batch.
        with Image.open(img_path) as img:
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        label = self.y[index]

        return img, label

    def __len__(self):
        return self.y.shape[0]


# ## Setting Up DataLoaders

# In[18]:

TRAIN_CSV_PATH = 'training_set_lite.csv'
TEST_CSV_PATH = 'test_set_lite.csv'
IMAGE_PATH = 'AFAD-Lite'
BATCH_SIZE = 128


# In[19]:

test_transform = transforms.Compose([transforms.Resize((128, 128)),
                                     transforms.CenterCrop((120, 120)),
                                     transforms.ToTensor()])

test_dataset = AFADDatasetAge(csv_path=TEST_CSV_PATH,
                              img_dir=IMAGE_PATH,
                              transform=test_transform)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=4,
                         shuffle=False)

# Checking the dataset
for images, labels in test_loader:
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break


# In[20]:

train_transform = transforms.Compose([transforms.Resize((128, 128)),
                                      transforms.RandomCrop((120, 120)),
                                      transforms.ToTensor()])

test_transform = transforms.Compose([transforms.Resize((128, 128)),
                                     transforms.CenterCrop((120, 120)),
                                     transforms.ToTensor()])

train_dataset = AFADDatasetAge(csv_path=TRAIN_CSV_PATH,
                               img_dir=IMAGE_PATH,
                               transform=train_transform)

# Hold out a validation subset of the training CSV.
# - The sizes are derived from the actual CSV length instead of hard-coded
#   index ranges, which would raise IndexError or silently drop examples if
#   the number of training rows changed.
# - The rows are grouped by <age>/<gender> directory (os.walk order), so a
#   contiguous tail slice would only contain a few ages. The indices are
#   therefore shuffled with a fixed, private seed, which gives a reproducible
#   split and leaves the global NumPy RNG untouched.
# - NUM_VALID matches the 1524 validation indices of the original hard-coded
#   46000/1524 split.
NUM_VALID = 1524
num_train_total = len(train_dataset)
num_valid = min(NUM_VALID, num_train_total)
shuffled_indices = np.random.RandomState(123).permutation(num_train_total)
valid_indices = shuffled_indices[:num_valid]
train_indices = shuffled_indices[num_valid:]

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)
valid_dataset = AFADDatasetAge(csv_path=TRAIN_CSV_PATH,
                               img_dir=IMAGE_PATH,
                               transform=test_transform)

test_dataset = AFADDatasetAge(csv_path=TEST_CSV_PATH,
                              img_dir=IMAGE_PATH,
                              transform=test_transform)

# Loader options shared by the training, validation, and test loaders.
loader_kwargs = {'batch_size': BATCH_SIZE, 'num_workers': 4}

train_loader = DataLoader(dataset=train_dataset,
                          sampler=train_sampler,
                          **loader_kwargs)

valid_loader = DataLoader(dataset=valid_dataset,
                          sampler=valid_sampler,
                          **loader_kwargs)

test_loader = DataLoader(dataset=test_dataset,
                         shuffle=False,
                         **loader_kwargs)


# In[21]:

# Checking the dataset: print the shapes of the first batch from each loader
# (test, then validation, then training).
for loader in (test_loader, valid_loader, train_loader):
    for images, labels in loader:
        print('Image batch dimensions:', images.shape)
        print('Image label dimensions:', labels.shape)
        break


# ## Iterating through the Custom Dataset

# In[22]:

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

torch.manual_seed(0)

num_epochs = 2
for epoch in range(num_epochs):
    for batch_idx, (x, y) in enumerate(train_loader):
        print('Epoch:', epoch + 1, end='')
        print(' | Batch index:', batch_idx, end='')
        print(' | Batch size:', y.size(0))

        # Move the batch to the selected device, as a training loop would.
        x, y = x.to(device), y.to(device)

        # Demo only: inspect just the first batch of every epoch.
        break