from mnist_ten.data import train_labeled_loader, train_rotoflip_loader
import matplotlib.pyplot as plt
def show(image_tensor, title=None):
    """Display one image tensor as a grayscale plot.

    The tensor is moved to CPU and converted to NumPy. Index ``[0]`` along
    its leading axis is what gets drawn. This presumably assumes a
    (C, H, W) image and draws channel 0 (TODO confirm against the loaders).

    Args:
        image_tensor: tensor-like object exposing ``.cpu().numpy()``.
        title: optional plot title, passed straight to ``plt.title``.
    """
    plt.title(title)
    first_slice = image_tensor.cpu().numpy()[0]
    plt.imshow(first_slice, cmap='gray')
    plt.show()
def preview_loader(loader, label_type):
    """Show the first sample of a freshly drawn batch from ``loader``.

    A new iterator is created on every call (``iter(loader)``). Successive
    calls therefore show different samples only if the loader's order
    varies between iterations, e.g. with shuffling. Otherwise the same
    first sample is presumably shown each time.

    Args:
        loader: iterable of batches where ``batch[0]`` holds the images and
            ``batch[1]`` the labels (assumed to be a PyTorch ``DataLoader``
            or similar; verify against callers).
        label_type: text used as the title prefix, e.g. ``'Digit'``.
    """
    batch = next(iter(loader))
    image, label = batch[0][0], batch[1][0]
    # When labels are collated into a tensor (PyTorch's default collate does
    # this for integer labels), indexing one out gives a 0-d tensor whose
    # str() is "tensor(3)". Unwrap scalars so the title reads "Digit: 3".
    if getattr(label, 'ndim', None) == 0 and hasattr(label, 'item'):
        label = label.item()
    show(image, f'{label_type}: {label}')
# Preview 4 samples from the rotation/flip loader, then 4 from the
# labelled-digit loader (same order as calling each loop in turn).
for loader, label_type in (
    (train_rotoflip_loader, 'Rotation/flip ID'),
    (train_labeled_loader, 'Digit'),
):
    for _ in range(4):
        preview_loader(loader, label_type=label_type)