import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In this notebook we show how a pretrained Generative Query Network (GQN) can be applied to the Shepard-Metzler mental rotation task. The problem is well studied in psychology to assess spatial intelligence. Mental rotation is a cognitively hard problem, as it typically requires the employment of both the ventral and dorsal visual streams, for recognition and spatial reasoning respectively. Additionally, a certain degree of metacognition is required to reason about uncertainty.
It turns out that the GQN is capable of this, as we will see in this notebook.
You can download the pretrained model weights from here: https://github.com/wohlert/generative-query-network-pytorch/releases/tag/0.1.
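If you prefer to fetch the checkpoint programmatically, something along these lines should work. Note that the asset filename here is an assumption (check the release page for the actual name); the destination path matches the one used when loading the model further down.
import urllib.request

# Hypothetical asset URL -- check the release page above for the actual
# filename and adjust accordingly
url = ("https://github.com/wohlert/generative-query-network-pytorch"
       "/releases/download/0.1/model-checkpoint.pth")
urllib.request.urlretrieve(url, "./model-checkpoint.pth")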
import torch
import torch.nn as nn
# Load dataset
from shepardmetzler import ShepardMetzler
from torch.utils.data import DataLoader
dataset = ShepardMetzler("/data/shepard_metzler_5_parts/") ## <= Choose your data location
loader = DataLoader(dataset, batch_size=1, shuffle=True)
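Before going further, it can help to peek at what the loader yields. Each dataset item is itself a batch of scenes, so with batch_size=1 the tensors carry an extra leading dimension. The shapes in the comments below are what the rest of this notebook assumes; if your copy of the data is laid out differently, the indexing later on will not line up.
# Inspect one batch; expected shapes (an assumption based on the 0.1 release):
#   images:     (1, B, 15, 3, 64, 64) -- B scenes, 15 views each, 64x64 RGB
#   viewpoints: (1, B, 15, 7)         -- 3D position plus yaw/pitch encodings
x_peek, v_peek = next(iter(loader))
print(x_peek.shape, v_peek.shape)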
from gqn import GenerativeQueryNetwork, partition
# Load model parameters onto CPU
state_dict = torch.load("./model-checkpoint.pth", map_location="cpu") ## <= Choose your model location
# Initialise new model with the settings of the trained one
model_settings = dict(x_dim=3, v_dim=7, r_dim=256, h_dim=128, z_dim=64, L=8)
model = GenerativeQueryNetwork(**model_settings)
# Load trained parameters, un-dataparallel if needed
if any("module" in k for k in state_dict.keys()):
    model = nn.DataParallel(model)
    model.load_state_dict(state_dict)
    model = model.module
else:
    model.load_state_dict(state_dict)
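We only run the network for inference in this notebook, so it is good practice to switch it to evaluation mode; wrapping the forward passes further down in torch.no_grad() would additionally avoid building the autograd graph, though that is optional here.
# Put the model in evaluation mode (standard practice for inference)
model.eval()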
model
GenerativeQueryNetwork(
  (generator): GeneratorNetwork(
    (inference_core): Conv2dLSTMCell(
      (forget): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (input): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (output): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (state): Conv2d(394, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (transform): Conv2d(128, 394, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
    (generator_core): Conv2dLSTMCell(
      (forget): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (input): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (output): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (state): Conv2d(327, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (transform): Conv2d(128, 327, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    )
    (posterior_density): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (prior_density): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (observation_density): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
    (upsample): ConvTranspose2d(128, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)
    (downsample): Conv2d(3, 3, kernel_size=(4, 4), stride=(4, 4), bias=False)
  )
  (representation): TowerRepresentation(
    (conv1): Conv2d(3, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv2): Conv2d(256, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv3): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(128, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv5): Conv2d(263, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv6): Conv2d(263, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv8): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
    (avgpool): AvgPool2d(kernel_size=16, stride=16, padding=0)
  )
)
We load a single batch of scenes, each containing one object seen from 15 different viewpoints. We describe the whole set of image-viewpoint pairs by $\{x_i, v_i\}_{i=1}^{n}$. We then separate this set into a context set $\{x_i, v_i\}_{i=1}^{m}$ of $m$ random elements and a query pair $(x^q, v^q)$, which contains just a single element.
def deterministic_partition(images, viewpoints, indices):
    """
    Partition batch into context and query sets.

    :param images: batch of images, shape (1, b, m, c, h, w)
    :param viewpoints: batch of viewpoints, shape (1, b, m, 7)
    :param indices: view indices; all but the last select the context,
        the last selects the query
    :return: context images, context viewpoints, query image, query viewpoint
    """
    # Unpack batch and view dimensions
    _, b, m, *x_dims = images.shape
    _, b, m, *v_dims = viewpoints.shape

    # "Squeeze" out the leading DataLoader batch dimension
    images = images.view((-1, m, *x_dims))
    viewpoints = viewpoints.view((-1, m, *v_dims))

    # Partition into context and query sets
    context_idx, query_idx = indices[:-1], indices[-1]

    x, v = images[:, context_idx], viewpoints[:, context_idx]
    x_q, v_q = images[:, query_idx], viewpoints[:, query_idx]

    return x, v, x_q, v_q
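Before using it on real data, we can sanity-check deterministic_partition on random tensors shaped like the Shepard-Metzler batches this notebook assumes (the 36 scenes here are an arbitrary choice):
# Dummy batch: (loader dim, scenes, views, channels, height, width)
dummy_x = torch.randn(1, 36, 15, 3, 64, 64)
dummy_v = torch.randn(1, 36, 15, 7)
idx = torch.randperm(15)[:8].tolist()  # 7 context views + 1 query view

x_ctx, v_ctx, x_qry, v_qry = deterministic_partition(dummy_x, dummy_v, idx)
print(x_ctx.shape, v_ctx.shape)  # torch.Size([36, 7, 3, 64, 64]) torch.Size([36, 7, 7])
print(x_qry.shape, v_qry.shape)  # torch.Size([36, 3, 64, 64]) torch.Size([36, 7])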
import random
# Pick a scene to visualise
scene_id = 0
# Load data
x, v = next(iter(loader))
x_, v_ = x.squeeze(0), v.squeeze(0)
# Sample a set of views: 7 context points plus 1 query
n_context = 7 + 1
indices = random.sample(range(v_.size(1)), n_context)

# Separate into context and query sets
x_c, v_c, x_q, v_q = deterministic_partition(x, v, indices)
# Visualise context and query images
f, axarr = plt.subplots(1, 15, figsize=(20, 7))
for i, ax in enumerate(axarr.flat):
    # Move channel dimension to end for imshow
    ax.imshow(x_[scene_id][i].permute(1, 2, 0))

    if i == indices[-1]:
        ax.set_title("Query", color="magenta")
    elif i in indices[:-1]:
        ax.set_title("Context", color="green")
    else:
        ax.set_title("Unused", color="grey")

    ax.axis("off")
Now we feed the context set, together with the query image and viewpoint, into the network. The query image is then reconstructed conditioned on the query viewpoint and a representation vector that has been generated only from the context set.
f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 7))
x_mu, r, kl = model(x_c[scene_id].unsqueeze(0),
v_c[scene_id].unsqueeze(0),
x_q[scene_id].unsqueeze(0),
v_q[scene_id].unsqueeze(0))
x_mu = x_mu.squeeze(0)
r = r.squeeze(0)
ax1.imshow(x_q[scene_id].data.permute(1, 2, 0))
ax1.set_title("Query image")
ax1.axis("off")
ax2.imshow(x_mu.data.permute(1, 2, 0))
ax2.set_title("Reconstruction")
ax2.axis("off")
ax3.imshow(r.data.view(16, 16))
ax3.set_title("Representation")
ax3.axis("off")
plt.show()
We might be interested in visualising the representation as more context points are introduced. The representation network $\phi(x_i, v_i)$ generates a single representation for each context point $(x_i, v_i)$; these are then aggregated by summation over all context points to form the final scene representation.
Below, we see how adding more context points creates a less sparse representation.
f, axarr = plt.subplots(1, 7, figsize=(20, 7))

# Running sum of per-view representations, one per scene in the batch
r = torch.zeros(x_c.size(0), 256, 1, 1)

for i, ax in enumerate(axarr.flat):
    phi = model.representation(x_c[:, i], v_c[:, i])
    r += phi

    ax.imshow(r[scene_id].data.view(16, 16))
    ax.axis("off")
    ax.set_title("#Context points: {}".format(i + 1))
Because we use a conditional prior density $\pi(z|y)$ that is parametrised by a neural network, we should be able to continuously refine it during training such that, if $y = (v, r)$, we can generate a sample from the data distribution by sampling $z \sim \pi(z|v, r)$ and sending it through the generative model $g_{\theta}(x|z, y)$.
This means that we can give a number of context points along with a query viewpoint and generate a new image.
# Create a progressively growing context set
batch_size, n_views, c, h, w = x_c.shape

# One panel per context size
num_samples = n_views

f, axarr = plt.subplots(1, num_samples, figsize=(20, 7))
for i, ax in enumerate(axarr.flat):
    x_ = x_c[scene_id][:i+1].view(-1, c, h, w)
    v_ = v_c[scene_id][:i+1].view(-1, 7)

    phi = model.representation(x_, v_)
    r = torch.sum(phi, dim=0)

    x_mu = model.generator.sample((h, w), v_q[scene_id].unsqueeze(0), r)

    ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
    ax.set_title("Context points: {}".format(i + 1))
    ax.axis("off")
As an extension of the above sampling procedure, we can perform the mental rotation task itself: we repeatedly sample from the prior given a fixed representation $r$, varying the query viewpoint vector $v^q$ between samples to "rotate" the object.
In the example below we increase the yaw by $\frac{\pi}{8}$ at each frame in the top row, and do the same for the pitch in the bottom row.
# Change viewpoint yaw and pitch
batch_size, n_views, c, h, w = x_c.shape

# Build a representation from all context points of the chosen scene
x_ = x_c[scene_id].view(-1, c, h, w)
v_ = v_c[scene_id].view(-1, 7)

phi = model.representation(x_, v_)
r = torch.sum(phi, dim=0)

f, axarr = plt.subplots(2, num_samples, figsize=(20, 7))

for i, ax in enumerate(axarr[0].flat):
    v = v_q[scene_id].clone()
    yaw = (i + 1) * (np.pi / 8) - np.pi / 2
    v[3], v[4] = np.cos(yaw), np.sin(yaw)

    x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)
    ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
    ax.set_title(r"Yaw: ${}\frac{{\pi}}{{8}} - \frac{{\pi}}{{2}}$".format(i + 1))
    ax.axis("off")

for i, ax in enumerate(axarr[1].flat):
    v = v_q[scene_id].clone()
    pitch = (i + 1) * (np.pi / 8) - np.pi / 2
    v[5], v[6] = np.cos(pitch), np.sin(pitch)

    x_mu = model.generator.sample((h, w), v.unsqueeze(0), r)
    ax.imshow(x_mu.squeeze(0).data.permute(1, 2, 0))
    ax.set_title(r"Pitch: ${}\frac{{\pi}}{{8}} - \frac{{\pi}}{{2}}$".format(i + 1))
    ax.axis("off")