Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.

Custom DataLoader Example for PNG files

This notebook illustrates how to efficiently iterate through a custom image dataset. For this example, suppose that

  • mnist_train, mnist_valid, and mnist_test are image folders you created with your own custom images
  • mnist_train.csv, mnist_valid.csv, and mnist_test.csv are tables that store the image names with their associated class labels
In [1]:
%load_ext watermark
%watermark -a 'Sebastian Raschka' -v -p torch,pandas,numpy,matplotlib
Sebastian Raschka 

CPython 3.7.1
IPython 7.2.0

torch 1.0.1
pandas 0.24.0
numpy 1.15.4
matplotlib 3.0.2

1) Inspecting the Dataset

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image
In [3]:
im = Image.open('mnist_train/1.png')
# MNIST digits are single-channel grayscale images; without an explicit
# colormap matplotlib would render them in its default false-colour map.
# The trailing ';' suppresses the AxesImage repr in the cell output.
plt.imshow(im, cmap='gray');
Out[3]:
<matplotlib.image.AxesImage at 0x11ea383c8>
In [4]:
import numpy as np

# Convert the PIL image into a NumPy array to inspect its raw pixel values
im_array = np.array(im)
# end='\n\n' prints the blank separator line before the array itself
print('Array Dimensions', im_array.shape, end='\n\n')
print(im_array)
Array Dimensions (28, 28)

[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   1  18  38 136 227 255
  254 132   0  90 136  98   3   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  82 156 253 253 253 253 253
  253 249 154 219 253 253  35   0   0   0]
 [  0   0   0   0   0   0   0   0   0  40 150 244 253 253 253 253 253 253
  253 253 253 253 253 253  35   0   0   0]
 [  0   0   0   0   0   0   0   0  74 237 253 253 253 253 253 203 182 242
  253 253 253 253 253 230  25   0   0   0]
 [  0   0   0   0   0   0   0  13 200 253 253 253 168 164  91  14  64 246
  253 253 253 195  79  32   0   0   0   0]
 [  0   0   0   0   0   0   0  21 219 253 253 159   2   0   0 103 233 253
  253 253 177  10   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0 171 253 253 147   0   1 155 250 253 253
  251 126   5   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0 101 236 253 206  32 152 253 253 253 253
  130   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  91 253 253 253 253 253 253 241 113
    9   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  91 243 253 253 253 253 239  81   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0 207 253 253 253 253 158   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0 207 253 253 253 253 121   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  24 145 249 253 253 253 253 194   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  59 253 253 253 253 253 253 224  30   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   5 181 253 253 241 114 240 253 253 136   5
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0  36 253 253 253 125   0  65 253 253 253  41
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0  67 253 253 253  29   2 138 253 253 253  41
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0  60 253 253 253 207 202 253 253 253 192   9
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   5 183 253 253 253 253 253 253 230  52   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  62 253 253 253 253 242 116  13   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]]










In [5]:
import pandas as pd
In [6]:
df_train = pd.read_csv('mnist_train.csv')
# MyDataset reads exactly these two columns, so fail fast if the table differs
assert {'Class Label', 'File Name'} <= set(df_train.columns), 'unexpected CSV columns'
print(df_train.shape)
df_train.head()
(256, 2)
Out[6]:
Class Label File Name
0 5 0.png
1 8 1.png
2 8 2.png
3 0 3.png
4 9 4.png
In [7]:
df_valid = pd.read_csv('mnist_valid.csv')
# MyDataset reads exactly these two columns, so fail fast if the table differs
assert {'Class Label', 'File Name'} <= set(df_valid.columns), 'unexpected CSV columns'
print(df_valid.shape)
df_valid.head()
(256, 2)
Out[7]:
Class Label File Name
0 0 256.png
1 8 257.png
2 7 258.png
3 4 259.png
4 7 260.png
In [8]:
df_test = pd.read_csv('mnist_test.csv')
# MyDataset reads exactly these two columns, so fail fast if the table differs
assert {'Class Label', 'File Name'} <= set(df_test.columns), 'unexpected CSV columns'
print(df_test.shape)
df_test.head()
(256, 2)
Out[8]:
Class Label File Name
0 4 512.png
1 0 513.png
2 6 514.png
3 8 515.png
4 4 516.png







2) Custom Dataset Class

In [9]:
import torch
from PIL import Image
from torch.utils.data import Dataset
import os



