As seen in its dedicated notebook, online softmax limits the computation to 2 passes over the input stored in global memory. In FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, the goal of the authors is to perform as few passes over the input data as possible and, most importantly, to avoid saving intermediate results to GPU global memory (the big, slow DRAM). The reason is that the intermediate results are essentially the attention matrix QK^t of shape seq length x seq length (one before softmax, one after it), which by definition scales quadratically with the sequence length.
Fusing two matmuls (a technique known as kernel fusion) is quite simple when the number of columns of the input matrices is low (which is the case in a transformer for a single head). That way, we avoid saving the output of the first matmul (an intermediate result) to global memory. A more complicated challenge is to perform the softmax between the two fused matmuls. The traditional softmax formula requires several full passes over the input data, which is impossible inside an operation fused between two matmuls because the intermediate result is never materialized.
In the original online softmax paper, the computation of the input maximum value and of the normalizer is progressive, using one part of the input data at a time. When we load a new part of the input vector, we may discover a new maximum value. If so, we need to adjust the already computed part of the softmax (the past) so that it looks as if the new row maximum had been applied since the beginning of the computation.
As explained at the beginning of this notebook, we don't want to save those very large intermediate data. So, the trick is the following: the adjustment of the partial softmax output is a multiplication of the past data by a per-row scalar, and this scalar multiplication commutes with the matmul, so we can apply the adjustment not to the softmax output itself, but to the output of the second matmul (softmax output times the V matrix). This output is the self-attention output; it is quite small (seq length x d) compared to the intermediate results (seq length x seq length), and it is saved to global memory anyway.
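Below is a minimal sketch of that claim (the tensors and the correction factor are made up for the example): rescaling each row of a softmax output block before the matmul with V gives the same result as rescaling the matmul output afterwards.
import torch
p = torch.rand(4, 6)   # a (partial) softmax output block, 4 rows
V = torch.rand(6, 8)   # a block of the V matrix
c = torch.rand(4, 1)   # per-row correction factor, e.g. exp(old row max - new row max)
# correcting the softmax output then multiplying by V equals multiplying by V then correcting the output
assert torch.allclose((c * p) @ V, c * (p @ V))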
As a reminder, seq length x d is the size of the V matrix (and of K and Q). d is the number of columns of V (and K, Q), aka the number of dimensions per head of the model. This number is low: <= 128 even for a model of the size of GPT-3 Davinci.
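To give an order of magnitude (the sequence length and dtype below are just illustrative choices), a back-of-the-envelope computation:
seq_length, d_head = 2048, 128
bytes_per_element = 2  # fp16
# intermediate attention matrix: quadratic in the sequence length (per head, per batch item)
print(seq_length * seq_length * bytes_per_element / 2**20)  # 8.0 MiB
# self-attention output: linear in the sequence length
print(seq_length * d_head * bytes_per_element / 2**20)  # 0.5 MiB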
import torch
torch.manual_seed(456)
N, d = 16, 8
Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))
# tile sizes for the tiled matmul: no operand bigger than a tile is supposed to fit in SRAM
Br = 4
Bc = d
We start by implementing the attention mechanism in plain PyTorch. The code is simple, but many reads/writes to global memory are performed.
expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)
expected_attention = expected_softmax @ V_mat
# 1st read
S_mat = Q_mat @ K_mat.T
row_max = torch.max(S_mat, dim=1).values[:, None]
# 2nd read
input_safe = S_mat - row_max
softmax_numerator = torch.exp(input_safe)
# 3rd read
softmax_denominator = torch.sum(softmax_numerator, dim=1)[:, None]
# 4th read
naive_softmax = softmax_numerator / softmax_denominator
# final matmul (another read / write)
matmul_result = naive_softmax @ V_mat
assert torch.allclose(naive_softmax, expected_softmax)
assert torch.allclose(matmul_result, expected_attention)
Tiling is a technique based on matrix partitioning; each block is called a tile. A dedicated notebook can be found in the tutorial folder of this repository.
We start with the matmul QK^t.
One particularity of transformer models is that the number of columns (d) of those matrices is small enough that several complete rows can be stored in shared memory (a fast memory close to the compute cores of the GPU), so we don't have to iterate across this axis. This is an important aspect that we will leverage later.
S_mat_for_check = torch.zeros((N, N))
for block_start_Bc in range(0, N, Bc):
    block_end_Bc = block_start_Bc + Bc
    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    for block_start_Br in range(0, N, Br):
        block_end_Br = block_start_Br + Br
        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d
        # QKt at the tile level
        Sij = Qi @ Kj.T  # shape Br x Bc
        S_mat_for_check[block_start_Br:block_end_Br, block_start_Bc:block_end_Bc] += Sij

assert torch.allclose(S_mat_for_check, Q_mat @ K_mat.T)
matmul
We will perform the computation O = SV where S = QK^t. Therefore, we need to perform 2 matmuls. For now, we do not put any softmax in between.
Our main challenge is to build on top of the previous notebook block without saving the intermediate matmul output S to GPU global memory. One trick to reduce global memory accesses is to reuse data as much as possible. That is why in the outer loop we load 2 blocks (Kj and Vj) and reuse both of them during the inner loop, where only a single block would be loaded from global memory (if executed in Cuda).
The input matrices are assumed not to be transposed. The transposition of K is done implicitly through the way we iterate over it. Because d is small, we have no non-coalesced memory access issue.
O = torch.zeros((N, d))
for block_start_Bc in range(0, N, Bc):
    block_end_Bc = block_start_Bc + Bc
    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    Vj = V_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    for block_start_Br in range(0, N, Br):
        block_end_Br = block_start_Br + Br
        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d
        # QKt at the tile level
        Sij = Qi @ Kj.T  # shape Br x Bc
        Oi = Sij @ Vj  # shape Br x d
        O[block_start_Br:block_end_Br, :] += Oi

assert torch.allclose(O, (Q_mat @ K_mat.T) @ V_mat)
In the previous block, we didn't apply the softmax on top of the QK^t matmul. The challenge in introducing it is that neither the softmax input nor its output should be saved in global memory. Remember that both the input and the output of the softmax have the shape seq length x seq length.
For that purpose, we will leverage the online softmax technique presented in detail in a dedicated notebook. The idea is that to compute the safe softmax (a numerically stable version of the softmax) we need to know the input max value (and the softmax denominator, which itself depends on this max statistic), information that we can only get by scanning the whole input data.
Online softmax computes the (safe) softmax progressively in a single pass over the input data (which are themselves computed progressively), block of rows after block of rows. During the process, each time we discover in a block a max bigger than the currently known row max, we correct the already computed values in a way which simulates that the new row max had been applied to the softmax numerator and denominator since the beginning.
The correction of the softmax denominator is applied here:
# This line is exactly the same mechanism seen in the `online softmax` notebook but applied to a vector instead of a scalar (the math stays the same).
li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat
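To see this recurrence in isolation, here is a small standalone sketch on a single row split into blocks (the variable names are ours, chosen for the example): the progressively corrected denominator ends up equal to the one computed in a single full pass.
x = torch.rand(16)  # scores of a single row, revealed block by block
block_size = 4
m = torch.tensor(-torch.inf)  # running row max
l = torch.tensor(0.0)         # running softmax denominator
for start in range(0, x.numel(), block_size):
    block = x[start:start + block_size]
    m_block = block.max()
    l_block = torch.exp(block - m_block).sum()
    m_new = torch.maximum(m, m_block)
    # correct the past denominator, then add the (corrected) contribution of the current block
    l = torch.exp(m - m_new) * l + torch.exp(m_block - m_new) * l_block
    m = m_new
assert torch.allclose(l, torch.exp(x - x.max()).sum())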
The remaining question is: where do we correct the past partial softmax output?
Because we want to perform all computations in a single pass, we do not save the first matmul output to GPU global memory, meaning there is no stored past softmax output to adjust.
It turns out that we can apply the correction to the matrix O (the output of the second matmul) instead. It works because multiplication by a per-row scalar commutes with the matmul, aka we can change the order of operations and get the same result mathematically.
Nb: this is not 100% true in our code; order matters because float numbers have limited precision and introduce some rounding, still, the effect is small and acceptable in deep learning.
This is done in the line:
Oi = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj
In the left part of the addition, li * torch.exp(mi - mi_new) * Oi / li_new, Oi contains the sum of the past output tiles, and that is where we can correct the past. In the right part, (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj, we first correct the softmax of the current tile and then compute the matmul output for the current tile.
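As a quick aside on the rounding remark above, this is also why the checks in this notebook use torch.allclose rather than strict equality; a minimal illustration of how the order of float operations changes the result:
a = torch.tensor([1e8, -1e8, 1.0])  # float32
print(((a[0] + a[1]) + a[2]).item())  # 1.0
print((a[0] + (a[1] + a[2])).item())  # 0.0, the small value is lost when added to -1e8 first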
The numbered comments in the code below refer to the lines of Algorithm 1 from the FlashAttention paper:
# variables outside the for loop represent the global memory
# they are the only ones bigger than what the SRAM can store
O = torch.zeros((N, d))
# the 2 variables below may be removed in a serially executed code (in particular the outer for loop)
# they are needed in a parallelized execution where each thread block needs to sync its findings with the others
# line 4, l will store the denominator of the softmax for each row
l = torch.zeros((N, 1))
# line 4, m will store the row max (computed progressively, block after block)
m = torch.full((N, 1), -torch.inf)

for block_start_Bc in range(0, N, Bc):
    block_end_Bc = block_start_Bc + Bc
    # line 6, load a block from the matmul input tensors
    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    Vj = V_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    for block_start_Br in range(0, N, Br):
        block_end_Br = block_start_Br + Br
        # line 8, load stuff from global memory, aka the work of the other thread blocks
        mi = m[block_start_Br:block_end_Br, :]  # shape Br x 1
        li = l[block_start_Br:block_end_Br, :]  # shape Br x 1
        Oi = O[block_start_Br:block_end_Br, :]  # shape Br x d
        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d
        # line 9, QKt at the tile level
        Sij = Qi @ Kj.T  # shape Br x Bc
        # line 10, find the max of each row of the currently loaded block (and only this block)
        mij_hat = torch.max(Sij, dim=1).values[:, None]
        # line 10, compute the softmax numerator as if we only had the data from this block (and nothing before or after)
        pij_hat = torch.exp(Sij - mij_hat)
        # line 10, compute the softmax denominator as if we only had the data from this block (and nothing before or after)
        lij_hat = torch.sum(pij_hat, dim=1)[:, None]
        # line 11, find the max of each row across the current block and all the previous ones already visited
        mi_new = torch.max(torch.column_stack([mi, mij_hat]), dim=1).values[:, None]
        # line 11, adjusting factor (see the online softmax computation above) leveraging the rule of exponentiation
        li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat
        # line 12, the first part before the "+" is the adjustment of the past blocks
        # the second part after the "+" incorporates the information from the current block and does the matmul for this block
        Oi = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj
        # note that we replace (=) and do not accumulate (+=) into the global variables like we would in the tiled matmul
        # line 13, save statistics
        m[block_start_Br:block_end_Br, :] = mi_new  # row max
        l[block_start_Br:block_end_Br, :] = li_new  # softmax denominator
        # save the attention block to global memory
        O[block_start_Br:block_end_Br, :] = Oi

assert torch.allclose(O, expected_attention)
The Triton implementation of Flash Attention and the original Flash Attention Cuda implementation differ on an important point: the way they are parallelized.
In the Cuda implementation, it's quite simple: the algorithm above is executed in a serialized way, and the parallelization only happens at the head x batch level (so on an A100 it needs at least head x batch >= 80 to keep the GPU busy).
In the Triton implementation, the inner and outer loops of the algorithm above are switched, and the parallelization happens at the level of the (new) outer loop. This increases the level of parallelism and keeps the GPU busy even for small batches / a low number of heads. See https://github.com/HazyResearch/flash-attention/issues/40 for a detailed analysis.
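To make the difference more concrete, here is a sketch of the swapped loop order in the same PyTorch style as above (not actual Triton code, and reusing the tensors defined earlier in this notebook). The outer loop now iterates over the row blocks of Q, each iteration keeps its statistics in local variables and writes only its own rows of the output, so the iterations are independent of each other and could each be mapped to a different thread block.
# sketch of the swapped loop order: executed serially here, but each outer
# iteration only touches its own rows of the output, so they could run in parallel
O_swapped = torch.zeros((N, d))
for block_start_Br in range(0, N, Br):
    block_end_Br = block_start_Br + Br
    Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d
    # statistics stay local to this (virtual) thread block, no sync through global memory needed
    mi = torch.full((Br, 1), -torch.inf)
    li = torch.zeros((Br, 1))
    Oi = torch.zeros((Br, d))
    for block_start_Bc in range(0, N, Bc):
        block_end_Bc = block_start_Bc + Bc
        Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
        Vj = V_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
        Sij = Qi @ Kj.T  # shape Br x Bc
        mij_hat = torch.max(Sij, dim=1).values[:, None]
        pij_hat = torch.exp(Sij - mij_hat)
        lij_hat = torch.sum(pij_hat, dim=1)[:, None]
        mi_new = torch.maximum(mi, mij_hat)
        li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat
        Oi = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj
        mi, li = mi_new, li_new
    O_swapped[block_start_Br:block_end_Br, :] = Oi
assert torch.allclose(O_swapped, expected_attention)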