# For TF 2.0 in a notebook: !pip install -U tensorflow-gpu
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import time
from skimage.io import imshow
from IPython.display import display

tf.__version__

# CIFAR-10: 50k train / 10k test images of shape (32, 32, 3), dtype uint8.
(train_images, train_labels), (test_images, test_labels) = \
    tf.keras.datasets.cifar10.load_data()
train_images.dtype, train_images.shape
imshow(train_images[2])


def img_to_float(img):
    """Scale uint8 pixels [0, 255] to float32 in [-1, 1] (tanh output range)."""
    return (np.float32(img) - 127.5) / 127.5


def img_to_uint8(img):
    """Map float pixels in [-1, 1] back to uint8 [0, 255].

    Bug fix: clip BEFORE casting. ``np.uint8(x)`` wraps modulo 256, so the
    original ``np.uint8(...).clip(0, 255)`` clipped already-wrapped values
    (a no-op). Also use the symmetric 127.5 offset to exactly invert
    ``img_to_float`` (the original used an asymmetric +128).
    """
    return np.uint8(np.clip(img * 127.5 + 127.5, 0, 255))


train_img_f32 = img_to_float(train_images)
imshow(img_to_uint8(train_img_f32[30]))

# Shuffle over the full training set, then batch.
BUFFER_SIZE = train_img_f32.shape[0]
BATCH_SIZE = 32
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_img_f32)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

from tensorflow.keras.layers import Dense, BatchNormalization
from tensorflow.keras.layers import LeakyReLU, Reshape, Conv2DTranspose, Conv2D

latent_dim = 64

# Generator: latent vector -> 4x4x256 feature map, upsampled 3x (x2 each)
# to 32x32, then a final 3x3 conv to RGB with tanh (matches [-1, 1] data).
generator = tf.keras.Sequential([
    Dense(4 * 4 * 256, use_bias=False, input_shape=(latent_dim,)),
    BatchNormalization(),
    LeakyReLU(),
    Reshape((4, 4, 256)),
    Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same',
                    use_bias=False),
    BatchNormalization(),
    LeakyReLU(),
    Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same',
                    use_bias=False),
    BatchNormalization(),
    LeakyReLU(),
    Conv2DTranspose(32, (4, 4), strides=(2, 2), padding='same',
                    use_bias=False),
    BatchNormalization(),
    LeakyReLU(),
    Conv2D(3, (3, 3), strides=(1, 1), padding='same', activation='tanh'),
])

from tensorflow.keras.layers import Conv2D, Dropout, Flatten

# Discriminator: strided convs downsample 32x32 -> 2x2, then a single
# logit (no sigmoid — the loss below uses from_logits=True).
discriminator = tf.keras.Sequential([
    Conv2D(64, (3, 3), strides=(2, 2), padding='same',
           input_shape=(32, 32, 3)),
    LeakyReLU(),
    Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False),
    BatchNormalization(),
    LeakyReLU(),
    Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False),
    BatchNormalization(),
    LeakyReLU(),
    Conv2D(128, (3, 3), strides=(2, 2), padding='same', use_bias=False),
    BatchNormalization(),
    LeakyReLU(),
    Flatten(),
    Dense(1),
])
# Both players use binary cross-entropy on raw logits (the discriminator's
# final Dense(1) has no activation).
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def generator_loss(generated_output):
    """Generator wants its fakes labelled real: target is all-ones."""
    return loss_fn(tf.ones_like(generated_output), generated_output)


def discriminator_loss(real_output, generated_output):
    """Discriminator targets: 1 for real images, 0 for generated ones."""
    # [1, 1, ..., 1] for real output — these samples are genuine.
    real_loss = loss_fn(tf.ones_like(real_output), real_output)
    # [0, 0, ..., 0] for generated images — these are fake.
    generated_loss = loss_fn(tf.zeros_like(generated_output),
                             generated_output)
    return real_loss + generated_loss


generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Fix: the original declared EPOCHS = 50 but actually trained with
# range(30); the constant now matches (and drives) the real epoch count.
EPOCHS = 30
num_examples_to_generate = 16

# We'll re-use this fixed random vector to seed the generator so it is
# easier to see the improvement over time.
random_vector_for_generation = tf.random.normal(
    [num_examples_to_generate, latent_dim])


@tf.function
def train_step(images):
    """One simultaneous G/D update on a single batch; returns both losses."""
    # Fresh latent noise from a standard normal distribution.
    noise = tf.random.normal([BATCH_SIZE, latent_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        generated_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(generated_output)
        disc_loss = discriminator_loss(real_output, generated_output)
    gradients_of_generator = gen_tape.gradient(
        gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss


for epoch in range(EPOCHS):
    start_time = time.time()
    loss = []
    for images in train_dataset:
        loss.append(np.array(train_step(images)))
    # Render the fixed seed batch as one horizontal 32px-tall strip:
    # (16, 32, 32, 3) -> transpose to (rows, images, cols, ch) -> reshape.
    fake = generator(random_vector_for_generation, training=False)
    fake_concat = np.transpose(
        img_to_uint8(fake), [1, 0, 2, 3]).reshape((32, -1, 3))
    print(epoch, np.mean(loss, axis=0), time.time() - start_time)
    display(PIL.Image.fromarray(fake_concat))