In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


# Mental rotation

In this notebook we show how a pretrained Generative Query Network (GQN) can be applied to the Shepard-Metzler mental rotation task. The problem is well studied in psychology as a way to assess spatial intelligence. Mental rotation is a cognitively hard problem, as it typically requires the involvement of both the ventral and dorsal visual streams, for recognition and spatial reasoning respectively. Additionally, a certain degree of metacognition is required to reason about uncertainty.

It turns out that the GQN is capable of this, as we will see in this notebook.

Note: This model has only been trained on around 10% of the data for $2 \times 10^5$ iterations instead of the $2 \times 10^6$ described in the original paper. This means that the reconstructions are quite bad and the samples are even worse. Consequently, this notebook is just a proof of concept that the model approximately works. If you have the computational means to fully train the model, then please feel free to make a pull request with the trained model; this would help me a lot.

In [35]:
import torch
import torch.nn as nn

from shepardmetzler import ShepardMetzler

dataset = ShepardMetzler("/data/shepard_metzler_5_parts/") ## <= Choose your data location

In [36]:
from gqn import GenerativeQueryNetwork, partition

# Load model parameters onto the CPU
state_dict = torch.load("model.pt", map_location="cpu")  ## <= Choose your checkpoint location

# Initialise a new model with the settings of the trained one
model_settings = dict(x_dim=3, v_dim=7, r_dim=256, h_dim=128, z_dim=64, L=8)
model = GenerativeQueryNetwork(**model_settings)

# Load trained parameters, un-dataparallel if needed
if any("module" in k for k in state_dict.keys()):
    model = nn.DataParallel(model)
    model.load_state_dict(state_dict)
    model = model.module
else:
    model.load_state_dict(state_dict)

model

Out[36]:
GenerativeQueryNetwork(
  (generator): GeneratorNetwork(
    (inference_core): Conv2dLSTMCell(
      (forget): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (input): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (output): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (state): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (transform): Conv2d(128, 394, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
    (generator_core): Conv2dLSTMCell(
      (forget): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (input): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (output): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (state): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (transform): Conv2d(128, 327, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
    (posterior_density): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (prior_density): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (observation_density): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
    (upsample): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)
    (downsample): Conv2d(3, 3, kernel_size=(4, 4), stride=(4, 4), bias=False)
  )
  (representation): TowerRepresentation(
    (conv1): Conv2d(3, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv2): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv3): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(128, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv5): Conv2d(263, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv6): Conv2d(263, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv8): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  )
)

We load a batch containing a single object seen from 15 different viewpoints. We denote the full set of image-viewpoint pairs by $\{x_i, v_i \}_{i=1}^{n}$, and then separate this set into a context set $\{x_i, v_i \}_{i=1}^{m}$ of $m$ random elements and a query pair $(x^q, v^q)$ containing just a single element.

In [184]:
def deterministic_partition(images, viewpoints, indices):
    """
    Partition a batch into context and query sets.
    :param images: batch of images
    :param viewpoints: batch of viewpoints
    :param indices: view indices; all but the last become the context, the last is the query
    :return: context images, context viewpoints, query image, query viewpoint
    """
    # Maximum number of context points to use
    _, b, m, *x_dims = images.shape
    _, b, m, *v_dims = viewpoints.shape

    # "Squeeze" the batch dimension
    images = images.view((-1, m, *x_dims))
    viewpoints = viewpoints.view((-1, m, *v_dims))

    # Partition into context and query sets
    context_idx, query_idx = indices[:-1], indices[-1]

    x, v = images[:, context_idx], viewpoints[:, context_idx]
    x_q, v_q = images[:, query_idx], viewpoints[:, query_idx]

    return x, v, x_q, v_q

In [186]:
import random
from torch.utils.data import DataLoader

# Draw a single batch of scenes; we assume the dataset yields
# (images, viewpoints) pairs. The extra leading dimension added by
# unsqueeze matches what deterministic_partition expects.
loader = DataLoader(dataset, batch_size=1, shuffle=True)
x, v = next(iter(loader))
x, v = x.unsqueeze(0), v.unsqueeze(0)

# Pick a scene to visualise
scene_id = 0

x_, v_ = x.squeeze(0), v.squeeze(0)

# Sample a set of views
n_context = 7 + 1
indices = random.sample([i for i in range(v_.size(1))], n_context)

# Separate into context and query sets
x_c, v_c, x_q, v_q = deterministic_partition(x, v, indices)

# Visualise context and query images
f, axarr = plt.subplots(1, 15, figsize=(20, 7))
for i, ax in enumerate(axarr.flat):
    # Move channel dimension to end
    ax.imshow(x_[scene_id][i].permute(1, 2, 0))

    if i == indices[-1]:
        ax.set_title("Query", color="magenta")
    elif i in indices[:-1]:
        ax.set_title("Context", color="green")
    else:
        ax.set_title("Unused", color="grey")

    ax.axis("off")


## Reconstruction

Now we feed the context set and the query viewpoint into the network. The query image is then reconstructed according to the given viewpoint and a representation vector that has been generated only from the context set.

In [187]:
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 7))

x_mu, r, kl = model(x_c[scene_id].unsqueeze(0),
                    v_c[scene_id].unsqueeze(0),
                    x_q[scene_id].unsqueeze(0),
                    v_q[scene_id].unsqueeze(0))

x_mu = x_mu.squeeze(0)
r = r.squeeze(0)

ax1.imshow(x_q[scene_id].data.permute(1, 2, 0))
ax1.set_title("Query image")
ax1.axis("off")

ax2.imshow(x_mu.data.permute(1, 2, 0))
ax2.set_title("Reconstruction")
ax2.axis("off")

ax3.imshow(r.data.view(16, 16))
ax3.set_title("Representation")
ax3.axis("off")

plt.show()


## Visualising representation

We might be interested in visualising the representation as more context points are introduced. The representation network $\phi$ generates a representation $\phi(x_i, v_i)$ for each context pair $(x_i, v_i)$; these are then aggregated (summed) over the context points to produce the final representation.
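Because the aggregation is a plain sum, the final representation is invariant to the order of the context points. A quick self-contained check on dummy tensors (shapes are illustrative):

```python
import torch

# Hypothetical per-view representations phi_i, each of shape (256, 1, 1)
phis = [torch.randn(256, 1, 1) for _ in range(7)]

# Summing in any order yields (numerically) the same aggregate r
r = sum(phis)
r_reversed = sum(reversed(phis))
print(torch.allclose(r, r_reversed))  # True
```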

Below, we see how adding more context points creates a less sparse representation.

In [188]:
f, axarr = plt.subplots(1, 7, figsize=(20, 7))

r = torch.zeros(128, 256, 1, 1)

for i, ax in enumerate(axarr.flat):
    phi = model.representation(x_c[:, i], v_c[:, i])
    r += phi
    ax.imshow(r[scene_id].data.view(16, 16))
    ax.axis("off")
    ax.set_title("#Context points: {}".format(i + 1))


## Sample from the prior

Because we use a conditional prior density $\pi(z|y)$ that is parametrised by a neural network, we can continuously refine it during training such that, with $y = (v, r)$, we can generate a sample from the data distribution by sampling $z \sim \pi(z|v, r)$ and sending it through the generative model $g_{\theta}(x|z, y)$.

This means that we can give a number of context points along with a query viewpoint and generate a new image.
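Mechanically, such conditional sampling can be sketched as below. This is only an illustration, not the repository's exact implementation: the hidden-state shape and the split of the density's output into a mean and a log-scale are assumptions.

```python
import torch
import torch.nn as nn

# Hypothetical generator hidden state at one drawing step
h = torch.randn(1, 128, 16, 16)

# A conv parametrises the conditional prior: its 128 output channels are
# split into a mean and a log-scale for a 64-channel latent map
prior_density = nn.Conv2d(128, 2 * 64, kernel_size=5, padding=2)
mu, log_sigma = torch.chunk(prior_density(h), 2, dim=1)

# Reparameterised sample z ~ N(mu, sigma^2)
z = mu + torch.exp(log_sigma) * torch.randn_like(mu)
print(z.shape)  # torch.Size([1, 64, 16, 16])
```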

In [189]:
# Create a progressively growing context set
batch_size, n_views, c, h, w = x_c.shape
num_samples = n_views  ## <= one sample per context-set size

f, axarr = plt.subplots(1, num_samples, figsize=(20, 7))
for i, ax in enumerate(axarr.flat):
    x_ = x_c[scene_id][:i+1].view(-1, c, h, w)
    v_ = v_c[scene_id][:i+1].view(-1, 7)

    phi = model.representation(x_, v_)

    r = torch.sum(phi, dim=0)
    x_mu = model.generator.sample((h, w), v_q[scene_id].unsqueeze(0), r)
    ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
    ax.set_title("Context points: {}".format(i + 1))
    ax.axis("off")


As an extension to the sampling procedure above, we can perform the mental rotation task by repeatedly sampling from the prior given a static representation $r$, while varying the query viewpoint vector $v^q$ between samples to "rotate the object".

In the example below we change the yaw (top row) and the pitch (bottom row) slightly at each frame for 8 frames.

In [198]:
# Change viewpoint yaw and pitch
batch_size, n_views, c, h, w = x_c.shape
num_samples = 8

x_ = x_c[scene_id].view(-1, c, h, w)
v_ = v_c[scene_id].view(-1, 7)

phi = model.representation(x_, v_)

r = torch.sum(phi, dim=0)

f, axarr = plt.subplots(2, num_samples, figsize=(20, 7))
for i, ax in enumerate(axarr[0].flat):
    v = torch.zeros(7).copy_(v_q[scene_id])

    yaw = (i + 1) * (np.pi / 8) - np.pi / 2
    v[3], v[4] = np.cos(yaw), np.sin(yaw)

    x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)
    ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
    ax.set_title(r"Yaw: " + str(i + 1) + r"$\frac{\pi}{8} - \frac{\pi}{2}$")
    ax.axis("off")

for i, ax in enumerate(axarr[1].flat):
    v = torch.zeros(7).copy_(v_q[scene_id])

    pitch = (i + 1) * (np.pi / 8) - np.pi / 2
    v[5], v[6] = np.cos(pitch), np.sin(pitch)

    x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)
    ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
    ax.set_title(r"Pitch: " + str(i + 1) + r"$\frac{\pi}{8} - \frac{\pi}{2}$")
    ax.axis("off")
