Pretrained GAN

In [ ]:
import fastai
from fastai.vision import *
from fastai.callbacks import *
from fastai.vision.gan import *
In [ ]:
path = untar_data(URLs.PETS)  # download/extract the fastai PETS dataset (if not already cached); `path` is its root
path_hr = path/'images'  # original images, used as the high-resolution targets
path_lr = path/'crappy'  # degraded inputs written by `crappifier` (see the commented-out cell below)

Crappified data

Prepare the input data by crappifying images.

In [ ]:
from crappify import *

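Uncomment and run the next cell the first time you run this notebook. It writes a crappified (degraded) copy of every image in `path_hr` to `path_lr`.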

In [ ]:
#il = ImageList.from_folder(path_hr)
#parallel(crappifier(path_lr, path_hr), il.items)

For gradual (progressive) resizing, swap in one of the commented-out `bs,size` lines below and re-run the cells that follow.

In [ ]:
# Batch size and image size for the current training stage.
bs,size=32, 128
# Later stages of gradual resizing: larger images with a smaller batch
# (presumably so each batch still fits in memory — TODO confirm).
# bs,size = 24,160
#bs,size = 8,256
arch = models.resnet34  # encoder backbone passed to unet_learner for the generator

Pre-train generator

Now let's pretrain the generator.

In [ ]:
# `arch` (resnet34) is already set in the configuration cell above; it is not redefined here.
# Generator inputs are the crappified images; hold out a random 10% for validation (fixed seed).
src = ImageImageList.from_folder(path_lr).split_by_rand_pct(0.1, seed=42)
In [ ]:
def get_data(bs,size):
    """Build the generator's image-to-image DataBunch.

    Inputs are the items of the global `src` (crappified images from `path_lr`).
    The target for each input is the file with the same name under `path_hr`.
    The same `get_transforms(max_zoom=2.)` augmentation is applied to inputs and
    targets (`tfm_y=True`), and both are resized to `size`, batched with `bs`,
    and normalized with ImageNet statistics (`do_y=True`).

    NOTE(review): in fastai v1 the data-block calls on `src` appear to mutate it
    in place (it turns into a labelled list). Calling this again with a new size
    for gradual resizing may need a freshly built `src` — TODO confirm.
    """
    labelled = src.label_from_func(lambda x: path_hr/x.name)
    resized = labelled.transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
    bunch = resized.databunch(bs=bs).normalize(imagenet_stats, do_y=True)

    # Presumably read by unet_learner as the number of output channels (RGB).
    bunch.c = 3
    return bunch
In [ ]:
data_gen = get_data(bs,size)
In [ ]:
data_gen.show_batch(4)
In [ ]:
wd = 1e-3  # weight decay passed to unet_learner in create_gen_learner
In [ ]:
y_range = (-3.,3.)  # output range passed to unet_learner; presumably chosen to cover ImageNet-normalized pixel values — TODO confirm
In [ ]:
loss_gen = MSELossFlat()  # MSE between the generator output and the high-res target image
In [ ]:
def create_gen_learner():
    """Create the U-Net generator learner from the notebook's global settings.

    Uses `data_gen`, the `arch` backbone, weight decay `wd`, output range
    `y_range` and loss `loss_gen`. The U-Net is built with blur, weight
    normalization and self-attention enabled.
    """
    learner = unet_learner(
        data_gen,
        arch,
        wd=wd,
        blur=True,
        norm_type=NormType.Weight,
        self_attention=True,
        y_range=y_range,
        loss_func=loss_gen,
    )
    return learner
In [ ]:
learn_gen = create_gen_learner()
In [ ]:
learn_gen.fit_one_cycle(2, pct_start=0.8)
Total time: 01:35

epoch train_loss valid_loss
1 0.061653 0.053493
2 0.051248 0.047272
In [ ]:
learn_gen.unfreeze()
In [ ]:
learn_gen.fit_one_cycle(3, slice(1e-6,1e-3))
Total time: 02:24

epoch train_loss valid_loss
1 0.050429 0.046088
2 0.049056 0.043954
3 0.045437 0.043146
In [ ]:
learn_gen.show_results(rows=4)
In [ ]:
learn_gen.save('gen-pre2')

Save generated images

In [ ]:
learn_gen.load('gen-pre2');
In [ ]:
name_gen = 'image_gen'
path_gen = path/name_gen
In [ ]:
# shutil.rmtree(path_gen)
In [ ]:
path_gen.mkdir(exist_ok=True)
In [ ]:
def save_preds(dl, learn=None, dest=None):
    """Run the generator over every batch of `dl` and save each output image.

    Each prediction is saved under the file name of the dataset item it is
    paired with. The pairing assumes `dl` yields its dataset in order
    (e.g. `data_gen.fix_dl`); a shuffled loader would mismatch names.

    Args:
        dl: fastai DataLoader whose `dataset.items` are image paths.
        learn: learner used for prediction; defaults to the global `learn_gen`.
        dest: output directory; defaults to the global `path_gen`.

    Raises:
        IndexError: if the loader yields more predictions than dataset items.
    """
    learn = learn_gen if learn is None else learn
    dest = path_gen if dest is None else dest
    # One shared iterator across all batches replaces the manual running index.
    names = iter(dl.dataset.items)
    for batch in dl:
        for img in learn.pred_batch(batch=batch, reconstruct=True):
            item = next(names, None)
            if item is None:
                raise IndexError('save_preds: more predictions than items in dl.dataset')
            img.save(dest/item.name)
In [ ]:
save_preds(data_gen.fix_dl)
In [ ]:
PIL.Image.open(path_gen.ls()[0])
Out[ ]:

Train critic

In [ ]:
# Drop the generator learner so it can be freed (presumably including its GPU memory) before the critic is trained.
learn_gen=None
gc.collect()  # force a collection pass; the returned count of unreachable objects is shown as the cell output
Out[ ]:
3755

Pretrain the critic to tell the generator's outputs (`image_gen`) apart from the original high-resolution images (`images`).

In [ ]:
def get_crit_data(classes, bs, size):
    """Build the critic's classification DataBunch.

    Collects images from the sub-folders of `path` listed in `classes` and
    holds out a random 10% for validation (seed 42). Each image is labelled by
    its parent folder. Images are augmented with `get_transforms(max_zoom=2.)`,
    resized to `size`, batched with `bs` and normalized with ImageNet statistics.
    """
    # Named `crit_split` (not `src`) so it is not confused with the generator's global `src`.
    crit_split = ImageList.from_folder(path, include=classes).split_by_rand_pct(0.1, seed=42)
    crit_labelled = crit_split.label_from_folder(classes=classes)
    crit_tfms = get_transforms(max_zoom=2.)
    crit_data = (crit_labelled
                 .transform(crit_tfms, size=size)
                 .databunch(bs=bs)
                 .normalize(imagenet_stats))
    # NOTE(review): overrides the class count of a 2-class dataset with 3; this looks
    # copied from get_data — confirm whether anything downstream relies on it.
    crit_data.c = 3
    return crit_data
In [ ]:
data_crit = get_crit_data([name_gen, 'images'], bs=bs, size=size)
In [ ]:
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)