ddn.pytorch.node Module

Unlike the previous tutorials, in this notebook we use the PyTorch framework to implement a declarative node. For information on how to use PyTorch, see the official documentation and tutorials. Here we will show how to implement a declarative node using the ddn.pytorch.node module, explore the behavior of the node, and solve simple bi-level optimization problems. The ddn.pytorch.node module is very useful for prototyping ideas, at a small performance cost, without having to implement all the expressions in the backward pass yourself.

The typical pattern is: first, develop a declarative node deriving from the appropriate DeclarativeNode class; next, wrap the node within a DeclarativeLayer; last, use that layer in the same way you would any other layer in a PyTorch model (e.g., torch.nn.Linear), as the bi-level optimization examples below show.
We consider the problem of minimizing the KL-divergence between the input $x$ and output $y$ subject to the output forming a valid probability vector (i.e., the elements of $y$ are positive and sum to one). We will assume $x$ is strictly positive. The problem can be written formally as
$$ \begin{array}{rll} y =& \text{argmin}_u & - \sum_{i=1}^{n} x_i \log u_i \\ & \text{subject to} & \sum_{i=1}^{n} u_i = 1 \end{array} $$where the positivity constraint on $y$ is automatically satisfied by the domain of the log function.
A nice feature of this problem is that we can solve it in closed-form as $$ y = \frac{1}{\sum_{i=1}^{n} x_i} x. $$
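As a quick numerical sanity check (a small side sketch, not part of the original notebook; the helper `objective` and the sample vectors are illustrative), we can confirm that normalizing $x$ by its sum gives a valid probability vector and achieves a lower objective value than any other point on the simplex:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.1, 1.0, size=5)  # strictly positive input

# Closed-form solution: normalize x so its elements sum to one
y_star = x / x.sum()

def objective(x, u):
    """f(x, u) = -sum_i x_i log(u_i)."""
    return -np.dot(x, np.log(u))

# Any other point on the simplex should give a larger objective value,
# since the objective is strictly convex in u on the simplex
u_other = rng.uniform(0.1, 1.0, size=5)
u_other /= u_other.sum()

assert np.isclose(y_star.sum(), 1.0)                      # valid probability vector
assert objective(x, y_star) <= objective(x, u_other)      # closed form is the minimizer
```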
However, we will only use this for verification and pretend for now that we do not have a closed-form solution. Instead, we will make use of the scipy.optimize module to solve the problem via an iterative method. Deriving our deep declarative node from the LinEqConstDeclarativeNode class, we will need to implement three functions: the objective function, the solve function, and the linear_constraint_parameters function (the gradient function is already implemented for us).
import torch
import numpy as np
import scipy.optimize as opt
import sys
sys.path.append("../")
from ddn.pytorch.node import *
import warnings
warnings.filterwarnings('ignore')
# create the example node
class MinKL(LinEqConstDeclarativeNode):
    def __init__(self):
        super().__init__()

    def objective(self, x, y):
        """f(x, y) = -sum x*log(y)"""
        return -1.0 * torch.einsum('bn,bn->b', (x, y.log()))

    def linear_constraint_parameters(self, y):
        """Ay=d ==> sum(y) = 1"""
        A = y.new_ones(1, y.size(-1)) # 1xm
        d = y.new_ones(1) # 1
        return A, d

    def solve(self, x):
        """Solve the constrained optimization problem using scipy's built-in minimize function.
        Here we initialize the solver at the uniform distribution.
        """
        m = x.size(-1)
        u0 = np.ones((m,)) / m
        y = torch.zeros_like(x)
        # Loop over batch:
        for i, xi in enumerate(x):
            result = opt.minimize(lambda u: -1.0 * np.dot(xi.detach().numpy(), np.log(u)),
                                  u0,
                                  constraints={'type': 'eq', 'fun': lambda u: np.sum(u) - 1.0},
                                  bounds=opt.Bounds(1e-12, np.inf, keep_feasible=True),
                                  options={'maxiter': 100000, 'ftol': 1e-12})
            y[i, :] = torch.tensor(result.x)
        # The solve function must always return two arguments, the solution and context (i.e., cached
        # values needed for computing the gradient). In the case of linearly constrained problems we do
        # not need the dual solution in computing the gradient, so we return None for the context.
        return y, None
And now we test the node.
node = MinKL()
x = torch.rand(1, 5)
print("Input:\n{}".format(x.squeeze().numpy()))
print("Expected output:\n{}".format((x / x.sum(dim=-1, keepdim=True)).squeeze().numpy()))
y, _ = node.solve(x)
print("Actual output:\n{}".format(y.squeeze().numpy()))
Input: [0.6568583 0.38127953 0.21917361 0.8904883 0.93559736] Expected output: [0.21303073 0.12365567 0.07108186 0.28880104 0.3034307 ] Actual output: [0.21303076 0.12365563 0.07108174 0.28880104 0.30343086]
We now plot the function and gradient sweeping the first component of the input $x_1$ from 0.1 to 10.0 while holding the other elements of $x$ constant.
%matplotlib notebook
import matplotlib.pyplot as plt
x1_data = torch.linspace(0.1, 10.0, 100)
x = x.detach() # Don't track computation graph
y_data = []
Dy_data = []
vjp_data = []
for x1 in x1_data:
    x_new = x.clone()
    x_new[0, 0] = x1
    x_new.requires_grad = True
    y, _ = torch.no_grad()(node.solve)(x_new) # Run node's forward pass
    y.requires_grad = True
    y_data.append(y.squeeze().detach().numpy())
    # Note that the jacobian function call is inefficient
    # and is used only for teaching and analysis purposes
    Dy_data.append(node.jacobian(x_new, y=y)[0][0, :, 0].detach().numpy())
    vjp_data.append(node.gradient(x_new, y=y)[0][0, :].detach().numpy())
# Plot output y as x varies
fig = plt.figure()
plt.subplot(3, 1, 1)
plt.plot(x1_data, y_data)
plt.ylabel(r"$y$")
# Plot derivative dy/dx1 as x1 varies
# dy/dx = (I - y 1^T) / sum(xi)
# dy1/dx1 = (1 - y1) / sum(xi)
# dyi/dx1 = -yi / sum(xi), i > 1
plt.subplot(3, 1, 2)
plt.plot(x1_data, Dy_data)
#plt.ylabel(r"$Dy_{:,1}$")
plt.ylabel(r"$\frac{dy}{dx_1}$")
# Plot vector-Jacobian product as x1 varies
plt.subplot(3, 1, 3)
plt.plot(x1_data, vjp_data)
plt.xlabel(r"$x_1$");
plt.ylabel(r"$\mathbf{1}^\mathsf{T}Dy$")
fig.subplots_adjust(hspace=0.5)
plt.show()
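The code comments above state that $Dy = (I - y\mathbf{1}^{\mathsf{T}}) / \sum_i x_i$. As a sanity check (a small sketch, not part of the original notebook), we can confirm this formula against autograd applied directly to the closed-form solution:

```python
import torch

# Closed-form solution of the KL projection problem
def y_closed_form(x):
    return x / x.sum()

x = torch.tensor([0.6, 0.4, 0.2, 0.9, 0.9], dtype=torch.double)
y = y_closed_form(x)

# Jacobian of the closed-form map via autograd
Dy_auto = torch.autograd.functional.jacobian(y_closed_form, x)

# Analytic Jacobian from the comments above: Dy = (I - y 1^T) / sum(x)
n = x.numel()
Dy_analytic = (torch.eye(n, dtype=torch.double)
               - torch.outer(y, torch.ones(n, dtype=torch.double))) / x.sum()

assert torch.allclose(Dy_auto, Dy_analytic)
```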
Now let's see whether we can use the node within a bi-level optimization problem. We will attempt to learn an input $x$ that results in an output $y$ with the smallest squared norm. Moreover, we will regularize the norm of $x$ to be close to 10. Given our understanding of the KL-divergence, this should learn a vector $x$ that is a constant multiple of the ones vector (i.e., all elements of $x$ should be equal). Let's see what happens.
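To see why this should drive $x$ toward a constant vector, note (a short side derivation, not part of the original notebook) that minimizing the squared norm over the probability simplex yields the uniform distribution: $$ \min_y \|y\|^2 \ \text{subject to}\ \mathbf{1}^{\mathsf{T}} y = 1 \implies 2y + \lambda \mathbf{1} = 0 \implies y = \tfrac{1}{n}\mathbf{1}, $$ and from the closed-form solution $y = x / \sum_{i} x_i$, the output is uniform exactly when all elements of $x$ are equal.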
# define the upper-level objective
def J(x, y=None):
    """Computes our upper-level objective given both x and y."""
    if y is None:
        y, _ = torch.no_grad()(node.solve)(x)
    return ((y.norm(dim=-1)) ** 2 + (x.norm(dim=-1) - 10.0) ** 2).mean()
kl_problem = MinKL()
kl_declarative_layer = DeclarativeLayer(kl_problem)
# Solve using gradient descent:
learning_rate = 0.1
x = torch.rand(1, 5, requires_grad=True)
history = [J(x).detach().numpy()]
for i in range(100):
    y = kl_declarative_layer(x)
    z = J(x, y)
    z.backward()
    x_new = x - learning_rate * x.grad
    x = x_new.detach().requires_grad_(True)
    history.append(J(x).detach().numpy())
y, _ = torch.no_grad()(node.solve)(x)
x_np = x.detach().squeeze().numpy()
y_np = y.detach().squeeze().numpy()
print("Found x = {} with norm {:0.2f}".format(x_np, np.sqrt(np.dot(x_np, x_np))))
print("Results in y = {}".format(y_np))
fig = plt.figure()
plt.semilogy(history)
plt.ylabel("upper-level objective (log-scale)"); plt.xlabel("iteration")
plt.show()
# Solve using LBFGS:
x = torch.rand(1, 5, requires_grad=True)
history = []
optimizer = torch.optim.LBFGS([x], lr=1, max_iter=100)
def reevaluate():
    optimizer.zero_grad()
    y = kl_declarative_layer(x)
    z = J(x, y)
    z.backward()
    history.append(z.clone().detach().numpy())
    return z
optimizer.step(reevaluate)
y, _ = torch.no_grad()(node.solve)(x)
x_np = x.detach().squeeze().numpy()
y_np = y.detach().squeeze().numpy()
print("Found x = {} with norm {:0.2f}".format(x_np, np.sqrt(np.dot(x_np, x_np))))
print("Results in y = {}".format(y_np))
fig = plt.figure()
plt.semilogy(history)
plt.ylabel("upper-level objective (log-scale)"); plt.xlabel("iteration")
plt.show()
Found x = [5.8497915 2.1849608 4.9227815 5.558141 2.3281703] with norm 9.98 Results in y = [0.28064838 0.10482523 0.23617435 0.26665622 0.11169581]
Found x = [2.6130035 6.8791103 4.9775457 3.803213 2.4792774] with norm 9.98 Results in y = [0.12591484 0.33148903 0.2398569 0.18326835 0.11947087]
We consider the problem of minimizing the distance between the input $x$ and output $y$ using the robust pseudo-Huber penalty function. The problem can be written formally as $$ \begin{equation} y = \text{argmin}_u \sum_{i=1}^{n} \phi^\text{pseudo}(u - x_i; \alpha) \end{equation} $$ where the pseudo-Huber penalty function is given by $$ \begin{equation} \phi^{\text{pseudo}}(z; \alpha) = \alpha^2 \left( \sqrt{1 + \left(\frac{z}{\alpha}\right)^2} - 1 \right). \end{equation} $$
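The pseudo-Huber penalty behaves quadratically near zero and linearly in the tails, which is what makes it robust to outliers. A small illustrative sketch (`pseudo_huber` is a standalone re-implementation of the formula above, not a library function):

```python
import numpy as np

def pseudo_huber(z, alpha):
    """phi(z; alpha) = alpha^2 * (sqrt(1 + (z/alpha)^2) - 1)."""
    return alpha ** 2 * (np.sqrt(1.0 + (z / alpha) ** 2) - 1.0)

alpha = 0.5
# Near zero, the penalty is approximately quadratic: phi(z) ~ z^2 / 2
z_small = 1e-3
assert np.isclose(pseudo_huber(z_small, alpha), 0.5 * z_small ** 2, rtol=1e-4)
# For large |z|, the penalty grows approximately linearly: phi(z) ~ alpha * |z|
z_large = 1e6
assert np.isclose(pseudo_huber(z_large, alpha), alpha * z_large, rtol=1e-3)
```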
Deriving our deep declarative node from the AbstractDeclarativeNode class, we will need to implement two functions: the objective function and the solve function. However, we will also provide a gradient function to compare the generic gradient result with an efficient hand-coded gradient that makes use of the structure of the problem.
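Before looking at the code, it helps to see where the hand-coded gradient comes from (a sketch via implicit differentiation, consistent with the weights used in the gradient function). The first-order optimality condition for the pooling problem is $$ \sum_{i=1}^{n} \phi'(y - x_i; \alpha) = 0, $$ and differentiating both sides with respect to $x_j$ and solving gives $$ \frac{\partial y}{\partial x_j} = \frac{\phi''(y - x_j; \alpha)}{\sum_{i=1}^{n} \phi''(y - x_i; \alpha)}, \quad \text{where}\ \phi''(z; \alpha) = \left(1 + \frac{z^2}{\alpha^2}\right)^{-3/2}. $$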
import torch
import numpy as np
import sys
sys.path.append("../")
from ddn.pytorch.node import *
import warnings
warnings.filterwarnings('ignore')
class GlobalPseudoHuberPool2d(AbstractDeclarativeNode):
    """Global pseudo-Huber pooling node."""
    def __init__(self):
        super().__init__()

    def objective(self, x, alpha, y):
        alpha2 = alpha * alpha
        z = y.unsqueeze(-1).unsqueeze(-1) - x
        phi = alpha2 * (torch.sqrt(1.0 + torch.pow(z, 2) / alpha2) - 1.0)
        return phi.sum(dim=(-2, -1)) # b

    def solve(self, x, alpha):
        x = x.detach()
        y = x.mean([-2, -1]).clone().requires_grad_()
        y = self._runOptimisation(x, alpha, y)
        y = y.detach()
        z = (y.unsqueeze(-1).unsqueeze(-1) - x).clone()
        ctx = {'z': z}
        return y, ctx

    def _runOptimisation(self, x, alpha, y):
        with torch.enable_grad():
            opt = torch.optim.LBFGS([y],
                                    lr=1, # Default: 1
                                    max_iter=100, # Default: 20
                                    max_eval=None, # Default: None
                                    tolerance_grad=1e-05, # Default: 1e-05
                                    tolerance_change=1e-09, # Default: 1e-09
                                    history_size=100, # Default: 100
                                    line_search_fn=None # Default: None, Alternative: "strong_wolfe"
                                    )
            def reevaluate():
                opt.zero_grad()
                f = self.objective(x, alpha, y).sum() # sum over batch elements
                f.backward()
                return f
            opt.step(reevaluate)
        return y

    def gradient(self, x, alpha, y=None, v=None, ctx=None):
        """Override base class to compute the analytic gradient of the optimal solution."""
        if y is None:
            y, ctx = torch.no_grad()(self.solve)(x, alpha)
        if v is None:
            v = torch.ones_like(y)
        z = ctx['z'] # b x n1 x n2
        w = torch.pow(1.0 + torch.pow(z, 2) / (alpha * alpha), -1.5)
        w_sum = w.sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True).expand_as(w)
        Dy_at_x = torch.where(w_sum.abs() <= 1e-9, torch.zeros_like(w), w.div(w_sum)) # b x n1 x n2
        return torch.einsum('b,bmn->bmn', (v, Dy_at_x)), None
And now we test the node.
node = GlobalPseudoHuberPool2d()
batch_size = 3
input_size = (6, 6)
x = torch.randn(batch_size, *input_size, dtype=torch.double, requires_grad=True)
alpha = torch.tensor([0.5], dtype=torch.double, requires_grad=False)
y, _ = torch.no_grad()(node.solve)(x, alpha)
print("Input:\n{}".format(x[0,...].squeeze().detach().numpy())) # First batch element only
print("Output:\n{}".format(y[0,...].squeeze().detach().numpy())) # First batch element only
Input: [[ 2.13697157e+00 3.05468524e-02 2.19965548e-01 1.14265167e+00 3.21381158e-01 4.99261317e-01] [ 1.19857750e-03 1.07470571e+00 -1.63209451e+00 1.35819507e+00 -7.17839812e-01 2.18305918e-03] [ 1.40257322e+00 1.94317989e+00 8.63105100e-01 -1.36807300e+00 2.24705225e-01 -2.67362331e+00] [-8.24784197e-02 -6.40890761e-01 -6.20824929e-03 -1.25405629e-01 -6.76392417e-01 4.02266401e-01] [ 4.75056082e-01 -9.71727130e-01 -1.56911624e+00 8.24252399e-01 1.98249452e+00 -8.53628414e-01] [-2.96082926e+00 -1.15065371e+00 6.47177871e-02 -1.12872788e+00 2.97239133e-02 -5.01244771e-01]] Output: 0.0009873927345873989
We now plot the function and gradient sweeping the first component of the input $x_1$ from -10.0 to 10.0 while holding the other elements of $x$ constant.
%matplotlib notebook
import matplotlib.pyplot as plt
x1_data = torch.linspace(-10.0, 10.0, 110)
x = x.detach() # Don't track computation graph
y_data = []
vjp_data = []
vjp2_data = []
for x1 in x1_data:
    x_new = x.clone()
    x_new[:, 0, 0] = x1
    x_new.requires_grad = True
    y, ctx = torch.no_grad()(node.solve)(x_new, alpha)
    y.requires_grad = True
    y_data.append(y[0, ...].squeeze().detach().numpy()) # First batch element only
    vjp_data.append(super(type(node), node).gradient(x_new, alpha, y=y, ctx=ctx)[0][0, 0, :].detach().numpy()) # First 6 components
    vjp2_data.append(node.gradient(x_new, alpha, y=y, ctx=ctx)[0][0, 0, :].detach().numpy()) # First 6 components
fig = plt.figure()
plt.subplot(3, 1, 1)
plt.plot(x1_data, y_data)
plt.ylabel(r"$y$")
plt.subplot(3, 1, 2)
plt.plot(x1_data, vjp_data)
plt.xlabel(r"$x_1$"); plt.ylabel(r"$\mathbf{1}^\mathsf{T}Dy$ (generic)")
plt.subplot(3, 1, 3)
plt.plot(x1_data, vjp2_data)
plt.xlabel(r"$x_1$"); plt.ylabel(r"$\mathbf{1}^\mathsf{T}Dy$ (analytic)")
fig.subplots_adjust(hspace=0.5)
plt.show()
We consider the problem of minimizing the weighted reprojection error between a set of corresponding 3D and 2D points $\{p_i, q_i \}_{i=1}^n$ by varying the rigid transformation parameters $y$ applied to the 3D points. Here the transformation parameters consist of an angle-axis rotation vector concatenated with a translation vector. The problem can be written formally as $$ y = \text{argmin}_u \sum_{i=1}^{n} w_i \| \pi(p_i, u) - q_i \|_2^2 $$ where the projection $\pi(\cdot)$ is given by $$ \pi(p, u) = h(K (R(u) p + t(u))) $$ with intrinsic camera parameters $K$, rotation $R$, translation $t$, and map from homogeneous-to-Cartesian coordinates $h$, where $h(x) = [x_1 / x_3, x_2 / x_3]$.
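As an illustration of the projection $\pi$ defined above, here is a minimal standalone sketch using scipy for the angle-axis conversion (this is not the library's geo.project_points_by_theta implementation; the `project` helper and its test points are illustrative):

```python
import numpy as np
from scipy.spatial.transform import Rotation

def project(p, rvec, tvec, K=np.eye(3)):
    """pi(p, u) = h(K (R p + t)) for a batch of 3D points p (n x 3)."""
    R = Rotation.from_rotvec(rvec).as_matrix()  # angle-axis vector to rotation matrix
    x = (K @ (p @ R.T + tvec).T).T              # apply rigid transform, then intrinsics
    return x[:, :2] / x[:, 2:3]                 # homogeneous-to-Cartesian: h(x) = [x1/x3, x2/x3]

# With the identity transformation, projection reduces to perspective division
p = np.array([[0.5, -0.2, 2.0], [1.0, 1.0, 4.0]])
q = project(p, rvec=np.zeros(3), tvec=np.zeros(3))
assert np.allclose(q, p[:, :2] / p[:, 2:3])
```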
Deriving our deep declarative node from the AbstractDeclarativeNode class, we will need to implement two functions: the objective function and the solve function. For this class, we use the solvePnPRansac function from the Python OpenCV library.
import torch
import numpy as np
import sys
import cv2 as cv
from math import degrees
sys.path.append("../")
from ddn.pytorch.node import *
import ddn.pytorch.geometry_utilities as geo
import warnings
warnings.filterwarnings('ignore')
class PnP(AbstractDeclarativeNode):
    """Declarative PnP layer"""
    def __init__(self,
                 ransac_threshold=0.1,
                 ransac_max_iterations=1000,
                 ):
        super().__init__()
        self.ransac_threshold = ransac_threshold
        self.ransac_max_iterations = ransac_max_iterations

    def objective(self, p, q, w, K, y):
        """Weighted reprojection error"""
        p_projected = geo.project_points_by_theta(p, y, K)
        squared_error = torch.sum((p_projected - q) ** 2, dim=-1)
        w = torch.nn.functional.relu(w) # Enforce non-negative weights
        return torch.einsum('bn,bn->b', (w, squared_error))

    def solve(self, p, q, w, K=None):
        p = p.detach()
        q = q.detach()
        w = w.detach()
        K = K.detach() if K is not None else None
        y = self._initialise_transformation(p, q, w, K).requires_grad_()
        y = self._run_optimisation(p, q, w, K, y=y)
        return y.detach(), None

    def _ransac_p3p(self, p, q, K, threshold, max_iterations):
        p_np = p.cpu().numpy()
        q_np = q.cpu().numpy()
        y = q.new_zeros(q.size(0), 6)
        if K is None:
            K_np = np.float32(np.array([[1.0, 0.0, 0.0],
                                        [0.0, 1.0, 0.0],
                                        [0.0, 0.0, 1.0]]))
        for i in range(q_np.shape[0]): # loop over batch
            if K is not None:
                K_np = np.float32(np.array([[K[i, 0], 0.0, K[i, 2]],
                                            [0.0, K[i, 1], K[i, 3]],
                                            [0.0, 0.0, 1.0]]))
            retval, rvec, tvec, inliers = cv.solvePnPRansac(
                p_np[i, :, :], q_np[i, :, :], K_np, None,
                iterationsCount=max_iterations,
                reprojectionError=threshold,
                flags=cv.SOLVEPNP_EPNP)
            if rvec is not None and tvec is not None and retval:
                rvec = torch.as_tensor(rvec, dtype=q.dtype, device=q.device).squeeze(-1)
                tvec = torch.as_tensor(tvec, dtype=q.dtype, device=q.device).squeeze(-1)
                if torch.isfinite(rvec).all() and torch.isfinite(tvec).all():
                    y[i, :3] = rvec
                    y[i, 3:] = tvec
        return y

    def _initialise_transformation(self, p, q, w, K):
        return self._ransac_p3p(p, q, K, self.ransac_threshold, self.ransac_max_iterations)

    def _run_optimisation(self, *xs, y):
        with torch.enable_grad():
            opt = torch.optim.LBFGS([y],
                                    lr=1.0,
                                    max_iter=1000,
                                    max_eval=None,
                                    tolerance_grad=1e-40,
                                    tolerance_change=1e-40,
                                    history_size=100,
                                    line_search_fn="strong_wolfe"
                                    )
            def reevaluate():
                opt.zero_grad()
                f = self.objective(*xs, y=y).sum() # sum over batch elements
                f.backward()
                return f
            opt.step(reevaluate)
        return y
Now we test the node with 15 random 2D-3D point pairs, random camera parameters, significant additive Gaussian noise, and a single outlier correspondence. We should expect poor results for PnP algorithms when there are outliers, but perhaps we can learn to identify such outliers?
node = PnP()
b = 1
n = 15
# Generate camera parameters:
y_true = torch.randn(b, 6, dtype=torch.double)
R_true = geo.angle_axis_to_rotation_matrix(y_true[..., :3])
t_true = y_true[..., 3:]
# Generate image points, then assign depths:
xy = 2.0 * torch.rand(b, n, 2, dtype=torch.double) - 1.0 # [-1, 1]
z = 2.0 * torch.rand(b, n, 1, dtype=torch.double) + 1.0 # [1, 3]
p_transformed = torch.cat((z * xy, z), dim=-1)
p = torch.einsum('brs,bms->bmr', (R_true.transpose(-2,-1), p_transformed - t_true.unsqueeze(-2))) # Inverse transform
q = xy.clone()
q = q + 0.1 * torch.randn(b, n, 2, dtype=torch.double) # add noise
q[:, 0:1, :] = torch.randn(b, 1, 2, dtype=torch.double) # add outliers
# Generate weights (uniform):
w = torch.ones(b, n, dtype=torch.double) # bxn
w = w.div(w.sum(-1).unsqueeze(-1))
# Run solver:
y, _ = torch.no_grad()(node.solve)(p, q, w)
R = geo.angle_axis_to_rotation_matrix(y[..., :3])
t = y[..., 3:]
# Compute objective function value:
reproj_error = torch.no_grad()(node.objective)(p, q, w, K=None, y=y)
reproj_error_true = torch.no_grad()(node.objective)(p, q, w, K=None, y=y_true)
# Compute transformation errors:
error_rotation = (0.5 * ((R * R_true).sum(dim=(-2, -1)) - 1.0)).acos()
error_translation = (t - t_true).norm(dim=-1)
# Save original data:
p_orig = p.clone()
q_orig = q.clone()
w_orig = w.clone()
y_orig = y.clone()
print("True Output:\n{}".format(y_true[0,...].squeeze().detach().numpy())) # First batch element only
print("Est. Output:\n{}".format(y[0,...].squeeze().detach().numpy())) # First batch element only
print("True Reprojection Error: {:0.4f}".format(reproj_error_true[0,...].squeeze().detach().numpy())) # First batch element only
print("Est. Reprojection Error: {:0.4f}".format(reproj_error[0,...].squeeze().detach().numpy())) # First batch element only
print("Rotation Error: {:0.2f} degrees".format(degrees(error_rotation[0,...].squeeze().detach().numpy()))) # First batch element only
print("Translation Error: {:0.2f}".format(error_translation[0,...].squeeze().detach().numpy())) # First batch element only
True Output: [-1.18942164 0.47228852 0.87570294 -0.50857038 -0.27704414 0.67678228] Est. Output: [-1.09173455 0.11698683 0.14654788 1.48048001 -0.09961496 1.17458384] True Reprojection Error: 0.2374 Est. Reprojection Error: 0.1509 Rotation Error: 44.64 degrees Translation Error: 2.06
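The rotation error above uses the identity $\theta = \arccos\big(\tfrac{1}{2}(\operatorname{tr}(R R_{\text{true}}^{\mathsf{T}}) - 1)\big)$, where the elementwise product sum equals the trace of $R R_{\text{true}}^{\mathsf{T}}$. A quick numerical check of this formula (a standalone sketch using scipy; the perturbation angle and axis are illustrative):

```python
import numpy as np
from scipy.spatial.transform import Rotation

rng = np.random.default_rng(1)
R_true = Rotation.from_rotvec(rng.normal(size=3)).as_matrix()

# Perturb the true rotation by a known angle about a random axis
angle = 0.3  # radians
axis = rng.normal(size=3)
axis /= np.linalg.norm(axis)
R = Rotation.from_rotvec(angle * axis).as_matrix() @ R_true

# (R * R_true).sum() == trace(R @ R_true.T), so the error formula recovers the angle
error = np.arccos(0.5 * ((R * R_true).sum() - 1.0))
assert np.isclose(error, angle)
```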
It is clear that even a single outlier can play havoc with PnP estimation. We can visualize this by plotting the 2D points and projected 3D points, using the true and estimated transformation parameters. We link the putative 2D and 3D correspondences with a line, to make the outlier correspondence clear.
%matplotlib notebook
import matplotlib.pyplot as plt
fig = plt.figure()
q_np = q.numpy()
p_proj_true_np = geo.project_points_by_theta(p, y_true).numpy()
p_proj_np = geo.project_points_by_theta(p, y).numpy()
for i in range(q_np.shape[1]): # loop over correspondences
    plt.plot([q_np[0, i, 0], p_proj_true_np[0, i, 0]], [q_np[0, i, 1], p_proj_true_np[0, i, 1]], color='gray', linewidth=0.5)
plt.scatter(q_np[0, :, 0], q_np[0, :, 1], s=16, c='k', alpha=1.0, marker='s', label='2D points')
plt.scatter(p_proj_true_np[0, :, 0], p_proj_true_np[0, :, 1], s=16, c='r', alpha=1.0, marker='o', label='3D points (true projection)')
plt.scatter(p_proj_np[0, :, 0], p_proj_np[0, :, 1], s=16, facecolors='none', edgecolors='k', alpha=1.0, marker='o', label='3D points (est. projection)')
plt.legend(fontsize='small')
plt.show()
Now let's try to learn weights $w$ that attenuate the effect of the outlier correspondences, including those that occur due to noise. Our upper-level objective function will be a weighted sum of rotation and translation errors, given that we know the true camera pose. We expect the outlier correspondence to be downweighted, as well as some of the noisier points.
# Define the upper-level objective:
def J(p, q, w, y=None):
    """Compute sum of angular and positional camera errors"""
    if y is None:
        y, _ = torch.no_grad()(node.solve)(p, q, w)
    R = geo.angle_axis_to_rotation_matrix(y[..., :3])
    t = y[..., 3:]
    max_dot_product = 1.0 - 1e-7
    error_rotation = (0.5 * ((R * R_true).sum(dim=(-2, -1)) - 1.0)
                      ).clamp_(-max_dot_product, max_dot_product).acos()
    error_translation = (t - t_true).norm(dim=-1)
    #print("rot: {:0.2f}, trans: {:0.6f}".format(degrees(error_rotation[0,...]), error_translation[0,...]))
    return (error_rotation + 0.25 * error_translation).mean(), error_rotation, error_translation
# Reset parameters:
w = w_orig.clone().detach().requires_grad_()
y = y_orig.clone()
# Form a declarative layer:
pnp_declarative_layer = DeclarativeLayer(node)
loss, error_rotation, error_translation = J(p, q, w, y)
history_loss = [loss.detach().numpy()]
history_rot = [degrees(error_rotation[0, ...])] # First batch element only
history_tran = [error_translation[0, ...].detach().numpy()] # First batch element only
# Solve using LBFGS optimizer:
optimizer = torch.optim.LBFGS([w], lr=1, max_iter=50, line_search_fn="strong_wolfe")
def reevaluate():
    optimizer.zero_grad()
    y = pnp_declarative_layer(p, q, w, None)
    z, error_rotation, error_translation = J(p, q, w, y)
    z.backward()
    history_loss.append(z.clone().detach().numpy())
    history_rot.append(degrees(error_rotation[0, ...])) # First batch element only
    history_tran.append(error_translation[0, ...].detach().numpy()) # First batch element only
    return z
optimizer.step(reevaluate)
w = torch.nn.functional.relu(w) # Enforce non-negativity
y, _ = torch.no_grad()(node.solve)(p, q, w)
R = geo.angle_axis_to_rotation_matrix(y[..., :3])
t = y[..., 3:]
reproj_error = torch.no_grad()(node.objective)(p, q, w, K=None, y=y)
error_rotation = (0.5 * ((R * R_true).sum(dim=(-2, -1)) - 1.0)).acos()
error_translation = (t - t_true).norm(dim=-1)
p_np = p.detach().numpy()
q_np = q.detach().numpy()
w_np = w.detach().numpy()
y_np = y.detach().numpy()
print("Found w = {}".format(w_np[0, ...]))
print("Reprojection Error: {:0.4f}".format(reproj_error[0,...].squeeze().detach().numpy()))
print("Rotation Error: {:0.2f} degrees".format(degrees(error_rotation[0,...].squeeze().detach().numpy())))
print("Translation Error: {:0.6f}".format(error_translation[0,...].squeeze().detach().numpy()))
print("True Output: {}".format(y_true[0,...].squeeze().detach().numpy())) # First batch element only
print("Est. Output: {}".format(y[0,...].squeeze().detach().numpy())) # First batch element only
Found w = [0. 0.10564394 0.03942948 0.09214075 0.12407448 0.0411957 0.14691241 0.12342888 0.19623603 0.15112368 0.11926811 0.13557277 0. 0.21916919 0.05708209] Reprojection Error: 0.0219 Rotation Error: 0.03 degrees Translation Error: 0.092251 True Output: [-1.18942164 0.47228852 0.87570294 -0.50857038 -0.27704414 0.67678228] Est. Output: [-1.18940645 0.47258121 0.8760778 -0.46006755 -0.26463463 0.59929912]
And now we plot the learning curves.
%matplotlib notebook
import matplotlib.pyplot as plt
fig = plt.figure()
plt.plot(history_loss)
plt.ylabel("upper-level objective"); plt.xlabel("iteration")
plt.show()
fig = plt.figure()
plt.plot(history_rot)
plt.ylabel("rotation error (degrees)"); plt.xlabel("iteration")
plt.show()
fig = plt.figure()
plt.plot(history_tran)
plt.ylabel("translation error"); plt.xlabel("iteration")
plt.show()
We can visualize the results by plotting the 2D points and projected 3D points. We scale the points by the estimated weight, and replace points with weight $\approx 0$ with crosses to indicate outlier correspondences.
%matplotlib notebook
import matplotlib.pyplot as plt
fig = plt.figure()
p_proj_true_np = geo.project_points_by_theta(p.detach(), y_true).numpy()
p_proj_np = geo.project_points_by_theta(p.detach(), y).numpy()
for i in range(q_np.shape[1]): # loop over correspondences
    # plt.plot([q_np[0, i, 0], p_proj_true_np[0, i, 0]], [q_np[0, i, 1], p_proj_true_np[0, i, 1]], color='gray', linewidth=0.5)
    plt.plot([q_np[0, i, 0], p_proj_np[0, i, 0]], [q_np[0, i, 1], p_proj_np[0, i, 1]], color='gray', linewidth=0.5)
plt.scatter(q_np[0, :, 0], q_np[0, :, 1], s=200.*w_np[0,...], c='k', alpha=1.0, marker='s', label='2D points')
plt.scatter(p_proj_true_np[0, :, 0], p_proj_true_np[0, :, 1], s=200.*w_np[0,...], c='r', alpha=1.0, marker='o', label='3D points (true projection)')
plt.scatter(p_proj_np[0, :, 0], p_proj_np[0, :, 1], s=200.*w_np[0,...], facecolors='none', edgecolors='k', alpha=1.0, marker='o', label='3D points (est. projection)')
# Plot identified outliers separately:
plt.scatter(q_np[0, w_np[0,...] < 1e-3, 0], q_np[0, w_np[0,...] < 1e-3, 1], s=16, c='k', alpha=1.0, marker='x', label='2D points (outliers)')
plt.scatter(p_proj_true_np[0, w_np[0,...] < 1e-3, 0], p_proj_true_np[0, w_np[0,...] < 1e-3, 1], s=16, c='k', alpha=1.0, marker='x', label='3D points (outliers)')
plt.scatter(p_proj_np[0, w_np[0,...] < 1e-3, 0], p_proj_np[0, w_np[0,...] < 1e-3, 1], s=16, c='k', alpha=1.0, marker='x', label='3D points (outliers)')
#plt.legend(fontsize='small')
plt.show()