Understanding the Triton Tutorials Part 1
Motivation
I’m not sure I’ll be part of the group, but EleutherAI’s Discord server recently started a Triton reading group, which motivated me to study Triton. I’ve also heard that OpenAI, at least, is pretty interested in Triton developers, so I think it might be a good skill set to have!
What issue does Triton solve?
When working with deep learning, the most common approach is to put everything in PyTorch or TensorFlow and just start experimenting. However, say you want to run large-scale experiments, e.g. OpenAI’s GPT-3/4. One thing you quickly realize is that GPU training, or even just inference, is extremely expensive. OpenAI has reportedly lost around half a billion dollars so far on GPU costs, while one training run of Facebook’s LLaMA is estimated to cost a couple of million dollars. In both cases, a way to cut this cost by even 1% can save a huge amount of money.
One solution is to write low-level CUDA code. This means that instead of having PyTorch handle complicated operations like allocating tensors, we do all the tiny things ourselves. While this tends to be significantly faster, at a certain point it becomes too tedious. This is the gap that Triton, a language released by OpenAI, tries to fill. Triton’s goal is to sit at a higher level than CUDA but a lower level than PyTorch, while still producing well-optimized code.
In this blog, I plan to go through the Triton tutorials. To follow along, check out this link!
Shortcut
One shortcut, if you already know PyTorch, might be TorchInductor (the default backend of torch.compile), which I heard can compile PyTorch code to Triton. I’m personally fairly sure that writing your own Triton code will be more efficient, but I’ve heard this saves a lot of memory.
Vector Addition
The first example in the tutorial is pretty simple, but I will go over it since it establishes some basic fundamentals. First, we import the libraries:
import torch
import triton
import triton.language as tl
Next, we write the kernel that Triton will compile:
@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
The first thing to point out is that every Triton kernel starts with the decorator
@triton.jit
which tells Triton that this function should be JIT-compiled into a GPU kernel.
Then, we get the arguments,
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
I’ve seen the legendary lucidrains write the decorator in the following fashion:
@triton.autotune(configs = [
    triton.Config({'BLOCK_SIZE': 128}, num_warps = 4, pre_hook = clone_inplace_updated_params),
    triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8, pre_hook = clone_inplace_updated_params),
], key = ['n_elements'])
@triton.jit
My understanding of this decorator is that instead of passing a fixed BLOCK_SIZE as an argument like above, we let Triton choose the best block size for the given number of elements. From what I read here, the tuner benchmarks every config (running each one several times) before it settles on the best one. It redoes this whenever the value of an argument listed in key (here n_elements) changes.
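To make the caching behavior concrete, here is a toy pure-Python sketch of how I understand the autotuner to work. It benchmarks every config the first time a new value of the key shows up, then reuses the winner. `ToyAutotuner` and `fake_cost` are hypothetical stand-ins of my own, not Triton’s API. The real `triton.autotune` times actual kernel launches on the GPU.

```python
from typing import Callable, Dict, List, Tuple


class ToyAutotuner:
    """Hypothetical stand-in for the caching idea behind triton.autotune."""

    def __init__(self, configs: List[dict], key: List[str],
                 bench: Callable[[dict, dict], float]):
        self.configs = configs            # candidate meta-parameter dicts
        self.key = key                    # argument names whose values trigger re-tuning
        self.bench = bench                # returns a cost (e.g. runtime) for (config, args)
        self.cache: Dict[Tuple, dict] = {}
        self.bench_calls = 0

    def best_config(self, **args) -> dict:
        key_value = tuple(args[name] for name in self.key)
        if key_value not in self.cache:
            # First time we see this key value: try every config and keep the cheapest.
            costs = []
            for cfg in self.configs:
                self.bench_calls += 1
                costs.append((self.bench(cfg, args), cfg))
            self.cache[key_value] = min(costs, key=lambda c: c[0])[1]
        return self.cache[key_value]


def fake_cost(cfg: dict, args: dict) -> float:
    # Made-up cost model: small inputs favor small blocks, large inputs favor big ones.
    return abs(cfg["BLOCK_SIZE"] - args["n_elements"] / 8)


tuner = ToyAutotuner([{"BLOCK_SIZE": 128}, {"BLOCK_SIZE": 1024}],
                     key=["n_elements"], bench=fake_cost)
print(tuner.best_config(n_elements=1000))    # {'BLOCK_SIZE': 128}
print(tuner.best_config(n_elements=100000))  # {'BLOCK_SIZE': 1024}
print(tuner.best_config(n_elements=1000))    # served from the cache
print(tuner.bench_calls)                     # 4: two configs x two distinct key values
```

The point of the `key` is that tuning cost is paid once per distinct problem size rather than on every call.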
Now, what is BLOCK_SIZE? It is the number of elements each program instance loads and processes at once.
We tell the kernel how many program instances it needs to cover all n_elements when we launch it, through the grid:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE:
# - Each torch.tensor object is implicitly converted into a pointer to its first element.
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
# - Don't forget to pass meta-parameters as keywords arguments.
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
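The grid arithmetic is just ceiling division. Here is a small pure-Python sketch showing why the grid is written as a function of `meta`: the same lambda gives the right number of programs for whatever BLOCK_SIZE ends up being used, including one picked by an autotuner. The `cdiv` helper below mirrors what `triton.cdiv` computes, and the size 98432 is just an example.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division: how many blocks of size b are needed to cover a elements."""
    return (a + b - 1) // b


n_elements = 98432
# Same shape as the tutorial's grid: a 1-tuple computed from the meta-parameters.
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)

print(grid({"BLOCK_SIZE": 1024}))  # (97,): 96 full blocks plus one partial block
print(grid({"BLOCK_SIZE": 128}))   # (769,): 98432 is an exact multiple of 128
```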
Now, any loop needs an index. That’s what pid provides:
pid = tl.program_id(axis=0)
This is the index of the block that this program instance is responsible for. So to get the indices we are currently processing, we do
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
Here, if our current pid is 3, we get the indices from 3*BLOCK_SIZE up to (but not including) 4*BLOCK_SIZE. The mask then switches off any of those indices that fall past the end of the vector.
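To see the mask in action, here is a pure-Python sketch of what a single program instance computes. `program_offsets` is a hypothetical helper that mirrors the `block_start`/`offsets`/`mask` lines using plain lists instead of Triton tensors.

```python
def program_offsets(pid: int, block_size: int, n_elements: int):
    """Mirror of block_start / offsets / mask in add_kernel, using Python lists."""
    block_start = pid * block_size
    offsets = [block_start + i for i in range(block_size)]  # like tl.arange(0, BLOCK_SIZE)
    mask = [off < n_elements for off in offsets]            # like offsets < n_elements
    return offsets, mask


# 200 elements with BLOCK_SIZE=64 need 4 programs; the last one is only partly full.
offsets, mask = program_offsets(pid=3, block_size=64, n_elements=200)
print(offsets[0], offsets[-1])  # 192 255
print(sum(mask))                # 8 valid lanes: indices 192..199
```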
Then, we load the data
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
sum them
output = x + y
and then store to output
tl.store(output_ptr + offsets, output, mask=mask)
Then, to call the kernel from Python, we wrap it in a function:
def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
    # In this case, we use a 1D grid where the size is the number of blocks:
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # NOTE:
    #  - Each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
    #  - Don't forget to pass meta-parameters as keywords arguments.
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # We return a handle to `output` but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return output
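Putting the pieces together, here is a minimal CPU emulation of the launch. It assumes we can model the GPU's program instances as a plain sequential loop; on a real GPU they run concurrently. `add_kernel_emulated` and `add_emulated` are hypothetical names, not Triton functions.

```python
from typing import List


def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b


def add_kernel_emulated(pid, x, y, out, n_elements, BLOCK_SIZE):
    # One "program instance": the same steps as add_kernel, on Python lists.
    block_start = pid * BLOCK_SIZE
    offsets = [block_start + i for i in range(BLOCK_SIZE)]
    mask = [o < n_elements for o in offsets]
    # Masked load: switched-off lanes never touch memory. They get a placeholder here
    # (in Triton their value is undefined unless `other=` is given).
    xs = [x[o] if m else 0.0 for o, m in zip(offsets, mask)]
    ys = [y[o] if m else 0.0 for o, m in zip(offsets, mask)]
    # Elementwise add, then a masked store.
    for o, m, a, b in zip(offsets, mask, xs, ys):
        if m:
            out[o] = a + b


def add_emulated(x: List[float], y: List[float], BLOCK_SIZE: int = 4) -> List[float]:
    assert len(x) == len(y)
    n_elements = len(x)
    out = [0.0] * n_elements              # preallocated output, like torch.empty_like
    grid = (cdiv(n_elements, BLOCK_SIZE),)
    for pid in range(grid[0]):            # parallel on a GPU, sequential here
        add_kernel_emulated(pid, x, y, out, n_elements, BLOCK_SIZE)
    return out


print(add_emulated([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0]))
# [11.0, 22.0, 33.0, 44.0, 55.0]
```

With 5 elements and BLOCK_SIZE=4, the second program only has one active lane, which is exactly the case the mask exists for.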
So, essentially, we tell the kernel how many program instances to launch to cover the tensor, and we are done! This runs in roughly the same time as the PyTorch implementation.
Now, let’s move on to harder examples with actual computational gains.
Fused Softmax
The tutorial begins with a runtime analysis of the default PyTorch version. The most interesting part for me is that the main bottleneck doesn’t seem to be the computation itself but the loading and saving of matrices.
import torch
import triton
import triton.language as tl
@torch.jit.script
def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret
Overall, given an M by N matrix x, each operation reads its input from GPU DRAM and writes its result back to DRAM.
Why is memory read/write so important?
First, a note on GPU architecture. Below is an image of the GPU architecture from the CUDA C++ Programming Guide.
The main idea is that, for both GPUs and CPUs, global data shared across cores is stored in DRAM. Above that sits a cache for faster access to global memory. Then, my understanding is that each group of cores (a streaming multiprocessor on NVIDIA GPUs) gets its own small, fast memory for the data it is working on, which it doesn’t need to share with the others during processing.
In general, the closer the memory is to the cores (and the less it is shared), the faster the computation. One of the cleverest uses of this that I know of is FlashAttention: instead of doing all the heavy computation out of GPU DRAM (HBM), it works in fast on-chip SRAM with some clever blocking. This leads to around a two times speedup overall.
When we do
x = tl.load(x_ptr + offsets, mask=mask)
in Triton, we are loading data from DRAM into fast on-chip memory (the tutorials call this SRAM).
Back to fused softmax
When we count the reads and writes, we find that the vanilla PyTorch method reads 5MN + 2M elements from DRAM and writes 3MN + 2M. In theory, we only need to move x (MN elements) to the GPU cores once and write MN elements back. So we can expect roughly a 4x speedup: (8MN + 4M) / 2MN = 4 + 2/N ≈ 4.
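The counts above can be checked with a few lines of arithmetic. This sketch just adds up the per-line (reads, writes) from the comments in `naive_softmax`. `naive_softmax_traffic` is a hypothetical helper, and M=4096, N=1024 are arbitrary example sizes.

```python
def naive_softmax_traffic(M: int, N: int):
    # (reads, writes) in elements for each line of naive_softmax, taken from its comments.
    ops = [
        (M * N, M),          # x.max(dim=1)
        (M * N + M, M * N),  # x - x_max[:, None]
        (M * N, M * N),      # torch.exp(z)
        (M * N, M),          # numerator.sum(dim=1)
        (M * N + M, M * N),  # numerator / denominator[:, None]
    ]
    reads = sum(r for r, _ in ops)
    writes = sum(w for _, w in ops)
    return reads, writes


M, N = 4096, 1024
reads, writes = naive_softmax_traffic(M, N)
fused = 2 * M * N                # fused kernel: read x once, write the result once
ratio = (reads + writes) / fused
print(reads == 5 * M * N + 2 * M, writes == 3 * M * N + 2 * M)  # True True
print(ratio)                     # 4.001953125, i.e. 4 + 2/N
```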
Now, here, we get our first complication. How do we work with 2D matrices? In softmax, we need to find the max across each row, not per block. How do we work with that?
To answer the first question, we can have each Triton program instance handle one row, using the row stride to find where that row starts:
@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the smallest power of two greater than or equal to n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be larger than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
That also answers the second question: for this to work, each program needs to load at least the entire row into SRAM at once, so BLOCK_SIZE must be at least n_cols:
BLOCK_SIZE = triton.next_power_of_2(n_cols)
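Here is a small pure-Python sketch of the two tricks used for each row: rounding BLOCK_SIZE up to a power of two, and padding the extra lanes with -inf. The `next_power_of_2` helper is my own stand-in that should agree with `triton.next_power_of_2` for positive inputs. `load_row_padded` is also hypothetical.

```python
import math


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def load_row_padded(row, block_size):
    # Plays the role of tl.load(..., mask=col_offsets < n_cols, other=-float('inf')).
    n_cols = len(row)
    return [row[c] if c < n_cols else -math.inf for c in range(block_size)]


row = [1.0, 2.0, 3.0, 4.0, 5.0]
BLOCK_SIZE = next_power_of_2(len(row))               # 8
padded = load_row_padded(row, BLOCK_SIZE)
row_max = max(padded)                                # -inf padding can never win the max
numerator = [math.exp(v - row_max) for v in padded]  # exp(-inf) == 0.0
print(BLOCK_SIZE, numerator[5:])                     # 8 [0.0, 0.0, 0.0]
```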
The full code is below
@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr
):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the smallest power of two greater than or equal to n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be larger than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


def softmax(x):
    n_rows, n_cols = x.shape
    # The block size is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 4
    if BLOCK_SIZE >= 2048:
        num_warps = 8
    if BLOCK_SIZE >= 4096:
        num_warps = 16
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of
    # the input matrix
    softmax_kernel[(n_rows,)](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_cols,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y
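To check the logic on the CPU, here is a pure-Python emulation of the fused kernel over a flat row-major buffer, with one loop iteration per program instance (row). The function names are hypothetical, and this is only a sketch of the algorithm, not a Triton or PyTorch API. It uses exact `math.exp` rather than Triton's fast approximate exponential.

```python
import math
from typing import List


def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def softmax_kernel_emulated(row_idx, out, inp, in_row_stride, out_row_stride, n_cols, BLOCK_SIZE):
    # One program instance handles one row, as in softmax_kernel.
    row_start = row_idx * in_row_stride
    col_offsets = range(BLOCK_SIZE)
    # Masked load with -inf padding for lanes past n_cols.
    row = [inp[row_start + c] if c < n_cols else -math.inf for c in col_offsets]
    row_max = max(row)
    numerator = [math.exp(v - row_max) for v in row]
    denominator = sum(numerator)
    out_start = row_idx * out_row_stride
    for c in col_offsets:
        if c < n_cols:                     # masked store
            out[out_start + c] = numerator[c] / denominator


def softmax_emulated(flat: List[float], n_rows: int, n_cols: int) -> List[float]:
    # `flat` is a contiguous row-major buffer, so the row stride equals n_cols.
    BLOCK_SIZE = next_power_of_2(n_cols)
    out = [0.0] * (n_rows * n_cols)
    for row_idx in range(n_rows):          # grid = (n_rows,)
        softmax_kernel_emulated(row_idx, out, flat, n_cols, n_cols, n_cols, BLOCK_SIZE)
    return out


x = [1.0, 2.0, 3.0,
     1000.0, 1000.0, 1000.0]  # exp(1000.0) alone would overflow; the max shift avoids it
y = softmax_emulated(x, n_rows=2, n_cols=3)
print([round(v, 4) for v in y])            # each row sums to 1
```

The second row shows why the kernel subtracts the row max: without it, `math.exp(1000.0)` raises an overflow error in Python (and gives inf on a GPU).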
A few new ideas show up here:
1. num_warps. A warp is a group of 32 threads on NVIDIA GPUs, so num_warps sets how many threads work together on each program instance (one row here). My understanding is that for block-wide operations like max, the warps sync with each other and then keep going.
2. Since BLOCK_SIZE might be larger than n_cols, we do
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
since e to the power of negative infinity is 0! So these padding elements won’t affect the sum, and -inf can never be the row max either.
3. To get the row stride, I was pretty impressed that it was as simple as
x.stride(0)
in PyTorch. For this tensor, it is the same as the number of columns, as my brief experiment below shows:
>>> import torch
>>> a = torch.zeros((5, 5))
>>> a
tensor([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
>>> a.stride(0)
5
>>> a.stride(1)
1
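That equality only holds for contiguous row-major tensors, though. For a view such as a transpose (a.t() has stride(0) == 1) or a column slice (a[:, :3] still has stride(0) == 5), the row stride is not the number of columns, which is why the kernel takes the stride as an argument.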
4. The effect of local variables. I don’t fully get this, but my theory is that the compiler does most of the heavy lifting to keep them in the fastest storage (registers), and when that’s impossible, they get spilled to slower memory.
5. If the number of program instances doesn’t depend on BLOCK_SIZE (here it is always one per row), we can pass a plain tuple as the grid:
softmax_kernel[(n_rows,)](...)
instead of a callable that reads BLOCK_SIZE from the meta-parameters:
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](...)
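Since strides come up again in the matmul kernel, here is a pure-Python sketch of how strides relate to shape for a row-major buffer, and why stride(0) stops matching the number of columns for a transposed view. `contiguous_strides` and `is_contiguous` are hypothetical helpers. The contiguity check is simplified compared to PyTorch's; for example, it ignores the special handling of size-1 dimensions.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    # Row-major: the last dim has stride 1, and each earlier dim skips over
    # everything in the dims after it.
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def is_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    # Simplified check: strides must match those of a freshly allocated row-major tensor.
    return tuple(strides) == contiguous_strides(shape)


shape = (5, 10)
strides = contiguous_strides(shape)              # (10, 1): stride(0) equals n_cols
t_shape, t_strides = shape[::-1], strides[::-1]  # transposed view: shape (10, 5), strides (1, 10)
print(strides, t_strides, is_contiguous(t_shape, t_strides))  # (10, 1) (1, 10) False
print(is_contiguous((5, 3), strides))            # False: a column slice keeps stride (10, 1)

# Element (i, j) of any view lives at i*stride(0) + j*stride(1) in the same flat buffer.
buf = list(range(50))                            # a 5x10 tensor holding 0..49, flattened
i, j = 3, 2
assert buf[i * t_strides[0] + j * t_strides[1]] == buf[j * strides[0] + i * strides[1]]
```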
Overall, according to the tutorial’s benchmark, this implementation is about 4x faster than the naive version, matching the estimate above!
Matrix Multiplication
Here, let’s multiply a matrix of size (M, K) by one of size (K, N). In blocked pseudocode, this looks like:
# Do in parallel
for m in range(0, M, BLOCK_SIZE_M):
    # Do in parallel
    for n in range(0, N, BLOCK_SIZE_N):
        acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
        for k in range(0, K, BLOCK_SIZE_K):
            a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
            b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
            acc += dot(a, b)
        C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
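For a runnable version of the pseudocode, here is a pure-Python sketch using plain nested lists (no NumPy or Triton). The names and the small default block sizes are my own choices. Edge blocks are clipped with `min`, whereas the Triton kernel later in this post uses masks and modulo instead.

```python
from typing import List

Matrix = List[List[float]]


def blocked_matmul(A: Matrix, B: Matrix, BM: int = 2, BN: int = 2, BK: int = 2) -> Matrix:
    M, K = len(A), len(A[0])
    K2, N = len(B), len(B[0])
    assert K == K2, "Incompatible dimensions"
    C = [[0.0] * N for _ in range(M)]
    for m in range(0, M, BM):                 # one program per (m, n) block in Triton
        for n in range(0, N, BN):
            rows = range(m, min(m + BM, M))   # clip the last block
            cols = range(n, min(n + BN, N))
            acc = [[0.0] * len(cols) for _ in rows]
            for k in range(0, K, BK):         # sequential loop inside one program
                ks = range(k, min(k + BK, K))
                # acc += dot(A[m:m+BM, k:k+BK], B[k:k+BK, n:n+BN])
                for r, i in enumerate(rows):
                    for c, j in enumerate(cols):
                        acc[r][c] += sum(A[i][kk] * B[kk][j] for kk in ks)
            for r, i in enumerate(rows):      # write the finished output block
                for c, j in enumerate(cols):
                    C[i][j] = acc[r][c]
    return C


def naive_matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(blocked_matmul(A, B))  # [[4.0, 5.0], [10.0, 11.0], [16.0, 17.0]]
```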
In case you are a bit rusty on matrix multiplication: entry (i, j) of the output is the dot product of row i of the first matrix and column j of the second.
So what the code above does is, instead of working with single numbers, multiply a block of rows by a block of columns to compute a whole output block!
Now, in terms of pointer arithmetic, the addresses of these blocks can be written as the pseudocode below. One thing I needed to remind myself is that a stride is the number of elements you skip to get to the next element along a dimension. So going from a[0] to a[1] takes a.stride(0) elements!
&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
which becomes, in Triton:
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
Now, some observations:
- The % M and % N are there to avoid going past the last valid row/column. The catch is that out-of-range offsets wrap around to the start, so those lanes load valid but redundant data. Thus, it is very important to mask the final store of C!
- These pointers only cover the first block along K. So, after each dot product, to move one step in the inner loop, we do
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
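Here is a pure-Python sketch of the `a_ptrs` arithmetic for a 5 x 4 row-major A with 2 x 2 blocks. It shows the `% M` wraparound and the K-advance step. `a_block_offsets` is hypothetical, and the nested lists stand in for Triton's broadcasting of `offs_am[:, None]` against `offs_k[None, :]`.

```python
def a_block_offsets(pid_m, M, stride_am, stride_ak, BLOCK_SIZE_M, BLOCK_SIZE_K):
    # offs_am = (pid_m * BLOCK_SIZE_M + arange(BLOCK_SIZE_M)) % M
    offs_am = [(pid_m * BLOCK_SIZE_M + i) % M for i in range(BLOCK_SIZE_M)]
    offs_k = list(range(BLOCK_SIZE_K))
    # a_ptrs = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak (offsets from a_ptr)
    a_ptrs = [[r * stride_am + k * stride_ak for k in offs_k] for r in offs_am]
    return offs_am, a_ptrs


# A is 5 x 4 and row-major, so stride_am = 4 and stride_ak = 1.
offs_am, a_ptrs = a_block_offsets(pid_m=2, M=5, stride_am=4, stride_ak=1,
                                  BLOCK_SIZE_M=2, BLOCK_SIZE_K=2)
print(offs_am)  # [4, 0]: row 5 does not exist, so it wraps around to row 0
print(a_ptrs)   # [[16, 17], [0, 1]]

# Next K block: a_ptrs += BLOCK_SIZE_K * stride_ak
a_ptrs = [[p + 2 * 1 for p in row] for row in a_ptrs]
print(a_ptrs)   # [[18, 19], [2, 3]]
```

The wrapped lane reads real data from row 0, so nothing crashes. Its result is simply thrown away by `c_mask` when C is stored.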
However, if we go back to the fused softmax and the idea of reading and writing from DRAM, you might notice that this is pretty inefficient. The most wasteful part is that, for the M by K matrix, we load each row block into memory N/BLOCK_SIZE_N times (once per output block in that row). Ideally, although impossibly, we would like to load it only once.
So, can we rearrange our loops to minimize the memory reads, i.e. maximize the L2 cache hit rate?
L2 Cache Optimizations
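The way the tutorial suggests doing this is fairly simple: compute output blocks in groups that span several rows, instead of one full row at a time!
In the tutorial’s diagram, both matrices are 9 by 9 blocks. If we compute the first 9 output blocks in row-major order, we load one block-row of A (9 blocks) and all 9 block-columns of B (81 blocks), so 90 blocks in total. But if we compute a 3 by 3 group of output blocks instead, we still write 9 output blocks while loading only 3 block-rows and 3 block-columns, i.e. 54 blocks. The tutorial reports that in practice this improves performance by more than 10% on some hardware.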
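The block counts above follow from simple arithmetic. This sketch assumes, like the tutorial's diagram, that A and B are both 9 x 9 blocks, so every block-row of A and block-column of B is 9 blocks long. `blocks_loaded` is a hypothetical helper.

```python
def blocks_loaded(rows_of_a: int, cols_of_b: int, blocks_per_side: int = 9) -> int:
    # Each block-row of A and each block-column of B is `blocks_per_side` blocks long.
    return (rows_of_a + cols_of_b) * blocks_per_side


# Two ways of computing 9 output blocks:
row_major = blocks_loaded(rows_of_a=1, cols_of_b=9)  # a 1 x 9 strip of C
grouped = blocks_loaded(rows_of_a=3, cols_of_b=3)    # a 3 x 3 tile of C
print(row_major, grouped)                            # 90 54
```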
However, the catch is that a single program can’t just load more rows by itself, since that is expensive. Each program still loads one block-row and one block-column at a time. So how do we do this?
The code for this part is
# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m
It seems like the general idea is:
1. We get the program id.
2. We get the number of row blocks of the MxK matrix (num_pid_m) and the number of column blocks of the KxN matrix (num_pid_n).
3. We multiply the number of column blocks by GROUP_SIZE_M. This gives the number of output blocks across GROUP_SIZE_M block-rows. If this is confusing, check out the diagram above! There, GROUP_SIZE_M=3 and num_pid_n=9, so num_pid_in_group=27.
4. We get the id of the group we are in by integer-dividing the program id by the above. This tells us which group of block-rows we are in!
5. We get the first row block of the current group by multiplying the group id by GROUP_SIZE_M.
6. Next, we get the group size, i.e. the number of row blocks in this group. It is GROUP_SIZE_M, except for the last group. That group can be smaller if num_pid_m isn’t divisible by GROUP_SIZE_M, which is what the min handles.
7. For the row id, we take the first row block from step 5 and add (pid % group_size_m) to get the row within the group. This means that each time pid increases by one, we move to the next row of the group.
8. Finally, we get the column id as (pid % num_pid_in_group) // group_size_m. In plain English, we take the program’s position within its group and integer-divide it by group_size_m. So pid_n stays the same for group_size_m consecutive programs, which compute the dot products of that one column block with each of the group’s row blocks.
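Here is the same index mapping run in plain Python, with `grouped_pid` as a hypothetical wrapper around the lines from the kernel. It prints the order in which the first 9 program ids visit output blocks for a 9 x 9 grid of blocks with GROUP_SIZE_M=3.

```python
def grouped_pid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    # Same arithmetic as the kernel, with Python ints instead of Triton scalars.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


num_pid_m = num_pid_n = 9
order = [grouped_pid(p, num_pid_m, num_pid_n, GROUP_SIZE_M=3)
         for p in range(num_pid_m * num_pid_n)]
print(order[:9])
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 2), (1, 2), (2, 2)]
assert len(set(order)) == 81  # every output block is computed exactly once
```

The first 9 programs cover a 3 x 3 tile of C, which matches the 54-block case from the diagram.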
Now this is interesting, but I am curious about 2 parts:
- What is the optimal GROUP_SIZE_M value here? I’m pretty sure this is a math question, so let me know if anyone has ideas!
- How do we handle caching the column for GROUP_SIZE_M steps? I’m not entirely sure, since it doesn’t seem like we program this explicitly. We just seem to let the hardware L2 cache handle it: programs that run at about the same time request the same blocks, so the later requests are likely to hit in L2. If so, it’s an automatic process, which is awesome! Let me know if anyone knows more. In that case, we can just query the same matrix location and know that it’s faster!
Overall, the final code is
import torch
import triton
import triton.language as tl
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetics` section for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul`.
@triton.jit
def leaky_relu(x):
    # Leaky ReLU: keep non-negative values, scale negative values by 0.01.
    return tl.where(x >= 0, x, 0.01 * x)
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        ACTIVATION=activation
    )
    return c
We have already talked about most of this, but a few points:
- For .is_contiguous(), my understanding is that it checks that the matrix’s elements are laid out densely in row-major order rather than all over the place, since otherwise we might not be able to do the nice striding tricks!
- We can call Triton functions from within Triton functions, like the leaky_relu code, which is pretty awesome too.
- We can change the dtype with .to(tl.float16), which is very PyTorch-like.
Overall, this gives performance comparable to cuBLAS, which is pretty nice, especially since we wrote relatively simple code to match a well-respected standard library. I’ve also heard recently that DeepMind’s AlphaTensor found ways to speed up matrix multiplication, but I might be very wrong!
Conclusion
Thanks for reading! These were my notes on the first 3 of the 8 Triton tutorials. Hope they were helpful! I will cut it here for now, since this got longer than I expected. Here is part 2.