The resulting image $\pmb{x}^*$ minimizes a weighted sum of three losses: the content loss $\mathcal{L}_{content}(\pmb{c}, \pmb{x})$, the style loss $\mathcal{L}_{style}(\pmb{s}, \pmb{x})$ and the total variation loss $\mathcal{L}_{total\_variation}(\pmb{x})$.
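$$\pmb{x}^* = \underset{\pmb{x}}{\mathrm{argmin}}\left(\alpha\mathcal{L}_{content}(\pmb{c}, \pmb{x}) + \beta\mathcal{L}_{style}(\pmb{s}, \pmb{x}) + \gamma\mathcal{L}_{total\_variation}(\pmb{x})\right)$$

Here $\pmb{c}$ is the content image, $\pmb{s}$ is the style reference image, and $\alpha$, $\beta$ and $\gamma$ are the weights of the three losses.

(Heavily inspired by this Keras example.)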
import platform
print(f"Python version: {platform.python_version()}")
# python_version_tuple() returns strings; comparing them directly is lexicographic,
# so ("3", "10") would wrongly be considered older than ("3", "6").
# Compare the major and minor components as integers instead.
assert tuple(int(part) for part in platform.python_version_tuple()[:2]) >= (3, 6)
from IPython.display import Image, display
import numpy as np
Python version: 3.8.2
import tensorflow as tf
# Report the library versions in use and whether TensorFlow can see a GPU
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {tf.keras.__version__}")
gpu_devices = tf.config.list_physical_devices("GPU")
print('GPU found :)' if gpu_devices else 'No GPU :(')
from tensorflow.keras import Model
from tensorflow.keras.applications import vgg19
from tensorflow.keras.utils import get_file
from tensorflow.keras.preprocessing.image import load_img, save_img, img_to_array
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers.schedules import ExponentialDecay
TensorFlow version: 2.2.0 Keras version: 2.3.0-tf GPU found :)
# Fetch the content (base) image and the style reference image
base_image_path = get_file(fname="paris.jpg", origin="https://i.imgur.com/F28w3Ac.jpg")
style_reference_image_path = get_file(
    fname="starry_night.jpg",
    origin="https://i.imgur.com/9ooB60I.jpg",
)

# Dimensions of the content image, reported as (width, height)
content_picture = load_img(base_image_path)
width, height = content_picture.size
print(f"Input image dimensions: {width, height}")
Input image dimensions: (910, 607)
# Show the content image first, then the style reference image
for picture_path in (base_image_path, style_reference_image_path):
    display(Image(picture_path))
def preprocess_image(image_path, height, width):
    """Load a picture, resize it and turn it into a tensor ready for VGG19.

    Args:
        image_path: Path of the image file to open.
        height: Target height, passed to load_img() as part of target_size.
        width: Target width, passed to load_img() as part of target_size.

    Returns:
        The result of tf.convert_to_tensor() applied to the preprocessed
        image, with a leading batch axis of size 1.
    """
    picture = load_img(image_path, target_size=(height, width))
    # Add a leading batch axis in front of the image array
    batch = img_to_array(picture)[np.newaxis, ...]
    # RGB -> BGR and zero-centering of each color channel w.r.t. the ImageNet dataset
    batch = vgg19.preprocess_input(batch)
    return tf.convert_to_tensor(batch)
def deprocess_image(x, height, width):
    """Convert a preprocessed image array back into a valid RGB image.

    Undoes the zero-centering and channel reordering applied by
    vgg19.preprocess_input(): the ImageNet mean pixel values are added back
    and the channels are flipped from BGR to RGB.

    The input array is left untouched; all arithmetic happens on a copy.

    Args:
        x: Array-like holding height * width * 3 values (any shape that can
            be reshaped to (height, width, 3)).
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        A uint8 NumPy array of shape (height, width, 3) with values
        clipped to [0, 255].
    """
    # reshape() returns a view, so adding in place to it would silently modify
    # the caller's array. Work on an explicit copy instead.
    img = np.array(x, copy=True)
    # In-place float additions are rejected on integer arrays; promote them first
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype("float32")
    img = img.reshape((height, width, 3))
    # Remove zero-center by mean pixel applied by vgg19.preprocess_input()
    # The following values are the mean pixel values of each color channel for the ImageNet dataset
    img[:, :, 0] += 103.939
    img[:, :, 1] += 116.779
    img[:, :, 2] += 123.68
    # 'BGR'->'RGB'
    img = img[:, :, ::-1]
    return np.clip(img, 0, 255).astype("uint8")
# The gram matrix of a 3D tensor (correlations between the feature maps of a convolutional layer)
def gram_matrix(x):
    """Return the Gram matrix of a 3D feature-map tensor.

    The third axis of `x` is moved to the front and the remaining axes are
    flattened, giving one row per feature map. The result is the product of
    that matrix with its own transpose.
    """
    # Bring the 3rd dimension (the feature maps) to the front
    maps_first = tf.transpose(x, perm=(2, 0, 1))
    n_maps = tf.shape(maps_first)[0]
    # One row per feature map, all remaining positions flattened into columns
    flat_maps = tf.reshape(maps_first, (n_maps, -1))
    # Product of the flattened maps with their transpose
    return tf.matmul(flat_maps, tf.transpose(flat_maps))
# The style loss keeps the style of the reference image in the generated image.
# It compares the gram matrices (which capture style) of feature maps coming from
# the style reference image and from the generated image
def style_loss(style, combination, height, width):
    """Return the scaled squared distance between two Gram matrices.

    Args:
        style: Feature maps of the style reference image (passed to gram_matrix()).
        combination: Feature maps of the generated image (passed to gram_matrix()).
        height: Height used for the normalization factor.
        width: Width used for the normalization factor.
    """
    gram_difference = gram_matrix(style) - gram_matrix(combination)
    channels = 3
    size = height * width
    normalizer = 4.0 * (channels ** 2) * (size ** 2)
    # Squared distance between the Gram matrices, divided by the normalization factor
    return tf.reduce_sum(tf.square(gram_difference)) / normalizer
# The content loss keeps the "content" of the base image in the generated image
def content_loss(base, combination):
    """Return the sum of squared differences between `combination` and `base`."""
    squared_difference = tf.square(combination - base)
    return tf.reduce_sum(squared_difference)
# The total variation loss keeps the generated image locally coherent
def total_variation_loss(x, height, width):
    """Return the total variation loss of a 4-axis image tensor.

    Each position (cropped by one along axes 1 and 2) is compared with its
    neighbor one step further along axis 1 and one step further along
    axis 2; `height` bounds axis 1 and `width` bounds axis 2.
    """
    last_row = height - 1
    last_col = width - 1
    reference = x[:, :last_row, :last_col, :]
    # Squared difference with the neighbor shifted by one along axis 1
    shifted_rows = tf.square(reference - x[:, 1:, :last_col, :])
    # Squared difference with the neighbor shifted by one along axis 2
    shifted_cols = tf.square(reference - x[:, :last_row, 1:, :])
    return tf.reduce_sum(tf.pow(shifted_rows + shifted_cols, 1.25))
# Convolutional base of VGG19 (no classification head), loaded with pre-trained ImageNet weights
vgg = vgg19.VGG19(weights="imagenet", include_top=False)

# Map each layer name to the symbolic output of that layer
outputs_dict = {layer.name: layer.output for layer in vgg.layers}

# Model returning the activation values of every VGG19 layer, as a dict keyed by layer name
feature_extractor = Model(inputs=vgg.inputs, outputs=outputs_dict)

vgg.summary()
Model: "vgg19" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, None, None, 3)] 0 _________________________________________________________________ block1_conv1 (Conv2D) (None, None, None, 64) 1792 _________________________________________________________________ block1_conv2 (Conv2D) (None, None, None, 64) 36928 _________________________________________________________________ block1_pool (MaxPooling2D) (None, None, None, 64) 0 _________________________________________________________________ block2_conv1 (Conv2D) (None, None, None, 128) 73856 _________________________________________________________________ block2_conv2 (Conv2D) (None, None, None, 128) 147584 _________________________________________________________________ block2_pool (MaxPooling2D) (None, None, None, 128) 0 _________________________________________________________________ block3_conv1 (Conv2D) (None, None, None, 256) 295168 _________________________________________________________________ block3_conv2 (Conv2D) (None, None, None, 256) 590080 _________________________________________________________________ block3_conv3 (Conv2D) (None, None, None, 256) 590080 _________________________________________________________________ block3_conv4 (Conv2D) (None, None, None, 256) 590080 _________________________________________________________________ block3_pool (MaxPooling2D) (None, None, None, 256) 0 _________________________________________________________________ block4_conv1 (Conv2D) (None, None, None, 512) 1180160 _________________________________________________________________ block4_conv2 (Conv2D) (None, None, None, 512) 2359808 _________________________________________________________________ block4_conv3 (Conv2D) (None, None, None, 512) 2359808 _________________________________________________________________ block4_conv4 (Conv2D) (None, None, None, 
512) 2359808 _________________________________________________________________ block4_pool (MaxPooling2D) (None, None, None, 512) 0 _________________________________________________________________ block5_conv1 (Conv2D) (None, None, None, 512) 2359808 _________________________________________________________________ block5_conv2 (Conv2D) (None, None, None, 512) 2359808 _________________________________________________________________ block5_conv3 (Conv2D) (None, None, None, 512) 2359808 _________________________________________________________________ block5_conv4 (Conv2D) (None, None, None, 512) 2359808 _________________________________________________________________ block5_pool (MaxPooling2D) (None, None, None, 512) 0 ================================================================= Total params: 20,024,384 Trainable params: 20,024,384 Non-trainable params: 0 _________________________________________________________________
# Layers used for the style loss: the first convolution of each of the five VGG19 blocks
style_layer_names = [f"block{block_index}_conv1" for block_index in range(1, 6)]

# Layer used for the content loss
content_layer_name = "block5_conv2"

# Weights of the different loss components
total_variation_weight = 1e-6
style_weight = 1e-6
content_weight = 2.5e-8
def compute_loss(combination_image, base_image, style_reference_image, height, width):
    """Return the weighted sum of the content, style and total variation losses.

    The three images are stacked along axis 0 (base, style reference,
    generated) and run through `feature_extractor` in a single call, so
    index 0, 1 and 2 of every feature tensor refer to those images in
    that order.

    Args:
        combination_image: The generated image being optimized.
        base_image: The content image.
        style_reference_image: The style reference image.
        height: Height forwarded to style_loss() and total_variation_loss().
        width: Width forwarded to style_loss() and total_variation_loss().
    """
    stacked_images = tf.concat(
        [base_image, style_reference_image, combination_image], axis=0
    )
    features = feature_extractor(stacked_images)

    # Start from a scalar zero and accumulate each component
    loss = tf.zeros(shape=())

    # Content component: base image (index 0) against generated image (index 2)
    content_features = features[content_layer_name]
    loss = loss + content_weight * content_loss(
        content_features[0], content_features[2]
    )

    # Style component: style reference (index 1) against generated image (index 2),
    # with the style weight shared equally between the style layers
    per_layer_weight = style_weight / len(style_layer_names)
    for layer_name in style_layer_names:
        style_features = features[layer_name]
        layer_style_loss = style_loss(
            style_features[1], style_features[2], height, width
        )
        loss += per_layer_weight * layer_style_loss

    # Total variation component, computed on the generated image itself
    loss += total_variation_weight * total_variation_loss(
        combination_image, height, width
    )
    return loss
@tf.function
def compute_loss_and_grads(combination_image, base_image, style_reference_image, height, width):
    """Return the loss and its gradient with respect to `combination_image`."""
    # Record the forward pass so the loss can be differentiated afterwards
    with tf.GradientTape() as gradient_tape:
        loss = compute_loss(
            combination_image, base_image, style_reference_image, height, width
        )
    gradients = gradient_tape.gradient(loss, combination_image)
    return loss, gradients
# Height of the generated image
gen_height = 400
# Width of the generated image, chosen so that it keeps the aspect ratio of the
# input image (the result is truncated to an integer)
gen_width = int((width * gen_height) / height)
print(f"Generated image dimensions: {gen_width, gen_height}")
Generated image dimensions: (599, 400)
# SGD with an exponentially decaying learning rate
learning_rate_schedule = ExponentialDecay(
    initial_learning_rate=100.0,
    decay_steps=100,
    decay_rate=0.96,
)
optimizer = SGD(learning_rate_schedule)

# Content and style images, resized to the generated image dimensions
base_image = preprocess_image(base_image_path, gen_height, gen_width)
style_reference_image = preprocess_image(
    style_reference_image_path, gen_height, gen_width
)
# The generated image starts as the content image and is the variable being optimized
initial_image = preprocess_image(base_image_path, gen_height, gen_width)
combination_image = tf.Variable(initial_image)
# Training loop: one gradient step on the generated image per epoch
n_epochs = 4000
for epoch in range(1, n_epochs + 1):
    loss, grads = compute_loss_and_grads(
        combination_image, base_image, style_reference_image, gen_height, gen_width
    )
    optimizer.apply_gradients([(grads, combination_image)])

    # Only report progress every 100 epochs
    if epoch % 100 != 0:
        continue
    print(f"Epoch [{epoch}/{n_epochs}], loss: {loss:.2f}")

# Convert the optimized tensor back into an image and write it to disk
result_image_path = f"generated_image_epoch_{n_epochs}.png"
final_img = deprocess_image(combination_image.numpy(), gen_height, gen_width)
save_img(result_image_path, final_img)
Epoch [100/4000], loss: 11018.34 Epoch [200/4000], loss: 8514.22 Epoch [300/4000], loss: 7571.62 Epoch [400/4000], loss: 7064.01 Epoch [500/4000], loss: 6736.22 Epoch [600/4000], loss: 6501.81 Epoch [700/4000], loss: 6323.29 Epoch [800/4000], loss: 6181.49 Epoch [900/4000], loss: 6065.40 Epoch [1000/4000], loss: 5967.84 Epoch [1100/4000], loss: 5884.72 Epoch [1200/4000], loss: 5812.83 Epoch [1300/4000], loss: 5750.32 Epoch [1400/4000], loss: 5695.63 Epoch [1500/4000], loss: 5647.32 Epoch [1600/4000], loss: 5604.25 Epoch [1700/4000], loss: 5565.57 Epoch [1800/4000], loss: 5530.69 Epoch [1900/4000], loss: 5499.09 Epoch [2000/4000], loss: 5470.43 Epoch [2100/4000], loss: 5444.34 Epoch [2200/4000], loss: 5420.45 Epoch [2300/4000], loss: 5398.52 Epoch [2400/4000], loss: 5378.34 Epoch [2500/4000], loss: 5359.73 Epoch [2600/4000], loss: 5342.55 Epoch [2700/4000], loss: 5326.65 Epoch [2800/4000], loss: 5311.85 Epoch [2900/4000], loss: 5298.08 Epoch [3000/4000], loss: 5285.24 Epoch [3100/4000], loss: 5273.25 Epoch [3200/4000], loss: 5262.03 Epoch [3300/4000], loss: 5251.54 Epoch [3400/4000], loss: 5241.72 Epoch [3500/4000], loss: 5232.52 Epoch [3600/4000], loss: 5223.90 Epoch [3700/4000], loss: 5215.81 Epoch [3800/4000], loss: 5208.19 Epoch [3900/4000], loss: 5201.01 Epoch [4000/4000], loss: 5194.25
# Display the image that was saved at the end of training
generated_picture = Image(result_image_path)
display(generated_picture)