[DOCS] Minor modifications of the matmul tutorial (#199)
Makes the code more compact and fixes inconsistencies between the variable names used in the tutorial text and those in the final Python program.
@@ -649,6 +649,17 @@ def max_contiguous(input, value, builder=None):
 # Standard library
 # -----------------------
 
+@triton.jit
+def cdiv(x, div):
+    """
+    Computes the ceiling division of :code:`x` by :code:`div`
+
+    :param x: the input number
+    :type x: Block
+    :param div: the divisor
+    :type div: Block
+    """
+    return (x + div - 1) // div
 
 @triton.jit
 def minimum(x, y):
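For readers skimming the diff, the identity behind the new helper is easy to sanity-check in plain Python. This is an illustrative sketch only, not part of the commit; the real `cdiv` runs inside `@triton.jit`-compiled code:

    def cdiv(x, div):
        return (x + div - 1) // div

    assert cdiv(10, 2) == 5      # exact division stays exact
    assert cdiv(11, 2) == 6      # any remainder rounds up
    assert cdiv(1024, 33) == 32  # ceil(1024 / 33)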
@@ -23,7 +23,7 @@ You will specifically learn about:
 # yourself with Triton, in a way that is easy to customize and extend.
 #
 # Roughly speaking, the kernel that we will write will implement the following blocked
-# algorithm to multiply a (MxK) by a (KxN) matrix:
+# algorithm to multiply a (M, K) by a (K, N) matrix:
 #
 # .. code-block:: python
 #
@@ -38,7 +38,7 @@ You will specifically learn about:
 #    acc += dot(a, b)
 #    C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc;
 #
-# where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
+# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance.
 
 # %%
 # Compute Kernel
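The blocked algorithm described by this comment can be spelled out as a small runnable NumPy sketch (illustrative only, not part of the commit; shapes are assumed to be exact multiples of the block sizes):

    import numpy as np

    M, N, K = 8, 8, 8
    BLOCK_SIZE_M = BLOCK_SIZE_N = BLOCK_SIZE_K = 4
    A = np.random.rand(M, K).astype(np.float32)
    B = np.random.rand(K, N).astype(np.float32)
    C = np.empty((M, N), dtype=np.float32)

    # each (m, n) iteration corresponds to one Triton program instance
    for m in range(0, M, BLOCK_SIZE_M):
        for n in range(0, N, BLOCK_SIZE_N):
            acc = np.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=np.float32)
            for k in range(0, K, BLOCK_SIZE_K):
                a = A[m:m + BLOCK_SIZE_M, k:k + BLOCK_SIZE_K]
                b = B[k:k + BLOCK_SIZE_K, n:n + BLOCK_SIZE_N]
                acc += a @ b
            C[m:m + BLOCK_SIZE_M, n:n + BLOCK_SIZE_N] = acc

    assert np.allclose(C, A @ B)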
@@ -53,35 +53,31 @@ You will specifically learn about:
 # ~~~~~~~~~~~~~~~~~~~~
 #
 # For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given b
-# y :code:`&X[i, j] = X + i*stride_x_0 + j*stride_x_1`.
+# y :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`.
 # Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and
 # :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as:
 #
 # .. code-block:: python
 #
-#    &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = A + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
-#    &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = B + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
+#    &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
+#    &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
 #
 # Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as:
 #
 # .. code-block:: python
 #
-#    pid_m = triton.program_id(0)
-#    pid_n = triton.program_id(1)
-#    rm = pid_m * BLOCK_SIZE_M + triton.arange(0, BLOCK_SIZE_M)
-#    rn = pid_n * BLOCK_SIZE_N + triton.arange(0, BLOCK_SIZE_N)
-#    rk = triton.arange(0, BLOCK_SIZE_K)
-#    // pointer for A operand
-#    pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
-#    // pointer for B operand
-#    pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
+#    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+#    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+#    offs_k = tl.arange(0, BLOCK_SIZE_K)
+#    a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
+#    b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
 #
 # And then updated in the inner loop as follows:
 #
 # .. code-block:: python
 #
-#    pa += BLOCK_SIZE_K * stride_a_1;
-#    pb += BLOCK_SIZE_K * stride_b_0;
+#    pa += BLOCK_SIZE_K * stride_ak;
+#    pb += BLOCK_SIZE_K * stride_bk;
 #
 #
 # L2 Cache Optimizations
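The pointer arithmetic in this hunk can be mimicked with NumPy flat offsets: row offsets broadcast against column offsets, exactly like `a_ptrs` in the new text. A minimal sketch, not part of the commit (strides here are in elements, assuming a contiguous row-major array):

    import numpy as np

    M, K = 6, 8
    BLOCK_SIZE_M, BLOCK_SIZE_K = 2, 4
    A = np.arange(M * K, dtype=np.float32).reshape(M, K)
    stride_am, stride_ak = K, 1  # row-major strides, in elements

    pid_m, k = 1, 0  # tile starting at row pid_m*BLOCK_SIZE_M, column k
    offs_am = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)
    offs_k = k + np.arange(BLOCK_SIZE_K)
    offsets = offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak

    tile = A.ravel()[offsets]  # gather the tile through the computed offsets
    assert (tile == A[pid_m * BLOCK_SIZE_M:(pid_m + 1) * BLOCK_SIZE_M,
                      k:k + BLOCK_SIZE_K]).all()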
@@ -109,13 +105,25 @@ You will specifically learn about:
 #
 # .. code-block:: python
 #
-#    pid = triton.program_id(0);
-#    width = GROUP_M * grid_n;
-#    group_id = pid // width;
-#    # we need to handle the case where M % (GROUP_M*BLOCK_SIZE_M) != 0
-#    group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
-#    pid_m = group_id * GROUP_M + (pid % group_size);
-#    pid_n = (pid % width) // (group_size);
+#    # program ID
+#    pid = tl.program_id(axis=0)
+#    # number of program ids along the M axis
+#    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+#    # number of program ids along the N axis
+#    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+#    # number of programs in a group
+#    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+#    # id of the group this program is in
+#    group_id = pid // num_pid_in_group
+#    # row-id of the first program in the group
+#    first_pid_m = group_id * GROUP_SIZE_M
+#    # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
+#    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+#    # *within groups*, programs are ordered in a column-major order
+#    # row-id of the program in the *launch grid*
+#    pid_m = first_pid_m + (pid % group_size_m)
+#    # col-id of the program in the *launch grid*
+#    pid_n = (pid % num_pid_in_group) // group_size_m
 #
 # For example, in the following matmul where each matrix is 9 blocks by 9 blocks,
 # we can see that if we compute the output in row-major ordering, we need to load 90
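The grouped pid -> (pid_m, pid_n) mapping added above can be re-implemented as a standalone Python function and its launch order printed for inspection (illustrative sketch; GROUP_SIZE_M=3 is chosen here only to keep the printout small):

    def grouped_order(num_pid_m, num_pid_n, GROUP_SIZE_M):
        order = []
        for pid in range(num_pid_m * num_pid_n):
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            group_id = pid // num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (pid % group_size_m)
            pid_n = (pid % num_pid_in_group) // group_size_m
            order.append((pid_m, pid_n))
        return order

    # With a 9x9 grid of blocks, the first 9 programs cover a 3x3 corner of C
    # (column-major within the group) instead of one full row of 9 blocks:
    print(grouped_order(9, 9, 3)[:9])
    # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 2), (1, 2), (2, 2)]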
@@ -164,26 +172,19 @@ import triton.language as tl
 @triton.jit
 def matmul_kernel(
     # Pointers to matrices
-    a_ptr,
-    b_ptr,
-    c_ptr,
+    a_ptr, b_ptr, c_ptr,
     # Matrix dimensions
-    M,
-    N,
-    K,
+    M, N, K,
     # The stride variables represent how much to increase the ptr by when moving by 1
     # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
     # by to get the element one row down (A has M rows)
-    stride_am,
-    stride_ak,
-    stride_bk,
-    stride_bn,
-    stride_cm,
-    stride_cn,
+    stride_am, stride_ak,
+    stride_bk, stride_bn,
+    stride_cm, stride_cn,
+    # Meta-parameters
     **meta,
 ):
-    """Kernel for computing the matmul AB = C
-
+    """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
     """
     # extract meta-parameters
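How the stride arguments relate to a concrete tensor: for a contiguous row-major torch tensor, stride(0) is the element step between rows and stride(1) the step between columns. A minimal sketch, assuming torch is available (not part of the commit):

    import torch

    a = torch.empty(128, 64)  # contiguous, row-major (M, K) = (128, 64)
    stride_am, stride_ak = a.stride(0), a.stride(1)
    print(stride_am, stride_ak)  # -> 64 1: one row down advances the pointer
                                 # by 64 elements, one column right by 1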
@@ -191,67 +192,65 @@ def matmul_kernel(
     BLOCK_SIZE_N = meta['BLOCK_SIZE_N']
     BLOCK_SIZE_K = meta['BLOCK_SIZE_K']
     GROUP_SIZE_M = 8
 
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse
+    # See above `L2 Cache Optimizations` section for details
     pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
 
-    # the number of blocks is the ceil(M / BLOCK_SIZE_M) since we need an extra block
-    # Note that this will lead to some quantization in performance where time-taken jumps
-    # when you need to add a new block
-    n_blocks_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
-    n_blocks_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
-
-    # Map PIDs to the block they should compute. This is done in a grouped ordering
-    # to promote L2 cache reuse.
-    n_output_blocks_in_group = GROUP_SIZE_M * n_blocks_n
-    group_id = pid // n_output_blocks_in_group
-    first_m_block_in_group = group_id * GROUP_SIZE_M
-
-    # If the number of blocks is not divisible by the group size, the last group is smaller
-    group_size_m = min(n_blocks_m - first_m_block_in_group, GROUP_SIZE_M)
-
-    # Within a group, we compute in col-major ordering, block_m and block_n are the
-    # output row and col that this program is computing in terms of blocks
-    block_m = first_m_block_in_group + (pid % group_size_m)
-    block_n = (pid % n_output_blocks_in_group) // group_size_m
-
-    # Convert from block indices back to element indices
-    m_start = block_m * BLOCK_SIZE_M
-    n_start = block_n * BLOCK_SIZE_N
-
-    # Expand out to all the offsets for each of the elements in this block.
-    m_offsets_a = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
-    n_offsets_b = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
-    k_offsets = tl.arange(0, BLOCK_SIZE_K)
-
-    # Get the pointers for the first block of each. We will advance this pointer
-    # as we move in the K direction and accumulate.
-    # a_ptrs should contain BLOCK_SIZE_M * BLOCK_SIZE_K pointers
-    a_ptrs = a_ptr + (stride_am * m_offsets_a + stride_ak * k_offsets[None, :])
-    # b_ptrs should contain BLOCK_SIZE_K * BLOCK_SIZE_N pointers
-    b_ptrs = b_ptr + (stride_bk * k_offsets[:, None] + stride_bn * n_offsets_b)
-    # We accumulate internally in fp32, but the output is written out in the dtype
-    # of the tensor when it is stored
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    # see above `Pointer Arithmetics` section for details
+    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak)
+    b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn)
 
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, K, BLOCK_SIZE_K):
-        # Note that for simplicity, we don't apply a mask here. This means that if K is
-        # not a multiple of BLOCK_SIZE_K, this will access out-of-bounds memory and
-        # accumulate it incorrectly.
+        # Note that for simplicity, we don't apply a mask here.
+        # This means that if K is not a multiple of BLOCK_SIZE_K,
+        # this will access out-of-bounds memory and produce an
+        # error or (worse!) incorrect results.
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
         # We accumulate along the K dimension
         accumulator += tl.dot(a, b)
 
         # Advance the ptrs to the next K block
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
-    # triton can accept arbitrary activation function via metaparameters!
+    # you can fuse arbitrary activation functions here
+    # while the accumulator is still in FP32 !
     if meta['ACTIVATION']:
         accumulator = meta['ACTIVATION'](accumulator)
+    c = accumulator.to(tl.float16)
 
-    m_offsets_c = (m_start + tl.arange(0, BLOCK_SIZE_M))[:, None]
-    n_offsets_c = (n_start + tl.arange(0, BLOCK_SIZE_N))[None, :]
-    c_ptrs = c_ptr + stride_cm * m_offsets_c + stride_cn * n_offsets_c
-    mask = (m_offsets_c < M) & (n_offsets_c < N)
-    tl.store(c_ptrs, accumulator, mask=mask)
+    # -----------------------------------------------------------
+    # Write back the block of the output matrix C
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
 
 
 # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
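Why the new comments stress fp32 accumulation with a single fp16 conversion at the end: the effect can be reproduced in NumPy. A rough sketch, not part of the commit (exact error magnitudes depend on the data, but the fp16-accumulation path is typically noticeably worse for long K):

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((64, 4096)).astype(np.float16)
    b = rng.standard_normal((4096, 64)).astype(np.float16)
    ref = a.astype(np.float64) @ b.astype(np.float64)

    # accumulate in fp32, convert once at the end (what the kernel does)
    acc32 = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)

    # accumulate block-by-block directly in fp16 (what the kernel avoids)
    acc16 = np.zeros((64, 64), dtype=np.float16)
    for k in range(0, 4096, 32):
        acc16 += a[:, k:k + 32] @ b[k:k + 32, :]

    print(np.abs(acc32 - ref).max())  # on the order of fp16 rounding
    print(np.abs(acc16 - ref).max())  # typically noticeably larger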
@@ -282,18 +281,11 @@ def matmul(a, b, activation=None):
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
     )
     matmul_kernel[grid](
-        a,
-        b,
-        c,
-        M,
-        N,
-        K,
-        a.stride(0),
-        a.stride(1),
-        b.stride(0),
-        b.stride(1),
-        c.stride(0),
-        c.stride(1),
+        a, b, c,
+        M, N, K,
+        a.stride(0), a.stride(1),
+        b.stride(0), b.stride(1),
+        c.stride(0), c.stride(1),
         ACTIVATION=activation,
     )
     return c
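A hypothetical end-to-end use of the wrapper touched by this hunk, assuming a CUDA device and that `matmul` and `leaky_relu` are defined as elsewhere in the tutorial file (illustrative only):

    import torch

    a = torch.randn(512, 512, device='cuda', dtype=torch.float16)
    b = torch.randn(512, 512, device='cuda', dtype=torch.float16)
    c = matmul(a, b)                              # plain matmul
    c_relu = matmul(a, b, activation=leaky_relu)  # with the fused activation
    print(torch.allclose(c, torch.matmul(a, b), atol=1e-2, rtol=0))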