Preamble: At the time of this writing, I'm using PyTorch v1.7.1 built with CUDA 11.0 and cuDNN 8.0.
import numpy as np
import torch
print("version: ", torch.__version__)
mydevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device : ", mydevice)
Einsum is a powerful concept for processing tensors while at the same time writing very succinct code: a single subscript expression can describe matrix multiplications, element-wise products, transposes, and sums over axes.
Let us now see some sample problems to fully grasp the power of einsum.
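As a mental model (this is my own sketch, not from the PyTorch docs): the subscript string tells einsum which indices survive into the output and which get summed away. For instance, 'ij, jk -> ik' behaves like the nested loops below; the einsum_matmul_loops helper is purely illustrative.
def einsum_matmul_loops(A, B):
    # plain-loop equivalent of torch.einsum('ij, jk -> ik', A, B):
    # i and k appear in the output, j is summed away
    C = torch.zeros(A.shape[0], B.shape[1], dtype=A.dtype)
    for i in range(A.shape[0]):
        for k in range(B.shape[1]):
            for j in range(A.shape[1]):
                C[i, k] += A[i, j] * B[j, k]
    return C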
# some input tensors to work with
vec = torch.tensor([0, 1, 2, 3])
aten = torch.tensor([[11, 12, 13, 14],
                     [21, 22, 23, 24],
                     [31, 32, 33, 34],
                     [41, 42, 43, 44]])
bten = torch.tensor([[1, 1, 1, 1],
                     [2, 2, 2, 2],
                     [3, 3, 3, 3],
                     [4, 4, 4, 4]])
$C_{ik} = \sum_{j} A_{ij} \, B_{jk}$
For a matrix multiplication to work, the number of columns in the first matrix (e.g., A) should match the number of rows in the second matrix (e.g., B).
c = torch.einsum('ij, jk -> ik', aten, bten)
print("einsum matmul: \n", c)
# sanity check
c = torch.matmul(aten, bten) # or: aten.mm(bten)
print("torch matmul: \n", c)
hp = torch.einsum('ij, ij, i -> ij', aten, bten, vec) # note: `vec` is treated as a column vector
print("einsum hadamard product: \n", hp)
# sanity check
ep = aten * bten * vec[:, None]
print("element-wise product: \n", ep)
Note: we can raise the elements of a tensor to the power n by repeating the tensor n times in the einsum expression. For instance, a tensor can be cubed element-wise by repeating it 3 times.
hp = torch.einsum('ij, ij, ij -> ij', bten, bten, bten)
print("einsum hadamard product: \n", hp)
# sanity check
ep = bten * bten * bten
print("element-wise product: \n", ep)