#!/usr/bin/env python
# coding: utf-8

# In[1]:

get_ipython().run_line_magic('matplotlib', 'inline')

import os
from time import time

import numpy as np
import matplotlib.pyplot as plt

from keras.models import Model, Sequential
from keras.layers import Input
from keras.optimizers import Adam
import keras.backend as K
from keras.utils.generic_utils import Progbar

from model import *


# In[2]:

# Keep TensorFlow from grabbing all GPU memory up front; allocate it as needed.
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)


# In[3]:

from keras.datasets import cifar100, cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()


# In[4]:

# Use all 60,000 CIFAR-10 images (train + test) for unsupervised GAN training.
X = np.concatenate((x_test, x_train))
X.shape


# In[5]:

plt.imshow(X[9487])


# In[6]:

# Hyperparameters
BATCHSIZE = 64
LEARNING_RATE = 0.0002
TRAINING_RATIO = 1          # discriminator updates per generator update
BETA_1 = 0.0
BETA_2 = 0.9
EPOCHS = 500
BN_MOMENTUM = 0.9
BN_EPSILON = 0.00002
SAVE_DIR = 'img/generated_img_CIFAR10_ResNet/'
GENERATE_ROW_NUM = 8
GENERATE_BATCHSIZE = GENERATE_ROW_NUM * GENERATE_ROW_NUM

# Make sure the output directory exists before imsave is called in the training loop.
os.makedirs(SAVE_DIR, exist_ok=True)


# In[7]:

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

generator = BuildGenerator(bn_momentum=BN_MOMENTUM, bn_epsilon=BN_EPSILON)
discriminator = BuildDiscriminator()

# Stacked model for training the generator: noise -> generator -> (frozen) discriminator.
Noise_input_for_training_generator = Input(shape=(128,))
Generated_image = generator(Noise_input_for_training_generator)
Discriminator_output = discriminator(Generated_image)
model_for_training_generator = Model(Noise_input_for_training_generator, Discriminator_output)
print("model_for_training_generator")
discriminator.trainable = False
model_for_training_generator.summary()
model_for_training_generator.compile(optimizer=Adam(LEARNING_RATE, beta_1=BETA_1, beta_2=BETA_2),
                                     loss=wasserstein_loss)


# In[8]:

# Model for training the discriminator on a real batch and a generated (fake) batch.
Real_image = Input(shape=(32, 32, 3))
Noise_input_for_training_discriminator = Input(shape=(128,))
Fake_image = generator(Noise_input_for_training_discriminator)
Discriminator_output_for_real = discriminator(Real_image)
Discriminator_output_for_fake = discriminator(Fake_image)

model_for_training_discriminator = Model([Real_image, Noise_input_for_training_discriminator],
                                         [Discriminator_output_for_real, Discriminator_output_for_fake])
print("model_for_training_discriminator")
generator.trainable = False
discriminator.trainable = True
model_for_training_discriminator.compile(optimizer=Adam(LEARNING_RATE, beta_1=BETA_1, beta_2=BETA_2),
                                         loss=[wasserstein_loss, wasserstein_loss])
model_for_training_discriminator.summary()


# In[11]:

# Targets for the Wasserstein loss: +1 for real samples, -1 for fake samples.
real_y = np.ones((BATCHSIZE, 1), dtype=np.float32)
fake_y = -real_y


# In[10]:

# Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output.
X = X / 255 * 2 - 1


# In[ ]:

test_noise = np.random.randn(GENERATE_BATCHSIZE, 128)
W_loss = []
discriminator_loss = []
generator_loss = []

for epoch in range(EPOCHS):
    np.random.shuffle(X)
    print("epoch {} of {}".format(epoch + 1, EPOCHS))
    num_batches = int(X.shape[0] // BATCHSIZE)
    print("number of batches: {}".format(num_batches))

    progress_bar = Progbar(target=int(X.shape[0] // (BATCHSIZE * TRAINING_RATIO)))
    minibatches_size = BATCHSIZE * TRAINING_RATIO
    start_time = time()
    for index in range(int(X.shape[0] // (BATCHSIZE * TRAINING_RATIO))):
        progress_bar.update(index)
        discriminator_minibatches = X[index * minibatches_size:(index + 1) * minibatches_size]

        # Update the discriminator TRAINING_RATIO times per generator update.
        for j in range(TRAINING_RATIO):
            image_batch = discriminator_minibatches[j * BATCHSIZE:(j + 1) * BATCHSIZE]
            noise = np.random.randn(BATCHSIZE, 128).astype(np.float32)
            discriminator.trainable = True
            generator.trainable = False
            discriminator_loss.append(model_for_training_discriminator.train_on_batch([image_batch, noise],
                                                                                      [real_y, fake_y]))

        # Update the generator through the frozen discriminator.
        discriminator.trainable = False
        generator.trainable = True
        generator_loss.append(model_for_training_generator.train_on_batch(np.random.randn(BATCHSIZE, 128), real_y))

    print('\nepoch time: {}'.format(time() - start_time))

    # Evaluate the generator model against the +1 and -1 targets on the fixed noise batch and log the sum.
    W_real = model_for_training_generator.evaluate(test_noise, real_y)
    print(W_real)
    W_fake = model_for_training_generator.evaluate(test_noise, fake_y)
    print(W_fake)
    W_l = W_real + W_fake
    print('wasserstein_loss: {}'.format(W_l))
    W_loss.append(W_l)

    # Generate an 8x8 grid of samples from the fixed test noise and save it.
    generated_image = generator.predict(test_noise)
    generated_image = (generated_image + 1) / 2     # rescale from [-1, 1] back to [0, 1]
    for i in range(GENERATE_ROW_NUM):
        new = generated_image[i * GENERATE_ROW_NUM:(i + 1) * GENERATE_ROW_NUM].reshape(32 * GENERATE_ROW_NUM, 32, 3)
        if i != 0:
            old = np.concatenate((old, new), axis=1)
        else:
            old = new
    print('plot generated_image')
    plt.imsave('{}/SN_epoch_{}.png'.format(SAVE_DIR, epoch), old)


# In[ ]:
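
# Optional follow-up (a minimal sketch, not part of the original training code):
# the loop above accumulates `discriminator_loss`, `generator_loss`, and `W_loss`
# but never visualizes them. This cell plots the recorded histories so training
# progress can be inspected after the run. It assumes the lists are populated
# exactly as collected above (each `discriminator_loss` entry is a list
# [total, real_term, fake_term], as train_on_batch returns for a two-output model,
# while `generator_loss` and `W_loss` entries are scalars). The output filename
# `loss_curves.png` is arbitrary.

d_total = [entry[0] for entry in discriminator_loss]   # combined discriminator loss per batch
g_total = list(generator_loss)                         # generator loss per batch

plt.figure(figsize=(10, 4))

plt.subplot(1, 2, 1)
plt.plot(d_total, label='discriminator')
plt.plot(g_total, label='generator')
plt.xlabel('training batch')
plt.ylabel('loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(W_loss)
plt.xlabel('epoch')
plt.ylabel('logged Wasserstein value')

plt.tight_layout()
plt.savefig('{}loss_curves.png'.format(SAVE_DIR))
plt.show()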