Understanding Triton Tutorials Part 2

Isamu Isozaki
30 min read · Jun 6, 2024


Hi! I recently wanted to review Triton, so I’m making a second part of Understanding Triton Tutorials. The first part can be viewed here. After this article, I plan to post another blog on understanding Torch Inductor and its limitations. Overall, my goal with this blog is just to understand and wrap up the basic functions of Triton.

One disclaimer: I wasn’t able to fully understand all of the tutorials or cover every one of them, so this is not a comprehensive explanation. I’ll probably come back in the future to fix or explain some parts that I was confused by.

Low-Memory Dropout

In the previous blog we left off with matrix multiplication so let’s move on to Low-Memory Dropout! The link to the tutorial is here.

Dropout is commonly applied in deep learning to randomly zero out some features during training to reduce overfitting, like below

Image taken from https://medium.com/@amarbudhiraja/https-medium-com-amarbudhiraja-learning-less-to-learn-better-dropout-in-deep-machine-learning-74334da4bfc5

As the tutorial states, “Each scalar in the output has a probability 𝑝 of being changed to zero and otherwise it is copied from the input. This forces the network to perform well even when only 1−𝑝 scalars from the input are available”

To keep the expected value of each output equal to the input, the kept values are multiplied by 1/(1-p). The baseline implementation is below!

import tabulate
import torch

import triton
import triton.language as tl


@triton.jit
def _dropout(
    x_ptr,  # pointer to the input
    x_keep_ptr,  # pointer to a mask of 0s and 1s
    output_ptr,  # pointer to the output
    n_elements,  # number of elements in the `x` tensor
    p,  # probability that an element of `x` is changed to zero
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
x = torch.randn(size=(10, )).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
["input"] + x.tolist(),
["keep mask"] + x_keep.tolist(),
["output"] + output.tolist(),
]))

I think the idea here is not too new compared to what we had in part 1. Basically, the dropout mask is computed up front on the PyTorch side

# Input tensor
x = torch.randn(size=(10, )).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()

then applied block by block

output = tl.where(x_keep, x / (1 - p), 0.0)

while the kept values are scaled by 1/(1-p).
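To make the “block by block” part concrete, here is a minimal CPU sketch in NumPy (not Triton) of what each program of _dropout does: build its offsets, mask off lanes past n_elements, and apply the tl.where with the 1/(1-p) scaling. The function name dropout_blockwise and the tiny block size are my own choices for illustration.

```python
import numpy as np


def dropout_blockwise(x, x_keep, p, block_size=4):
    """CPU emulation of the _dropout kernel: one loop iteration per Triton program."""
    n_elements = x.size
    output = np.empty_like(x)
    num_programs = -(-n_elements // block_size)          # same as triton.cdiv
    for pid in range(num_programs):                      # pid = tl.program_id(axis=0)
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < n_elements                      # last block may be partial
        valid = offsets[mask]                            # masked lanes are never loaded/stored
        output[valid] = np.where(x_keep[valid] == 1, x[valid] / (1 - p), 0.0)
    return output


x = np.arange(1, 11, dtype=np.float32)
x_keep = np.array([1, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=np.int32)
print(dropout_blockwise(x, x_keep, p=0.5, block_size=4))  # kept entries are doubled, dropped ones are 0
```

With 10 elements and a block size of 4, three programs run and the last one only touches 2 valid lanes, which is exactly what the mask is for.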

However, this implementation is suboptimal. The reasons the tutorial gives are

  1. We need to store the dropout mask for backpropagation.
  2. Dropout state management can get very tricky when using recompute/checkpointing. According to here, for gradient checkpointing, a technique to save VRAM, PyTorch reruns each segment during backprop and stashes and restores the RNG state. So if we use dropout, by default PyTorch gets exactly the same dropout mask on backprop! One part I don’t get is that they also say they “juggle” the RNG state for deterministic output. I’ll try adding this in if I get it later.

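Here is where triton.language.rand comes in! The tutorial argues that apart from simplifying code and RNG state handling, this also reduces VRAM. I believe the saving comes from the dropout mask never being materialized: tl.rand(seed, offsets) is a pure function of the seed and each element’s offset, so the same mask can be regenerated from just the seed whenever it’s needed instead of being stored.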

@triton.jit
def _seeded_dropout(
    x_ptr,
    output_ptr,
    n_elements,
    p,
    seed,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10, )).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(
tabulate.tabulate([
["input"] + x.tolist(),
["output (seed = 123)"] + output.tolist(),
["output (seed = 123)"] + output2.tolist(),
["output (seed = 512)"] + output3.tolist(),
]))
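The key property here is that the random number for an element depends only on (seed, offset), so a backward pass can regenerate the mask instead of loading it. Below is a minimal pure-Python sketch of that idea under clearly stated assumptions: rand_like_tl and the SplitMix64-style mixing are hypothetical helpers I made up for illustration, not what Triton does internally (as far as I know, tl.rand is built on the Philox counter-based generator).

```python
MASK64 = (1 << 64) - 1


def _splitmix64(z):
    """SplitMix64 finalizer: scrambles a 64-bit integer into a well-mixed one."""
    z = (z + 0x9E3779B97F4A7C15) & MASK64
    z = ((z ^ (z >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    z = ((z ^ (z >> 27)) * 0x94D049BB133111EB) & MASK64
    return z ^ (z >> 31)


def rand_like_tl(seed, offsets):
    """Hypothetical stand-in for tl.rand: one float in [0, 1) per offset that
    depends only on (seed, offset), so it can be recomputed at any time."""
    key = _splitmix64(seed)
    return [(_splitmix64(key ^ off) >> 11) * 2.0 ** -53 for off in offsets]


def seeded_dropout_fwd(x, p, seed):
    r = rand_like_tl(seed, range(len(x)))
    return [xi / (1 - p) if ri > p else 0.0 for xi, ri in zip(x, r)]


def seeded_dropout_bwd(grad_out, p, seed):
    # The backward pass regenerates the same mask from the seed; nothing was stored.
    r = rand_like_tl(seed, range(len(grad_out)))
    return [g / (1 - p) if ri > p else 0.0 for g, ri in zip(grad_out, r)]


x = [float(i + 1) for i in range(8)]
print(seeded_dropout_fwd(x, 0.5, seed=123))
print(seeded_dropout_fwd(x, 0.5, seed=123))  # identical to the line above
print(seeded_dropout_fwd(x, 0.5, seed=512))  # generally a different mask
```

The only state carried from forward to backward is the integer seed, which is the memory saving the tutorial is pointing at.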

Layer Normalization

The next tutorial is on Layer Normalization which you can follow along here.

LayerNorm is a well-known normalization method, alongside Batch Norm, Instance Norm, and Group Norm. For layer norm in particular, normalization is done per sample across the feature dimensions (in the image below, across C and H×W).

Picture taken from https://arxiv.org/pdf/1803.08494

I heard each normalization method offers a different benefit, but that is a topic for another blog. Now, let’s look at how we should implement layer norm in Triton! The formula for layer norm is

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot w + b$$

So overall, given x, we subtract the mean and divide by the standard deviation, adding a small epsilon to avoid division by zero. w and b are learnable per-feature parameters, so the network can scale and shift the output to whatever mean and std it wants!

The code is

import torch

import triton
import triton.language as tl

try:
    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
    # should not be added to extras_require in setup.py.
    import apex
    HAS_APEX = True
except ModuleNotFoundError:
    HAS_APEX = False


@triton.jit
def _layer_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)

Let’s go step by step.

The input shape is [rows, columns], and the computation is done one row per program. Here, the columns are the feature dimension, so we normalize each row to zero mean and unit standard deviation over its columns. Thus, we get the current row and move to the start of that row in the input (X) and output (Y). I believe that for a contiguous input the stride is exactly N (the number of columns); it is passed separately because, in general, the distance between rows in memory does not have to equal the number of columns.
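Here is a small NumPy sketch of why “stride” and “number of columns” are different things. NumPy reports strides in bytes, so I divide by the item size to get element strides like PyTorch’s .stride(); row_stride_in_elements is a hypothetical helper for illustration.

```python
import numpy as np


def row_stride_in_elements(a):
    """Row stride in elements (NumPy's .strides are in bytes)."""
    return a.strides[0] // a.itemsize


x = np.zeros((4, 6), dtype=np.float32)
print(x.shape[1], row_stride_in_elements(x))        # 6 6: contiguous, so stride == N

view = x[:, :3]                                     # keep only the first 3 columns
print(view.shape[1], row_stride_in_elements(view))  # 3 6: N is 3 but rows are still 6 apart
```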

row = tl.program_id(0)
Y += row * stride
X += row * stride

The mean is computed. The reason we need a loop is only in case N is larger than BLOCK_SIZE

mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    _mean += a
mean = tl.sum(_mean, axis=0) / N
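A minimal NumPy emulation of this loop for a single row may help (blocked_mean is a hypothetical name). The np.where with a clamped index stands in for tl.load(..., mask=..., other=0.), and the final acc.sum() plays the role of tl.sum(_mean, axis=0).

```python
import numpy as np


def blocked_mean(row, block_size):
    """NumPy mirror of the mean loop in _layer_norm_fwd_fused for one row."""
    n = row.size
    acc = np.zeros(block_size, dtype=np.float32)       # _mean = tl.zeros([BLOCK_SIZE])
    for off in range(0, n, block_size):                # runs more than once only if n > block_size
        cols = off + np.arange(block_size)
        mask = cols < n
        safe_cols = np.minimum(cols, n - 1)            # avoid indexing past the end in NumPy
        a = np.where(mask, row[safe_cols], 0.0)        # tl.load(..., mask=..., other=0.)
        acc += a.astype(np.float32)
    return acc.sum() / n                               # tl.sum(_mean, axis=0) / N


row = np.arange(10, dtype=np.float32)
print(blocked_mean(row, 4), row.mean())                # both 4.5
```

Note that the accumulator stays BLOCK_SIZE wide across iterations and is only reduced to a scalar once, after the loop.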

Variance is computed in a similar way. We need tl.where here because the masked-out lanes were loaded as 0, so without it each of them would contribute (0 - mean)² = mean² to the variance!

_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    x = tl.where(cols < N, x - mean, 0.)
    _var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
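To see the effect of that tl.where, this NumPy sketch runs the variance loop with and without it on a row of 10 values with BLOCK_SIZE 4, so the last block has 2 padded lanes. The helper blocked_var and its use_where flag are my own, for illustration only.

```python
import numpy as np


def blocked_var(row, block_size, use_where=True):
    """NumPy mirror of the variance loop; use_where=False drops the tl.where."""
    n = row.size
    mean = row.mean()
    acc = np.zeros(block_size)
    for off in range(0, n, block_size):
        cols = off + np.arange(block_size)
        mask = cols < n
        x = np.where(mask, row[np.minimum(cols, n - 1)], 0.0)  # padded lanes load as 0
        if use_where:
            x = np.where(mask, x - mean, 0.0)   # padded lanes contribute nothing
        else:
            x = x - mean                        # padded lanes become -mean
        acc += x * x
    return acc.sum() / n


row = np.arange(10.0)                          # mean 4.5, true variance 8.25
print(blocked_var(row, 4))                     # 8.25
print(blocked_var(row, 4, use_where=False))    # 12.3, i.e. 8.25 + 2 * 4.5**2 / 10
```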

Finally, do normalization like so

tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
for off in range(0, N, BLOCK_SIZE):
    cols = off + tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    w = tl.load(W + cols, mask=mask)
    b = tl.load(B + cols, mask=mask)
    x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
    x_hat = (x - mean) * rstd
    y = x_hat * w + b
    # Write output
    tl.store(Y + cols, y, mask=mask)

I’m pretty sure this isn’t optimal in terms of memory traffic. For example, we load the columns of x 3 times: once for the mean, once for the variance, and once here. Writing Mean and Rstd is an extra write, but it isn’t wasted, since the backward kernel below loads them back (tl.load(Mean + row) and tl.load(Rstd + row)) instead of recomputing them. But either way, it works!

Backward Pass

Now, let’s do a backward pass! For this, let’s call the final loss L. To get some background on the terminology of the tutorial, let’s take a brief aside into Autograd!

Fundamentals of Autograd

For this, I recommend checking out this blog by PyTorch here. There is also a Hugging Face blog that helped me personally here. The main idea of backprop, as I understand it, is

  1. We want to minimize the loss. Gradient descent does this by nudging every parameter against the partial derivative of the loss with respect to it; at a (local) minimum, these partial derivatives are all 0.
  2. To get these, we could directly differentiate the loss with respect to every input separately, but this is not ideal. Networks are large, so doing this for every parameter would recompute the same intermediate gradients over and over.
  3. This is where autograd comes in. The idea is to compute the gradients of intermediate values step by step, chaining them together until we reach the gradients of the loss!
  4. For this, let’s say we have a simple function at the beginning of our network (e.g. a single MLP layer) whose output, y, is m-dimensional and whose input, x, is n-dimensional. Then we can make a matrix like below!

$$J = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n} \end{pmatrix}$$

This is called the Jacobian. Now, let’s say the next MLP layer outputs a single scalar, l, and we want the derivative of l with respect to that layer’s input, y. Then we only have to compute

$$v = \left(\frac{\partial l}{\partial y_1}, \dots, \frac{\partial l}{\partial y_m}\right)^T$$

which we can do independently from computing J! Finally, if we want the partial derivatives of l with respect to all the inputs x, we can just do a matrix multiplication

$$J^T v = \left(\frac{\partial l}{\partial x_1}, \dots, \frac{\partial l}{\partial x_n}\right)^T$$

My understanding is this is called the vector-Jacobian product(VJP). Now, let’s compute the VJP of the outputs of our Layernorm with respect to the inputs
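Here is a tiny NumPy sketch of a vector-Jacobian product under my own toy assumptions: the first layer is y = W x (so its Jacobian is just W) and the next layer is the scalar l = 0.5 * ||y||², so v = ∂l/∂y = y. It checks J^T v against central finite differences of l(x).

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4
W = rng.standard_normal((m, n))
x = rng.standard_normal(n)


def first_layer(x):
    # y = W x, so the Jacobian J = dy/dx is simply W (shape m x n)
    return W @ x


def next_layer(y):
    # l = 0.5 * ||y||^2 is a scalar, so v = dl/dy is just y
    return 0.5 * np.sum(y ** 2)


J = W
v = first_layer(x)          # v = dl/dy evaluated at y = first_layer(x)
vjp = J.T @ v               # vector-Jacobian product: dl/dx = J^T v

# Check against central finite differences of l(x) = next_layer(first_layer(x))
h = 1e-6
fd = np.array([(next_layer(first_layer(x + h * e)) - next_layer(first_layer(x - h * e))) / (2 * h)
               for e in np.eye(n)])
print(np.allclose(vjp, fd, atol=1e-6))   # True
```

In a real framework v arrives from the layer above, and each layer only needs to know how to multiply by its own J^T.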

Back to Layer Norm Back Prop

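Now let’s go step by step. Thankfully, someone already did part of the computations for us: here! In particular, writing μ and σ for the mean and standard deviation of the row and N for the number of columns, the partial derivative of the mean is

$$\frac{\partial \mu}{\partial x_i} = \frac{1}{N}$$

and for the standard deviation it is

$$\frac{\partial \sigma}{\partial x_i} = \frac{x_i - \mu}{N\sigma}$$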

Then, in summary, we get

Here, I’m pretty sure the scale factor is w_k, since w does an element-wise product with the normalized x, so the only index of w that contributes to y_k is w_k. a is 1 if i is the same as k, and 0 otherwise. Here, the authors above define the normalization part (subtract the mean and divide by the std) as

So it is slightly different from what we have. However, I argue that it won’t make much difference since during differentiation the triton authors seem to ignore ϵ anyway. The above formula simplifies to

Now I think here, the authors of the triton tutorial ignore ϵ. Then we have

Now, can we put this in matrix form? For this part, I wasn’t able to figure out how to get to the Triton authors’ expression for the moment, but I think you can see that we are close! The equation below gives the partial derivative of the loss with respect to x: we just multiply by the partial derivative of the loss with respect to y

I’ll try filling out a derivation if I can figure it out later (let me know if any math people happen to already know this!)
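For what it’s worth, here is one way to get from the per-element derivatives to the c1/c2 form used in the kernel. This is a sketch, not necessarily the derivation the Triton authors had in mind. I write x̂_k = (x_k − μ)/σ and y_k = w_k x̂_k + b_k, and take σ = sqrt(Var[x] + ε) = 1/rstd; as far as I can tell, keeping ε inside σ leaves every step below unchanged, so ε does not actually have to be dropped.

```latex
\frac{\partial \sigma}{\partial x_i} = \frac{x_i - \mu}{N\sigma} = \frac{\hat{x}_i}{N}
\quad\Longrightarrow\quad
\frac{\partial \hat{x}_k}{\partial x_i}
  = \frac{1}{\sigma}\left(\delta_{ik} - \frac{1}{N}\right) - \frac{x_k - \mu}{\sigma^2}\,\frac{\hat{x}_i}{N}
  = \frac{1}{\sigma}\left(\delta_{ik} - \frac{1}{N} - \frac{\hat{x}_k\,\hat{x}_i}{N}\right)

\frac{\partial L}{\partial x_i}
  = \sum_k \frac{\partial L}{\partial y_k}\, w_k\, \frac{\partial \hat{x}_k}{\partial x_i}
  = \frac{1}{\sigma}\Bigg(\underbrace{w_i \frac{\partial L}{\partial y_i}}_{\mathrm{wdy}_i}
    - \hat{x}_i\, \underbrace{\frac{1}{N}\sum_k \hat{x}_k\, w_k \frac{\partial L}{\partial y_k}}_{c_1}
    - \underbrace{\frac{1}{N}\sum_k w_k \frac{\partial L}{\partial y_k}}_{c_2}\Bigg)
```

With 1/σ = rstd, this is term for term the kernel line dx = (wdy - (xhat * c1 + c2)) * rstd.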

Now, for the gradients of the weights and biases, the authors already computed them as

$$\nabla_w = \nabla_y \odot \hat{x}, \qquad \nabla_b = \nabla_y$$
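As a sanity check, here is a NumPy sketch (no GPU, no Triton) confirming that the kernel’s formulas for dx, partial_dw, and partial_db match central finite differences on one row. ln_forward and ln_backward_row are my own single-row re-implementations of the math in _layer_norm_fwd_fused and _layer_norm_bwd_dx_fused, and the scalar L = sum(dy * y) is just a convenient stand-in loss whose gradient with respect to y is dy.

```python
import numpy as np


def ln_forward(x, w, b, eps=1e-5):
    """One row of layer norm, same math as _layer_norm_fwd_fused."""
    mean = x.mean()
    rstd = 1.0 / np.sqrt(((x - mean) ** 2).mean() + eps)
    return (x - mean) * rstd * w + b


def ln_backward_row(x, w, dy, eps=1e-5):
    """One row of _layer_norm_bwd_dx_fused's math (no locks, no group buffers)."""
    n = x.size
    mean = x.mean()
    rstd = 1.0 / np.sqrt(((x - mean) ** 2).mean() + eps)
    xhat = (x - mean) * rstd
    wdy = w * dy
    c1 = np.sum(xhat * wdy) / n
    c2 = np.sum(wdy) / n
    dx = (wdy - (xhat * c1 + c2)) * rstd
    return dx, dy * xhat, dy            # dx, partial_dw, partial_db


rng = np.random.default_rng(0)
n = 8
x, w, b, dy = (rng.standard_normal(n) for _ in range(4))
dx, dw, db = ln_backward_row(x, w, dy)


def loss(x_, w_, b_):
    # stand-in scalar loss L = sum(dy * y), so dL/dy == dy
    return np.sum(dy * ln_forward(x_, w_, b_))


h = 1e-6
E = np.eye(n)
fd_dx = np.array([(loss(x + h * e, w, b) - loss(x - h * e, w, b)) / (2 * h) for e in E])
fd_dw = np.array([(loss(x, w + h * e, b) - loss(x, w - h * e, b)) / (2 * h) for e in E])
fd_db = np.array([(loss(x, w, b + h * e) - loss(x, w, b - h * e)) / (2 * h) for e in E])
print(np.allclose(dx, fd_dx, atol=1e-6), np.allclose(dw, fd_dw, atol=1e-6), np.allclose(db, fd_db, atol=1e-6))
```

Note that eps is kept in both the forward and the closed-form backward here, and they still agree.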

Now we see that the gradients for these two are very simple, and every row contributes a term of the same form, so the full gradient is just a sum over all the rows! So if we want to accumulate these, it’d be nice to avoid constantly going out to global memory and have the updates hit just the L2 cache. If you don’t remember, the L2 cache is the on-chip cache shared by all the SMs that sits in front of global memory (DRAM), so it is much faster to access. The authors follow the following idea:

We want to accumulate the partial derivatives of the loss with respect to w, so

  1. We make a buffer called DW that accumulates these partial derivatives across every row, so that we can sum them later.
  2. Now, the above is not a great idea when many programs run in parallel: to update the buffer, a program has to read it, add the value it computed, and write it back while all the other programs wait. This is typically done with a mutex lock or similar.
  3. The authors’ idea: if everything writes to a single buffer we get this stalling, but what if we write to GROUP_SIZE_M buffers instead? Row i goes to buffer i % GROUP_SIZE_M, so each lock is only contended by the rows in its group!
  4. Then, to get the final gradient, we just sum these buffers!
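The four steps above can be sketched on the CPU with Python threads standing in for Triton programs. This is a toy model under my own assumptions: grouped_reduce is a hypothetical helper, and threading.Lock replaces the atomic spin lock the real kernel uses. It shows the two-stage shape: rows accumulate into buffer row % group_size, then a second pass sums the buffers.

```python
import threading


def grouped_reduce(partials, group_size):
    """Sum per-row partial gradients via group_size locked buffers, then a final sum."""
    n = len(partials[0])
    buffers = [[0.0] * n for _ in range(group_size)]   # like DW with shape (GROUP_SIZE_M, N)
    locks = [threading.Lock() for _ in range(group_size)]

    def program(row):                 # stands in for one Triton program (one row)
        gid = row % group_size        # lock_id = row % GROUP_SIZE_M
        with locks[gid]:              # only rows sharing gid contend for this lock
            buf = buffers[gid]
            for j in range(n):
                buf[j] += partials[row][j]

    threads = [threading.Thread(target=program, args=(r,)) for r in range(len(partials))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Second stage (like _layer_norm_bwd_dwdb): reduce the group_size buffers
    return [sum(buffers[g][j] for g in range(group_size)) for j in range(n)]


partials = [[float(r + j) for j in range(3)] for r in range(10)]   # 10 rows, N = 3
print(grouped_reduce(partials, group_size=4))                      # [45.0, 55.0, 65.0]
```

Raising group_size means fewer rows fight over each lock, at the cost of a bigger buffer to sum in the second stage.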

Here’s an illustration from the tutorial:

The authors say that we can keep DW here in L2 cache too! Let’s see how they do it!

@triton.jit
def _layer_norm_bwd_dx_fused(
    DX,  # pointer to the input gradient
    DY,  # pointer to the output gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    X,  # pointer to the input
    W,  # pointer to the weights
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    Lock,  # pointer to the lock
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N
    X += row * stride
    DY += row * stride
    DX += row * stride
    # Offset locks and weights/biases gradient pointer for parallel reduction
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols
    DB = DB + lock_id * N + cols
    # Load data to SRAM
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    mean = tl.load(Mean + row)
    rstd = tl.load(Rstd + row)
    # Compute dx
    xhat = (x - mean) * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    c1 = tl.sum(xhat * wdy, axis=0) / N
    c2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * c1 + c2)) * rstd
    # Write dx
    tl.store(DX + cols, dx, mask=mask)
    # Accumulate partial sums for dw/db
    partial_dw = (dy * xhat).to(w.dtype)
    partial_db = (dy).to(w.dtype)
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    # First store doesn't accumulate
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)
        partial_db += tl.load(DB, mask=mask)
    tl.store(DW, partial_dw, mask=mask)
    tl.store(DB, partial_db, mask=mask)
    # Release the lock
    tl.atomic_xchg(Lock, 0)


@triton.jit
def _layer_norm_bwd_dwdb(
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    FINAL_DW,  # pointer to the weights gradient
    FINAL_DB,  # pointer to the biases gradient
    M,  # GROUP_SIZE_M
    N,  # number of columns
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    # Map the program id to the elements of DW and DB it should compute.
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Iterate through the rows of DW and DB to sum the partial sums.
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)
        db += tl.load(DB + offs, mask=mask, other=0.)
    # Write the final sum to the output.
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)

The first part of _layer_norm_bwd_dx_fused seems standard: we get the lock id and the important positions in X, DY, and DX. In addition, we get a variable called Count

row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
X += row * stride
DY += row * stride
DX += row * stride
# Offset locks and weights/biases gradient pointer for parallel reduction
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M

Then for DW and DB, which are the buffers we want to store the partial gradients to, we do

DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols

to get to the position in the buffer where we want to add the current partial derivatives!

x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
tl.store(DX + cols, dx, mask=mask)

Data for everything except the buffers are loaded and dx is computed!

Then, we compute the partial of w and b

partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)

and then we get to a lock loop!

while tl.atomic_cas(Lock, 0, 1) == 1:
    pass

This function is a compare-and-swap (CAS) operation: tl.atomic_cas(Lock, 0, 1) atomically reads the value at Lock, writes 1 there only if the value was 0, and returns the old value either way.

So this loop keeps spinning while the old value is 1 (another program in the group holds the lock). As soon as it sees 0, the same atomic operation sets the Lock to 1 and returns 0, so the loop terminates and we can move forward, while the other programs in the group remain stuck in the while loop.
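To check this reading of the lock (and of the Count logic that follows), here is a pure-Python toy. AtomicInt is a hypothetical class that emulates a single int32 in GPU memory, with a Python lock making each operation atomic; its cas and xchg methods follow the return-the-old-value contract of tl.atomic_cas and tl.atomic_xchg. The buffer deliberately starts with stale values to show that the first writer overwrites instead of accumulating.

```python
import threading


class AtomicInt:
    """Toy model of one int32 in GPU memory; a Python lock makes each op atomic."""

    def __init__(self, value=0):
        self.value = value
        self._guard = threading.Lock()

    def cas(self, cmp, val):
        # like tl.atomic_cas: write val only if the current value == cmp; return the OLD value
        with self._guard:
            old = self.value
            if old == cmp:
                self.value = val
            return old

    def xchg(self, val):
        # like tl.atomic_xchg: write val unconditionally; return the OLD value
        with self._guard:
            old, self.value = self.value, val
            return old


def accumulate(lock, count, buf, partial):
    """Critical section of _layer_norm_bwd_dx_fused for one group buffer."""
    while lock.cas(0, 1) == 1:     # old value 1 means someone else holds it: keep spinning
        pass
    if count.value == 0:           # first writer for this buffer: don't read it
        count.xchg(1)
    else:
        partial = [p + q for p, q in zip(partial, buf)]
    buf[:] = partial               # tl.store(DW, partial_dw, mask=mask)
    lock.xchg(0)                   # release the lock


lock, count = AtomicInt(0), AtomicInt(0)
buf = [123.0, -7.0]                # stale contents on purpose
threads = [threading.Thread(target=accumulate, args=(lock, count, buf, [1.0, 2.0]))
           for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(buf)                         # [8.0, 16.0]: the stale values were overwritten, not added
```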

It’s a pretty interesting low-level way of doing accumulation in my opinion. Finally, we do

if count == 0:
    tl.atomic_xchg(Count, 1)
else:
    partial_dw += tl.load(DW, mask=mask)
    partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# Release the lock
tl.atomic_xchg(Lock, 0)

tl.atomic_xchg atomically writes a new value to a memory location and returns the old one; here it marks that this group’s buffer now holds data. The idea is that if nothing has been added to the buffer yet (count is 0), we don’t need to read it at all: the first program simply overwrites it!

So we only load the buffer and add it to the partial derivatives if the count is non-zero. And we see the point of Count now. It is stored at

Count = Lock + GROUP_SIZE_M

so the memory location won’t overlap with the locks!

And finally, after saving to buffer, we release the lock by setting it to 0!

Now, for adding up, it’s pretty simple compared to the above function as we do the sum in one go

@triton.jit
def _layer_norm_bwd_dwdb(
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    FINAL_DW,  # pointer to the weights gradient
    FINAL_DB,  # pointer to the biases gradient
    M,  # GROUP_SIZE_M
    N,  # number of columns
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    # Map the program id to the elements of DW and DB it should compute.
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Iterate through the rows of DW and DB to sum the partial sums.
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)
        db += tl.load(DB + offs, mask=mask, other=0.)
    # Write the final sum to the output.
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)

The function goes over BLOCK_SIZE_N columns of the partial derivatives per program, and over the rows BLOCK_SIZE_M at a time. The mask is just there so that the loads don’t go out of bounds. Now, since when we stored we did

DW = DW + lock_id * N + cols

we can get the offset to a specific group by doing

offs = rows[:, None] * N + cols[None, :]

What this offset does is it gets the entries of all the group buffers that are relevant to the current list of columns. The reason we index with None here, and not in the original script, is that it broadcasts the row and column indices into a 2D grid of offsets, like this

>>> np.arange(0, 5)[:, None]*6+np.arange(0, 6)[None, :]
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29]])

which I think illustrates my point! In the kernel, though, cols is only a BLOCK_SIZE_N-wide slice of the N columns, so the offsets won’t form one contiguous run like above: each row jumps ahead by N.
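Here is a CPU mirror of _layer_norm_bwd_dwdb in NumPy (bwd_dwdb is a hypothetical name) that uses the same rows[:, None] * N + cols[None, :] offsets on a flattened buffer. Masked lanes are pointed at index 0 and then zeroed, which is my stand-in for tl.load(..., mask=mask, other=0.).

```python
import numpy as np


def bwd_dwdb(DW, block_m, block_n):
    """NumPy mirror of _layer_norm_bwd_dwdb: sum the (M, N) group buffers over M."""
    M, N = DW.shape
    flat = DW.reshape(-1)                       # the kernel only sees a flat pointer
    final = np.zeros(N, dtype=np.float32)
    for pid in range(-(-N // block_n)):         # one program per BLOCK_SIZE_N columns
        cols = pid * block_n + np.arange(block_n)
        acc = np.zeros((block_m, block_n), dtype=np.float32)
        for i in range(0, M, block_m):
            rows = i + np.arange(block_m)
            mask = (rows[:, None] < M) & (cols[None, :] < N)
            offs = rows[:, None] * N + cols[None, :]
            # masked lanes read index 0 and are then zeroed (stand-in for other=0.)
            acc += np.where(mask, flat[np.where(mask, offs, 0)], 0.0)
        col_sum = acc.sum(axis=0)               # tl.sum(dw, axis=0)
        valid = cols < N
        final[cols[valid]] = col_sum[valid]     # tl.store(..., mask=cols < N)
    return final


DW = np.arange(30, dtype=np.float32).reshape(5, 6)   # GROUP_SIZE_M = 5 buffers, N = 6
print(bwd_dwdb(DW, block_m=2, block_n=4), DW.sum(axis=0))
```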

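Now, my main question here is how we can be so sure that the DW and DB buffers will stay in the L2 cache, since the kernel does issue ordinary global reads and writes to them. As far as I can tell, nothing in the code forces this; my guess is that the buffers (GROUP_SIZE_M × N elements) are small and reused so often that the hardware keeps them resident in L2, but I’ll update this if I can be sure.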

Overall, the authors benchmarked the Triton LayerNorm against torch, and it was significantly faster

Fused Attention 2

This seems to be just an implementation of Flash Attention 2 whose paper is here. The basic idea is that first there was the flash attention paper “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”.

This paper observed an issue with the attention architecture. When computing attention we do

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where the result of multiplying Q by K^T is an N by N matrix, with N the sequence length. The idea here is that the matrix coming out of the softmax tells us how much each token should pay attention to each other token, like so

Image taken from https://www.researchgate.net/figure/Attention-matrix-visualization-a-weights-in-BERT-Encoding-Unit-Entity-BERT-b_fig5_359215965

In the case of causal models like LLMs where we are not supposed to know future tokens, the upper triangle part of the attention matrix is zeroed out like so

Image taken from https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention

In the case of LLMs, N is the number of tokens. So we have to store an O(N²) attention matrix in VRAM, which is extremely expensive for long contexts (at 100k tokens that is 10^10 entries per head), and even computing 1 more token from there needs way more memory.

Now, to handle this, the authors of Flash Attention did 2 things

  1. They came up with a way to do the computations block by block, with a clever formulation to get around softmax, so the extra memory needed is just O(N) instead of O(N²)!

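The clever method is called lazy softmax (it is also often called online softmax).

Now, normal softmax, in its numerically safe form, is

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j$$

so it makes sense that the max value is subtracted. But if we do the computation block by block with Q, K, and V, how do we get the correct max values without approximating? The main idea in the code seems to be that we keep running values of the max m and of the sum l, and rescale them every time a new block of scores s_j (with value rows v_j) comes in, like so

$$m_{\text{new}} = \max\Big(m_{\text{old}}, \max_j s_j\Big), \qquad l_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}}\, l_{\text{old}} + \sum_j e^{s_j - m_{\text{new}}}$$

$$o_{\text{new}} = \frac{e^{m_{\text{old}} - m_{\text{new}}}\, l_{\text{old}}\, o_{\text{old}} + \sum_j e^{s_j - m_{\text{new}}}\, v_j}{l_{\text{new}}}$$

One optimization in Flash Attention 2 is that the authors observed we don’t have to divide by l after every block. We can keep an unnormalized output $\tilde{o}$ and divide by the final l just once, right at the end, like so

$$\tilde{o}_{\text{new}} = e^{m_{\text{old}} - m_{\text{new}}}\, \tilde{o}_{\text{old}} + \sum_j e^{s_j - m_{\text{new}}}\, v_j, \qquad o = \frac{\tilde{o}_{\text{final}}}{l_{\text{final}}}$$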

2. To do that computation, they work exclusively in on-chip SRAM, so it is extremely fast and the full N×N attention matrix never has to be written out to VRAM!
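Here is a minimal NumPy sketch of the lazy/online softmax for a single query row, in the Flash Attention 2 style of dividing by l only at the end. online_softmax_weighted_sum is a hypothetical helper, and real kernels do this for a whole block of query rows at once in SRAM; this only shows that the block-by-block result matches the ordinary softmax exactly, with no approximation.

```python
import numpy as np


def online_softmax_weighted_sum(s, v, block):
    """softmax(s) @ v for ONE query row, visiting s and v one block at a time."""
    m = -np.inf                       # running max of the scores seen so far
    l = 0.0                           # running sum of exp(s - m)
    acc = np.zeros(v.shape[1])        # running, *unnormalized* output (o tilde)
    for start in range(0, s.size, block):
        sb, vb = s[start:start + block], v[start:start + block]
        m_new = max(m, sb.max())
        alpha = np.exp(m - m_new)     # rescales everything accumulated so far (0 on the first block)
        p = np.exp(sb - m_new)
        l = alpha * l + p.sum()
        acc = alpha * acc + p @ vb
        m = m_new
    return acc / l                    # divide by l only once, at the end


rng = np.random.default_rng(0)
s = rng.standard_normal(10)           # scores of one query row against 10 keys
v = rng.standard_normal((10, 3))      # the 10 value rows
e = np.exp(s - s.max())
reference = (e / e.sum()) @ v         # ordinary softmax, all at once
print(np.allclose(online_softmax_weighted_sum(s, v, block=3), reference))  # True
```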

For Flash Attention 2, as the abstract mentions, “However, FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25–40% of the theoretical maximum FLOPs/s. We observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes.”

So essentially, it is a low-level fix to Flash Attention that reaches around 72% of the maximum FLOPs/s. Here, FLOPs/s means floating-point operations per second, a standard measure of GPU throughput!

Some (but not all) of the optimizations mentioned in the paper, apart from the l trick above, are:

  1. In causal attention, skip blocks that are entirely masked out instead of computing them.
  2. Swap the loop order. In the original Flash Attention, the outer loop goes over blocks of K and V and the inner loop over blocks of Q. Flash Attention 2 instead gives each program one block of Q (the outer loop, parallelized across thread blocks) and streams the blocks of K and V through the inner loop. Like so

I think intuitively this makes sense: each program loads its block of Q once and keeps it, together with its output accumulator, in SRAM while the K and V blocks stream past, so the output only has to be written to global memory once instead of being re-read and re-written for every K/V block.
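Both optimizations can be sketched together in NumPy for one (batch, head): an outer loop over Q blocks, an inner loop that streams K/V blocks with the online-softmax update, and causal skipping of K/V blocks that start after the last query row of the current Q block. tiled_causal_attention and dense_causal_attention are hypothetical helpers; unlike the Triton kernel, which splits off-diagonal and diagonal blocks into separate stages, this sketch simply applies the causal mask to every visited block.

```python
import numpy as np


def tiled_causal_attention(q, k, v, sm_scale, block_m, block_n):
    """Causal softmax(q k^T * sm_scale) v for one (batch, head), FA2 loop order."""
    n_ctx, d = q.shape
    out = np.empty_like(q)
    for start_m in range(0, n_ctx, block_m):             # outer loop: one Q block per "program"
        qb = q[start_m:start_m + block_m]
        rows = start_m + np.arange(qb.shape[0])
        m = np.full(qb.shape[0], -np.inf)
        l = np.zeros(qb.shape[0])
        acc = np.zeros((qb.shape[0], d))
        row_end = start_m + qb.shape[0]                   # K/V blocks starting here or later are fully masked
        for start_n in range(0, row_end, block_n):       # inner loop: stream K/V blocks, skipping masked ones
            kb, vb = k[start_n:start_n + block_n], v[start_n:start_n + block_n]
            cols = start_n + np.arange(kb.shape[0])
            s = (qb @ kb.T) * sm_scale
            s = np.where(cols[None, :] <= rows[:, None], s, -np.inf)   # causal mask
            m_new = np.maximum(m, s.max(axis=1))
            alpha = np.exp(m - m_new)
            p = np.exp(s - m_new[:, None])
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ vb
            m = m_new
        out[start_m:start_m + block_m] = acc / l[:, None]
    return out


def dense_causal_attention(q, k, v, sm_scale):
    """Reference: materializes the full N x N matrix, like the PyTorch version below."""
    n = q.shape[0]
    s = (q @ k.T) * sm_scale
    s = np.where(np.tril(np.ones((n, n), dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 4)) for _ in range(3))
print(np.allclose(tiled_causal_attention(q, k, v, 0.5, 4, 3), dense_causal_attention(q, k, v, 0.5)))
```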

The code is a bit long so let’s go step by step.

Now, first of all, let’s look at the pytorch implementation!

import torch

# Example sizes and settings (assumed here; the tutorial's test passes these in as parameters)
Z, H, N_CTX, HEAD_DIM = 1, 2, 1024, 64
dtype = torch.float16
causal = True

q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
sm_scale = 0.5
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
    p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)

First of all, we have q, k, and v from the formula

Then we define something called sm_scale. My understanding is that this plays the role of the 1/sqrt(d_k) factor in the equation; the test just sets it to 0.5 rather than computing it from HEAD_DIM. Next, we have M defined with torch.tril, which keeps ones in the lower triangle, including the diagonal, and zeros elsewhere, like so!

Taken from https://pytorch.org/docs/stable/generated/torch.tril.html

What this accomplishes is it’ll make a mask for causal attention. Next, we make an attention matrix, p like so,

p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
    p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()

and then we just multiply by v, call backward, and we are basically done!

ref_out = torch.matmul(p, v)
ref_out.backward(dout)

Now, let’s move on to the triton implementation. For triton, we call

tri_out = attention(q, k, v, causal, sm_scale).half()
tri_out.backward(dout)

Now, what’s the implementation of this attention function? This is

attention = _attention.apply

According to here, calling apply runs the forward method of the autograd.Function, which has the following implementation:

class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        # when v is in float8_e5m2 it is transposed.
        HEAD_DIM_V = v.shape[-2] if v.dtype == torch.float8_e5m2 else v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}
        o = torch.empty_like(q)
        stage = 3 if causal else 1
        extra_kern_args = {}
        # Tuning for AMD target
        if is_hip():
            waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
            extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}

        grid = lambda args: (triton.cdiv(q.shape[2], args["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            HEAD_DIM=HEAD_DIM_K,  #
            STAGE=stage,  #
            **extra_kern_args)

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o

Here, is_hip() checks whether we are on HIP, AMD’s GPU runtime (part of ROCm) for writing code that can run on both AMD and NVIDIA GPUs, so the extra_kern_args are AMD-specific tuning options.

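One part that looks interesting is that M is initialized with torch.empty instead of the lower-triangular ones like in the PyTorch version. Despite the shared name, though, this M is not a mask: at the end of _attn_fwd the kernel stores m_i + log2(l_i) for every query row into it (tl.store(m_ptrs, m_i)), and it is saved with ctx.save_for_backward, so it looks like per-row softmax statistics kept for the backward pass. Causal masking is handled inside the kernel instead.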

In addition, I’m curious why v is transposed when in float8 here

# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-2] if v.dtype == torch.float8_e5m2 else v.shape[-1]

And why only v in q, k, and v is getting this treatment. I’m guessing it’s something to do with numerical stability but I’ll write here if I get it. It wasn’t mentioned in the paper.

Another interesting part is the STAGE variable. If causal it is 3 so let’s go ahead assuming that this is 3.

Now, let’s go to the _attn_fwd function. We have

@triton.autotune(list(filter(keep, configs)), key=["N_CTX"])
@triton.jit
def _attn_fwd(
    Q, K, V, sm_scale, M, Out,  #
    stride_qz, stride_qh, stride_qm, stride_qk,  #
    stride_kz, stride_kh, stride_kn, stride_kk,  #
    stride_vz, stride_vh, stride_vk, stride_vn,  #
    stride_oz, stride_oh, stride_om, stride_on,  #
    Z, H, N_CTX,  #
    BLOCK_M: tl.constexpr,  #
    BLOCK_N: tl.constexpr,  #
    HEAD_DIM: tl.constexpr,  #
    STAGE: tl.constexpr  #
):
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(HEAD_DIM, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
            start_m, qk_scale,  #
            BLOCK_M, HEAD_DIM, BLOCK_N,  #
            4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
            start_m, qk_scale,  #
            BLOCK_M, HEAD_DIM, BLOCK_N,  #
            2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
        )
    # epilogue
    m_i += tl.math.log2(l_i)
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(m_ptrs, m_i)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))

First of all, we have

@triton.autotune(list(filter(keep, configs)), key=["N_CTX"])

What this does is it gets the filtered configs from

configs = [
    triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w)
    for BM in [64, 128]
    for BN in [32, 64]
    for s in ([1] if is_hip() else [3, 4, 7])
    for w in [4, 8]
]


def keep(conf):
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    return True

and then, whenever N_CTX changes in the arguments, it benchmarks them and picks the fastest combination of BLOCK_M, BLOCK_N, num_stages, and num_warps!
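Working through the filter by hand is a little surprising: with BLOCK_M in [64, 128] and BLOCK_N in [32, 64], the largest tile is 128 × 64 = 8192, which is below 128 × 128, so every num_warps=8 config gets dropped. Here is a plain-Python sketch that reproduces the list and the filter without a GPU. Config is a hypothetical NamedTuple stand-in for triton.Config (holding only the fields keep reads), and is_hip() is assumed to return False, i.e. an NVIDIA GPU.

```python
from itertools import product
from typing import NamedTuple


class Config(NamedTuple):
    """Hypothetical stand-in for triton.Config with only the fields keep() reads."""
    kwargs: dict
    num_stages: int
    num_warps: int


def is_hip() -> bool:
    # Assumption: we are on an NVIDIA GPU, so the HIP branch is not taken.
    return False


configs = [
    Config({"BLOCK_M": BM, "BLOCK_N": BN}, num_stages=s, num_warps=w)
    for BM, BN, s, w in product([64, 128], [32, 64], [1] if is_hip() else [3, 4, 7], [4, 8])
]


def keep(conf: Config) -> bool:
    # Same rule as the tutorial: only allow 8 warps for tiles of at least 128 x 128.
    BLOCK_M = conf.kwargs["BLOCK_M"]
    BLOCK_N = conf.kwargs["BLOCK_N"]
    if BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8:
        return False
    return True


kept = list(filter(keep, configs))
print(len(configs), len(kept))        # 24 12
print({c.num_warps for c in kept})    # {4}
```

So, under these assumptions, the autotuner ends up benchmarking 12 candidates, all with 4 warps, each time it sees a new N_CTX.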

Next, interestingly, we have two program ids

start_m = tl.program_id(0)
off_hz = tl.program_id(1)

My guess here is that these tell us which “block” of the launch grid this program instance is handling.

Next, remember that q, k, and v all have the shape (Z, H, N_CTX, HEAD_DIM):

q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())

So, judging from the name, off_hz should identify the batch index (Z) and the head (H) we are working on, flattened into a single id. start_m should then say which block of BLOCK_M query rows along N_CTX this program computes.

off_z = off_hz // H
off_h = off_hz % H
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
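Here is a plain-Python check of that arithmetic with made-up sizes, assuming q is contiguous (as torch.empty creates it). contiguous_strides is a hypothetical helper that mimics what tensor.stride() reports for such a tensor.

```python
def contiguous_strides(shape):
    """Element strides of a row-major (contiguous) tensor with the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


Z, H, N_CTX, HEAD_DIM = 2, 3, 8, 4
stride_qz, stride_qh, stride_qm, stride_qk = contiguous_strides((Z, H, N_CTX, HEAD_DIM))

# Each (batch, head) pair gets one value of off_hz = tl.program_id(1).
for off_hz in range(Z * H):
    off_z = off_hz // H
    off_h = off_hz % H
    qvk_offset = off_z * stride_qz + off_h * stride_qh
    # For a contiguous tensor the (N_CTX, HEAD_DIM) head slices sit back to back,
    # so the offset is just off_hz times the size of one slice.
    assert qvk_offset == off_hz * N_CTX * HEAD_DIM
    print(off_hz, (off_z, off_h), qvk_offset)
```

The .to(tl.int64) in the kernel presumably guards against int32 overflow, since for large Z, H, and N_CTX this offset can exceed 2**31 - 1.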

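And thus we have the offsets: qvk_offset is the element offset to the start of this (batch, head) slice. As far as I can tell, the same offset works for Q, K, V, and Out because they all share Q's strides here. Now we do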

# block pointers
Q_block_ptr = tl.make_block_ptr(
    base=Q + qvk_offset,
    shape=(N_CTX, HEAD_DIM),
    strides=(stride_qm, stride_qk),
    offsets=(start_m * BLOCK_M, 0),
    block_shape=(BLOCK_M, HEAD_DIM),
    order=(1, 0),
)
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
V_block_ptr = tl.make_block_ptr(
    base=V + qvk_offset,
    shape=(N_CTX, HEAD_DIM),
    strides=(stride_vk, stride_vn),
    offsets=(0, 0),
    block_shape=(BLOCK_N, HEAD_DIM),
    order=v_order,
)
K_block_ptr = tl.make_block_ptr(
    base=K + qvk_offset,
    shape=(HEAD_DIM, N_CTX),
    strides=(stride_kk, stride_kn),
    offsets=(0, 0),
    block_shape=(HEAD_DIM, BLOCK_N),
    order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
    base=Out + qvk_offset,
    shape=(N_CTX, HEAD_DIM),
    strides=(stride_om, stride_on),
    offsets=(start_m * BLOCK_M, 0),
    block_shape=(BLOCK_M, HEAD_DIM),
    order=(1, 0),
)

Now what are these tl.make_block_ptrs? The Triton documentation calls them “blocks of pointers”. Each one describes a 2D tile of a tensor: a base pointer, the full shape and strides, the current offsets, the tile size (block_shape), and the memory order. From the rest of the code, they seem to work much like normal pointers. The one difference is that you can do something a bit cool like

K_block_ptr = tl.advance(K_block_ptr, (0, lo))

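to move the pointer forward without keeping track of complicated indices ourselves! tl.advance shifts the block's offsets in the logical coordinates given by shape. K is advanced along its second dimension and V along its first because K's block pointer is set up as a transposed (HEAD_DIM, N_CTX) view (note the swapped strides), while V is (N_CTX, HEAD_DIM). My first guess was that the “order” parameter had to match the direction of tl.advance, but order instead describes the memory layout, i.e. which dimension is contiguous: (1, 0) for Q, V, and O says the last dimension is contiguous, while K's (0, 1) reflects that in the transposed view it is the first dimension that is contiguous.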
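To check my understanding of the strides and of tl.advance, here is a toy model in plain Python. BlockPtr is a hypothetical class, not Triton's API: it just reads a tile out of a flat list using the same shape/strides/offsets/block_shape arguments. It shows that K's block pointer is a transposed view, so a loaded tile is a slice of Kᵀ (which is why tl.dot(q, k) gives q·Kᵀ), and that advancing by (0, BLOCK_N) moves to the next keys.

```python
from dataclasses import dataclass, replace


@dataclass
class BlockPtr:
    """Toy, hypothetical model of a Triton block pointer over flat memory.

    Only base/shape/strides/offsets/block_shape are modeled. `order` is omitted:
    as far as I understand, it only tells the compiler which dimension is contiguous.
    No boundary checks are done, so the window must stay inside `shape`.
    """
    base: list
    shape: tuple
    strides: tuple
    offsets: tuple
    block_shape: tuple

    def load(self):
        # Read a block_shape tile starting at offsets, addressing memory via strides.
        r0, c0 = self.offsets
        rows, cols = self.block_shape
        return [[self.base[(r0 + i) * self.strides[0] + (c0 + j) * self.strides[1]]
                 for j in range(cols)]
                for i in range(rows)]

    def advance(self, delta):
        # Like tl.advance: shift the window in logical (dim 0, dim 1) coordinates.
        return replace(self, offsets=(self.offsets[0] + delta[0], self.offsets[1] + delta[1]))


N_CTX, HEAD_DIM, BLOCK_N = 6, 4, 2
# K for one head, stored row-major as (N_CTX, HEAD_DIM), with K[n][d] = 10 * n + d.
K = [10 * n + d for n in range(N_CTX) for d in range(HEAD_DIM)]
stride_kn, stride_kk = HEAD_DIM, 1

# Same trick as the tutorial: view K as (HEAD_DIM, N_CTX) by swapping the strides.
K_block_ptr = BlockPtr(K, (HEAD_DIM, N_CTX), (stride_kk, stride_kn), (0, 0), (HEAD_DIM, BLOCK_N))
K_block_ptr = K_block_ptr.advance((0, BLOCK_N))  # move to the next BLOCK_N keys
tile = K_block_ptr.load()
print(tile)  # [[20, 30], [21, 31], [22, 32], [23, 33]]: column j is key BLOCK_N + j
```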

Now, we do

offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale *= 1.44269504 # 1/log(2)
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)

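offs_m holds the query-row indices this program handles (the offset to our block, as mentioned above), and offs_n holds the column offsets within one tile of keys. BLOCK_N is the number of keys and values processed per inner-loop iteration for our BLOCK_M queries. The m_is are the running row maxima, initialized to negative infinity so that when they are exponentiated, as in softmax, they become 0.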

The l_is are the running softmax denominators (the row sums of the exponentials), initialized to 1.

acc accumulates the output for this block before it is normalized (divided by l_i) and stored to O.

and now we load q!

I initially found it confusing that q is loaded here rather than k and v, since I assumed we would load K and V in the outer loop and then load q in the inner loop, which is the loop order of Algorithm 1 in the original FlashAttention paper.

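My current understanding is that each program loads only one block of q and keeps it on-chip, while the inner function we will see later streams over all the blocks of k and v for that block of q. As far as I know, this swapped loop order is one of the changes in FlashAttention-2, and it lets each program own its rows of the output.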

Next, we do what the code calls “stage 1: off-band”

# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                    start_m, qk_scale,  #
                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                    4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                    )

Here, the reason for the separate if statements seems to be that we want the compiler to handle each loop independently. STAGE & 1 is true for both STAGE = 1 and STAGE = 3, so this first call always happens: if causal, _attn_fwd_inner gets 4 - 3 = 1 as its STAGE, and if not, it gets 4 - 1 = 3.
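The bit tests are easy to misread, so here is a tiny plain-Python mirror of the two if statements in _attn_fwd. inner_stages is a hypothetical helper; the STAGE = 3 (causal) / STAGE = 1 (non-causal) mapping is taken from the comments in the kernel.

```python
def inner_stages(causal: bool) -> list:
    """STAGE values that _attn_fwd would pass to _attn_fwd_inner, in call order."""
    STAGE = 3 if causal else 1
    calls = []
    if STAGE & 1:        # true for both 1 and 3: the off-band call always happens
        calls.append(4 - STAGE)
    if STAGE & 2:        # true only for 3, i.e. causal attention
        calls.append(2)
    return calls


print(inner_stages(causal=True))   # [1, 2]
print(inner_stages(causal=False))  # [3]
```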

@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX

Stage 2 is just the block right after the Stage 1 range, i.e. the diagonal block. The first call never uses it, regardless of causal or non-causal, since it only ever passes 1 or 3. The tl.multiple_of is explained here as just telling the compiler that lo is a multiple of BLOCK_M; it is only a hint and doesn't change the value, so I'm still curious why this line is necessary. If the stage is 3, as intended for non-causal inputs, the range is the entire context.
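Here is a plain-Python mirror of that range logic. inner_range is a hypothetical helper and the sizes are made up; it shows that for causal attention, stage 1 and stage 2 together cover exactly the keys up to the end of the diagonal block, and that stage 3 covers everything.

```python
def inner_range(stage: int, start_m: int, BLOCK_M: int, N_CTX: int) -> range:
    """Key positions [lo, hi) visited by _attn_fwd_inner for a given STAGE."""
    if stage == 1:        # off-band: key blocks strictly before the diagonal block
        lo, hi = 0, start_m * BLOCK_M
    elif stage == 2:      # on-band: the diagonal block only
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:                 # stage 3 (non-causal): the whole context
        lo, hi = 0, N_CTX
    return range(lo, hi)


BLOCK_M, N_CTX = 128, 512
for start_m in range(N_CTX // BLOCK_M):
    r1 = inner_range(1, start_m, BLOCK_M, N_CTX)
    r2 = inner_range(2, start_m, BLOCK_M, N_CTX)
    # Causal: stage 2 starts exactly where stage 1 stops.
    assert r1.stop == r2.start
    print(start_m, (r1.start, r1.stop), (r2.start, r2.stop))
```

For start_m = 0 the stage 1 range is empty, so in the causal case the first query block only runs stage 2.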

Now, from here, the pointers for K and V are moved forward to the chunk of data of interest!

K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

Then, our loop begins!

for start_n in range(lo, hi, BLOCK_N):
    start_n = tl.multiple_of(start_n, BLOCK_N)
    # -- compute qk ----
    k = tl.load(K_block_ptr)
    qk = tl.dot(q, k)
    if STAGE == 2:
        mask = offs_m[:, None] >= (start_n + offs_n[None, :])
        qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk -= m_ij[:, None]
    else:
        m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
        qk = qk * qk_scale - m_ij[:, None]
    p = tl.math.exp2(qk)
    l_ij = tl.sum(p, 1)
    # -- update m_i and l_i
    alpha = tl.math.exp2(m_i - m_ij)
    l_i = l_i * alpha + l_ij
    # -- update output accumulator --
    acc = acc * alpha[:, None]
    # update acc
    v = tl.load(V_block_ptr)
    if fp8_v:
        p = p.to(tl.float8e5)
    else:
        p = p.to(tl.float16)
    acc = tl.dot(p, v, acc)
    # update m_i and l_i
    m_i = m_ij
    V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))

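First, a tile of keys is loaded and we take its dot product with our block of q, which gives a (BLOCK_M, BLOCK_N) tile of scores (K_block_ptr is a transposed view of K, so tl.dot(q, k) computes q·Kᵀ). If we are in STAGE 2, masked-out positions, where the key comes after the query, get -1.0e6 added to their score rather than being zeroed, so they become 0 after exponentiation. Then the running max m_ij is subtracted.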
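To see why adding -1.0e6 works as a mask, here is a plain-Python sketch of one diagonal tile in STAGE 2. The sizes and the constant scores (assumed to be already multiplied by qk_scale) are made up for illustration; the mask expression and the subtract-max-then-exp2 steps follow the kernel.

```python
BLOCK_M = BLOCK_N = 4
start_m = 1
start_n = start_m * BLOCK_M                                 # stage 2 visits the diagonal tile
offs_m = [start_m * BLOCK_M + i for i in range(BLOCK_M)]    # query positions 4..7
offs_n = list(range(BLOCK_N))

qk = [[0.5] * BLOCK_N for _ in range(BLOCK_M)]              # made-up, already-scaled scores
# mask = offs_m[:, None] >= (start_n + offs_n[None, :]); masked entries get -1.0e6 added
masked = [[qk[i][j] + (0.0 if offs_m[i] >= start_n + offs_n[j] else -1.0e6)
           for j in range(BLOCK_N)]
          for i in range(BLOCK_M)]
m_ij = [max(row) for row in masked]
# 2 ** (-1e6) underflows to exactly 0.0, so masked keys get zero weight.
p = [[2.0 ** (x - m) for x in row] for row, m in zip(masked, m_ij)]
for row in p:
    print(row)
# [1.0, 0.0, 0.0, 0.0]
# [1.0, 1.0, 0.0, 0.0]
# [1.0, 1.0, 1.0, 0.0]
# [1.0, 1.0, 1.0, 1.0]
```

Setting the masked scores to 0 instead would give them a weight of 2^0 relative to the max, which is why a large negative number is added.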

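One part I was puzzled by at first: I remember the paper mentioning that computation for indices not covered by the mask is skipped, and that doesn't seem to happen here. Looking at the ranges again, though, the skipping happens at the block level. For causal attention, stage 1 only visits keys in [0, start_m * BLOCK_M) and stage 2 only the diagonal block, so key blocks entirely above the diagonal are never loaded at all; only the diagonal block needs the elementwise mask.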
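A quick way to see the block-level skipping is to count how many BLOCK_N-wide key tiles all the query blocks load together. key_tiles_visited is a hypothetical helper, the sizes are made up, and N_CTX is assumed to be a multiple of BLOCK_M and BLOCK_N.

```python
def key_tiles_visited(N_CTX: int, BLOCK_M: int, BLOCK_N: int, causal: bool) -> int:
    """Total number of BLOCK_N-wide key tiles loaded across all query blocks."""
    total = 0
    for start_m in range(N_CTX // BLOCK_M):
        # causal: stage 1 [0, start_m * BLOCK_M) plus the stage 2 diagonal block
        hi = (start_m + 1) * BLOCK_M if causal else N_CTX
        total += hi // BLOCK_N
    return total


N_CTX, BLOCK_M, BLOCK_N = 1024, 128, 64
full = key_tiles_visited(N_CTX, BLOCK_M, BLOCK_N, causal=False)
causal = key_tiles_visited(N_CTX, BLOCK_M, BLOCK_N, causal=True)
print(full, causal)  # 128 72
```

Of those 72 causal tiles, only 16 (two per diagonal block) go through the masked STAGE 2 branch.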

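Now, one thing about tl.math.exp2. I found an issue here explaining it, but basically it computes 2 raised to the power of its input (2^x) instead of e^x. This is valid because e^x = 2^(x·log2(e)), and the kernel folds log2(e) into the scale with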

qk_scale *= 1.44269504  # 1/log(2)

to fix the scale.
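The identity is easy to check numerically. This plain-Python sketch uses the same truncated constant the kernel multiplies into qk_scale; presumably the kernel prefers exp2 because it maps more directly onto GPU hardware, but the check below is only about the math.

```python
import math

LOG2E = 1.44269504  # the constant the kernel multiplies into qk_scale: log2(e) = 1/ln(2)

for x in [-3.0, -0.5, 0.0, 1.25, 4.0]:
    via_exp = math.exp(x)
    via_exp2 = 2.0 ** (x * LOG2E)
    # e**x == 2**(x * log2(e)); the truncated constant only costs a few parts per billion.
    assert math.isclose(via_exp, via_exp2, rel_tol=1e-8)
    print(f"{x:5.2f}  {via_exp:.8f}  {via_exp2:.8f}")
```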

Next, to update l we need the row-wise sum of p, which we get here

l_ij = tl.sum(p, 1)

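Here, the parameter alpha confused me a bit at first:

alpha = tl.math.exp2(m_i - m_ij)

alpha is the ratio between the exponentiated previous maximum and the exponentiated new maximum for this block, 2^m_i / 2^m_ij. Since m_ij ≥ m_i, alpha is at most 1, and it is exactly the factor needed to re-express anything computed relative to the old max relative to the new one!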

Recall that m_i is initialized like so

m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

So on the first iteration alpha is 0, since 2^(-inf) = 0 (this also wipes out the initial 1.0 in l_i). After that, alpha becomes meaningful because we set

m_i = m_ij

and l_i is updated using l_ij like so

l_i = l_i * alpha + l_ij

Now, why can we scale like this? My intuition is that to get l_ij we did

qk -= m_ij[:, None]

Then,

p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)

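and so if we substitute these into

l_i = l_i * alpha + l_ij

we get (writing s for the scaled scores qk * qk_scale before the max is subtracted)

l_i = l_i * tl.math.exp2(m_i) / tl.math.exp2(m_ij) + tl.sum(tl.math.exp2(s), 1) / tl.math.exp2(m_ij)

so both the old sum and the new terms end up divided by 2^m_ij: we are rescaling everything to the newly found max value! In other words, l_i always holds the row sum of 2^(s - m), measured against the current running max m.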
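To convince myself, here is a plain-Python sketch of just the m_i / l_i updates for a single row, streamed in blocks and compared against computing the max and the sum directly. online_row_stats is a hypothetical helper, and the scores are random and assumed to be already multiplied by qk_scale.

```python
import math
import random


def online_row_stats(scores, block):
    """Streaming max and sum of 2**(s - max) for one row, mirroring m_i and l_i."""
    m_i, l_i = -math.inf, 1.0           # same initial values as the kernel
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        m_ij = max(m_i, max(chunk))
        l_ij = sum(2.0 ** (s - m_ij) for s in chunk)
        alpha = 2.0 ** (m_i - m_ij)     # 0.0 on the first block because m_i = -inf
        l_i = l_i * alpha + l_ij
        m_i = m_ij
    return m_i, l_i


random.seed(0)
scores = [random.uniform(-5, 5) for _ in range(50)]
m, l = online_row_stats(scores, block=8)
# Direct computation with the true max, for comparison.
m_ref = max(scores)
l_ref = sum(2.0 ** (s - m_ref) for s in scores)
print(m == m_ref, math.isclose(l, l_ref))  # True True
```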

Next, we do

acc = acc * alpha[:, None]

which is the same as multiplying by

tl.math.exp2(m_i)/tl.math.exp2(m_ij)

so it’s “fixed” with the new max. Then, we load v and we do a dot product!

v = tl.load(V_block_ptr)
if fp8_v:
    p = p.to(tl.float8e5)
else:
    p = p.to(tl.float16)
acc = tl.dot(p, v, acc)

The documentation says that if we pass an accumulator as the third argument, the result is added to it. So this is the same as doing

acc = tl.dot(p, v) + acc
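Here is a plain-Python sketch of what that three-argument form computes, with made-up (BLOCK_M=2, BLOCK_N=2, HEAD_DIM=2) tiles. dot_acc is a hypothetical helper, not Triton's API; note that in the kernel p is cast to fp16 or fp8 before the dot while acc stays float32.

```python
def matmul(a, b):
    """Plain-Python matrix product of row lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def dot_acc(a, b, acc):
    """What tl.dot(a, b, acc) computes according to the docs: acc + a @ b."""
    prod = matmul(a, b)
    return [[acc[i][j] + prod[i][j] for j in range(len(prod[0]))] for i in range(len(prod))]


p = [[1.0, 0.0], [0.5, 0.5]]        # made-up softmax weights
v = [[1.0, 2.0], [3.0, 4.0]]        # made-up values
acc = [[10.0, 10.0], [20.0, 20.0]]  # made-up running accumulator
print(dot_acc(p, v, acc))  # [[11.0, 12.0], [22.0, 23.0]]
```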

Then finally everything is updated and we move to the next block

m_i = m_ij
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))

At the end of the function, the updated state is returned:

return acc, l_i, m_i

Now, back to the original forward function, we have

# stage 2: on-band
if STAGE & 2:
    # barrier makes it easier for compiler to schedule the
    # two loops independently
    acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                    start_m, qk_scale,  #
                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                    2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                    )

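This is only true for causal inputs (STAGE = 3), since STAGE & 2 is 0 when STAGE = 1! The comment says that splitting it out like this makes it easier for the compiler to schedule the two loops independently. As we discussed before, this call handles the diagonal block, i.e. the keys from start_m * BLOCK_M up to (start_m + 1) * BLOCK_M. One part that confused me for a bit was below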

if STAGE == 1:
    lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
    lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    lo = tl.multiple_of(lo, BLOCK_M)

Why can't we just run a single loop over keys from 0 to (start_m + 1) * BLOCK_M so that we don't need separate stages? The main reason, I think, is this branch:

if STAGE == 2:
    mask = offs_m[:, None] >= (start_n + offs_n[None, :])
    qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
    m_ij = tl.maximum(m_i, tl.max(qk, 1))
    qk -= m_ij[:, None]
else:
    m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
    qk = qk * qk_scale - m_ij[:, None]

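Only the diagonal block actually needs the mask, so splitting the range lets every off-diagonal block skip the tl.where and the extra add. And since STAGE is a tl.constexpr, the if STAGE == 2 check is resolved at compile time, so each call gets its own loop without the branch.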

And finally, we clean up

m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, acc.to(Out.type.element_ty))

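and we are done computing! Dividing acc by l_i gives the normalized softmax output, and m_i + log2(l_i) is the base-2 log of each row's sum of exponentials, which is stored in M (my understanding is that the backward pass uses it to recompute the softmax).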
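Putting the inner loop and the epilogue together, here is a plain-Python sketch for a single query row (non-causal, no fp16/fp8 casts), checked against a naive softmax. Both helper functions and all sizes are hypothetical; the sketch only mirrors the update rules discussed above and also checks that the value stored to M equals log2 of the row's sum of exponentials.

```python
import math
import random


def attention_row_streaming(q, K, V, sm_scale, BLOCK_N):
    """One query row of the non-causal forward pass: the inner loop plus the epilogue."""
    qk_scale = sm_scale * 1.44269504          # fold log2(e) into the scale, as the kernel does
    m_i, l_i = -math.inf, 1.0                 # same initial values as the kernel
    acc = [0.0] * len(V[0])
    for start_n in range(0, len(K), BLOCK_N):
        ks, vs = K[start_n:start_n + BLOCK_N], V[start_n:start_n + BLOCK_N]
        qk = [sum(a * b for a, b in zip(q, k)) for k in ks]   # one row of tl.dot(q, k)
        m_ij = max(m_i, max(qk) * qk_scale)
        p = [2.0 ** (s * qk_scale - m_ij) for s in qk]
        alpha = 2.0 ** (m_i - m_ij)
        l_i = l_i * alpha + sum(p)
        acc = [a * alpha for a in acc]                        # rescale to the new max
        for w, v in zip(p, vs):                               # acc = tl.dot(p, v, acc)
            acc = [a + w * x for a, x in zip(acc, v)]
        m_i = m_ij
    # epilogue
    m_i += math.log2(l_i)                                     # the value stored to M
    return [a / l_i for a in acc], m_i


def attention_row_naive(q, K, V, sm_scale):
    """Reference: softmax(sm_scale * q . K^T) @ V, plus log2 of the softmax denominator."""
    s = [sm_scale * sum(a * b for a, b in zip(q, k)) for k in K]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    out = [sum(wi * v[d] for wi, v in zip(w, V)) / z for d in range(len(V[0]))]
    lse2 = (mx + math.log(z)) / math.log(2)                  # log2(sum(exp(s)))
    return out, lse2


random.seed(1)
N_CTX, HEAD_DIM = 48, 8
q = [random.gauss(0, 0.5) for _ in range(HEAD_DIM)]
K = [[random.gauss(0, 0.5) for _ in range(HEAD_DIM)] for _ in range(N_CTX)]
V = [[random.gauss(0, 0.5) for _ in range(HEAD_DIM)] for _ in range(N_CTX)]
sm_scale = 1 / math.sqrt(HEAD_DIM)

out, m = attention_row_streaming(q, K, V, sm_scale, BLOCK_N=16)
out_ref, lse2_ref = attention_row_naive(q, K, V, sm_scale)
print(max(abs(a - b) for a, b in zip(out, out_ref)), abs(m - lse2_ref))
```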

I thought about going through the backward pass too, but this blog has gone on long enough, so I'll skip that and the last 2 tutorials for now.

There were some parts of this blog I didn't know or understand, like when exactly the L2 cache is used or how to fully derive the matrix form of backprop for layer norm, so I'll come back to fix those once I figure them out. But overall, I hope this blog helped on your Triton journey!
