%reload_ext autoreload
%autoreload 2
from fastai import *
from fastai.vision import *
For this lesson, we'll be using the bedrooms from the LSUN dataset. The full dataset is a bit too large, so we'll use a sample from Kaggle.
path = Config.data_path()/'lsun'
path.mkdir(parents=True, exist_ok=True)
path
PosixPath('/home/ubuntu/.fastai/data/lsun')
Uncomment the next commands to download and extract the data on your machine.
#! kaggle datasets download -d jhoward/lsun_bedroom -p {path}
#! unzip -q -n {path}/lsun_bedroom.zip -d {path}
#! unzip -q -n {path}/sample.zip -d {path}
We then grab all the images in the folder with the data block API. We don't create a validation set here for reasons we'll explain later.
def get_data(bs, size):
    train_ds = (ImageItemList.from_folder(path).label_empty()
                .transform(tfms=[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], size=size))
    return (ImageDataBunch.create(train_ds, valid_ds=None, path=path, bs=bs)
            .normalize(stats=[torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])]))
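Note that the 0.5 mean and 0.5 standard deviation used above are not the usual ImageNet stats: they map pixel values from [0, 1] to [-1, 1], which matches the output range of a generator ending in tanh. A quick sketch of the arithmetic (hypothetical helper names, not fastai's API):

```python
# Sketch only: with mean=0.5 and std=0.5 per channel, normalization
# maps [0, 1] pixel values to [-1, 1], the range of a tanh generator.
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

def denormalize(x, mean=0.5, std=0.5):
    # Inverse transform, used when displaying generated images.
    return x * std + mean

print(normalize(0.0))                 # -1.0
print(normalize(1.0))                 # 1.0
print(denormalize(normalize(0.25)))   # 0.25
```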
We'll begin with a small size and use gradual resizing.
data = get_data(64, 64)
data.show_batch(rows=5)
GAN stands for Generative Adversarial Networks; they were invented by Ian Goodfellow. The idea is to train two models at the same time: a generator and a discriminator. The generator tries to make new images similar to the ones in our dataset, while the discriminator tries to tell the real images apart from the ones the generator produces. The generator returns images; the discriminator returns a single number (usually 0. for fake images and 1. for real ones).
We train them against each other in the sense that at each step (more or less), we:

1. Freeze the generator and train the discriminator for one step:
   - get one batch of real images (call it `real`)
   - generate one batch of fake images (call it `fake`)
   - have the discriminator evaluate both batches and compute a loss that rewards it for recognizing the real images and penalizes it for being fooled by the fake ones
   - update the weights of the discriminator with the gradients of this loss
2. Freeze the discriminator and train the generator for one step:
   - generate one batch of fake images
   - evaluate the discriminator on it
   - compute a loss that rewards the generator when the discriminator thinks its images are real
   - update the weights of the generator with the gradients of this loss

Here, we'll use the Wasserstein GAN.
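The Wasserstein GAN replaces the usual 0/1 classification loss with unbounded critic scores: the critic loss is the mean score on fakes minus the mean score on reals, and the generator loss is the negated mean score on fakes. A minimal sketch of that math (hypothetical helpers, not fastai's `WasserteinLoss` implementation):

```python
# Sketch of the WGAN objective. The critic outputs unbounded scores;
# no sigmoid or log is involved.
def critic_loss(real_scores, fake_scores):
    # The critic wants real scores high and fake scores low.
    return sum(fake_scores) / len(fake_scores) - sum(real_scores) / len(real_scores)

def generator_loss(fake_scores):
    # The generator wants the critic to score its fakes highly.
    return -sum(fake_scores) / len(fake_scores)

real = [0.9, 1.1, 1.0]     # hypothetical critic scores on real images
fake = [-0.8, -1.2, -1.0]  # hypothetical critic scores on fakes
print(critic_loss(real, fake))  # -2.0: the critic separates them well
print(generator_loss(fake))     # 1.0: the generator still has work to do
```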
def gan_learner(data, generator, discriminator, loss_funcD=None, loss_funcG=None, noise_size:int=None, wgan:bool=False,
                **kwargs):
    gan = models.GAN(generator, discriminator)
    learn = Learner(data, gan, loss_func=NoopLoss(), **kwargs)
    if wgan: loss_funcD,loss_funcG = WasserteinLoss(),noop
    if noise_size is None: cb = callbacks.GANTrainer(learn, loss_funcD, loss_funcG)
    else: cb = callbacks.NoisyGANTrainer(learn, loss_funcD, loss_funcG, bs=data.batch_size, noise_sz=noise_size)
    learn.callbacks.append(cb)
    return learn
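One detail of the original WGAN paper worth knowing: the critic's scores only approximate a Wasserstein distance if the critic is (roughly) 1-Lipschitz, which the paper enforces by clipping the critic's weights to a small range after every update. The training callback is expected to take care of this; the operation itself can be sketched with a hypothetical helper, using the paper's default clip value of 0.01:

```python
# Sketch of WGAN weight clipping (hypothetical helper, not fastai's API).
# After each critic update, every weight is clamped to [-clip, clip] so
# the critic stays approximately 1-Lipschitz.
def clip_weights(weights, clip=0.01):
    return [max(-clip, min(clip, w)) for w in weights]

print(clip_weights([0.5, -0.02, 0.005]))  # [0.01, -0.01, 0.005]
```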
generator = models.basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
discriminator = models.basic_discriminator(in_size=64, n_channels=3, n_extra_layers=1)
learn = gan_learner(data, generator, discriminator, wgan=True, noise_size=100,
                    opt_func=optim.RMSprop, wd=0.)
learn.fit(1, 1e-4)
Total time: 15:05
epoch  train_loss  valid_loss  gen_loss  disc_loss
1  -0.716458  0.670430  -1.410371  (15:05)
x,y = next(iter(learn.data.train_dl))
tst = learn.model(learn.callbacks[0].input_fake(x, grad=False), gen=True)
img = data.denorm(tst[0].cpu()).numpy().clip(0,1).transpose(1,2,0)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7fb6dc85b048>
learn.model(tst)
tensor([-0.1581], device='cuda:0', grad_fn=<ViewBackward>)
learn.model(x)
tensor([-0.2798], device='cuda:0', grad_fn=<ViewBackward>)