#!/usr/bin/env python
# coding: utf-8
# In[1]:
import numpy as np
import matplotlib.pyplot as plt
# Render matplotlib figures inline in the notebook
get_ipython().run_line_magic('matplotlib', 'inline')
# # Mental rotation
#
# In this notebook we will show how we can enable a pretrained Generative Query Network (GQN) for the Shepard-Metzler mental rotation task. The problem is well studied in psychology to assess spatial intelligence. Mental rotation is a cognitively hard problem as it typically requires the employment of both the ventral and dorsal visual streams for recognition and spatial reasoning respectively. Additionally, a certain degree of metacognition is required to reason about uncertainty.
#
# It turns out that the GQN is capable of this, as we will see in this notebook.
#
#
# Note:
# This model has only been trained on around 10% of the data for $2 \times 10^5$ iterations instead of the $2 \times 10^6$ described in the original paper. This means that the reconstructions are quite bad and the samples are even worse. Consequently, this notebook is just a proof of concept that the model approximately works. If you have the computational means to fully train the model, then please feel free to make a pull request with the trained model, this will help me a lot.
#
#
# You can download the pretrained model weights from here: [https://github.com/wohlert/generative-query-network-pytorch/releases/tag/0.1](https://github.com/wohlert/generative-query-network-pytorch/releases/download/0.1/model-checkpoint.pth).
# In[35]:
import torch
import torch.nn as nn
# Load dataset
from shepardmetzler import ShepardMetzler
from torch.utils.data import DataLoader
# Serve scenes one at a time, in random order
data_root = "/data/shepard_metzler_5_parts/"  ## <= Choose your data location
dataset = ShepardMetzler(data_root)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
# In[36]:
from gqn import GenerativeQueryNetwork, partition
# Load model parameters onto CPU
state_dict = torch.load("./model-checkpoint.pth", map_location="cpu") ## <= Choose your model location
# Initialise new model with the settings of the trained one
model_settings = dict(x_dim=3, v_dim=7, r_dim=256, h_dim=128, z_dim=64, L=8)
model = GenerativeQueryNetwork(**model_settings)
# Load trained parameters. Checkpoints saved from a DataParallel-wrapped
# model prefix every key with "module.", so wrap before loading and
# unwrap afterwards.
if any("module" in key for key in state_dict.keys()):
    model = nn.DataParallel(model)
    model.load_state_dict(state_dict)
    model = model.module
else:
    model.load_state_dict(state_dict)
model
# We load a batch of a single image containing a single object seen from 15 different viewpoints. We describe the whole set of image, viewpoint pairs by $\{x_i, v_i \}_{i=1}^{n}$. We then separate this set into a context set $\{x_i, v_i \}_{i=1}^{m}$ of $m$ random elements and a query set $\{x^q, v^q \}$, which contains just a single element.
# In[184]:
def deterministic_partition(images, viewpoints, indices):
    """
    Partition a batch into context and query sets using fixed indices.

    :param images: batched images, one set of views per scene
    :param viewpoints: batched viewpoints, aligned with ``images``
    :param indices: view indices — every index but the last selects a
        context element, the last index selects the query element
    :return: context images, context viewpoints, query image, query viewpoint
    """
    # Maximum number of context points to use
    _, batch, n_views, *im_dims = images.shape
    _, batch, n_views, *vp_dims = viewpoints.shape

    # Collapse the two leading dimensions into a single batch axis
    images = images.view((-1, n_views, *im_dims))
    viewpoints = viewpoints.view((-1, n_views, *vp_dims))

    # Split the chosen view indices into context and query parts
    context_idx, query_idx = indices[:-1], indices[-1]

    x = images[:, context_idx]
    v = viewpoints[:, context_idx]
    x_q = images[:, query_idx]
    v_q = viewpoints[:, query_idx]

    return x, v, x_q, v_q
# In[186]:
import random
# Pick a scene to visualise
scene_id = 0
# Load data
x, v = next(iter(loader))
x_, v_ = x.squeeze(0), v.squeeze(0)
# Sample a set of views: 7 context views plus 1 query view
n_context = 7 + 1
# random.sample accepts a range directly — no need to build a list first
indices = random.sample(range(v_.size(1)), n_context)
# Separate into context and query sets
x_c, v_c, x_q, v_q = deterministic_partition(x, v, indices)
# Visualise context and query images
f, axarr = plt.subplots(1, 15, figsize=(20, 7))
for i, ax in enumerate(axarr.flat):
    # Move channel dimension to end: (C, H, W) -> (H, W, C) for imshow
    ax.imshow(x_[scene_id][i].permute(1, 2, 0))
    if i == indices[-1]:
        ax.set_title("Query", color="magenta")
    elif i in indices[:-1]:
        ax.set_title("Context", color="green")
    else:
        ax.set_title("Unused", color="grey")
    ax.axis("off")
