In [1]:
import torch
import torch.nn as nn
import numpy as np


Let $\theta \in \mathbb R^M, \phi \in \mathbb R^N$ and $f: \mathbb R^M \times \mathbb R^N \to \mathbb R$. Given initial values $\theta = \theta_0$, $\phi = \phi_0$, we want to:

• Compute $g(\theta_0, \phi_0)$, where $g(\theta, \phi) = \frac{\partial f}{\partial \theta} \bigr\rvert_{\theta, \phi}$.
• Update $\theta = \theta - \alpha g(\theta_0, \phi_0)$.
• Compute $\tilde g(\theta_0, \phi_0)$ where
• $\tilde g(\theta, \phi) = J(\theta, \phi)^T g(\theta, \phi)$ and
• $J(\theta, \phi)_{m, n} = \frac{\partial g_m}{\partial \phi_n} \bigr\rvert_{\theta, \phi}$.
• Update $\phi = \phi - \beta \tilde g(\theta_0, \phi_0)$.

Let $$f(\theta, \phi) = (\theta^T \theta)(\phi^T \phi).$$ We have:

• $g(\theta, \phi) = 2 (\phi^T \phi) \theta$ and
• $J(\theta, \phi)_{m, n} = 4 \begin{bmatrix}  \theta_1\phi_1 & \dots & \theta_1 \phi_N \\ \vdots & \ddots & \vdots \\ \theta_M \phi_1 & \dots & \theta_M \phi_N  \end{bmatrix}$.
In [2]:
class Foo(nn.Module):
def __init__(self, theta_init, phi_init):
super(Foo, self).__init__()
self.theta = nn.Parameter(torch.Tensor(theta_init))
self.phi = nn.Parameter(torch.Tensor(phi_init))

def theta_params(self):
return [self.theta]

def phi_params(self):
return [self.phi]

def forward(self):

theta_init = np.random.rand(3)
phi_init = np.random.rand(4)

foo = Foo(theta_init, phi_init)
optimizer = torch.optim.SGD(foo.theta_params(), lr=1)
phi_optimizer = torch.optim.SGD(foo.phi_params(), lr=1)

loss = foo()

print('Backward theta')
loss.backward(create_graph=True)

print('Expected: {}'.format(2 * np.sum(phi_init * phi_init) * theta_init))

print('\nOptimize theta')
optimizer.step()

print('\nBackward phi')
print('Expected: {}'.format(4 * np.outer(theta_init, phi_init).T @ (2 * np.sum(phi_init * phi_init) * theta_init)))

print('\nOptimize phi')
phi_optimizer.step()

Backward theta

Actual: [ 1.47633135  0.45280895  0.09186564]
Expected: [ 1.47633131  0.45280897  0.09186564]

Optimize theta

Backward phi