class MyDataset(Dataset):
    """Map-style dataset pairing image files with class labels from a CSV table.

    The CSV at ``csv_path`` must contain a ``'File Name'`` column (image file
    names, relative to ``img_dir``) and a ``'Class Label'`` column.

    Args:
        csv_path: Path to the CSV file listing image file names and labels.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded PIL image
            (e.g. a ``torchvision.transforms.Compose``).
    """

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_csv(csv_path)
        self.img_dir = img_dir
        self.img_names = df['File Name']
        self.y = df['Class Label']
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at row position ``index``."""
        img_path = os.path.join(self.img_dir, self.img_names.iloc[index])

        # Image.open is lazy: it keeps the file open until the image data is
        # used and the object is garbage-collected. copy() forces the pixels
        # into memory so the handle can be released when the `with` exits,
        # instead of accumulating open files across many DataLoader calls.
        with Image.open(img_path) as src:
            img = src.copy()

        if self.transform is not None:
            img = self.transform(img)

        # .iloc indexes by position, which is what the DataLoader sampler passes
        label = self.y.iloc[index]
        return img, label

    def __len__(self):
        """Return the number of samples (rows in the CSV table)."""
        return len(self.y)







3) Custom Dataloader

In [10]:
from torchvision import transforms
from torch.utils.data import DataLoader


# transforms.ToTensor() converts a PIL image into a float tensor of shape
# (C, H, W) and already divides the pixel values by 255, so no separate
# scaling step is needed.
custom_transform = transforms.Compose([transforms.ToTensor()])

train_dataset = MyDataset(csv_path='mnist_train.csv',
                          img_dir='mnist_train',
                          transform=custom_transform)

# NOTE(review): with num_workers > 0 on platforms that start workers via
# 'spawn' (Windows; macOS on Python >= 3.8), a Dataset class defined inside a
# notebook may not be importable by the worker processes. Set num_workers=0
# there if the loader fails to start.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=32,
                          shuffle=True,   # reshuffle the samples every epoch
                          num_workers=4)  # number of worker processes for loading




4) Iterating Through the Dataset

In [11]:
# Prefer the first GPU when one is present, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Seed the global RNG so the shuffled batch order is reproducible
torch.manual_seed(0)

num_epochs = 2
for epoch_idx in range(num_epochs):

    for step, (x, y) in enumerate(train_loader):

        print('Epoch:', epoch_idx + 1, end='')
        print(' | Batch index:', step, end='')
        print(' | Batch size:', y.size(0))

        # Move the current batch onto the selected device
        x, y = x.to(device), y.to(device)
Epoch: 1 | Batch index: 0 | Batch size: 32
Epoch: 1 | Batch index: 1 | Batch size: 32
Epoch: 1 | Batch index: 2 | Batch size: 32
Epoch: 1 | Batch index: 3 | Batch size: 32
Epoch: 1 | Batch index: 4 | Batch size: 32
Epoch: 1 | Batch index: 5 | Batch size: 32
Epoch: 1 | Batch index: 6 | Batch size: 32
Epoch: 1 | Batch index: 7 | Batch size: 32
Epoch: 2 | Batch index: 0 | Batch size: 32
Epoch: 2 | Batch index: 1 | Batch size: 32
Epoch: 2 | Batch index: 2 | Batch size: 32
Epoch: 2 | Batch index: 3 | Batch size: 32
Epoch: 2 | Batch index: 4 | Batch size: 32
Epoch: 2 | Batch index: 5 | Batch size: 32
Epoch: 2 | Batch index: 6 | Batch size: 32
Epoch: 2 | Batch index: 7 | Batch size: 32
In [12]:
# Shape of the last batch: (batch, channels, height, width)
print(x.size())
torch.Size([32, 1, 28, 28])
In [13]:
# Flatten each image into a row vector while keeping the batch dimension
# explicit. view(-1, 28*28) hard-codes the image size and lets the first
# dimension silently absorb any mismatch (e.g. multi-channel images would
# yield extra rows); view(x.size(0), -1) always keeps one row per sample.
x_image_as_vector = x.view(x.size(0), -1)
print(x_image_as_vector.shape)
torch.Size([32, 784])
In [14]:
# Display the last batch from the iteration loop; PyTorch summarizes large
# tensors with '...', so only the first and last few rows/columns are shown
x
Out[14]:
tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])