# ## Reconstruction
#
# Now we feed the whole set into the network and the network will perform the segregation of sets. The query image is then reconstructed in accordance with a given viewpoint and a representation vector that has been generated only by the context set.
# In[187]:
# Reconstruct the query view from the context set and plot the result
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 7))

# Re-introduce a singleton batch dimension for the chosen scene
batch_x = x_c[scene_id].unsqueeze(0)
batch_v = v_c[scene_id].unsqueeze(0)
batch_x_q = x_q[scene_id].unsqueeze(0)
batch_v_q = v_q[scene_id].unsqueeze(0)

x_mu, r, kl = model(batch_x, batch_v, batch_x_q, batch_v_q)

# Drop the singleton batch dimension again
x_mu = x_mu.squeeze(0)
r = r.squeeze(0)

ax1.imshow(x_q[scene_id].data.permute(1, 2, 0))
ax1.set_title("Query image")
ax1.axis("off")

ax2.imshow(x_mu.data.permute(1, 2, 0))
ax2.set_title("Reconstruction")
ax2.axis("off")

# Render the 256-dimensional representation vector as a 16x16 image
ax3.imshow(r.data.view(16, 16))
ax3.set_title("Representation")
ax3.axis("off")

plt.show()
# ## Visualising representation
#
# We might be interested in visualising the representation as more context points are introduced. The representation network $\phi(x_i, v_i)$ generates a single representation for a context point $(x_i, v_i)$ which is then aggregated (summed) for each context point to generate the final representation.
#
# Below, we see how adding more context points creates a less sparse representation.
# In[188]:
f, axarr = plt.subplots(1, 7, figsize=(20, 7))
# Accumulate the aggregated (summed) representation one context point at a
# time. Initialising from the first phi avoids hard-coding the batch size
# and representation shape (previously torch.zeros(128, 256, 1, 1)).
r = None
for i, ax in enumerate(axarr.flat):
    phi = model.representation(x_c[:, i], v_c[:, i])
    r = phi if r is None else r + phi
    ax.imshow(r[scene_id].data.view(16, 16))
    ax.axis("off")
    ax.set_title("#Context points: {}".format(i+1))
# ## Sample from the prior.
#
# Because we use a conditional prior density $\pi(z|y)$ that is parametrised by a neural network, we should be able to continuously refine it during training such that if $y = (v, r)$ we can generate a sample from the data distribution by sampling $z \sim \pi(z|v,r)$ and sending it through the generative model $g_{\theta}(x|z, y)$.
#
# This means that we can give a number of context points along with a query viewpoint and generate a new image.
# In[189]:
# Create progressively growing context set
batch_size, n_views, c, h, w = x_c.shape
f, axarr = plt.subplots(1, num_samples, figsize=(20, 7))
for i, ax in enumerate(axarr.flat):
x_ = x_c[scene_id][:i+1].view(-1, c, h, w)
v_ = v_c[scene_id][:i+1].view(-1, 7)
phi = model.representation(x_, v_)
r = torch.sum(phi, dim=0)
x_mu = model.generator.sample((h, w), v_q[scene_id].unsqueeze(0), r)
ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
ax.set_title("Context points: {}".format(i))
ax.axis("off")
# ## Mental rotation task
#
# As an extension to the above mentioned sampling procedure, we can perform the mental rotation task by continuously sampling from the prior given a static representation $r$ and then varying the query viewpoint vector $v^q$ between each sample to "rotate the object".
#
# In the example below we change the yaw slightly at each frame for 8 frames.
# In[198]:
# Change viewpoint yaw
batch_size, n_views, c, h, w = context_x.shape
pi = 3.1415629
x_ = x_c[scene_id].view(-1, c, h, w)
v_ = v_c[scene_id].view(-1, 7)
phi = model.representation(x_, v_)
r = torch.sum(phi, dim=0)
f, axarr = plt.subplots(2, num_samples, figsize=(20, 7))
for i, ax in enumerate(axarr[0].flat):
v = torch.zeros(7).copy_(v_q[scene_id])
yaw = (i+1) * (pi/8) - pi/2
v[3], v[4] = np.cos(yaw), np.sin(yaw)
x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)
ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
ax.set_title(r"Yaw:" + str(i+1) + r"$\frac{\pi}{8} - \frac{\pi}{2}$")
ax.axis("off")
for i, ax in enumerate(axarr[1].flat):
v = torch.zeros(7).copy_(v_q[scene_id])
pitch = (i+1) * (pi/8) - pi/2
v[5], v[6] = np.cos(pitch), np.sin(pitch)
x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)
ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
ax.set_title(r"Pitch:" + str(i+1) + r"$\frac{\pi}{8} - \frac{\pi}{2}$")
ax.axis("off")
# In[ ]: