In [1]:
import torch
import torch.nn as nn
import numpy as np

Let $\theta \in \mathbb R^M, \phi \in \mathbb R^N$ and $f: \mathbb R^M \times \mathbb R^N \to \mathbb R$. Given initial values $\theta = \theta_0$, $\phi = \phi_0$, we want to:

  • Compute $g(\theta_0, \phi_0)$, where $g(\theta, \phi) = \frac{\partial f}{\partial \theta} \bigr\rvert_{\theta, \phi}$.
  • Update $\theta = \theta - \alpha g(\theta_0, \phi_0)$.
  • Compute $\tilde g(\theta_0, \phi_0)$ where
    • $\tilde g(\theta, \phi) = J(\theta, \phi)^T g(\theta, \phi)$ and
    • $J(\theta, \phi)_{m, n} = \frac{\partial g_m}{\partial \phi_n} \bigr\rvert_{\theta, \phi}$.
  • Update $\phi = \phi - \beta \tilde g(\theta_0, \phi_0)$ (a PyTorch sketch of these steps follows below).
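
Here is a minimal sketch of these four steps using `torch.autograd.grad`; the sizes, step sizes, and the quadratic placeholder objective are illustrative only and not part of the setup above. Since both $g$ and $\tilde g$ are evaluated at $(\theta_0, \phi_0)$, the sketch computes both gradients before applying either update.

In [ ]:
import torch

theta = torch.randn(3, requires_grad=True)   # theta_0 (illustrative size M = 3)
phi = torch.randn(4, requires_grad=True)     # phi_0 (illustrative size N = 4)
alpha, beta = 0.1, 0.1                       # illustrative step sizes

def f(theta, phi):
    # placeholder scalar objective; any twice-differentiable f(theta, phi) works
    return torch.sum(theta**2) * torch.sum(phi**2)

# g(theta_0, phi_0) = df/dtheta, keeping the graph so g itself is differentiable
g, = torch.autograd.grad(f(theta, phi), theta, create_graph=True)

# tilde_g = J^T g, where J_{m,n} = dg_m/dphi_n: backpropagate g with the constant
# vector g.detach() as grad_outputs
tilde_g, = torch.autograd.grad(g, phi, grad_outputs=g.detach())

# both gradients were taken at (theta_0, phi_0), so the updates can be applied last
with torch.no_grad():
    theta -= alpha * g
    phi -= beta * tilde_g

The In [2] cell below carries out the same two steps with SGD optimizers and the parameters' `.grad` attributes instead, which is closer to how a training loop is usually written.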

Let $$f(\theta, \phi) = (\theta^T \theta)(\phi^T \phi).$$ We have:

  • $g(\theta, \phi) = 2 (\phi^T \phi) \theta$ and
  • $J(\theta, \phi)_{m, n} = 4 \theta_m \phi_n$, i.e.
    $$J(\theta, \phi) = 4 \begin{bmatrix}
      \theta_1\phi_1 & \dots & \theta_1 \phi_N \\
      \vdots & \ddots & \vdots \\
      \theta_M \phi_1 & \dots & \theta_M \phi_N
    \end{bmatrix} = 4 \theta \phi^T.$$
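
Composing the two, the $\phi$-update direction is
$$\tilde g(\theta, \phi) = J(\theta, \phi)^T g(\theta, \phi) = 4 \phi \theta^T \cdot 2 (\phi^T \phi) \theta = 8 (\theta^T \theta)(\phi^T \phi) \phi.$$
The cell below checks both gradients numerically, forming $J^T g$ directly from the initial values for the $\phi$ case.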
In [2]:
class Foo(nn.Module):
    """f(theta, phi) = (theta^T theta)(phi^T phi), with theta and phi as separate parameters."""

    def __init__(self, theta_init, phi_init):
        super(Foo, self).__init__()
        self.theta = nn.Parameter(torch.Tensor(theta_init))
        self.phi = nn.Parameter(torch.Tensor(phi_init))

    def theta_params(self):
        # parameter group updated with the plain gradient g
        return [self.theta]

    def phi_params(self):
        # parameter group updated with tilde_g = J^T g
        return [self.phi]

    def forward(self):
        return torch.sum(self.theta**2) * torch.sum(self.phi**2)
    
    
theta_init = np.random.rand(3)
phi_init = np.random.rand(4)

foo = Foo(theta_init, phi_init)
# separate SGD optimizers for the two parameter groups (alpha = beta = 1)
optimizer = torch.optim.SGD(foo.theta_params(), lr=1)
phi_optimizer = torch.optim.SGD(foo.phi_params(), lr=1)

optimizer.zero_grad()
loss = foo()

print('Backward theta')
# create_graph=True records the backward pass itself, so theta.grad is a
# differentiable function of theta and phi and can be backpropagated through below
loss.backward(create_graph=True)

print('\nGradient with respect to theta:')
print('Actual: {}'.format(foo.theta.grad.data.numpy()))
print('Expected: {}'.format(2 * np.sum(phi_init * phi_init) * theta_init))

print('\nOptimize theta')
optimizer.step()

phi_optimizer.zero_grad()

print('\nBackward phi')
# Backpropagate g = theta.grad with the constant vector g.detach() as the incoming
# gradient: this accumulates d(g^T g_const)/d(phi) = J^T g into phi.grad
torch.autograd.backward(
    [theta_param.grad for theta_param in foo.theta_params()],
    [theta_param.grad.detach() for theta_param in foo.theta_params()]
)

print('\nGradient with respect to phi:')
print('Actual: {}'.format(foo.phi.grad.data.numpy()))
print('Expected: {}'.format(4 * np.outer(theta_init, phi_init).T @ (2 * np.sum(phi_init * phi_init) * theta_init)))

print('\nOptimize phi')
phi_optimizer.step()
Backward theta

Gradient with respect to theta:
Actual: [ 1.47633135  0.45280895  0.09186564]
Expected: [ 1.47633131  0.45280897  0.09186564]

Optimize theta

Backward phi

Gradient with respect to phi:
Actual: [ 3.80052328  2.62138939  0.48601541  1.70397508]
Expected: [ 3.80052316  2.62138921  0.48601538  1.7039751 ]

Optimize phi
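
As a cross-check against the closed form $\tilde g = 8 (\theta^T \theta)(\phi^T \phi) \phi$ derived above, the $\phi$-gradient left in `foo.phi.grad` can be compared with that expression evaluated at the initial values. A small verification cell (not executed here, assuming the state from the cell above):

In [ ]:
# foo.phi.grad still holds J^T g evaluated at (theta_0, phi_0); phi_optimizer.step()
# changed phi itself, not phi.grad. Agreement is up to float32 precision.
closed_form = 8 * np.sum(theta_init**2) * np.sum(phi_init**2) * phi_init
print('Actual:      {}'.format(foo.phi.grad.detach().numpy()))
print('Closed form: {}'.format(closed_form))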